Skip to content

Commit

Permalink
Merge pull request #25 from sid0/flush-corrupt
Browse files Browse the repository at this point in the history
encoder: write out buffer before starting flush or finish
  • Loading branch information
gyscos authored Apr 12, 2017
2 parents 8321d86 + d1e7285 commit 3123e41
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 25 deletions.
61 changes: 60 additions & 1 deletion src/stream/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,13 @@ impl<W: Write> Encoder<W> {
}

fn do_finish(&mut self) -> io::Result<()> {
if self.state != EncoderState::StreamEnd {
if self.state == EncoderState::Accepting {
// Write any data pending in `self.buffer`.
self.write_from_offset()?;
self.state = EncoderState::Finished;
}

if self.state == EncoderState::Finished {
// First, closes the stream.
let mut buffer = zstd_sys::ZSTD_outBuffer {
dst: self.buffer.as_mut_ptr() as *mut c_void,
Expand Down Expand Up @@ -321,6 +326,8 @@ impl<W: Write> Write for Encoder<W> {

fn flush(&mut self) -> io::Result<()> {
if self.state == EncoderState::Accepting {
self.write_from_offset()?;

let mut buffer = zstd_sys::ZSTD_outBuffer {
dst: self.buffer.as_mut_ptr() as *mut c_void,
size: self.buffer.capacity(),
Expand All @@ -341,3 +348,55 @@ impl<W: Write> Write for Encoder<W> {
Ok(())
}
}

#[cfg(test)]
mod tests {
    //! Regression tests for flushing/finishing an encoder whose internal
    //! output buffer still holds bytes the underlying writer has not yet
    //! accepted. They live here (rather than in `stream::tests`) because they
    //! inspect the encoder's private `offset` and `buffer` fields.

    use stream::decode_all;
    use stream::tests::WritePartial;
    use super::Encoder;

    /// Test that flush after a partial write works successfully without
    /// corrupting the frame. This test is in this module because it checks
    /// internal implementation details.
    #[test]
    fn test_partial_write_flush() {
        use std::io::Write;

        let (input, mut z) = setup_partial_write();

        // flush shouldn't corrupt the stream
        z.flush().unwrap();

        // The finished frame must decode back to exactly what was written.
        let buf = z.finish().unwrap().into_inner();
        assert_eq!(&decode_all(&buf[..]).unwrap(), &input);
    }

    /// Test that finish after a partial write works successfully without
    /// corrupting the frame. This test is in this module because it checks
    /// internal implementation details.
    #[test]
    fn test_partial_write_finish() {
        let (input, z) = setup_partial_write();

        // finish shouldn't corrupt the stream
        let buf = z.finish().unwrap().into_inner();
        assert_eq!(&decode_all(&buf[..]).unwrap(), &input);
    }

    /// Builds an encoder that has accepted `input` but still has output
    /// pending in its internal buffer.
    ///
    /// Returns the raw input bytes together with the encoder, so callers can
    /// compare the decoded result against what was written.
    fn setup_partial_write() -> (Vec<u8>, Encoder<WritePartial>) {
        use std::io::Write;

        // A writer that takes at most one byte per `write` call.
        let mut buf = WritePartial::new();
        buf.accept(Some(1));
        // NOTE(review): the second argument is presumably the compression
        // level -- confirm against `Encoder::new`.
        let mut z = Encoder::new(buf, 1).unwrap();

        // Fill in enough data to make sure the buffer gets written out.
        let input = "b".repeat(128 * 1024).into_bytes();

        // `Write::write` is allowed to consume only a prefix of its input.
        // The round-trip comparisons in the callers are only meaningful if
        // every byte was accepted, so pin that here instead of silently
        // discarding the count. A single `write` call is kept on purpose:
        // the test needs the state left behind by exactly one call.
        let consumed = z.write(&input).unwrap();
        assert_eq!(consumed, input.len());

        // At this point, the internal buffer in z should have some data.
        assert_ne!(z.offset, z.buffer.len());

        (input, z)
    }
}
47 changes: 23 additions & 24 deletions src/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ pub fn copy_encode<R, W>(mut source: R, destination: W, level: i32)
mod tests {
use super::{Decoder, Encoder};
use super::{copy_encode, decode_all, encode_all};
use std::cmp;
use std::io;

#[test]
Expand Down Expand Up @@ -109,42 +110,42 @@ mod tests {
}

#[derive(Debug)]
struct WriteWithReject {
pub struct WritePartial {
inner: Vec<u8>,
reject: bool,
accept: Option<usize>,
}

impl WriteWithReject {
fn new() -> Self {
WriteWithReject {
impl WritePartial {
pub fn new() -> Self {
WritePartial {
inner: Vec::new(),
reject: false,
accept: Some(0),
}
}

fn reject(&mut self) {
self.reject = true;
/// Make the writer only accept a certain number of bytes per write call.
/// If `bytes` is `Some(0)`, accept an arbitrary number of bytes.
/// If `bytes` is `None`, reject with `WouldBlock`.
pub fn accept(&mut self, bytes: Option<usize>) {
self.accept = bytes;
}

fn accept(&mut self) {
self.reject = false;
}

fn into_inner(self) -> Vec<u8> {
pub fn into_inner(self) -> Vec<u8> {
self.inner
}
}

impl io::Write for WriteWithReject {
impl io::Write for WritePartial {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.reject {
return Err(io::Error::new(io::ErrorKind::WouldBlock, "reject"));
match self.accept {
None => Err(io::Error::new(io::ErrorKind::WouldBlock, "reject")),
Some(0) => self.inner.write(buf),
Some(n) => self.inner.write(&buf[..cmp::min(n, buf.len())]),
}
self.inner.write(buf)
}

fn flush(&mut self) -> io::Result<()> {
if self.reject {
if self.accept.is_none() {
return Err(io::Error::new(io::ErrorKind::WouldBlock, "reject"));
}
self.inner.flush()
Expand All @@ -156,7 +157,7 @@ mod tests {
use std::io::Write;
let mut z = setup_try_finish();

z.get_mut().accept();
z.get_mut().accept(Some(0));

// flush() should continue to work even though write() doesn't.
z.flush().unwrap();
Expand All @@ -166,8 +167,6 @@ mod tests {
Err((_z, e)) => panic!("try_finish failed with {:?}", e),
};

println!("{:?}", buf);

// Make sure the multiple try_finish calls didn't screw up the internal
// buffer and continued to produce valid compressed data.
assert_eq!(&decode_all(&buf[..]).unwrap(), b"hello");
Expand All @@ -181,15 +180,15 @@ mod tests {
z.write_all(b"hello world").unwrap();
}

fn setup_try_finish() -> Encoder<WriteWithReject> {
fn setup_try_finish() -> Encoder<WritePartial> {
use std::io::Write;

let buf = WriteWithReject::new();
let buf = WritePartial::new();
let mut z = Encoder::new(buf, 19).unwrap();

z.write_all(b"hello").unwrap();

z.get_mut().reject();
z.get_mut().accept(None);

let (z, err) = z.try_finish().unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
Expand Down

0 comments on commit 3123e41

Please sign in to comment.