fix copy specialization not updating Take wrappers

This commit is contained in:
The8472 2020-12-02 23:35:40 +01:00
parent 9b390e73db
commit a9b1381b8d
3 changed files with 45 additions and 14 deletions

View File

@ -1211,7 +1211,8 @@ pub fn copy(from: &Path, to: &Path) -> io::Result<u64> {
use super::kernel_copy::{copy_regular_files, CopyResult}; use super::kernel_copy::{copy_regular_files, CopyResult};
match copy_regular_files(reader.as_raw_fd(), writer.as_raw_fd(), max_len) { match copy_regular_files(reader.as_raw_fd(), writer.as_raw_fd(), max_len) {
CopyResult::Ended(result) => result, CopyResult::Ended(bytes) => Ok(bytes),
CopyResult::Error(e, _) => Err(e),
CopyResult::Fallback(written) => match io::copy::generic_copy(&mut reader, &mut writer) { CopyResult::Fallback(written) => match io::copy::generic_copy(&mut reader, &mut writer) {
Ok(bytes) => Ok(bytes + written), Ok(bytes) => Ok(bytes + written),
Err(e) => Err(e), Err(e) => Err(e),

View File

@ -167,10 +167,11 @@ fn copy(self) -> Result<u64> {
if input_meta.copy_file_range_candidate() && output_meta.copy_file_range_candidate() { if input_meta.copy_file_range_candidate() && output_meta.copy_file_range_candidate() {
let result = copy_regular_files(readfd, writefd, max_write); let result = copy_regular_files(readfd, writefd, max_write);
result.update_take(reader);
match result { match result {
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written), CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
CopyResult::Ended(err) => return err, CopyResult::Error(e, _) => return Err(e),
CopyResult::Fallback(bytes) => written += bytes, CopyResult::Fallback(bytes) => written += bytes,
} }
} }
@ -182,20 +183,22 @@ fn copy(self) -> Result<u64> {
// fall back to the generic copy loop. // fall back to the generic copy loop.
if input_meta.potential_sendfile_source() { if input_meta.potential_sendfile_source() {
let result = sendfile_splice(SpliceMode::Sendfile, readfd, writefd, max_write); let result = sendfile_splice(SpliceMode::Sendfile, readfd, writefd, max_write);
result.update_take(reader);
match result { match result {
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written), CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
CopyResult::Ended(err) => return err, CopyResult::Error(e, _) => return Err(e),
CopyResult::Fallback(bytes) => written += bytes, CopyResult::Fallback(bytes) => written += bytes,
} }
} }
if input_meta.maybe_fifo() || output_meta.maybe_fifo() { if input_meta.maybe_fifo() || output_meta.maybe_fifo() {
let result = sendfile_splice(SpliceMode::Splice, readfd, writefd, max_write); let result = sendfile_splice(SpliceMode::Splice, readfd, writefd, max_write);
result.update_take(reader);
match result { match result {
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written), CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
CopyResult::Ended(err) => return err, CopyResult::Error(e, _) => return Err(e),
CopyResult::Fallback(0) => { /* use the fallback below */ } CopyResult::Fallback(0) => { /* use the fallback below */ }
CopyResult::Fallback(_) => { CopyResult::Fallback(_) => {
unreachable!("splice should not return > 0 bytes on the fallback path") unreachable!("splice should not return > 0 bytes on the fallback path")
@ -225,6 +228,9 @@ fn drain_to<W: Write>(&mut self, _writer: &mut W, _limit: u64) -> Result<u64> {
Ok(0) Ok(0)
} }
/// Updates `Take` wrappers to remove the number of bytes copied.
fn taken(&mut self, _bytes: u64) {}
/// The minimum of the limit of all `Take<_>` wrappers, `u64::MAX` otherwise. /// The minimum of the limit of all `Take<_>` wrappers, `u64::MAX` otherwise.
/// This method does not account for data `BufReader` buffers and would underreport /// This method does not account for data `BufReader` buffers and would underreport
/// the limit of a `Take<BufReader<Take<_>>>` type. Thus its result is only valid /// the limit of a `Take<BufReader<Take<_>>>` type. Thus its result is only valid
@ -251,6 +257,10 @@ fn drain_to<W: Write>(&mut self, writer: &mut W, limit: u64) -> Result<u64> {
(**self).drain_to(writer, limit) (**self).drain_to(writer, limit)
} }
fn taken(&mut self, bytes: u64) {
(**self).taken(bytes);
}
fn min_limit(&self) -> u64 { fn min_limit(&self) -> u64 {
(**self).min_limit() (**self).min_limit()
} }
@ -407,6 +417,11 @@ fn drain_to<W: Write>(&mut self, writer: &mut W, outer_limit: u64) -> Result<u64
Ok(bytes_drained) Ok(bytes_drained)
} }
fn taken(&mut self, bytes: u64) {
self.set_limit(self.limit() - bytes);
self.get_mut().taken(bytes);
}
fn min_limit(&self) -> u64 { fn min_limit(&self) -> u64 {
min(Take::limit(self), self.get_ref().min_limit()) min(Take::limit(self), self.get_ref().min_limit())
} }
@ -432,6 +447,10 @@ fn drain_to<W: Write>(&mut self, writer: &mut W, outer_limit: u64) -> Result<u64
Ok(bytes as u64 + inner_bytes) Ok(bytes as u64 + inner_bytes)
} }
fn taken(&mut self, bytes: u64) {
self.get_mut().taken(bytes);
}
fn min_limit(&self) -> u64 { fn min_limit(&self) -> u64 {
self.get_ref().min_limit() self.get_ref().min_limit()
} }
@ -457,10 +476,21 @@ fn fd_to_meta<T: AsRawFd>(fd: &T) -> FdMeta {
} }
pub(super) enum CopyResult { pub(super) enum CopyResult {
Ended(Result<u64>), Ended(u64),
Error(Error, u64),
Fallback(u64), Fallback(u64),
} }
impl CopyResult {
fn update_take(&self, reader: &mut impl CopyRead) {
match *self {
CopyResult::Fallback(bytes)
| CopyResult::Ended(bytes)
| CopyResult::Error(_, bytes) => reader.taken(bytes),
}
}
}
/// linux-specific implementation that will attempt to use copy_file_range for copy offloading /// linux-specific implementation that will attempt to use copy_file_range for copy offloading
/// as the name says, it only works on regular files /// as the name says, it only works on regular files
/// ///
@ -527,7 +557,7 @@ fn copy_file_range(
// - copying from an overlay filesystem in docker. reported to occur on fedora 32. // - copying from an overlay filesystem in docker. reported to occur on fedora 32.
return CopyResult::Fallback(0); return CopyResult::Fallback(0);
} }
Ok(0) => return CopyResult::Ended(Ok(written)), // reached EOF Ok(0) => return CopyResult::Ended(written), // reached EOF
Ok(ret) => written += ret as u64, Ok(ret) => written += ret as u64,
Err(err) => { Err(err) => {
return match err.raw_os_error() { return match err.raw_os_error() {
@ -545,12 +575,12 @@ fn copy_file_range(
assert_eq!(written, 0); assert_eq!(written, 0);
CopyResult::Fallback(0) CopyResult::Fallback(0)
} }
_ => CopyResult::Ended(Err(err)), _ => CopyResult::Error(err, written),
}; };
} }
} }
} }
CopyResult::Ended(Ok(written)) CopyResult::Ended(written)
} }
#[derive(PartialEq)] #[derive(PartialEq)]
@ -623,10 +653,10 @@ fn splice(
Some(os_err) if mode == SpliceMode::Sendfile && os_err == libc::EOVERFLOW => { Some(os_err) if mode == SpliceMode::Sendfile && os_err == libc::EOVERFLOW => {
CopyResult::Fallback(written) CopyResult::Fallback(written)
} }
_ => CopyResult::Ended(Err(err)), _ => CopyResult::Error(err, written),
}; };
} }
} }
} }
CopyResult::Ended(Ok(written)) CopyResult::Ended(written)
} }

View File

@ -217,7 +217,7 @@ fn bench_socket_pipe_socket_copy(b: &mut test::Bencher) {
); );
match probe { match probe {
CopyResult::Ended(Ok(1)) => { CopyResult::Ended(1) => {
// splice works // splice works
} }
_ => { _ => {