Auto merge of #79650 - the8472:fix-take, r=dtolnay

Fix incorrect io::Take's limit resulting from io::copy specialization

The specialization introduced in #75272 fails to update `io::Take` wrappers after performing the copy syscalls which bypass those wrappers. The buffer flushing before the copy does update them correctly, but the bytes copied after the initial flush weren't subtracted.

The fix is to subtract the bytes copied from each `Take` in the chain of wrappers, even when an error occurs during the syscall loop. To do so the `CopyResult` enum now has to carry the bytes copied so far in the error case.
This commit is contained in:
bors 2020-12-06 01:15:37 +00:00
commit ddafcc0b66
3 changed files with 54 additions and 16 deletions

View File

@ -1211,7 +1211,8 @@ pub fn copy(from: &Path, to: &Path) -> io::Result<u64> {
use super::kernel_copy::{copy_regular_files, CopyResult};
match copy_regular_files(reader.as_raw_fd(), writer.as_raw_fd(), max_len) {
CopyResult::Ended(result) => result,
CopyResult::Ended(bytes) => Ok(bytes),
CopyResult::Error(e, _) => Err(e),
CopyResult::Fallback(written) => match io::copy::generic_copy(&mut reader, &mut writer) {
Ok(bytes) => Ok(bytes + written),
Err(e) => Err(e),

View File

@ -167,10 +167,11 @@ fn copy(self) -> Result<u64> {
if input_meta.copy_file_range_candidate() && output_meta.copy_file_range_candidate() {
let result = copy_regular_files(readfd, writefd, max_write);
result.update_take(reader);
match result {
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
CopyResult::Ended(err) => return err,
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
CopyResult::Error(e, _) => return Err(e),
CopyResult::Fallback(bytes) => written += bytes,
}
}
@ -182,20 +183,22 @@ fn copy(self) -> Result<u64> {
// fall back to the generic copy loop.
if input_meta.potential_sendfile_source() {
let result = sendfile_splice(SpliceMode::Sendfile, readfd, writefd, max_write);
result.update_take(reader);
match result {
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
CopyResult::Ended(err) => return err,
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
CopyResult::Error(e, _) => return Err(e),
CopyResult::Fallback(bytes) => written += bytes,
}
}
if input_meta.maybe_fifo() || output_meta.maybe_fifo() {
let result = sendfile_splice(SpliceMode::Splice, readfd, writefd, max_write);
result.update_take(reader);
match result {
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
CopyResult::Ended(err) => return err,
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
CopyResult::Error(e, _) => return Err(e),
CopyResult::Fallback(0) => { /* use the fallback below */ }
CopyResult::Fallback(_) => {
unreachable!("splice should not return > 0 bytes on the fallback path")
@ -225,6 +228,9 @@ fn drain_to<W: Write>(&mut self, _writer: &mut W, _limit: u64) -> Result<u64> {
Ok(0)
}
/// Updates `Take` wrappers to remove the number of bytes copied.
fn taken(&mut self, _bytes: u64) {}
/// The minimum of the limit of all `Take<_>` wrappers, `u64::MAX` otherwise.
/// This method does not account for data `BufReader` buffers and would underreport
/// the limit of a `Take<BufReader<Take<_>>>` type. Thus its result is only valid
@ -251,6 +257,10 @@ fn drain_to<W: Write>(&mut self, writer: &mut W, limit: u64) -> Result<u64> {
(**self).drain_to(writer, limit)
}
fn taken(&mut self, bytes: u64) {
(**self).taken(bytes);
}
fn min_limit(&self) -> u64 {
(**self).min_limit()
}
@ -407,6 +417,11 @@ fn drain_to<W: Write>(&mut self, writer: &mut W, outer_limit: u64) -> Result<u64
Ok(bytes_drained)
}
fn taken(&mut self, bytes: u64) {
self.set_limit(self.limit() - bytes);
self.get_mut().taken(bytes);
}
fn min_limit(&self) -> u64 {
min(Take::limit(self), self.get_ref().min_limit())
}
@ -432,6 +447,10 @@ fn drain_to<W: Write>(&mut self, writer: &mut W, outer_limit: u64) -> Result<u64
Ok(bytes as u64 + inner_bytes)
}
fn taken(&mut self, bytes: u64) {
self.get_mut().taken(bytes);
}
fn min_limit(&self) -> u64 {
self.get_ref().min_limit()
}
@ -457,10 +476,21 @@ fn fd_to_meta<T: AsRawFd>(fd: &T) -> FdMeta {
}
pub(super) enum CopyResult {
Ended(Result<u64>),
Ended(u64),
Error(Error, u64),
Fallback(u64),
}
impl CopyResult {
fn update_take(&self, reader: &mut impl CopyRead) {
match *self {
CopyResult::Fallback(bytes)
| CopyResult::Ended(bytes)
| CopyResult::Error(_, bytes) => reader.taken(bytes),
}
}
}
/// linux-specific implementation that will attempt to use copy_file_range for copy offloading
/// as the name says, it only works on regular files
///
@ -527,7 +557,7 @@ fn copy_file_range(
// - copying from an overlay filesystem in docker. reported to occur on fedora 32.
return CopyResult::Fallback(0);
}
Ok(0) => return CopyResult::Ended(Ok(written)), // reached EOF
Ok(0) => return CopyResult::Ended(written), // reached EOF
Ok(ret) => written += ret as u64,
Err(err) => {
return match err.raw_os_error() {
@ -545,12 +575,12 @@ fn copy_file_range(
assert_eq!(written, 0);
CopyResult::Fallback(0)
}
_ => CopyResult::Ended(Err(err)),
_ => CopyResult::Error(err, written),
};
}
}
}
CopyResult::Ended(Ok(written))
CopyResult::Ended(written)
}
#[derive(PartialEq)]
@ -623,10 +653,10 @@ fn splice(
Some(os_err) if mode == SpliceMode::Sendfile && os_err == libc::EOVERFLOW => {
CopyResult::Fallback(written)
}
_ => CopyResult::Ended(Err(err)),
_ => CopyResult::Error(err, written),
};
}
}
}
CopyResult::Ended(Ok(written))
CopyResult::Ended(written)
}

View File

@ -42,8 +42,15 @@ fn copy_specialization() -> Result<()> {
assert_eq!(sink.buffer(), b"wxyz");
let copied = crate::io::copy(&mut source, &mut sink)?;
assert_eq!(copied, 10);
assert_eq!(sink.buffer().len(), 0);
assert_eq!(copied, 10, "copy obeyed limit imposed by Take");
assert_eq!(sink.buffer().len(), 0, "sink buffer was flushed");
assert_eq!(source.limit(), 0, "outer Take was exhausted");
assert_eq!(source.get_ref().buffer().len(), 0, "source buffer should be drained");
assert_eq!(
source.get_ref().get_ref().limit(),
1,
"inner Take allowed reading beyond end of file, some bytes should be left"
);
let mut sink = sink.into_inner()?;
sink.seek(SeekFrom::Start(0))?;
@ -210,7 +217,7 @@ fn bench_socket_pipe_socket_copy(b: &mut test::Bencher) {
);
match probe {
CopyResult::Ended(Ok(1)) => {
CopyResult::Ended(1) => {
// splice works
}
_ => {