diff --git a/benches/parse.rs b/benches/parse.rs index 98952e3..515f4d6 100644 --- a/benches/parse.rs +++ b/benches/parse.rs @@ -152,7 +152,11 @@ fn version(c: &mut Criterion) { .bench_function(name, |b| b.iter(|| { black_box({ let mut b = httparse::_benchable::Bytes::new(input); - httparse::_benchable::parse_version(&mut b).unwrap() + match httparse::_benchable::parse_version(&mut b) { + // Somewhat awkward, but this is internal code so it's ok. + Ok(_) | Err(None) => (), + Err(Some(e)) => panic!("parse_version failed: {}", e), + } }); })); } diff --git a/src/lib.rs b/src/lib.rs index 7e10dbf..8d6841e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,6 @@ use core::mem::{self, MaybeUninit}; use crate::iter::Bytes; mod iter; -#[macro_use] mod macros; mod simd; #[doc(hidden)] @@ -40,6 +39,13 @@ pub mod _benchable { pub use super::iter::Bytes; } +// Macro to generate byte lookup tables +macro_rules! byte_map { + ($($flag:expr,)*) => ([ + $($flag != 0,)* + ]) +} + /// Determines if byte is a token char. /// /// > ```notrust @@ -205,6 +211,27 @@ impl fmt::Display for InvalidChunkSize { /// a `Ok(Status::Partial)`. pub type Result = result::Result, Error>; +// Intermediate result with a more compact representation. +type InnerResult = result::Result>; +const ERR_PARTIAL: Option = None; + +// NOTE: we use this "extension trait" to provide a .into_outer() method for InnerResult +// (since it's an external type and wrapping it would be inelegant) +trait IntoOuter: Sized { + fn into_outer(self) -> T; +} + +impl IntoOuter> for InnerResult { + #[inline] + fn into_outer(self) -> Result { + match self { + Ok(x) => Ok(Status::Complete(x)), + Err(None) => Ok(Status::Partial), + Err(Some(err)) => Err(err), + } + } +} + /// The result of a successful parse pass. /// /// `Complete` is used when the buffer contained the complete value. @@ -367,7 +394,7 @@ impl ParserConfig { buf: &'buf [u8], headers: &'headers mut [MaybeUninit>], ) -> Result { - request.parse_with_config_and_uninit_headers(buf, self, headers) + request.parse_with_config_and_uninit_headers(buf, self, headers).into_outer() } /// Sets whether invalid header lines should be silently ignored in responses. @@ -424,7 +451,7 @@ impl ParserConfig { buf: &'buf [u8], headers: &'headers mut [MaybeUninit>], ) -> Result { - response.parse_with_config_and_uninit_headers(buf, self, headers) + response.parse_with_config_and_uninit_headers(buf, self, headers).into_outer() } } @@ -482,32 +509,38 @@ impl<'h, 'b> Request<'h, 'b> { buf: &'b [u8], config: &ParserConfig, mut headers: &'h mut [MaybeUninit>], - ) -> Result { - let orig_len = buf.len(); + ) -> InnerResult { + let start = buf.as_ptr() as usize; let mut bytes = Bytes::new(buf); - complete!(skip_empty_lines(&mut bytes)); - let method = complete!(parse_method(&mut bytes)); + + skip_empty_lines(&mut bytes)?; + let method = parse_method(&mut bytes)?; self.method = Some(method); if config.allow_multiple_spaces_in_request_line_delimiters { - complete!(skip_spaces(&mut bytes)); + skip_spaces(&mut bytes)?; } - self.path = Some(complete!(parse_uri(&mut bytes))); + self.path = Some(parse_uri(&mut bytes)?); if config.allow_multiple_spaces_in_request_line_delimiters { - complete!(skip_spaces(&mut bytes)); + skip_spaces(&mut bytes)?; + } + self.version = Some(parse_version(&mut bytes)?); + match parse_header_eol(&mut bytes)? { + // Req with no headers + true => { + let end = bytes.as_ptr() as usize; + return Ok(end-start) + }, + false => {}, // CRLF transition to headers } - self.version = Some(complete!(parse_version(&mut bytes))); - newline!(bytes); - let len = orig_len - bytes.len(); - let headers_len = complete!(parse_headers_iter_uninit( + parse_request_headers_iter_uninit( &mut headers, &mut bytes, - &ParserConfig::default(), - )); + )?; /* SAFETY: see `parse_headers_iter_uninit` guarantees */ self.headers = unsafe { assume_init_slice(headers) }; - - Ok(Status::Complete(len + headers_len)) + let end = bytes.as_ptr() as usize; + Ok(end-start) } /// Try to parse a buffer of bytes into the Request, @@ -519,7 +552,7 @@ impl<'h, 'b> Request<'h, 'b> { buf: &'b [u8], headers: &'h mut [MaybeUninit>], ) -> Result { - self.parse_with_config_and_uninit_headers(buf, &Default::default(), headers) + self.parse_with_config_and_uninit_headers(buf, &Default::default(), headers).into_outer() } fn parse_with_config(&mut self, buf: &'b [u8], config: &ParserConfig) -> Result { @@ -529,34 +562,33 @@ impl<'h, 'b> Request<'h, 'b> { unsafe { let headers: *mut [Header<'_>] = headers; let headers = headers as *mut [MaybeUninit>]; - match self.parse_with_config_and_uninit_headers(buf, config, &mut *headers) { - Ok(Status::Complete(idx)) => Ok(Status::Complete(idx)), - other => { - // put the original headers back - self.headers = &mut *(headers as *mut [Header<'_>]); - other - }, + let result = self.parse_with_config_and_uninit_headers(buf, config, &mut *headers); + if result.is_err() { + // put the original headers back + self.headers = &mut *(headers as *mut [Header<'_>]); } + result.into_outer() } } /// Try to parse a buffer of bytes into the Request. /// /// Returns byte offset in `buf` to start of HTTP body. + #[inline] pub fn parse(&mut self, buf: &'b [u8]) -> Result { self.parse_with_config(buf, &Default::default()) } } #[inline] -fn skip_empty_lines(bytes: &mut Bytes<'_>) -> Result<()> { +fn skip_empty_lines(bytes: &mut Bytes<'_>) -> InnerResult<()> { loop { let b = bytes.peek(); match b { Some(b'\r') => { // there's `\r`, so it's safe to bump 1 pos unsafe { bytes.bump() }; - expect!(bytes.next() == b'\n' => Err(Error::NewLine)); + expect_next(bytes, |b| *b == b'\n', Error::NewLine)?; }, Some(b'\n') => { // there's `\n`, so it's safe to bump 1 pos @@ -564,15 +596,15 @@ fn skip_empty_lines(bytes: &mut Bytes<'_>) -> Result<()> { }, Some(..) => { bytes.slice(); - return Ok(Status::Complete(())); + return Ok(()); }, - None => return Ok(Status::Partial) + None => return Err(ERR_PARTIAL), } } } #[inline] -fn skip_spaces(bytes: &mut Bytes<'_>) -> Result<()> { +fn skip_spaces(bytes: &mut Bytes<'_>) -> InnerResult<()> { loop { let b = bytes.peek(); match b { @@ -582,9 +614,9 @@ fn skip_spaces(bytes: &mut Bytes<'_>) -> Result<()> { } Some(..) => { bytes.slice(); - return Ok(Status::Complete(())); + return Ok(()); } - None => return Ok(Status::Partial), + None => return Err(ERR_PARTIAL), } } } @@ -629,14 +661,12 @@ impl<'h, 'b> Response<'h, 'b> { unsafe { let headers: *mut [Header<'_>] = headers; let headers = headers as *mut [MaybeUninit>]; - match self.parse_with_config_and_uninit_headers(buf, config, &mut *headers) { - Ok(Status::Complete(idx)) => Ok(Status::Complete(idx)), - other => { - // put the original headers back - self.headers = &mut *(headers as *mut [Header<'_>]); - other - }, + let result = self.parse_with_config_and_uninit_headers(buf, config, &mut *headers); + if result.is_err() { + // put the original headers back + self.headers = &mut *(headers as *mut [Header<'_>]); } + result.into_outer() } } @@ -645,17 +675,18 @@ impl<'h, 'b> Response<'h, 'b> { buf: &'b [u8], config: &ParserConfig, mut headers: &'h mut [MaybeUninit>], - ) -> Result { + ) -> InnerResult { let orig_len = buf.len(); let mut bytes = Bytes::new(buf); - complete!(skip_empty_lines(&mut bytes)); - self.version = Some(complete!(parse_version(&mut bytes))); - space!(bytes or Error::Version); + skip_empty_lines(&mut bytes)?; + self.version = Some(parse_version(&mut bytes)?); + expect_next(&mut bytes, |b| *b == b' ', Error::Version)?; + bytes.commit(); // TODO: remove ? if config.allow_multiple_spaces_in_response_status_delimiters { - complete!(skip_spaces(&mut bytes)); + skip_spaces(&mut bytes)?; } - self.code = Some(complete!(parse_code(&mut bytes))); + self.code = Some(parse_code(&mut bytes)?); // RFC7230 says there must be 'SP' and then reason-phrase, but admits // its only for legacy reasons. With the reason-phrase completely @@ -666,16 +697,16 @@ impl<'h, 'b> Response<'h, 'b> { // So, a SP means parse a reason-phrase. // A newline means go to headers. // Anything else we'll say is a malformed status. - match next!(bytes) { + match next(&mut bytes)? { b' ' => { if config.allow_multiple_spaces_in_response_status_delimiters { - complete!(skip_spaces(&mut bytes)); + skip_spaces(&mut bytes)?; } bytes.slice(); - self.reason = Some(complete!(parse_reason(&mut bytes))); + self.reason = Some(parse_reason(&mut bytes)?); }, b'\r' => { - expect!(bytes.next() == b'\n' => Err(Error::Status)); + expect_next(&mut bytes, |b| *b == b'\n', Error::Status)?; bytes.slice(); self.reason = Some(""); }, @@ -683,19 +714,19 @@ impl<'h, 'b> Response<'h, 'b> { bytes.slice(); self.reason = Some(""); } - _ => return Err(Error::Status), + _ => return Err(Some(Error::Status)), } let len = orig_len - bytes.len(); - let headers_len = complete!(parse_headers_iter_uninit( + let headers_len = parse_response_headers_iter_uninit( &mut headers, &mut bytes, config - )); + )?; /* SAFETY: see `parse_headers_iter_uninit` guarantees */ self.headers = unsafe { assume_init_slice(headers) }; - Ok(Status::Complete(len + headers_len)) + Ok(len + headers_len) } } @@ -740,8 +771,8 @@ pub const EMPTY_HEADER: Header<'static> = Header { name: "", value: b"" }; #[doc(hidden)] #[allow(missing_docs)] // WARNING: Exported for internal benchmarks, not fit for public consumption -pub fn parse_version(bytes: &mut Bytes) -> Result { - if let Some(eight) = bytes.peek_n::<[u8; 8]>(8) { +pub fn parse_version(bytes: &mut Bytes) -> InnerResult { + if let Some(eight) = bytes.peek_n(8) { // NOTE: should be const once MSRV >= 1.44 let h10: u64 = u64::from_ne_bytes(*b"HTTP/1.0"); let h11: u64 = u64::from_ne_bytes(*b"HTTP/1.1"); @@ -749,11 +780,11 @@ pub fn parse_version(bytes: &mut Bytes) -> Result { let block = u64::from_ne_bytes(eight); // NOTE: should be match once h10 & h11 are consts return if block == h10 { - Ok(Status::Complete(0)) + Ok(0) } else if block == h11 { - Ok(Status::Complete(1)) + Ok(1) } else { - Err(Error::Version) + Err(Some(Error::Version)) } } @@ -761,24 +792,21 @@ pub fn parse_version(bytes: &mut Bytes) -> Result { // If there aren't at least 8 bytes, we still want to detect early // if this is a valid version or not. If it is, we'll return Partial. - expect!(bytes.next() == b'H' => Err(Error::Version)); - expect!(bytes.next() == b'T' => Err(Error::Version)); - expect!(bytes.next() == b'T' => Err(Error::Version)); - expect!(bytes.next() == b'P' => Err(Error::Version)); - expect!(bytes.next() == b'/' => Err(Error::Version)); - expect!(bytes.next() == b'1' => Err(Error::Version)); - expect!(bytes.next() == b'.' => Err(Error::Version)); - Ok(Status::Partial) + if b"HTTP/1.".starts_with(bytes.as_ref()) { + return Err(ERR_PARTIAL); + } else { + return Err(Some(Error::Version)); + } } #[inline] #[doc(hidden)] #[allow(missing_docs)] // WARNING: Exported for internal benchmarks, not fit for public consumption -pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { +pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> InnerResult<&'a str> { const GET: [u8; 4] = *b"GET "; const POST: [u8; 4] = *b"POST"; - match bytes.peek_n::<[u8; 4]>(4) { + match bytes.peek_n(4) { Some(GET) => { // SAFETY: matched the ASCII string and boundary checked let method = unsafe { @@ -786,7 +814,7 @@ pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { let buf = bytes.slice_skip(1); str::from_utf8_unchecked(buf) }; - Ok(Status::Complete(method)) + Ok(method) } Some(POST) if bytes.peek_ahead(4) == Some(b' ') => { // SAFETY: matched the ASCII string and boundary checked @@ -795,7 +823,7 @@ pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { let buf = bytes.slice_skip(1); str::from_utf8_unchecked(buf) }; - Ok(Status::Complete(method)) + Ok(method) } _ => parse_token(bytes), } @@ -815,13 +843,13 @@ pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { /// > Non-US-ASCII content in header fields and the reason phrase /// > has been obsoleted and made opaque (the TEXT rule was removed). #[inline] -fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { +fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> InnerResult<&'a str> { let mut seen_obs_text = false; loop { - let b = next!(bytes); + let b = next(bytes)?; if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::Status)); - return Ok(Status::Complete(unsafe { + expect_next(bytes, |b| *b == b'\n', Error::Status)?; + return Ok(unsafe { let bytes = bytes.slice_skip(2); if !seen_obs_text { // all bytes up till `i` must have been HTAB / SP / VCHAR @@ -830,9 +858,9 @@ fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { // obs-text characters were found, so return the fallback empty string "" } - })); + }); } else if b == b'\n' { - return Ok(Status::Complete(unsafe { + return Ok(unsafe { let bytes = bytes.slice_skip(1); if !seen_obs_text { // all bytes up till `i` must have been HTAB / SP / VCHAR @@ -841,9 +869,9 @@ fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { // obs-text characters were found, so return the fallback empty string "" } - })); + }); } else if !(b == 0x09 || b == b' ' || (0x21..=0x7E).contains(&b) || b >= 0x80) { - return Err(Error::Status); + return Err(Some(Error::Status)); } else if b >= 0x80 { seen_obs_text = true; } @@ -851,22 +879,22 @@ fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { } #[inline] -fn parse_token<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { - let b = next!(bytes); +fn parse_token<'a>(bytes: &mut Bytes<'a>) -> InnerResult<&'a str> { + let b = next(bytes)?; if !is_token(b) { // First char must be a token char, it can't be a space which would indicate an empty token. - return Err(Error::Token); + return Err(Some(Error::Token)); } loop { - let b = next!(bytes); + let b = next(bytes)?; if b == b' ' { - return Ok(Status::Complete(unsafe { + return Ok(unsafe { // all bytes up till `i` must have been `is_token`. str::from_utf8_unchecked(bytes.slice_skip(1)) - })); + }); } else if !is_token(b) { - return Err(Error::Token); + return Err(Some(Error::Token)); } } } @@ -875,33 +903,29 @@ fn parse_token<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { #[doc(hidden)] #[allow(missing_docs)] // WARNING: Exported for internal benchmarks, not fit for public consumption -pub fn parse_uri<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { - let start = bytes.pos(); +pub fn parse_uri<'a>(bytes: &mut Bytes<'a>) -> InnerResult<&'a str> { simd::match_uri_vectored(bytes); - // URI must have at least one char - if bytes.pos() == start { - return Err(Error::Token); - } + // SAFTEY: the validated bytes are ASCII and thus UTF-8 + let uri = unsafe { + str::from_utf8_unchecked(bytes.slice()) + }; - if next!(bytes) == b' ' { - return Ok(Status::Complete(unsafe { - // all bytes up till `i` must have been `is_token`. - str::from_utf8_unchecked(bytes.slice_skip(1)) - })); + if !uri.is_empty() && next(bytes)? == b' ' { + return Ok(uri); } else { - return Err(Error::Token); + return Err(Some(Error::Token)); } } #[inline] -fn parse_code(bytes: &mut Bytes<'_>) -> Result { - let hundreds = expect!(bytes.next() == b'0'..=b'9' => Err(Error::Status)); - let tens = expect!(bytes.next() == b'0'..=b'9' => Err(Error::Status)); - let ones = expect!(bytes.next() == b'0'..=b'9' => Err(Error::Status)); +fn parse_code(bytes: &mut Bytes<'_>) -> InnerResult { + let hundreds = expect_next(bytes, u8::is_ascii_digit, Error::Status)?; + let tens = expect_next(bytes, u8::is_ascii_digit, Error::Status)?; + let ones = expect_next(bytes, u8::is_ascii_digit, Error::Status)?; - Ok(Status::Complete((hundreds - b'0') as u16 * 100 + + Ok((hundreds - b'0') as u16 * 100 + (tens - b'0') as u16 * 10 + - (ones - b'0') as u16)) + (ones - b'0') as u16) } /// Parse a buffer of bytes as headers. @@ -927,22 +951,18 @@ pub fn parse_headers<'b: 'h, 'h>( mut dst: &'h mut [Header<'b>], ) -> Result<(usize, &'h [Header<'b>])> { let mut iter = Bytes::new(src); - let pos = complete!(parse_headers_iter(&mut dst, &mut iter, &ParserConfig::default())); - Ok(Status::Complete((pos, dst))) -} - -#[inline] -fn parse_headers_iter<'a, 'b>( - headers: &mut &mut [Header<'a>], - bytes: &'b mut Bytes<'a>, - config: &ParserConfig, -) -> Result { - parse_headers_iter_uninit( + let start = iter.as_ptr() as usize; + let result = parse_request_headers_iter_uninit( /* SAFETY: see `parse_headers_iter_uninit` guarantees */ - unsafe { deinit_slice_mut(headers) }, - bytes, - config, - ) + unsafe { deinit_slice_mut(&mut dst) }, + &mut iter, + ); + let result = result.map(move |_| { + let end = iter.as_ptr() as usize; + (end - start, &*dst) + }); + + result.into_outer() } unsafe fn deinit_slice_mut<'a, 'b, T>(s: &'a mut &'b mut [T]) -> &'a mut &'b mut [MaybeUninit] { @@ -965,41 +985,147 @@ unsafe fn assume_init_slice(s: &mut [MaybeUninit]) -> &mut [T] { * Also it promises `headers` get shrunk to number of initialized headers, * so casting the other way around after calling this function is safe */ -fn parse_headers_iter_uninit<'a, 'b>( +#[inline(always)] +fn parse_request_headers_iter_uninit<'a, 'b>( headers: &mut &mut [MaybeUninit>], bytes: &'b mut Bytes<'a>, - config: &ParserConfig, -) -> Result { +) -> InnerResult<()> { + let headers_ptr = headers.as_mut_ptr(); + let max_headers = headers.len(); + + // SAFETY: Size headers at 0, we'll grow it as we parse headers. + *headers = unsafe { core::slice::from_raw_parts_mut(headers_ptr, 0) }; + + for i in 0..max_headers { + bytes.commit(); + + // parse header name until colon + let header_name = { + simd::match_header_name_vectored(bytes); + + // SAFTEY: the validated bytes are ASCII and thus UTF-8 + unsafe { str::from_utf8_unchecked(bytes.slice()) } + }; + + // parse colon and leading whitespace, possibly empty value + parse_header_colspace(bytes)?; + + // parse value until CRLF + let header_value = { + simd::match_header_value_vectored(bytes); + let slice = bytes.slice(); + trim_ascii_end(slice) + }; - /* Flow of this function is pretty complex, especially with macros, - * so this struct makes sure we shrink `headers` to only parsed ones. - * Comparing to previous code, this only may introduce some additional - * instructions in case of early return */ - struct ShrinkOnDrop<'r1, 'r2, 'a> { - headers: &'r1 mut &'r2 mut [MaybeUninit>], - num_headers: usize, - } + // Grow the header slice + *headers = unsafe { core::slice::from_raw_parts_mut(headers_ptr, i + 1) }; + // Write new header + headers[i] = MaybeUninit::new(Header { + name: header_name, + value: header_value, + }); - impl<'r1, 'r2, 'a> Drop for ShrinkOnDrop<'r1, 'r2, 'a> { - fn drop(&mut self) { - let headers = mem::replace(self.headers, &mut []); + // Check for end + if parse_header_eol(bytes)? { + return Ok(()); + } + } + + return Err(Some(Error::TooManyHeaders)); +} - /* SAFETY: num_headers is the number of initialized headers */ - let headers = unsafe { headers.get_unchecked_mut(..self.num_headers) }; +#[inline] +fn parse_header_eol(bytes: &mut Bytes) -> InnerResult { + match bytes.peek_n::<[u8; 4]>(4) { + // End of all headers + Some([b'\r', b'\n', b'\r', b'\n']) => { + unsafe { bytes.advance(4); } + return Ok(true); + }, + // End of header-line + Some([b'\r', b'\n', _, _]) /* if is_header_name_token(b) */ => { + unsafe { bytes.advance(2); } + }, + _ => { + // Slow cases + if let Some([b'\r', b'\n', b'\r']) = bytes.peek_n::<[u8; 3]>(3) { + // 3 bytes but not 4, edge case of incomplete CR LF CR LF + return Err(ERR_PARTIAL) + } + match bytes.peek_n::<[u8; 2]>(2) { + Some([b'\n', b'\n']) => { + unsafe { bytes.advance(2) }; + return Ok(true); + } + Some([b'\r', b'\n']) => { + unsafe { bytes.advance(2) }; + } + _ => { + let b = next(bytes)?; + if b != b'\n' && b != b'\r' { + return Err(Some(Error::HeaderName)); + } + } + } + } + } + return Ok(false) +} - *self.headers = headers; +#[inline(always)] +fn parse_header_colspace(bytes: &mut Bytes) -> InnerResult<()> { + // Fast path + if let Some([b':', b' ', b]) = bytes.peek_n::<[u8; 3]>(3) { + if b > b' ' { + // SAFETY: we know that the next 2-3 bytes are valid + unsafe { bytes.advance_and_commit(2); } + return Ok(()); + } + } + // Validate colon + if next(bytes)? != b':' { + return Err(Some(Error::HeaderName)); + } + // eat white space between colon and value + while let Some(b) = bytes.peek() { + if b == b' ' || b == b'\t' { + unsafe { bytes.advance_and_commit(1); } + } else { + return Ok(()); } } + return Err(ERR_PARTIAL) +} - let mut autoshrink = ShrinkOnDrop { - headers, - num_headers: 0, - }; - // Track starting pointer to calculate the number of bytes parsed. - let start = bytes.as_ref().as_ptr() as usize; - let mut result = Err(Error::TooManyHeaders); +// NOTE: in future use https://doc.rust-lang.org/std/primitive.slice.html#method.trim_ascii_end +#[inline] +fn trim_ascii_end(s: &[u8]) -> &[u8] { + match s.iter().rposition(|b| *b != b' ' && *b != b'\t' && *b != b'\r' && *b != b'\n') { + Some(end) => &s[..=end], + None => &s[..0], + } +} - let mut iter = autoshrink.headers.iter_mut(); +/* Function which parsers headers into uninitialized buffer. + * + * Guarantees that it doesn't write garbage, so casting + * &mut &mut [Header] -> &mut &mut [MaybeUninit
] + * is safe here. + * + * Also it promises `headers` get shrunk to number of initialized headers, + * so casting the other way around after calling this function is safe + */ +fn parse_response_headers_iter_uninit<'a, 'b>( + headers: &mut &mut [MaybeUninit>], + bytes: &'b mut Bytes<'a>, + config: &ParserConfig, +) -> InnerResult { + let headers_ptr = headers.as_mut_ptr(); + let max_headers = headers.len(); + // SAFETY: Size headers at 0, we'll grow it as we parse headers. + *headers = unsafe { core::slice::from_raw_parts_mut(headers_ptr, 0) }; + // Track starting pointer to calculate the number of bytes parsed. + let start = bytes.as_ptr() as usize; macro_rules! maybe_continue_after_obsolete_line_folding { ($bytes:ident, $label:lifetime) => { @@ -1009,7 +1135,7 @@ fn parse_headers_iter_uninit<'a, 'b>( // Next byte may be a space, in which case that header // is using obsolete line folding, so we may have more // whitespace to skip after colon. - return Ok(Status::Partial); + return Err(ERR_PARTIAL); } Some(b' ') | Some(b'\t') => { // The space will be consumed next iteration. @@ -1033,23 +1159,23 @@ fn parse_headers_iter_uninit<'a, 'b>( macro_rules! handle_invalid_char { ($bytes:ident, $b:ident, $err:ident) => { if !config.ignore_invalid_headers_in_responses { - return Err(Error::$err); + return Err(Some(Error::$err)); } let mut b = $b; loop { if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::$err)); + expect_next(bytes, |b| *b == b'\n', Error::$err)?; break; } if b == b'\n' { break; } if b == b'\0' { - return Err(Error::$err); + return Err(Some(Error::$err)); } - b = next!($bytes); + b = next($bytes)?; } $bytes.slice(); @@ -1059,17 +1185,15 @@ fn parse_headers_iter_uninit<'a, 'b>( } // a newline here means the head is over! - let b = next!(bytes); + let b = next(bytes)?; if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::NewLine)); - let end = bytes.as_ref().as_ptr() as usize; - result = Ok(Status::Complete(end - start)); - break; + expect_next(bytes, |b| *b == b'\n', Error::NewLine)?; + let end = bytes.as_ptr() as usize; + return Ok(end - start); } if b == b'\n' { - let end = bytes.as_ref().as_ptr() as usize; - result = Ok(Status::Complete(end - start)); - break; + let end = bytes.as_ptr() as usize; + return Ok(end - start); } if !is_header_name_token(b) { handle_invalid_char!(bytes, b, HeaderName); @@ -1078,7 +1202,7 @@ fn parse_headers_iter_uninit<'a, 'b>( // parse header name until colon let header_name: &str = 'name: loop { simd::match_header_name_vectored(bytes); - let mut b = next!(bytes); + let mut b = next(bytes)?; let name = unsafe { str::from_utf8_unchecked(bytes.slice_skip(1)) @@ -1090,7 +1214,7 @@ fn parse_headers_iter_uninit<'a, 'b>( if config.allow_spaces_after_header_name_in_responses { while b == b' ' || b == b'\t' { - b = next!(bytes); + b = next(bytes)?; if b == b':' { bytes.slice(); @@ -1107,7 +1231,7 @@ fn parse_headers_iter_uninit<'a, 'b>( let value_slice = 'value: loop { // eat white space between colon and value 'whitespace_after_colon: loop { - b = next!(bytes); + b = next(bytes)?; if b == b' ' || b == b'\t' { bytes.slice(); continue 'whitespace_after_colon; @@ -1117,7 +1241,7 @@ fn parse_headers_iter_uninit<'a, 'b>( } if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::HeaderValue)); + expect_next(bytes, |b| *b == b'\n', Error::HeaderValue)?; } else if b != b'\n' { handle_invalid_char!(bytes, b, HeaderValue); } @@ -1135,11 +1259,11 @@ fn parse_headers_iter_uninit<'a, 'b>( // parse value till EOL simd::match_header_value_vectored(bytes); - let b = next!(bytes); + let b = next(bytes)?; //found_ctl let skip = if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::HeaderValue)); + expect_next(bytes, |b| *b == b'\n', Error::HeaderValue)?; 2 } else if b == b'\n' { 1 @@ -1156,32 +1280,18 @@ fn parse_headers_iter_uninit<'a, 'b>( } }; - let uninit_header = match iter.next() { - Some(header) => header, - None => break 'headers - }; - - // trim trailing whitespace in the header - let header_value = if let Some(last_visible) = value_slice - .iter() - .rposition(|b| *b != b' ' && *b != b'\t' && *b != b'\r' && *b != b'\n') - { - // There is at least one non-whitespace character. - &value_slice[0..last_visible+1] - } else { - // There is no non-whitespace character. This can only happen when value_slice is - // empty. - value_slice - }; - - *uninit_header = MaybeUninit::new(Header { + // TODO: can move to the top of the loop + if headers.len() >= max_headers { + return Err(Some(Error::TooManyHeaders)); + } + // Grow the header slice + *headers = unsafe { core::slice::from_raw_parts_mut(headers_ptr, headers.len() + 1) }; + // Write new header + headers[headers.len() - 1] = MaybeUninit::new(Header { name: header_name, - value: header_value, + value: trim_ascii_end(value_slice), }); - autoshrink.num_headers += 1; } - - result } /// Parse a buffer of bytes as a chunk size. @@ -1205,7 +1315,10 @@ pub fn parse_chunk_size(buf: &[u8]) let mut in_ext = false; let mut count = 0; loop { - let b = next!(bytes); + let b = match bytes.next() { + Some(b) => b, + None => return Ok(Status::Partial), + }; match b { b'0' ..= b'9' if in_chunk_size => { if count > 15 { @@ -1232,9 +1345,10 @@ pub fn parse_chunk_size(buf: &[u8]) size += (b + 10 - b'A') as u64; } b'\r' => { - match next!(bytes) { - b'\n' => break, - _ => return Err(InvalidChunkSize), + match bytes.next() { + Some(b'\n') => break, + Some(_) => return Err(InvalidChunkSize), + None => return Ok(Status::Partial), } } // If we weren't in the extension yet, the ";" signals its start @@ -1261,6 +1375,22 @@ pub fn parse_chunk_size(buf: &[u8]) Ok(Status::Complete((bytes.pos(), size))) } +// Parser helpers +#[inline] +fn next(bytes: &mut Bytes) -> InnerResult { + bytes.next().ok_or(ERR_PARTIAL) +} + +#[inline] +fn expect_next(bytes: &mut Bytes, predicate: impl Fn(&u8) -> bool, err: Error) -> InnerResult { + let b = next(bytes)?; + if predicate(&b) { + Ok(b) + } else { + Err(Some(err)) + } +} + #[cfg(test)] mod tests { use super::{Request, Response, Status, EMPTY_HEADER, parse_chunk_size}; @@ -2230,6 +2360,26 @@ mod tests { assert_eq!(response.headers[0].value, &b"baguette"[..]); } + #[test] + fn test_inner_result_assumptions() { + assert_eq!( + core::mem::size_of::>(), + 2, + ); + assert_eq!( + core::mem::size_of::>>(), + 1, + ); + assert_eq!( + core::mem::size_of::>(), + 2, + ); + assert_eq!( + core::mem::size_of::>>(), + 2, + ); + } + #[test] fn test_method_within_buffer() { const REQUEST: &[u8] = b"GET / HTTP/1.1\r\n\r\n"; diff --git a/src/macros.rs b/src/macros.rs deleted file mode 100644 index fa4cf03..0000000 --- a/src/macros.rs +++ /dev/null @@ -1,59 +0,0 @@ -///! Utility macros - -macro_rules! next { - ($bytes:ident) => ({ - match $bytes.next() { - Some(b) => b, - None => return Ok(Status::Partial) - } - }) -} - -macro_rules! expect { - ($bytes:ident.next() == $pat:pat => $ret:expr) => { - expect!(next!($bytes) => $pat |? $ret) - }; - ($e:expr => $pat:pat |? $ret:expr) => { - match $e { - v@$pat => v, - _ => return $ret - } - }; -} - -macro_rules! complete { - ($e:expr) => { - match $e? { - Status::Complete(v) => v, - Status::Partial => return Ok(Status::Partial) - } - } -} - -macro_rules! byte_map { - ($($flag:expr,)*) => ([ - $($flag != 0,)* - ]) -} - -macro_rules! space { - ($bytes:ident or $err:expr) => ({ - expect!($bytes.next() == b' ' => Err($err)); - $bytes.slice(); - }) -} - -macro_rules! newline { - ($bytes:ident) => ({ - match next!($bytes) { - b'\r' => { - expect!($bytes.next() == b'\n' => Err(Error::NewLine)); - $bytes.slice(); - }, - b'\n' => { - $bytes.slice(); - }, - _ => return Err(Error::NewLine) - } - }) -} diff --git a/src/simd/nop.rs b/src/simd/nop.rs new file mode 100644 index 0000000..871cd01 --- /dev/null +++ b/src/simd/nop.rs @@ -0,0 +1,8 @@ +use crate::iter::Bytes; + +// Fallbacks that do nothing... + +#[inline(always)] +pub fn match_uri_vectored(_: &mut Bytes<'_>) {} +#[inline(always)] +pub fn match_header_value_vectored(_: &mut Bytes<'_>) {}