From d1e7285c4bbe09148658cfa6f2b3b15678052ed7 Mon Sep 17 00:00:00 2001 From: Siddharth Agarwal Date: Wed, 12 Apr 2017 14:28:15 -0700 Subject: [PATCH] encoder: write out buffer before starting flush or finish The problem here was that the buffer could still have some data in it at the time it is overwritten. None of the existing tests tested partial writes at all. To test this, generalize `WriteWithReject` to also handle partial writes. --- src/stream/encoder.rs | 61 ++++++++++++++++++++++++++++++++++++++++++- src/stream/mod.rs | 47 ++++++++++++++++----------------- 2 files changed, 83 insertions(+), 25 deletions(-) diff --git a/src/stream/encoder.rs b/src/stream/encoder.rs index 84d4048a..01419a7d 100644 --- a/src/stream/encoder.rs +++ b/src/stream/encoder.rs @@ -220,8 +220,13 @@ impl Encoder { } fn do_finish(&mut self) -> io::Result<()> { - if self.state != EncoderState::StreamEnd { + if self.state == EncoderState::Accepting { + // Write any data pending in `self.buffer`. + self.write_from_offset()?; self.state = EncoderState::Finished; + } + + if self.state == EncoderState::Finished { // First, closes the stream. let mut buffer = zstd_sys::ZSTD_outBuffer { dst: self.buffer.as_mut_ptr() as *mut c_void, @@ -321,6 +326,8 @@ impl Write for Encoder { fn flush(&mut self) -> io::Result<()> { if self.state == EncoderState::Accepting { + self.write_from_offset()?; + let mut buffer = zstd_sys::ZSTD_outBuffer { dst: self.buffer.as_mut_ptr() as *mut c_void, size: self.buffer.capacity(), @@ -341,3 +348,55 @@ impl Write for Encoder { Ok(()) } } + +#[cfg(test)] +mod tests { + use stream::decode_all; + use stream::tests::WritePartial; + use super::Encoder; + + /// Test that flush after a partial write works successfully without + /// corrupting the frame. This test is in this module because it checks + /// internal implementation details. + #[test] + fn test_partial_write_flush() { + use std::io::Write; + + let (input, mut z) = setup_partial_write(); + + // flush shouldn't corrupt the stream + z.flush().unwrap(); + + let buf = z.finish().unwrap().into_inner(); + assert_eq!(&decode_all(&buf[..]).unwrap(), &input); + } + + /// Test that finish after a partial write works successfully without + /// corrupting the frame. This test is in this module because it checks + /// internal implementation details. + #[test] + fn test_partial_write_finish() { + let (input, z) = setup_partial_write(); + + // finish shouldn't corrupt the stream + let buf = z.finish().unwrap().into_inner(); + assert_eq!(&decode_all(&buf[..]).unwrap(), &input); + } + + fn setup_partial_write() -> (Vec, Encoder) { + use std::io::Write; + + let mut buf = WritePartial::new(); + buf.accept(Some(1)); + let mut z = Encoder::new(buf, 1).unwrap(); + + // Fill in enough data to make sure the buffer gets written out. + let input = "b".repeat(128 * 1024).into_bytes(); + z.write(&input).unwrap(); + + // At this point, the internal buffer in z should have some data. + assert_ne!(z.offset, z.buffer.len()); + + (input, z) + } +} diff --git a/src/stream/mod.rs b/src/stream/mod.rs index d90997ec..d0f2b3ef 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -60,6 +60,7 @@ pub fn copy_encode(mut source: R, destination: W, level: i32) mod tests { use super::{Decoder, Encoder}; use super::{copy_encode, decode_all, encode_all}; + use std::cmp; use std::io; #[test] @@ -109,42 +110,42 @@ mod tests { } #[derive(Debug)] - struct WriteWithReject { + pub struct WritePartial { inner: Vec, - reject: bool, + accept: Option, } - impl WriteWithReject { - fn new() -> Self { - WriteWithReject { + impl WritePartial { + pub fn new() -> Self { + WritePartial { inner: Vec::new(), - reject: false, + accept: Some(0), } } - fn reject(&mut self) { - self.reject = true; + /// Make the writer only accept a certain number of bytes per write call. + /// If `bytes` is Some(0), accept an arbitrary number of bytes. + /// If `bytes` is None, reject with WouldBlock. + pub fn accept(&mut self, bytes: Option) { + self.accept = bytes; } - fn accept(&mut self) { - self.reject = false; - } - - fn into_inner(self) -> Vec { + pub fn into_inner(self) -> Vec { self.inner } } - impl io::Write for WriteWithReject { + impl io::Write for WritePartial { fn write(&mut self, buf: &[u8]) -> io::Result { - if self.reject { - return Err(io::Error::new(io::ErrorKind::WouldBlock, "reject")); + match self.accept { + None => Err(io::Error::new(io::ErrorKind::WouldBlock, "reject")), + Some(0) => self.inner.write(buf), + Some(n) => self.inner.write(&buf[..cmp::min(n, buf.len())]), } - self.inner.write(buf) } fn flush(&mut self) -> io::Result<()> { - if self.reject { + if self.accept.is_none() { return Err(io::Error::new(io::ErrorKind::WouldBlock, "reject")); } self.inner.flush() @@ -156,7 +157,7 @@ mod tests { use std::io::Write; let mut z = setup_try_finish(); - z.get_mut().accept(); + z.get_mut().accept(Some(0)); // flush() should continue to work even though write() doesn't. z.flush().unwrap(); @@ -166,8 +167,6 @@ mod tests { Err((_z, e)) => panic!("try_finish failed with {:?}", e), }; - println!("{:?}", buf); - // Make sure the multiple try_finish calls didn't screw up the internal // buffer and continued to produce valid compressed data. assert_eq!(&decode_all(&buf[..]).unwrap(), b"hello"); @@ -181,15 +180,15 @@ mod tests { z.write_all(b"hello world").unwrap(); } - fn setup_try_finish() -> Encoder { + fn setup_try_finish() -> Encoder { use std::io::Write; - let buf = WriteWithReject::new(); + let buf = WritePartial::new(); let mut z = Encoder::new(buf, 19).unwrap(); z.write_all(b"hello").unwrap(); - z.get_mut().reject(); + z.get_mut().accept(None); let (z, err) = z.try_finish().unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::WouldBlock);