diff --git a/src/http/h1/decode.rs b/src/http/h1/decode.rs index 1cc4b586fb..cbb24f9839 100644 --- a/src/http/h1/decode.rs +++ b/src/http/h1/decode.rs @@ -1,4 +1,4 @@ -use std::cmp; +use std::{cmp, usize}; use std::io::{self, Read}; use self::Kind::{Length, Chunked, Eof}; @@ -14,21 +14,15 @@ pub struct Decoder { impl Decoder { pub fn length(x: u64) -> Decoder { - Decoder { - kind: Kind::Length(x) - } + Decoder { kind: Kind::Length(x) } } pub fn chunked() -> Decoder { - Decoder { - kind: Kind::Chunked(None) - } + Decoder { kind: Kind::Chunked(ChunkedState::Size, 0) } } pub fn eof() -> Decoder { - Decoder { - kind: Kind::Eof(false) - } + Decoder { kind: Kind::Eof(false) } } } @@ -37,7 +31,7 @@ enum Kind { /// A Reader used when a Content-Length header is passed with a positive integer. Length(u64), /// A Reader used when Transfer-Encoding is `chunked`. - Chunked(Option), + Chunked(ChunkedState, u64), /// A Reader used for responses that don't indicate a length or chunked. /// /// Note: This should only used for `Response`s. It is illegal for a @@ -55,14 +49,26 @@ enum Kind { Eof(bool), } +#[derive(Debug, PartialEq, Clone)] +enum ChunkedState { + Size, + SizeLws, + Extension, + SizeLf, + Body, + BodyCr, + BodyLf, + End, +} + impl Decoder { pub fn is_eof(&self) -> bool { trace!("is_eof? {:?}", self); match self.kind { Length(0) | - Chunked(Some(0)) | + Chunked(ChunkedState::End, _) | Eof(true) => true, - _ => false + _ => false, } } } @@ -87,183 +93,248 @@ impl Decoder { } Ok(num as usize) } - }, - Chunked(ref mut opt_remaining) => { - let mut rem = match *opt_remaining { - Some(ref rem) => *rem, - // None means we don't know the size of the next chunk - None => try!(read_chunk_size(body)) - }; - trace!("Chunked read, remaining={:?}", rem); - - if rem == 0 { - *opt_remaining = Some(0); - - // chunk of size 0 signals the end of the chunked stream - // if the 0 digit was missing from the stream, it would - // be an InvalidInput error instead. - trace!("end of chunked"); - return Ok(0) - } - - let to_read = cmp::min(rem as usize, buf.len()); - let count = try!(body.read(&mut buf[..to_read])) as u64; - - if count == 0 { - *opt_remaining = Some(0); - return Err(io::Error::new(io::ErrorKind::Other, "early eof")); + } + Chunked(ref mut state, ref mut size) => { + loop { + let mut read = 0; + // advances the chunked state + *state = try!(state.step(body, size, buf, &mut read)); + if *state == ChunkedState::End { + trace!("end of chunked"); + return Ok(0); + } + if read > 0 { + return Ok(read); + } } - - rem -= count; - *opt_remaining = if rem > 0 { - Some(rem) - } else { - try!(eat(body, b"\r\n")); - None - }; - Ok(count as usize) - }, + } Eof(ref mut is_eof) => { match body.read(buf) { Ok(0) => { *is_eof = true; Ok(0) } - other => other + other => other, } - }, + } } } } -fn eat(rdr: &mut R, bytes: &[u8]) -> io::Result<()> { - let mut buf = [0]; - for &b in bytes.iter() { - match try!(rdr.read(&mut buf)) { - 1 if buf[0] == b => (), - _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid characters found")), +macro_rules! byte ( + ($rdr:ident) => ({ + let mut buf = [0]; + match try!($rdr.read(&mut buf)) { + 1 => buf[0], + _ => return Err(io::Error::new(io::ErrorKind::UnexpectedEof, + "Unexpected eof during chunk size line")), } - } - Ok(()) -} + }) +); -/// Chunked chunks start with 1*HEXDIGIT, indicating the size of the chunk. -fn read_chunk_size(rdr: &mut R) -> io::Result { - macro_rules! byte ( - ($rdr:ident) => ({ - let mut buf = [0]; - match try!($rdr.read(&mut buf)) { - 1 => buf[0], - _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid chunk size line")), - - } +impl ChunkedState { + fn step(&self, + body: &mut R, + size: &mut u64, + buf: &mut [u8], + read: &mut usize) + -> io::Result { + use self::ChunkedState::*; + Ok(match *self { + Size => try!(ChunkedState::read_size(body, size)), + SizeLws => try!(ChunkedState::read_size_lws(body)), + Extension => try!(ChunkedState::read_extension(body)), + SizeLf => try!(ChunkedState::read_size_lf(body, size)), + Body => try!(ChunkedState::read_body(body, size, buf, read)), + BodyCr => try!(ChunkedState::read_body_cr(body)), + BodyLf => try!(ChunkedState::read_body_lf(body)), + End => ChunkedState::End, }) - ); - let mut size = 0u64; - let radix = 16; - let mut in_ext = false; - let mut in_chunk_size = true; - loop { + } + fn read_size(rdr: &mut R, size: &mut u64) -> io::Result { + trace!("Read size"); + let radix = 16; match byte!(rdr) { - b@b'0'...b'9' if in_chunk_size => { - size *= radix; - size += (b - b'0') as u64; - }, - b@b'a'...b'f' if in_chunk_size => { - size *= radix; - size += (b + 10 - b'a') as u64; - }, - b@b'A'...b'F' if in_chunk_size => { - size *= radix; - size += (b + 10 - b'A') as u64; - }, - b'\r' => { - match byte!(rdr) { - b'\n' => break, - _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid chunk size line")) - - } - }, - // If we weren't in the extension yet, the ";" signals its start - b';' if !in_ext => { - in_ext = true; - in_chunk_size = false; - }, - // "Linear white space" is ignored between the chunk size and the - // extension separator token (";") due to the "implied *LWS rule". - b'\t' | b' ' if !in_ext & !in_chunk_size => {}, - // LWS can follow the chunk size, but no more digits can come - b'\t' | b' ' if in_chunk_size => in_chunk_size = false, - // We allow any arbitrary octet once we are in the extension, since - // they all get ignored anyway. According to the HTTP spec, valid - // extensions would have a more strict syntax: - // (token ["=" (token | quoted-string)]) - // but we gain nothing by rejecting an otherwise valid chunk size. - _ext if in_ext => { - //TODO: chunk extension byte; - }, - // Finally, if we aren't in the extension and we're reading any - // other octet, the chunk size line is invalid! + b @ b'0'...b'9' => { + *size *= radix; + *size += (b - b'0') as u64; + } + b @ b'a'...b'f' => { + *size *= radix; + *size += (b + 10 - b'a') as u64; + } + b @ b'A'...b'F' => { + *size *= radix; + *size += (b + 10 - b'A') as u64; + } + b'\t' | b' ' => return Ok(ChunkedState::SizeLws), + b';' => return Ok(ChunkedState::Extension), + b'\r' => return Ok(ChunkedState::SizeLf), _ => { return Err(io::Error::new(io::ErrorKind::InvalidInput, - "Invalid chunk size line")); + "Invalid chunk size line: Invalid Size")); } } + Ok(ChunkedState::Size) } - trace!("chunk size={:?}", size); - Ok(size) -} + fn read_size_lws(rdr: &mut R) -> io::Result { + trace!("read_size_lws"); + match byte!(rdr) { + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' => Ok(ChunkedState::SizeLws), + b';' => Ok(ChunkedState::Extension), + b'\r' => return Ok(ChunkedState::SizeLf), + _ => { + Err(io::Error::new(io::ErrorKind::InvalidInput, + "Invalid chunk size linear white space")) + } + } + } + fn read_extension(rdr: &mut R) -> io::Result { + trace!("read_extension"); + match byte!(rdr) { + b'\r' => return Ok(ChunkedState::SizeLf), + _ => return Ok(ChunkedState::Extension), // no supported extensions + } + } + fn read_size_lf(rdr: &mut R, size: &mut u64) -> io::Result { + trace!("Chunk size is {:?}", size); + match byte!(rdr) { + b'\n' if *size > 0 => Ok(ChunkedState::Body), + b'\n' if *size == 0 => Ok(ChunkedState::End), + _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk size LF")), + } + } + fn read_body(rdr: &mut R, + rem: &mut u64, + buf: &mut [u8], + read: &mut usize) + -> io::Result { + trace!("Chunked read, remaining={:?}", rem); + + // cap remaining bytes at the max capacity of usize + let rem_cap = match *rem { + r if r > usize::MAX as u64 => usize::MAX, + r => r as usize, + }; + + let to_read = cmp::min(rem_cap, buf.len()); + let count = try!(rdr.read(&mut buf[..to_read])); + + trace!("to_read = {}", to_read); + trace!("count = {}", count); + + if count == 0 { + *rem = 0; + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "early eof")); + } + *rem -= count as u64; + *read = count; + + if *rem > 0 { + Ok(ChunkedState::Body) + } else { + Ok(ChunkedState::BodyCr) + } + } + fn read_body_cr(rdr: &mut R) -> io::Result { + match byte!(rdr) { + b'\r' => Ok(ChunkedState::BodyLf), + _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body CR")), + } + } + fn read_body_lf(rdr: &mut R) -> io::Result { + match byte!(rdr) { + b'\n' => Ok(ChunkedState::Size), + _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body LF")), + } + } +} #[cfg(test)] mod tests { use std::error::Error; use std::io; - use super::{Decoder, read_chunk_size}; + use std::io::Write; + use super::Decoder; + use super::ChunkedState; + use mock::Async; #[test] fn test_read_chunk_size() { - fn read(s: &str, result: u64) { - assert_eq!(read_chunk_size(&mut s.as_bytes()).unwrap(), result); + use std::io::ErrorKind::{UnexpectedEof, InvalidInput}; + + fn read(s: &str) -> u64 { + let mut state = ChunkedState::Size; + let mut rdr = &mut s.as_bytes(); + let mut size = 0; + let mut count = 0; + loop { + let mut buf = [0u8; 10]; + let result = state.step(&mut rdr, &mut size, &mut buf, &mut count); + let desc = format!("read_size failed for {:?}", s); + state = result.expect(desc.as_str()); + trace!("State {:?}", state); + if state == ChunkedState::Body || state == ChunkedState::End { + break; + } + } + size } - fn read_err(s: &str) { - assert_eq!(read_chunk_size(&mut s.as_bytes()).unwrap_err().kind(), - io::ErrorKind::InvalidInput); + fn read_err(s: &str, expected_err: io::ErrorKind) { + let mut state = ChunkedState::Size; + let mut rdr = &mut s.as_bytes(); + let mut size = 0; + let mut count = 0; + loop { + let mut buf = [0u8; 10]; + let result = state.step(&mut rdr, &mut size, &mut buf, &mut count); + state = match result { + Ok(s) => s, + Err(e) => { + assert!(expected_err == e.kind(), "Reading {:?}, expected {:?}, but got {:?}", + s, expected_err, e.kind()); + return; + } + }; + trace!("State {:?}", state); + if state == ChunkedState::Body || state == ChunkedState::End { + panic!(format!("Was Ok. Expected Err for {:?}", s)); + } + } } - read("1\r\n", 1); - read("01\r\n", 1); - read("0\r\n", 0); - read("00\r\n", 0); - read("A\r\n", 10); - read("a\r\n", 10); - read("Ff\r\n", 255); - read("Ff \r\n", 255); + assert_eq!(1, read("1\r\n")); + assert_eq!(1, read("01\r\n")); + assert_eq!(0, read("0\r\n")); + assert_eq!(0, read("00\r\n")); + assert_eq!(10, read("A\r\n")); + assert_eq!(10, read("a\r\n")); + assert_eq!(255, read("Ff\r\n")); + assert_eq!(255, read("Ff \r\n")); // Missing LF or CRLF - read_err("F\rF"); - read_err("F"); + read_err("F\rF", InvalidInput); + read_err("F", UnexpectedEof); // Invalid hex digit - read_err("X\r\n"); - read_err("1X\r\n"); - read_err("-\r\n"); - read_err("-1\r\n"); + read_err("X\r\n", InvalidInput); + read_err("1X\r\n", InvalidInput); + read_err("-\r\n", InvalidInput); + read_err("-1\r\n", InvalidInput); // Acceptable (if not fully valid) extensions do not influence the size - read("1;extension\r\n", 1); - read("a;ext name=value\r\n", 10); - read("1;extension;extension2\r\n", 1); - read("1;;; ;\r\n", 1); - read("2; extension...\r\n", 2); - read("3 ; extension=123\r\n", 3); - read("3 ;\r\n", 3); - read("3 ; \r\n", 3); + assert_eq!(1, read("1;extension\r\n")); + assert_eq!(10, read("a;ext name=value\r\n")); + assert_eq!(1, read("1;extension;extension2\r\n")); + assert_eq!(1, read("1;;; ;\r\n")); + assert_eq!(2, read("2; extension...\r\n")); + assert_eq!(3, read("3 ; extension=123\r\n")); + assert_eq!(3, read("3 ;\r\n")); + assert_eq!(3, read("3 ; \r\n")); // Invalid extensions cause an error - read_err("1 invalid extension\r\n"); - read_err("1 A\r\n"); - read_err("1;no CRLF"); + read_err("1 invalid extension\r\n", InvalidInput); + read_err("1 A\r\n", InvalidInput); + read_err("1;no CRLF", UnexpectedEof); } #[test] @@ -287,7 +358,108 @@ mod tests { let mut buf = [0u8; 10]; assert_eq!(decoder.decode(&mut bytes, &mut buf).unwrap(), 7); let e = decoder.decode(&mut bytes, &mut buf).unwrap_err(); - assert_eq!(e.kind(), io::ErrorKind::Other); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); assert_eq!(e.description(), "early eof"); } + + #[test] + fn test_read_chunked_single_read() { + let content = b"10\r\n1234567890abcdef\r\n0\r\n"; + let mut mock_buf = io::Cursor::new(content); + let mut buf = [0u8; 16]; + let count = Decoder::chunked().decode(&mut mock_buf, &mut buf).expect("decode"); + assert_eq!(16, count); + let result = String::from_utf8(buf.to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + } + + #[test] + fn test_read_chunked_after_eof() { + let content = b"10\r\n1234567890abcdef\r\n0\r\n"; + let mut mock_buf = io::Cursor::new(content); + let mut buf = [0u8; 50]; + let mut decoder = Decoder::chunked(); + + // normal read + let count = decoder.decode(&mut mock_buf, &mut buf).expect("decode"); + assert_eq!(16, count); + let result = String::from_utf8(buf[0..count].to_vec()).expect("decode String"); + assert_eq!("1234567890abcdef", &result); + + // eof read + let count = decoder.decode(&mut mock_buf, &mut buf).expect("decode"); + assert_eq!(0, count); + + // ensure read after eof also returns eof + let count = decoder.decode(&mut mock_buf, &mut buf).expect("decode"); + assert_eq!(0, count); + } + + // perform an async read using a custom buffer size and causing a blocking + // read at the specified byte + fn read_async(mut decoder: Decoder, + content: &[u8], + block_at: usize, + read_buffer_size: usize) + -> String { + let content_len = content.len(); + let mock_buf = io::Cursor::new(content.clone()); + let mut ins = Async::new(mock_buf, block_at); + let mut outs = vec![]; + loop { + let mut buf = vec![0; read_buffer_size]; + match decoder.decode(&mut ins, buf.as_mut_slice()) { + Ok(0) => break, + Ok(i) => outs.write(&buf[0..i]).expect("write buffer"), + Err(e) => { + if e.kind() != io::ErrorKind::WouldBlock { + break; + } + ins.block_in(content_len); // we only block once + 0 as usize + } + }; + } + String::from_utf8(outs).expect("decode String") + } + + // iterate over the different ways that this async read could go. + // tests every combination of buffer size that is passed in, with a blocking + // read at each byte along the content - The shotgun approach + fn all_async_cases(content: &str, expected: &str, decoder: Decoder) { + let content_len = content.len(); + for block_at in 0..content_len { + for read_buffer_size in 1..content_len { + let actual = read_async(decoder.clone(), + content.as_bytes(), + block_at, + read_buffer_size); + assert_eq!(expected, + &actual, + "Failed async. Blocking at {} with read buffer size {}", + block_at, + read_buffer_size); + } + } + } + + #[test] + fn test_read_length_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::length(content.len() as u64)); + } + + #[test] + fn test_read_chunked_async() { + let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n"; + let expected = "foobar"; + all_async_cases(content, expected, Decoder::chunked()); + } + + #[test] + fn test_read_eof_async() { + let content = "foobar"; + all_async_cases(content, content, Decoder::eof()); + } + }