From 2fee28e7138d8753487ed8895ce0f5f2e643ffad Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 12 Dec 2019 16:47:42 -0800 Subject: [PATCH] std: Implement `LineWriter::write_vectored` This commit implements the `write_vectored` method of the `LineWriter` type. First discovered in bytecodealliance/wasmtime#629 the `write_vectored` method of `Stdout` bottoms out here but only ends up writing the first buffer due to the default implementation of `write_vectored`. Like `BufWriter`, however, `LineWriter` can have a non-default implementation of `write_vectored` which tries to preserve the vectored-ness as much as possible. Namely we can have a vectored write for everything before the newline and everything after the newline if all the stars align well. Also like `BufWriter`, though, special care is taken to ensure that whenever bytes are written we're sure to signal success since that represents a "commit" of writing bytes. --- src/libstd/io/buffered.rs | 172 +++++++++++++++++++++++++++++++++++++- 1 file changed, 171 insertions(+), 1 deletion(-) diff --git a/src/libstd/io/buffered.rs b/src/libstd/io/buffered.rs index 8e81b292f6f..df259dc2f56 100644 --- a/src/libstd/io/buffered.rs +++ b/src/libstd/io/buffered.rs @@ -989,6 +989,68 @@ fn write(&mut self, buf: &[u8]) -> io::Result { } } + // Vectored writes are very similar to the writes above, but adjusted for + // the list of buffers that we have to write. + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + if self.need_flush { + self.flush()?; + } + + // Find the last newline, and failing that write the whole buffer + let last_newline = bufs + .iter() + .enumerate() + .rev() + .filter_map(|(i, buf)| { + let pos = memchr::memrchr(b'\n', buf)?; + Some((i, pos)) + }) + .next(); + let (i, j) = match last_newline { + Some(pair) => pair, + None => return self.inner.write_vectored(bufs), + }; + let (prefix, suffix) = bufs.split_at(i); + let (buf, suffix) = suffix.split_at(1); + let buf = &buf[0]; + + // Write everything up to the last newline, flushing afterwards. Note + // that only if we finished our entire `write_vectored` do we try the + // subsequent + // `write` + let mut n = 0; + let prefix_amt = prefix.iter().map(|i| i.len()).sum(); + if prefix_amt > 0 { + n += self.inner.write_vectored(prefix)?; + self.need_flush = true; + } + if n == prefix_amt { + match self.inner.write(&buf[..=j]) { + Ok(m) => n += m, + Err(e) if n == 0 => return Err(e), + Err(_) => return Ok(n), + } + self.need_flush = true; + } + if self.flush().is_err() || n != j + 1 + prefix_amt { + return Ok(n); + } + + // ... and now write out everything remaining + match self.inner.write(&buf[j + 1..]) { + Ok(i) => n += i, + Err(_) => return Ok(n), + } + + if suffix.iter().map(|s| s.len()).sum::() == 0 { + return Ok(n) + } + match self.inner.write_vectored(suffix) { + Ok(i) => Ok(n + i), + Err(_) => Ok(n), + } + } + fn flush(&mut self) -> io::Result<()> { self.inner.flush()?; self.need_flush = false; @@ -1015,7 +1077,7 @@ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { #[cfg(test)] mod tests { use crate::io::prelude::*; - use crate::io::{self, BufReader, BufWriter, LineWriter, SeekFrom}; + use crate::io::{self, BufReader, BufWriter, LineWriter, SeekFrom, IoSlice}; use crate::sync::atomic::{AtomicUsize, Ordering}; use crate::thread; @@ -1483,4 +1545,112 @@ fn erroneous_flush_retried() { assert_eq!(l.write(b"a").unwrap_err().kind(), io::ErrorKind::Other) } + + #[test] + fn line_vectored() { + let mut a = LineWriter::new(Vec::new()); + assert_eq!( + a.write_vectored(&[ + IoSlice::new(&[]), + IoSlice::new(b"\n"), + IoSlice::new(&[]), + IoSlice::new(b"a"), + ]) + .unwrap(), + 2, + ); + assert_eq!(a.get_ref(), b"\n"); + + assert_eq!( + a.write_vectored(&[ + IoSlice::new(&[]), + IoSlice::new(b"b"), + IoSlice::new(&[]), + IoSlice::new(b"a"), + IoSlice::new(&[]), + IoSlice::new(b"c"), + ]) + .unwrap(), + 3, + ); + assert_eq!(a.get_ref(), b"\n"); + a.flush().unwrap(); + assert_eq!(a.get_ref(), b"\nabac"); + assert_eq!(a.write_vectored(&[]).unwrap(), 0); + assert_eq!( + a.write_vectored(&[ + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + ]) + .unwrap(), + 0, + ); + assert_eq!(a.write_vectored(&[IoSlice::new(b"a\nb"),]).unwrap(), 3); + assert_eq!(a.get_ref(), b"\nabaca\n"); + } + + #[test] + fn line_vectored_partial_and_errors() { + enum Call { + Write { inputs: Vec<&'static [u8]>, output: io::Result }, + Flush { output: io::Result<()> }, + } + struct Writer { + calls: Vec, + } + + impl Write for Writer { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.write_vectored(&[IoSlice::new(buf)]) + } + + fn write_vectored(&mut self, buf: &[IoSlice<'_>]) -> io::Result { + match self.calls.pop().unwrap() { + Call::Write { inputs, output } => { + assert_eq!(inputs, buf.iter().map(|b| &**b).collect::>()); + output + } + _ => panic!("unexpected call to write"), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self.calls.pop().unwrap() { + Call::Flush { output } => output, + _ => panic!("unexpected call to flush"), + } + } + } + + impl Drop for Writer { + fn drop(&mut self) { + if !thread::panicking() { + assert_eq!(self.calls.len(), 0); + } + } + } + + // partial writes keep going + let mut a = LineWriter::new(Writer { calls: Vec::new() }); + a.write_vectored(&[IoSlice::new(&[]), IoSlice::new(b"abc")]).unwrap(); + a.get_mut().calls.push(Call::Flush { output: Ok(()) }); + a.get_mut().calls.push(Call::Write { inputs: vec![b"bcx\n"], output: Ok(4) }); + a.get_mut().calls.push(Call::Write { inputs: vec![b"abcx\n"], output: Ok(1) }); + a.write_vectored(&[IoSlice::new(b"x"), IoSlice::new(b"\n")]).unwrap(); + a.get_mut().calls.push(Call::Flush { output: Ok(()) }); + a.flush().unwrap(); + + // erroneous writes stop and don't write more + a.get_mut().calls.push(Call::Write { inputs: vec![b"x\n"], output: Err(err()) }); + assert_eq!(a.write_vectored(&[IoSlice::new(b"x"), IoSlice::new(b"\na")]).unwrap(), 2); + a.get_mut().calls.push(Call::Flush { output: Ok(()) }); + a.get_mut().calls.push(Call::Write { inputs: vec![b"x\n"], output: Ok(2) }); + a.flush().unwrap(); + + fn err() -> io::Error { + io::Error::new(io::ErrorKind::Other, "x") + } + } }