diff --git a/benches/parse.rs b/benches/parse.rs index 98952e3..515f4d6 100644 --- a/benches/parse.rs +++ b/benches/parse.rs @@ -152,7 +152,11 @@ fn version(c: &mut Criterion) { .bench_function(name, |b| b.iter(|| { black_box({ let mut b = httparse::_benchable::Bytes::new(input); - httparse::_benchable::parse_version(&mut b).unwrap() + match httparse::_benchable::parse_version(&mut b) { + // Somewhat awkward, but this is internal code so it's ok. + Ok(_) | Err(None) => (), + Err(Some(e)) => panic!("parse_version failed: {}", e), + } }); })); } diff --git a/src/iter.rs b/src/iter.rs index 0d86f9e..275c18a 100644 --- a/src/iter.rs +++ b/src/iter.rs @@ -1,108 +1,157 @@ -use core::slice; -use core::convert::TryInto; -use core::convert::TryFrom; - #[allow(missing_docs)] pub struct Bytes<'a> { - slice: &'a [u8], - pos: usize + start: *const u8, + end: *const u8, + cursor: *const u8, + phantom: core::marker::PhantomData<&'a ()>, } #[allow(missing_docs)] impl<'a> Bytes<'a> { #[inline] pub fn new(slice: &'a [u8]) -> Bytes<'a> { + let start = slice.as_ptr(); + let end = unsafe { start.add(slice.len()) }; + let cursor = start; Bytes { - slice, - pos: 0 + start, + end, + cursor, + phantom: core::marker::PhantomData, } } #[inline] pub fn pos(&self) -> usize { - self.pos + self.cursor as usize - self.start as usize } #[inline] pub fn peek(&self) -> Option { - self.peek_ahead(0) + if self.cursor < self.end { + // SAFETY: bounds checked + Some(unsafe { *self.cursor }) + } else { + None + } } #[inline] pub fn peek_ahead(&self, n: usize) -> Option { - self.slice.get(self.pos + n).copied() + let ptr = unsafe { self.cursor.add(n) }; + if ptr < self.end { + // SAFETY: bounds checked + Some(unsafe { *ptr }) + } else { + None + } } - + #[inline] - pub fn peek_n>(&self, n: usize) -> Option { - self.slice.get(self.pos..self.pos + n)?.try_into().ok() + pub fn peek_n(&self) -> Option { + let n = core::mem::size_of::(); + // Boundary check then read array from ptr + if self.len() >= n { + let ptr = self.cursor as *const U; + let x = unsafe { core::ptr::read_unaligned(ptr) }; + Some(x) + } else { + None + } } #[inline] pub unsafe fn bump(&mut self) { - debug_assert!(self.pos < self.slice.len(), "overflow"); - self.pos += 1; + self.advance(1) } - #[allow(unused)] #[inline] pub unsafe fn advance(&mut self, n: usize) { - debug_assert!(self.pos + n <= self.slice.len(), "overflow"); - self.pos += n; + self.cursor = self.cursor.add(n); + debug_assert!(self.cursor <= self.end, "overflow"); } #[inline] pub fn len(&self) -> usize { - self.slice.len() + self.end as usize - self.cursor as usize } #[inline] pub fn slice(&mut self) -> &'a [u8] { // not moving position at all, so it's safe - unsafe { - self.slice_skip(0) - } + let slice = unsafe { slice_from_ptr_range(self.start, self.cursor) }; + self.commit(); + slice } + // TODO: this is an anti-pattern, should be removed #[inline] pub unsafe fn slice_skip(&mut self, skip: usize) -> &'a [u8] { - debug_assert!(self.pos >= skip); - let head_pos = self.pos - skip; - let ptr = self.slice.as_ptr(); - let head = slice::from_raw_parts(ptr, head_pos); - let tail = slice::from_raw_parts(ptr.add(self.pos), self.slice.len() - self.pos); - self.pos = 0; - self.slice = tail; + debug_assert!(self.cursor.sub(skip) >= self.start); + let head = slice_from_ptr_range(self.start, self.cursor.sub(skip)); + self.commit(); head } + + #[inline] + pub fn commit(&mut self) { + self.start = self.cursor + } #[inline] pub unsafe fn advance_and_commit(&mut self, n: usize) { - debug_assert!(self.pos + n <= self.slice.len(), "overflow"); - self.pos += n; - let ptr = self.slice.as_ptr(); - let tail = slice::from_raw_parts(ptr.add(n), self.slice.len() - n); - self.pos = 0; - self.slice = tail; + self.advance(n); + self.commit(); + } + + #[inline] + pub fn as_ptr(&self) -> *const u8 { + self.cursor + } + + #[inline] + pub fn start(&self) -> *const u8 { + self.start + } + + #[inline] + pub fn end(&self) -> *const u8 { + self.end + } + + #[inline] + pub unsafe fn set_cursor(&mut self, ptr: *const u8) { + debug_assert!(ptr >= self.start); + debug_assert!(ptr <= self.end); + self.cursor = ptr; } } impl<'a> AsRef<[u8]> for Bytes<'a> { #[inline] fn as_ref(&self) -> &[u8] { - &self.slice[self.pos..] + unsafe { slice_from_ptr_range(self.cursor, self.end) } } } +#[inline] +unsafe fn slice_from_ptr_range<'a>(start: *const u8, end: *const u8) -> &'a [u8] { + debug_assert!(start <= end); + core::slice::from_raw_parts(start, end as usize - start as usize) +} + impl<'a> Iterator for Bytes<'a> { type Item = u8; #[inline] fn next(&mut self) -> Option { - if self.slice.len() > self.pos { - let b = unsafe { *self.slice.get_unchecked(self.pos) }; - self.pos += 1; - Some(b) + if self.cursor < self.end { + // SAFETY: bounds checked + unsafe { + let b = *self.cursor; + self.bump(); + Some(b) + } } else { None } diff --git a/src/lib.rs b/src/lib.rs index 988d982..b50affc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,6 @@ use core::mem::{self, MaybeUninit}; use crate::iter::Bytes; mod iter; -#[macro_use] mod macros; mod simd; #[doc(hidden)] @@ -40,6 +39,13 @@ pub mod _benchable { pub use super::iter::Bytes; } +// Macro to generate byte lookup tables +macro_rules! byte_map { + ($($flag:expr,)*) => ([ + $($flag != 0,)* + ]) +} + /// Determines if byte is a token char. /// /// > ```notrust @@ -87,65 +93,10 @@ static URI_MAP: [bool; 256] = byte_map![ ]; #[inline] -fn is_uri_token(b: u8) -> bool { +pub(crate) fn is_uri_token(b: u8) -> bool { URI_MAP[b as usize] } -// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44) -// creates a u64 whose bytes are each equal to b -const fn uniform_block(b: u8) -> u64 { - b as u64 * 0x01_01_01_01_01_01_01_01 // [1_u8; 8] -} - -// A byte-wise range-check on an enire word/block, -// ensuring all bytes in the word satisfy -// `33 <= x <= 126 && x != '>' && x != '<'` -// it false negatives if the block contains '?' -#[inline] -fn validate_uri_block(block: [u8; 8]) -> usize { - // 33 <= x <= 126 - const M: u8 = 0x21; - const N: u8 = 0x7E; - const BM: u64 = uniform_block(M); - const BN: u64 = uniform_block(127-N); - const M128: u64 = uniform_block(128); - - let x = u64::from_ne_bytes(block); // Really just a transmute - let lt = x.wrapping_sub(BM) & !x; // <= m - let gt = x.wrapping_add(BN) | x; // >= n - - // XOR checks to catch '<' & '>' for correctness - // - // XOR can be thought of as a "distance function" - // (somewhat extrapolating from the `xor(x, x) = 0` identity and ∀ x != y: xor(x, y) != 0` - // (each u8 "xor key" providing a unique total ordering of u8) - // '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`) - // xor(x, '>') <= 2 => {'>', '?', '<'} - // xor(x, '<') <= 2 => {'<', '=', '>'} - // - // We assume P('=') > P('?'), - // given well/commonly-formatted URLs with querystrings contain - // a single '?' but possibly many '=' - // - // Thus it's preferable/near-optimal to "xor distance" on '>', - // since we'll slowpath at most one block per URL - // - // Some rust code to sanity check this yourself: - // ```rs - // fn xordist(x: u8, n: u8) -> Vec<(char, u8)> { - // (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect() - // } - // (xordist(b'<', 2), xordist(b'>', 2)) - // ``` - const B3: u64 = uniform_block(3); // (dist <= 2) + 1 to wrap - const BGT: u64 = uniform_block(b'>'); - - let xgt = x ^ BGT; - let ltgtq = xgt.wrapping_sub(B3) & !xgt; - - offsetnz((ltgtq | lt | gt) & M128) -} - static HEADER_NAME_MAP: [bool; 256] = byte_map![ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -166,7 +117,7 @@ static HEADER_NAME_MAP: [bool; 256] = byte_map![ ]; #[inline] -fn is_header_name_token(b: u8) -> bool { +pub(crate) fn is_header_name_token(b: u8) -> bool { HEADER_NAME_MAP[b as usize] } @@ -191,45 +142,10 @@ static HEADER_VALUE_MAP: [bool; 256] = byte_map![ #[inline] -fn is_header_value_token(b: u8) -> bool { +pub(crate) fn is_header_value_token(b: u8) -> bool { HEADER_VALUE_MAP[b as usize] } -// A byte-wise range-check on an entire word/block, -// ensuring all bytes in the word satisfy `32 <= x <= 126` -#[inline] -fn validate_header_value_block(block: [u8; 8]) -> usize { - // 32 <= x <= 126 - const M: u8 = 0x20; - const N: u8 = 0x7E; - const BM: u64 = uniform_block(M); - const BN: u64 = uniform_block(127-N); - const M128: u64 = uniform_block(128); - - let x = u64::from_ne_bytes(block); // Really just a transmute - let lt = x.wrapping_sub(BM) & !x; // <= m - let gt = x.wrapping_add(BN) | x; // >= n - offsetnz((lt | gt) & M128) -} - -#[inline] -/// Check block to find offset of first non-zero byte -// NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit -fn offsetnz(block: u64) -> usize { - // fast path optimistic case (common for long valid sequences) - if block == 0 { - return 8; - } - - // perf: rust will unroll this loop - for (i, b) in block.to_ne_bytes().iter().copied().enumerate() { - if b != 0 { - return i; - } - } - unreachable!() -} - /// An error in parsing. #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum Error { @@ -295,6 +211,27 @@ impl fmt::Display for InvalidChunkSize { /// a `Ok(Status::Partial)`. pub type Result = result::Result, Error>; +// Intermediate result with a more compact representation. +type InnerResult = result::Result>; +const ERR_PARTIAL: Option = None; + +// NOTE: we use this "extension trait" to provide a .into() method for InnerResult +// (since it's an external type and wrapping it would be inelegant) +trait IntoOuter: Sized { + fn into(self) -> T; +} + +impl IntoOuter> for InnerResult { + #[inline] + fn into(self) -> Result { + match self { + Ok(x) => Ok(Status::Complete(x)), + Err(None) => Ok(Status::Partial), + Err(Some(err)) => Err(err), + } + } +} + /// The result of a successful parse pass. /// /// `Complete` is used when the buffer contained the complete value. @@ -457,7 +394,7 @@ impl ParserConfig { buf: &'buf [u8], headers: &'headers mut [MaybeUninit>], ) -> Result { - request.parse_with_config_and_uninit_headers(buf, self, headers) + request.parse_with_config_and_uninit_headers(buf, self, headers).into() } /// Sets whether invalid header lines should be silently ignored in responses. @@ -514,7 +451,7 @@ impl ParserConfig { buf: &'buf [u8], headers: &'headers mut [MaybeUninit>], ) -> Result { - response.parse_with_config_and_uninit_headers(buf, self, headers) + response.parse_with_config_and_uninit_headers(buf, self, headers).into() } } @@ -572,32 +509,38 @@ impl<'h, 'b> Request<'h, 'b> { buf: &'b [u8], config: &ParserConfig, mut headers: &'h mut [MaybeUninit>], - ) -> Result { - let orig_len = buf.len(); + ) -> InnerResult { + let start = buf.as_ptr() as usize; let mut bytes = Bytes::new(buf); - complete!(skip_empty_lines(&mut bytes)); - let method = complete!(parse_method(&mut bytes)); + + skip_empty_lines(&mut bytes)?; + let method = parse_method(&mut bytes)?; self.method = Some(method); if config.allow_multiple_spaces_in_request_line_delimiters { - complete!(skip_spaces(&mut bytes)); + skip_spaces(&mut bytes)?; } - self.path = Some(complete!(parse_uri(&mut bytes))); + self.path = Some(parse_uri(&mut bytes)?); if config.allow_multiple_spaces_in_request_line_delimiters { - complete!(skip_spaces(&mut bytes)); + skip_spaces(&mut bytes)?; + } + self.version = Some(parse_version(&mut bytes)?); + match parse_header_eol(&mut bytes)? { + // Req with no headers + true => { + let end = bytes.as_ptr() as usize; + return Ok(end-start) + }, + false => {}, // CRLF transition to headers } - self.version = Some(complete!(parse_version(&mut bytes))); - newline!(bytes); - let len = orig_len - bytes.len(); - let headers_len = complete!(parse_headers_iter_uninit( + parse_request_headers_iter_uninit( &mut headers, &mut bytes, - &ParserConfig::default(), - )); + )?; /* SAFETY: see `parse_headers_iter_uninit` guarantees */ self.headers = unsafe { assume_init_slice(headers) }; - - Ok(Status::Complete(len + headers_len)) + let end = bytes.as_ptr() as usize; + Ok(end-start) } /// Try to parse a buffer of bytes into the Request, @@ -609,7 +552,7 @@ impl<'h, 'b> Request<'h, 'b> { buf: &'b [u8], headers: &'h mut [MaybeUninit>], ) -> Result { - self.parse_with_config_and_uninit_headers(buf, &Default::default(), headers) + self.parse_with_config_and_uninit_headers(buf, &Default::default(), headers).into() } fn parse_with_config(&mut self, buf: &'b [u8], config: &ParserConfig) -> Result { @@ -619,34 +562,33 @@ impl<'h, 'b> Request<'h, 'b> { unsafe { let headers: *mut [Header<'_>] = headers; let headers = headers as *mut [MaybeUninit>]; - match self.parse_with_config_and_uninit_headers(buf, config, &mut *headers) { - Ok(Status::Complete(idx)) => Ok(Status::Complete(idx)), - other => { - // put the original headers back - self.headers = &mut *(headers as *mut [Header<'_>]); - other - }, + let result = self.parse_with_config_and_uninit_headers(buf, config, &mut *headers); + if result.is_err() { + // put the original headers back + self.headers = &mut *(headers as *mut [Header<'_>]); } + result.into() } } /// Try to parse a buffer of bytes into the Request. /// /// Returns byte offset in `buf` to start of HTTP body. + #[inline] pub fn parse(&mut self, buf: &'b [u8]) -> Result { self.parse_with_config(buf, &Default::default()) } } #[inline] -fn skip_empty_lines(bytes: &mut Bytes<'_>) -> Result<()> { +fn skip_empty_lines(bytes: &mut Bytes<'_>) -> InnerResult<()> { loop { let b = bytes.peek(); match b { Some(b'\r') => { // there's `\r`, so it's safe to bump 1 pos unsafe { bytes.bump() }; - expect!(bytes.next() == b'\n' => Err(Error::NewLine)); + expect_next(bytes, |b| *b == b'\n', Error::NewLine)?; }, Some(b'\n') => { // there's `\n`, so it's safe to bump 1 pos @@ -654,15 +596,15 @@ fn skip_empty_lines(bytes: &mut Bytes<'_>) -> Result<()> { }, Some(..) => { bytes.slice(); - return Ok(Status::Complete(())); + return Ok(()); }, - None => return Ok(Status::Partial) + None => return Err(ERR_PARTIAL), } } } #[inline] -fn skip_spaces(bytes: &mut Bytes<'_>) -> Result<()> { +fn skip_spaces(bytes: &mut Bytes<'_>) -> InnerResult<()> { loop { let b = bytes.peek(); match b { @@ -672,9 +614,9 @@ fn skip_spaces(bytes: &mut Bytes<'_>) -> Result<()> { } Some(..) => { bytes.slice(); - return Ok(Status::Complete(())); + return Ok(()); } - None => return Ok(Status::Partial), + None => return Err(ERR_PARTIAL), } } } @@ -719,14 +661,12 @@ impl<'h, 'b> Response<'h, 'b> { unsafe { let headers: *mut [Header<'_>] = headers; let headers = headers as *mut [MaybeUninit>]; - match self.parse_with_config_and_uninit_headers(buf, config, &mut *headers) { - Ok(Status::Complete(idx)) => Ok(Status::Complete(idx)), - other => { - // put the original headers back - self.headers = &mut *(headers as *mut [Header<'_>]); - other - }, + let result = self.parse_with_config_and_uninit_headers(buf, config, &mut *headers); + if result.is_err() { + // put the original headers back + self.headers = &mut *(headers as *mut [Header<'_>]); } + result.into() } } @@ -735,17 +675,18 @@ impl<'h, 'b> Response<'h, 'b> { buf: &'b [u8], config: &ParserConfig, mut headers: &'h mut [MaybeUninit>], - ) -> Result { + ) -> InnerResult { let orig_len = buf.len(); let mut bytes = Bytes::new(buf); - complete!(skip_empty_lines(&mut bytes)); - self.version = Some(complete!(parse_version(&mut bytes))); - space!(bytes or Error::Version); + skip_empty_lines(&mut bytes)?; + self.version = Some(parse_version(&mut bytes)?); + expect_next(&mut bytes, |b| *b == b' ', Error::Version)?; + bytes.commit(); // TODO: remove ? if config.allow_multiple_spaces_in_response_status_delimiters { - complete!(skip_spaces(&mut bytes)); + skip_spaces(&mut bytes)?; } - self.code = Some(complete!(parse_code(&mut bytes))); + self.code = Some(parse_code(&mut bytes)?); // RFC7230 says there must be 'SP' and then reason-phrase, but admits // its only for legacy reasons. With the reason-phrase completely @@ -756,16 +697,16 @@ impl<'h, 'b> Response<'h, 'b> { // So, a SP means parse a reason-phrase. // A newline means go to headers. // Anything else we'll say is a malformed status. - match next!(bytes) { + match next(&mut bytes)? { b' ' => { if config.allow_multiple_spaces_in_response_status_delimiters { - complete!(skip_spaces(&mut bytes)); + skip_spaces(&mut bytes)?; } bytes.slice(); - self.reason = Some(complete!(parse_reason(&mut bytes))); + self.reason = Some(parse_reason(&mut bytes)?); }, b'\r' => { - expect!(bytes.next() == b'\n' => Err(Error::Status)); + expect_next(&mut bytes, |b| *b == b'\n', Error::Status)?; bytes.slice(); self.reason = Some(""); }, @@ -773,19 +714,19 @@ impl<'h, 'b> Response<'h, 'b> { bytes.slice(); self.reason = Some(""); } - _ => return Err(Error::Status), + _ => return Err(Some(Error::Status)), } let len = orig_len - bytes.len(); - let headers_len = complete!(parse_headers_iter_uninit( + let headers_len = parse_response_headers_iter_uninit( &mut headers, &mut bytes, config - )); + )?; /* SAFETY: see `parse_headers_iter_uninit` guarantees */ self.headers = unsafe { assume_init_slice(headers) }; - Ok(Status::Complete(len + headers_len)) + Ok(len + headers_len) } } @@ -830,8 +771,8 @@ pub const EMPTY_HEADER: Header<'static> = Header { name: "", value: b"" }; #[doc(hidden)] #[allow(missing_docs)] // WARNING: Exported for internal benchmarks, not fit for public consumption -pub fn parse_version(bytes: &mut Bytes) -> Result { - if let Some(eight) = bytes.peek_n::<[u8; 8]>(8) { +pub fn parse_version(bytes: &mut Bytes) -> InnerResult { + if let Some(eight) = bytes.peek_n() { // NOTE: should be const once MSRV >= 1.44 let h10: u64 = u64::from_ne_bytes(*b"HTTP/1.0"); let h11: u64 = u64::from_ne_bytes(*b"HTTP/1.1"); @@ -839,11 +780,11 @@ pub fn parse_version(bytes: &mut Bytes) -> Result { let block = u64::from_ne_bytes(eight); // NOTE: should be match once h10 & h11 are consts return if block == h10 { - Ok(Status::Complete(0)) + Ok(0) } else if block == h11 { - Ok(Status::Complete(1)) + Ok(1) } else { - Err(Error::Version) + Err(Some(Error::Version)) } } @@ -851,24 +792,21 @@ pub fn parse_version(bytes: &mut Bytes) -> Result { // If there aren't at least 8 bytes, we still want to detect early // if this is a valid version or not. If it is, we'll return Partial. - expect!(bytes.next() == b'H' => Err(Error::Version)); - expect!(bytes.next() == b'T' => Err(Error::Version)); - expect!(bytes.next() == b'T' => Err(Error::Version)); - expect!(bytes.next() == b'P' => Err(Error::Version)); - expect!(bytes.next() == b'/' => Err(Error::Version)); - expect!(bytes.next() == b'1' => Err(Error::Version)); - expect!(bytes.next() == b'.' => Err(Error::Version)); - Ok(Status::Partial) + if b"HTTP/1.".starts_with(bytes.as_ref()) { + return Err(ERR_PARTIAL); + } else { + return Err(Some(Error::Version)); + } } #[inline] #[doc(hidden)] #[allow(missing_docs)] // WARNING: Exported for internal benchmarks, not fit for public consumption -pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { +pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> InnerResult<&'a str> { const GET: [u8; 4] = *b"GET "; const POST: [u8; 4] = *b"POST"; - match bytes.peek_n::<[u8; 4]>(4) { + match bytes.peek_n() { Some(GET) => { // SAFETY: matched the ASCII string and boundary checked let method = unsafe { @@ -876,7 +814,7 @@ pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { let buf = bytes.slice_skip(1); str::from_utf8_unchecked(buf) }; - Ok(Status::Complete(method)) + Ok(method) } Some(POST) if bytes.peek_ahead(4) == Some(b' ') => { // SAFETY: matched the ASCII string and boundary checked @@ -885,7 +823,7 @@ pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { let buf = bytes.slice_skip(1); str::from_utf8_unchecked(buf) }; - Ok(Status::Complete(method)) + Ok(method) } _ => parse_token(bytes), } @@ -905,13 +843,13 @@ pub fn parse_method<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { /// > Non-US-ASCII content in header fields and the reason phrase /// > has been obsoleted and made opaque (the TEXT rule was removed). #[inline] -fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { +fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> InnerResult<&'a str> { let mut seen_obs_text = false; loop { - let b = next!(bytes); + let b = next(bytes)?; if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::Status)); - return Ok(Status::Complete(unsafe { + expect_next(bytes, |b| *b == b'\n', Error::Status)?; + return Ok(unsafe { let bytes = bytes.slice_skip(2); if !seen_obs_text { // all bytes up till `i` must have been HTAB / SP / VCHAR @@ -920,9 +858,9 @@ fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { // obs-text characters were found, so return the fallback empty string "" } - })); + }); } else if b == b'\n' { - return Ok(Status::Complete(unsafe { + return Ok(unsafe { let bytes = bytes.slice_skip(1); if !seen_obs_text { // all bytes up till `i` must have been HTAB / SP / VCHAR @@ -931,9 +869,9 @@ fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { // obs-text characters were found, so return the fallback empty string "" } - })); + }); } else if !(b == 0x09 || b == b' ' || (0x21..=0x7E).contains(&b) || b >= 0x80) { - return Err(Error::Status); + return Err(Some(Error::Status)); } else if b >= 0x80 { seen_obs_text = true; } @@ -941,22 +879,22 @@ fn parse_reason<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { } #[inline] -fn parse_token<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { - let b = next!(bytes); +fn parse_token<'a>(bytes: &mut Bytes<'a>) -> InnerResult<&'a str> { + let b = next(bytes)?; if !is_token(b) { // First char must be a token char, it can't be a space which would indicate an empty token. - return Err(Error::Token); + return Err(Some(Error::Token)); } loop { - let b = next!(bytes); + let b = next(bytes)?; if b == b' ' { - return Ok(Status::Complete(unsafe { + return Ok(unsafe { // all bytes up till `i` must have been `is_token`. str::from_utf8_unchecked(bytes.slice_skip(1)) - })); + }); } else if !is_token(b) { - return Err(Error::Token); + return Err(Some(Error::Token)); } } } @@ -965,47 +903,29 @@ fn parse_token<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { #[doc(hidden)] #[allow(missing_docs)] // WARNING: Exported for internal benchmarks, not fit for public consumption -pub fn parse_uri<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { - let b = next!(bytes); - if !is_uri_token(b) { - // First char must be a URI char, it can't be a space which would indicate an empty path. - return Err(Error::Token); - } - +pub fn parse_uri<'a>(bytes: &mut Bytes<'a>) -> InnerResult<&'a str> { simd::match_uri_vectored(bytes); + // SAFTEY: the validated bytes are ASCII and thus UTF-8 + let uri = unsafe { + str::from_utf8_unchecked(bytes.slice()) + }; - let mut b; - loop { - if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { - let n = validate_uri_block(bytes8); - unsafe { bytes.advance(n); } - if n == 8 { continue; } - } - b = next!(bytes); - if !is_uri_token(b) { - break; - } - } - - if b == b' ' { - return Ok(Status::Complete(unsafe { - // all bytes up till `i` must have been `is_token`. - str::from_utf8_unchecked(bytes.slice_skip(1)) - })); + if !uri.is_empty() && next(bytes)? == b' ' { + return Ok(uri); } else { - return Err(Error::Token); + return Err(Some(Error::Token)); } } #[inline] -fn parse_code(bytes: &mut Bytes<'_>) -> Result { - let hundreds = expect!(bytes.next() == b'0'..=b'9' => Err(Error::Status)); - let tens = expect!(bytes.next() == b'0'..=b'9' => Err(Error::Status)); - let ones = expect!(bytes.next() == b'0'..=b'9' => Err(Error::Status)); +fn parse_code(bytes: &mut Bytes<'_>) -> InnerResult { + let hundreds = expect_next(bytes, u8::is_ascii_digit, Error::Status)?; + let tens = expect_next(bytes, u8::is_ascii_digit, Error::Status)?; + let ones = expect_next(bytes, u8::is_ascii_digit, Error::Status)?; - Ok(Status::Complete((hundreds - b'0') as u16 * 100 + + Ok((hundreds - b'0') as u16 * 100 + (tens - b'0') as u16 * 10 + - (ones - b'0') as u16)) + (ones - b'0') as u16) } /// Parse a buffer of bytes as headers. @@ -1031,22 +951,18 @@ pub fn parse_headers<'b: 'h, 'h>( mut dst: &'h mut [Header<'b>], ) -> Result<(usize, &'h [Header<'b>])> { let mut iter = Bytes::new(src); - let pos = complete!(parse_headers_iter(&mut dst, &mut iter, &ParserConfig::default())); - Ok(Status::Complete((pos, dst))) -} - -#[inline] -fn parse_headers_iter<'a, 'b>( - headers: &mut &mut [Header<'a>], - bytes: &'b mut Bytes<'a>, - config: &ParserConfig, -) -> Result { - parse_headers_iter_uninit( + let start = iter.as_ptr() as usize; + let result = parse_request_headers_iter_uninit( /* SAFETY: see `parse_headers_iter_uninit` guarantees */ - unsafe { deinit_slice_mut(headers) }, - bytes, - config, - ) + unsafe { deinit_slice_mut(&mut dst) }, + &mut iter, + ); + let result = result.map(move |_| { + let end = iter.as_ptr() as usize; + (end - start, &*dst) + }); + + result.into() } unsafe fn deinit_slice_mut<'a, 'b, T>(s: &'a mut &'b mut [T]) -> &'a mut &'b mut [MaybeUninit] { @@ -1069,40 +985,147 @@ unsafe fn assume_init_slice(s: &mut [MaybeUninit]) -> &mut [T] { * Also it promises `headers` get shrunk to number of initialized headers, * so casting the other way around after calling this function is safe */ -fn parse_headers_iter_uninit<'a, 'b>( +#[inline(always)] +fn parse_request_headers_iter_uninit<'a, 'b>( headers: &mut &mut [MaybeUninit>], bytes: &'b mut Bytes<'a>, - config: &ParserConfig, -) -> Result { +) -> InnerResult<()> { + let headers_ptr = headers.as_mut_ptr(); + let max_headers = headers.len(); - /* Flow of this function is pretty complex, especially with macros, - * so this struct makes sure we shrink `headers` to only parsed ones. - * Comparing to previous code, this only may introduce some additional - * instructions in case of early return */ - struct ShrinkOnDrop<'r1, 'r2, 'a> { - headers: &'r1 mut &'r2 mut [MaybeUninit>], - num_headers: usize, - } + // SAFETY: Size headers at 0, we'll grow it as we parse headers. + *headers = unsafe { core::slice::from_raw_parts_mut(headers_ptr, 0) }; + + for i in 0..max_headers { + bytes.commit(); + + // parse header name until colon + let header_name = { + simd::match_header_name_vectored(bytes); + + // SAFTEY: the validated bytes are ASCII and thus UTF-8 + unsafe { str::from_utf8_unchecked(bytes.slice()) } + }; + + // parse colon and leading whitespace, possibly empty value + parse_header_colspace(bytes)?; + + // parse value until CRLF + let header_value = { + simd::match_header_value_vectored(bytes); + let slice = bytes.slice(); + trim_ascii_end(slice) + }; - impl<'r1, 'r2, 'a> Drop for ShrinkOnDrop<'r1, 'r2, 'a> { - fn drop(&mut self) { - let headers = mem::replace(self.headers, &mut []); + // Grow the header slice + *headers = unsafe { core::slice::from_raw_parts_mut(headers_ptr, i + 1) }; + // Write new header + headers[i] = MaybeUninit::new(Header { + name: header_name, + value: header_value, + }); + + // Check for end + if parse_header_eol(bytes)? { + return Ok(()); + } + } + + return Err(Some(Error::TooManyHeaders)); +} - /* SAFETY: num_headers is the number of initialized headers */ - let headers = unsafe { headers.get_unchecked_mut(..self.num_headers) }; +#[inline] +fn parse_header_eol(bytes: &mut Bytes) -> InnerResult { + match bytes.peek_n::<[u8; 4]>() { + // End of all headers + Some([b'\r', b'\n', b'\r', b'\n']) => { + unsafe { bytes.advance(4); } + return Ok(true); + }, + // End of header-line + Some([b'\r', b'\n', _, _]) /* if is_header_name_token(b) */ => { + unsafe { bytes.advance(2); } + }, + _ => { + // Slow cases + if let Some([b'\r', b'\n', b'\r']) = bytes.peek_n::<[u8; 3]>() { + // 3 bytes but not 4, edge case of incomplete CR LF CR LF + return Err(ERR_PARTIAL) + } + match bytes.peek_n::<[u8; 2]>() { + Some([b'\n', b'\n']) => { + unsafe { bytes.advance(2) }; + return Ok(true); + } + Some([b'\r', b'\n']) => { + unsafe { bytes.advance(2) }; + } + _ => { + let b = next(bytes)?; + if b != b'\n' && b != b'\r' { + return Err(Some(Error::HeaderName)); + } + } + } + } + } + return Ok(false) +} - *self.headers = headers; +#[inline(always)] +fn parse_header_colspace(bytes: &mut Bytes) -> InnerResult<()> { + // Fast path + if let Some([b':', b' ', b]) = bytes.peek_n::<[u8; 3]>() { + if b > b' ' { + // SAFETY: we know that the next 2-3 bytes are valid + unsafe { bytes.advance_and_commit(2); } + return Ok(()); + } + } + // Validate colon + if next(bytes)? != b':' { + return Err(Some(Error::HeaderName)); + } + // eat white space between colon and value + while let Some(b) = bytes.peek() { + if b == b' ' || b == b'\t' { + unsafe { bytes.advance_and_commit(1); } + } else { + return Ok(()); } } + return Err(ERR_PARTIAL) +} - let mut autoshrink = ShrinkOnDrop { - headers, - num_headers: 0, - }; - let mut count: usize = 0; - let mut result = Err(Error::TooManyHeaders); +// NOTE: in future use https://doc.rust-lang.org/std/primitive.slice.html#method.trim_ascii_end +#[inline] +fn trim_ascii_end(s: &[u8]) -> &[u8] { + match s.iter().rposition(|b| *b != b' ' && *b != b'\t' && *b != b'\r' && *b != b'\n') { + Some(end) => &s[..=end], + None => &s[..0], + } +} - let mut iter = autoshrink.headers.iter_mut(); +/* Function which parsers headers into uninitialized buffer. + * + * Guarantees that it doesn't write garbage, so casting + * &mut &mut [Header] -> &mut &mut [MaybeUninit
] + * is safe here. + * + * Also it promises `headers` get shrunk to number of initialized headers, + * so casting the other way around after calling this function is safe + */ +fn parse_response_headers_iter_uninit<'a, 'b>( + headers: &mut &mut [MaybeUninit>], + bytes: &'b mut Bytes<'a>, + config: &ParserConfig, +) -> InnerResult { + let headers_ptr = headers.as_mut_ptr(); + let max_headers = headers.len(); + // SAFETY: Size headers at 0, we'll grow it as we parse headers. + *headers = unsafe { core::slice::from_raw_parts_mut(headers_ptr, 0) }; + // Track starting pointer to calculate the number of bytes parsed. + let start = bytes.as_ptr() as usize; macro_rules! maybe_continue_after_obsolete_line_folding { ($bytes:ident, $label:lifetime) => { @@ -1112,7 +1135,7 @@ fn parse_headers_iter_uninit<'a, 'b>( // Next byte may be a space, in which case that header // is using obsolete line folding, so we may have more // whitespace to skip after colon. - return Ok(Status::Partial); + return Err(ERR_PARTIAL); } Some(b' ') | Some(b'\t') => { // The space will be consumed next iteration. @@ -1136,26 +1159,25 @@ fn parse_headers_iter_uninit<'a, 'b>( macro_rules! handle_invalid_char { ($bytes:ident, $b:ident, $err:ident) => { if !config.ignore_invalid_headers_in_responses { - return Err(Error::$err); + return Err(Some(Error::$err)); } let mut b = $b; loop { if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::$err)); + expect_next(bytes, |b| *b == b'\n', Error::$err)?; break; } if b == b'\n' { break; } if b == b'\0' { - return Err(Error::$err); + return Err(Some(Error::$err)); } - b = next!($bytes); + b = next($bytes)?; } - count += $bytes.pos(); $bytes.slice(); continue 'headers; @@ -1163,53 +1185,25 @@ fn parse_headers_iter_uninit<'a, 'b>( } // a newline here means the head is over! - let b = next!(bytes); + let b = next(bytes)?; if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::NewLine)); - result = Ok(Status::Complete(count + bytes.pos())); - break; + expect_next(bytes, |b| *b == b'\n', Error::NewLine)?; + let end = bytes.as_ptr() as usize; + return Ok(end - start); } if b == b'\n' { - result = Ok(Status::Complete(count + bytes.pos())); - break; + let end = bytes.as_ptr() as usize; + return Ok(end - start); } if !is_header_name_token(b) { handle_invalid_char!(bytes, b, HeaderName); } // parse header name until colon - let mut b; let header_name: &str = 'name: loop { - 'name_inner: loop { - if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { - macro_rules! check { - ($bytes:ident, $i:literal) => ({ - b = $bytes[$i]; - if !is_header_name_token(b) { - unsafe { bytes.advance($i + 1); } - break 'name_inner; - } - }); - } - - check!(bytes8, 0); - check!(bytes8, 1); - check!(bytes8, 2); - check!(bytes8, 3); - check!(bytes8, 4); - check!(bytes8, 5); - check!(bytes8, 6); - check!(bytes8, 7); - unsafe { bytes.advance(8); } - } else { - b = next!(bytes); - if !is_header_name_token(b) { - break 'name_inner; - } - } - } - - count += bytes.pos(); + simd::match_header_name_vectored(bytes); + let mut b = next(bytes)?; + let name = unsafe { str::from_utf8_unchecked(bytes.slice_skip(1)) }; @@ -1220,10 +1214,9 @@ fn parse_headers_iter_uninit<'a, 'b>( if config.allow_spaces_after_header_name_in_responses { while b == b' ' || b == b'\t' { - b = next!(bytes); + b = next(bytes)?; if b == b':' { - count += bytes.pos(); bytes.slice(); break 'name name; } @@ -1238,9 +1231,8 @@ fn parse_headers_iter_uninit<'a, 'b>( let value_slice = 'value: loop { // eat white space between colon and value 'whitespace_after_colon: loop { - b = next!(bytes); + b = next(bytes)?; if b == b' ' || b == b'\t' { - count += bytes.pos(); bytes.slice(); continue 'whitespace_after_colon; } @@ -1249,14 +1241,13 @@ fn parse_headers_iter_uninit<'a, 'b>( } if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::HeaderValue)); + expect_next(bytes, |b| *b == b'\n', Error::HeaderValue)?; } else if b != b'\n' { handle_invalid_char!(bytes, b, HeaderValue); } maybe_continue_after_obsolete_line_folding!(bytes, 'whitespace_after_colon); - count += bytes.pos(); let whitespace_slice = bytes.slice(); // This produces an empty slice that points to the beginning @@ -1268,22 +1259,11 @@ fn parse_headers_iter_uninit<'a, 'b>( // parse value till EOL simd::match_header_value_vectored(bytes); - - 'value_line: loop { - if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { - let n = validate_header_value_block(bytes8); - unsafe { bytes.advance(n); } - if n == 8 { continue 'value_line; } - } - b = next!(bytes); - if !is_header_value_token(b) { - break 'value_line; - } - } + let b = next(bytes)?; //found_ctl let skip = if b == b'\r' { - expect!(bytes.next() == b'\n' => Err(Error::HeaderValue)); + expect_next(bytes, |b| *b == b'\n', Error::HeaderValue)?; 2 } else if b == b'\n' { 1 @@ -1293,7 +1273,6 @@ fn parse_headers_iter_uninit<'a, 'b>( maybe_continue_after_obsolete_line_folding!(bytes, 'value_lines); - count += bytes.pos(); // having just checked that a newline exists, it's safe to skip it. unsafe { break 'value bytes.slice_skip(skip); @@ -1301,32 +1280,18 @@ fn parse_headers_iter_uninit<'a, 'b>( } }; - let uninit_header = match iter.next() { - Some(header) => header, - None => break 'headers - }; - - // trim trailing whitespace in the header - let header_value = if let Some(last_visible) = value_slice - .iter() - .rposition(|b| *b != b' ' && *b != b'\t' && *b != b'\r' && *b != b'\n') - { - // There is at least one non-whitespace character. - &value_slice[0..last_visible+1] - } else { - // There is no non-whitespace character. This can only happen when value_slice is - // empty. - value_slice - }; - - *uninit_header = MaybeUninit::new(Header { + // TODO: can move to the top of the loop + if headers.len() >= max_headers { + return Err(Some(Error::TooManyHeaders)); + } + // Grow the header slice + *headers = unsafe { core::slice::from_raw_parts_mut(headers_ptr, headers.len() + 1) }; + // Write new header + headers[headers.len() - 1] = MaybeUninit::new(Header { name: header_name, - value: header_value, + value: trim_ascii_end(value_slice), }); - autoshrink.num_headers += 1; } - - result } /// Parse a buffer of bytes as a chunk size. @@ -1350,7 +1315,10 @@ pub fn parse_chunk_size(buf: &[u8]) let mut in_ext = false; let mut count = 0; loop { - let b = next!(bytes); + let b = match bytes.next() { + Some(b) => b, + None => return Ok(Status::Partial), + }; match b { b'0' ..= b'9' if in_chunk_size => { if count > 15 { @@ -1377,9 +1345,10 @@ pub fn parse_chunk_size(buf: &[u8]) size += (b + 10 - b'A') as u64; } b'\r' => { - match next!(bytes) { - b'\n' => break, - _ => return Err(InvalidChunkSize), + match bytes.next() { + Some(b'\n') => break, + Some(_) => return Err(InvalidChunkSize), + None => return Ok(Status::Partial), } } // If we weren't in the extension yet, the ";" signals its start @@ -1406,10 +1375,25 @@ pub fn parse_chunk_size(buf: &[u8]) Ok(Status::Complete((bytes.pos(), size))) } +// Parser helpers +#[inline] +fn next(bytes: &mut Bytes) -> InnerResult { + bytes.next().ok_or(ERR_PARTIAL) +} + +#[inline] +fn expect_next(bytes: &mut Bytes, predicate: impl Fn(&u8) -> bool, err: Error) -> InnerResult { + let b = next(bytes)?; + if predicate(&b) { + Ok(b) + } else { + Err(Some(err)) + } +} + #[cfg(test)] mod tests { use super::{Request, Response, Status, EMPTY_HEADER, parse_chunk_size}; - use super::{offsetnz, validate_header_value_block, validate_uri_block}; const NUM_OF_HEADERS: usize = 4; @@ -2377,57 +2361,25 @@ mod tests { } #[test] - fn test_is_header_value_block() { - let is_header_value_block = |b| validate_header_value_block(b) == 8; - - // 0..32 => false - for b in 0..32_u8 { - assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); - } - // 32..127 => true - for b in 32..127_u8 { - assert_eq!(is_header_value_block([b; 8]), true, "b={}", b); - } - // 127..=255 => false - for b in 127..=255_u8 { - assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); - } - - // A few sanity checks on non-uniform bytes for safe-measure - assert!(!is_header_value_block(*b"foo.com\n")); - assert!(!is_header_value_block(*b"o.com\r\nU")); - } - - #[test] - fn test_is_uri_block() { - let is_uri_block = |b| validate_uri_block(b) == 8; - - // 0..33 => false - for b in 0..33_u8 { - assert_eq!(is_uri_block([b; 8]), false, "b={}", b); - } - // 33..127 => true if b not in { '<', '?', '>' } - let falsy = |b| b"".contains(&b); - for b in 33..127_u8 { - assert_eq!(is_uri_block([b; 8]), !falsy(b), "b={}", b); - } - // 127..=255 => false - for b in 127..=255_u8 { - assert_eq!(is_uri_block([b; 8]), false, "b={}", b); - } + fn test_inner_result_assumptions() { + assert_eq!( + core::mem::size_of::>(), + 2, + ); + assert_eq!( + core::mem::size_of::>>(), + 1, + ); + assert_eq!( + core::mem::size_of::>(), + 2, + ); + assert_eq!( + core::mem::size_of::>>(), + 2, + ); } - #[test] - fn test_offsetnz() { - let seq = [0_u8; 8]; - for i in 0..8 { - let mut seq = seq.clone(); - seq[i] = 1; - let x = u64::from_ne_bytes(seq); - assert_eq!(offsetnz(x), i); - } - } - #[test] fn test_method_within_buffer() { const REQUEST: &[u8] = b"GET / HTTP/1.1\r\n\r\n"; diff --git a/src/macros.rs b/src/macros.rs deleted file mode 100644 index fa4cf03..0000000 --- a/src/macros.rs +++ /dev/null @@ -1,59 +0,0 @@ -///! Utility macros - -macro_rules! next { - ($bytes:ident) => ({ - match $bytes.next() { - Some(b) => b, - None => return Ok(Status::Partial) - } - }) -} - -macro_rules! expect { - ($bytes:ident.next() == $pat:pat => $ret:expr) => { - expect!(next!($bytes) => $pat |? $ret) - }; - ($e:expr => $pat:pat |? $ret:expr) => { - match $e { - v@$pat => v, - _ => return $ret - } - }; -} - -macro_rules! complete { - ($e:expr) => { - match $e? { - Status::Complete(v) => v, - Status::Partial => return Ok(Status::Partial) - } - } -} - -macro_rules! byte_map { - ($($flag:expr,)*) => ([ - $($flag != 0,)* - ]) -} - -macro_rules! space { - ($bytes:ident or $err:expr) => ({ - expect!($bytes.next() == b' ' => Err($err)); - $bytes.slice(); - }) -} - -macro_rules! newline { - ($bytes:ident) => ({ - match next!($bytes) { - b'\r' => { - expect!($bytes.next() == b'\n' => Err(Error::NewLine)); - $bytes.slice(); - }, - b'\n' => { - $bytes.slice(); - }, - _ => return Err(Error::NewLine) - } - }) -} diff --git a/src/simd/mod.rs b/src/simd/mod.rs index 26ba6b6..81bdd87 100644 --- a/src/simd/mod.rs +++ b/src/simd/mod.rs @@ -1,20 +1,14 @@ -#[cfg(not(all( - httparse_simd, - any( - target_arch = "x86", - target_arch = "x86_64", - ), -)))] -mod fallback; +mod swar; #[cfg(not(all( httparse_simd, any( target_arch = "x86", target_arch = "x86_64", + target_arch = "aarch64", ), )))] -pub use self::fallback::*; +pub use self::swar::*; #[cfg(all( httparse_simd, @@ -74,6 +68,11 @@ pub use self::runtime::*; ), ))] mod sse42_compile_time { + #[inline(always)] + pub fn match_header_name_vectored(b: &mut crate::iter::Bytes<'_>) { + super::swar::match_header_name_vectored(b); + } + #[inline(always)] pub fn match_uri_vectored(b: &mut crate::iter::Bytes<'_>) { // SAFETY: calls are guarded by a compile time feature check @@ -107,6 +106,11 @@ pub use self::sse42_compile_time::*; ), ))] mod avx2_compile_time { + #[inline(always)] + pub fn match_header_name_vectored(b: &mut crate::iter::Bytes<'_>) { + super::swar::match_header_name_vectored(b); + } + #[inline(always)] pub fn match_uri_vectored(b: &mut crate::iter::Bytes<'_>) { // SAFETY: calls are guarded by a compile time feature check @@ -129,3 +133,15 @@ mod avx2_compile_time { ), ))] pub use self::avx2_compile_time::*; + +#[cfg(all( + httparse_simd, + target_arch = "aarch64", +))] +mod neon; + +#[cfg(all( + httparse_simd, + target_arch = "aarch64", +))] +pub use self::neon::*; diff --git a/src/simd/neon.rs b/src/simd/neon.rs new file mode 100644 index 0000000..d21cd83 --- /dev/null +++ b/src/simd/neon.rs @@ -0,0 +1,258 @@ +use crate::iter::Bytes; +use core::arch::aarch64::*; + +// NOTE: net-negative, so unused for now +#[allow(dead_code)] // NOTE: will use after https://github.com/seanmonstar/httparse/pull/134 +#[inline] +pub fn match_header_name_vectored(bytes: &mut Bytes) { + while bytes.as_ref().len() >= 16 { + unsafe { + let advance = match_header_name_char_16_neon(bytes.as_ref().as_ptr()); + bytes.advance(advance); + + if advance != 16 { + return; + } + } + } + super::swar::match_header_name_vectored(bytes); +} + +#[inline] +pub fn match_header_value_vectored(bytes: &mut Bytes) { + while bytes.as_ref().len() >= 16 { + unsafe { + let advance = match_header_value_char_16_neon(bytes.as_ref().as_ptr()); + bytes.advance(advance); + + if advance != 16 { + return; + } + } + } + super::swar::match_header_value_vectored(bytes); +} + +#[inline] +pub fn match_uri_vectored(bytes: &mut Bytes) { + while bytes.as_ref().len() >= 16 { + unsafe { + let advance = match_url_char_16_neon(bytes.as_ref().as_ptr()); + bytes.advance(advance); + + if advance != 16 { + return; + } + } + } + super::swar::match_uri_vectored(bytes); +} + +const fn bit_set(x: u8) -> bool { + // Validates if a byte is a valid header name character + // https://tools.ietf.org/html/rfc7230#section-3.2.6 + matches!(x, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'!' | b'#' | b'$' | b'%' | b'&' | b'\'' | b'*' | b'+' | b'-' | b'.' | b'^' | b'_' | b'`' | b'|' | b'~') +} + +// A 256-bit bitmap, split into two halves +// lower half contains bits whose higher nibble is <= 7 +// higher half contains bits whose higher nibble is >= 8 +const fn build_bitmap() -> ([u8; 16], [u8; 16]) { + let mut bitmap_0_7 = [0u8; 16]; // 0x00..0x7F + let mut bitmap_8_15 = [0u8; 16]; // 0x80..0xFF + let mut i = 0; + while i < 256 { + if bit_set(i as u8) { + // Nibbles + let (lo, hi) = (i & 0x0F, i >> 4); + if i < 128 { + bitmap_0_7[lo] |= 1 << hi; + } else { + bitmap_8_15[lo] |= 1 << hi; + } + } + i += 1; + } + (bitmap_0_7, bitmap_8_15) +} + +const BITMAPS: ([u8; 16], [u8; 16]) = build_bitmap(); + +// NOTE: adapted from 256-bit version, with upper 128-bit ops commented out +#[inline] +unsafe fn match_header_name_char_16_neon(ptr: *const u8) -> usize { + let bitmaps = BITMAPS; + // NOTE: ideally compile-time constants + let (bitmap_0_7, _bitmap_8_15) = bitmaps; + let bitmap_0_7 = vld1q_u8(bitmap_0_7.as_ptr()); + // let bitmap_8_15 = vld1q_u8(bitmap_8_15.as_ptr()); + + // Initialize the bitmask_lookup. + const BITMASK_LOOKUP_DATA: [u8; 16] = + [1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128]; + let bitmask_lookup = vld1q_u8(BITMASK_LOOKUP_DATA.as_ptr()); + + // Load 16 input bytes. + let input = vld1q_u8(ptr); + + // Extract indices for row_0_7. + let indices_0_7 = vandq_u8(input, vdupq_n_u8(0x8F)); // 0b1000_1111; + + // Extract indices for row_8_15. + // let msb = vandq_u8(input, vdupq_n_u8(0x80)); + // let indices_8_15 = veorq_u8(indices_0_7, msb); + + // Fetch row_0_7 and row_8_15. + let row_0_7 = vqtbl1q_u8(bitmap_0_7, indices_0_7); + // let row_8_15 = vqtbl1q_u8(bitmap_8_15, indices_8_15); + + // Calculate a bitmask, i.e. (1 << hi_nibble % 8). + let bitmask = vqtbl1q_u8(bitmask_lookup, vshrq_n_u8(input, 4)); + + // Choose rows halves depending on higher nibbles. + // let bitsets = vorrq_u8(row_0_7, row_8_15); + let bitsets = row_0_7; + + // Finally check which bytes belong to the set. + let tmp = vandq_u8(bitsets, bitmask); + let result = vceqq_u8(tmp, bitmask); + + offsetz(result) as usize +} + +#[inline] +unsafe fn match_url_char_16_neon(ptr: *const u8) -> usize { + let input = vld1q_u8(ptr); + + // Check that b'!' <= input <= b'~' + let result = vandq_u8( + vcleq_u8(vdupq_n_u8(b'!'), input), + vcleq_u8(input, vdupq_n_u8(b'~')), + ); + // Check that input != b'<' and input != b'>' + let lt = vceqq_u8(input, vdupq_n_u8(b'<')); + let gt = vceqq_u8(input, vdupq_n_u8(b'>')); + let ltgt = vorrq_u8(lt, gt); + // Nand with result + let result = vbicq_u8(result, ltgt); + + offsetz(result) as usize +} + +#[inline] +unsafe fn match_header_value_char_16_neon(ptr: *const u8) -> usize { + let input = vld1q_u8(ptr); + + // Check that b' ' <= and b != 127 or b == 9 + let result = vcleq_u8(vdupq_n_u8(b' '), input); + + // Allow tab + let tab = vceqq_u8(input, vdupq_n_u8(0x09)); + let result = vorrq_u8(result, tab); + + // Disallow del + let del = vceqq_u8(input, vdupq_n_u8(0x7F)); + let result = vbicq_u8(result, del); + + offsetz(result) as usize +} + +#[inline] +unsafe fn offsetz(x: uint8x16_t) -> u32 { + // NOT the vector since it's faster to operate with zeros instead + offsetnz(vmvnq_u8(x)) +} + +#[inline] +unsafe fn offsetnz(x: uint8x16_t) -> u32 { + // Extract two u64 + let x = vreinterpretq_u64_u8(x); + let low: u64 = std::mem::transmute(vget_low_u64(x)); + let high: u64 = std::mem::transmute(vget_high_u64(x)); + + #[inline] + fn clz(x: u64) -> u32 { + // perf: rust will unroll this loop + // and it's much faster than rbit + clz so voila + for (i, b) in x.to_ne_bytes().iter().copied().enumerate() { + if b != 0 { + return i as u32; + } + } + 8 // Technically not reachable since zero-guarded + } + + if low != 0 { + return clz(low); + } else if high != 0 { + return 8 + clz(high); + } else { + return 16; + } +} + +#[test] +fn neon_code_matches_uri_chars_table() { + unsafe { + assert!(byte_is_allowed(b'_', match_uri_vectored)); + + for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() { + assert_eq!( + byte_is_allowed(b as u8, match_uri_vectored), + allowed, + "byte_is_allowed({:?}) should be {:?}", + b, + allowed, + ); + } + } +} + +#[test] +fn neon_code_matches_header_value_chars_table() { + unsafe { + assert!(byte_is_allowed(b'_', match_header_value_vectored)); + + for (b, allowed) in crate::HEADER_VALUE_MAP.iter().cloned().enumerate() { + assert_eq!( + byte_is_allowed(b as u8, match_header_value_vectored), + allowed, + "byte_is_allowed({:?}) should be {:?}", + b, + allowed, + ); + } + } +} + +#[test] +fn neon_code_matches_header_name_chars_table() { + unsafe { + assert!(byte_is_allowed(b'_', match_header_name_vectored)); + + for (b, allowed) in crate::HEADER_NAME_MAP.iter().cloned().enumerate() { + assert_eq!( + byte_is_allowed(b as u8, match_header_name_vectored), + allowed, + "byte_is_allowed({:?}) should be {:?}", + b, + allowed, + ); + } + } +} + +#[cfg(test)] +unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>)) -> bool { + let mut slice = [b'_'; 16]; + slice[10] = byte; + let mut bytes = Bytes::new(&slice); + + f(&mut bytes); + + match bytes.pos() { + 16 => true, + 10 => false, + x => panic!("unexpected pos: {}", x), + } +} diff --git a/src/simd/nop.rs b/src/simd/nop.rs new file mode 100644 index 0000000..871cd01 --- /dev/null +++ b/src/simd/nop.rs @@ -0,0 +1,8 @@ +use crate::iter::Bytes; + +// Fallbacks that do nothing... + +#[inline(always)] +pub fn match_uri_vectored(_: &mut Bytes<'_>) {} +#[inline(always)] +pub fn match_header_value_vectored(_: &mut Bytes<'_>) {} diff --git a/src/simd/runtime.rs b/src/simd/runtime.rs index 3bf4d3b..c523a92 100644 --- a/src/simd/runtime.rs +++ b/src/simd/runtime.rs @@ -30,13 +30,17 @@ fn get_runtime_feature() -> u8 { feature } +pub fn match_header_name_vectored(bytes: &mut Bytes) { + super::swar::match_header_name_vectored(bytes); +} + pub fn match_uri_vectored(bytes: &mut Bytes) { // SAFETY: calls are guarded by a feature check unsafe { match get_runtime_feature() { AVX2 => avx2::match_uri_vectored(bytes), SSE42 => sse42::match_uri_vectored(bytes), - _ => {}, + _ /* NOP */ => super::swar::match_uri_vectored(bytes), } } } @@ -47,7 +51,7 @@ pub fn match_header_value_vectored(bytes: &mut Bytes) { match get_runtime_feature() { AVX2 => avx2::match_header_value_vectored(bytes), SSE42 => sse42::match_header_value_vectored(bytes), - _ => {}, + _ /* NOP */ => super::swar::match_header_value_vectored(bytes), } } } diff --git a/src/simd/sse42.rs b/src/simd/sse42.rs index e6b41f9..dcae82b 100644 --- a/src/simd/sse42.rs +++ b/src/simd/sse42.rs @@ -7,9 +7,10 @@ pub unsafe fn match_uri_vectored(bytes: &mut Bytes) { bytes.advance(advance); if advance != 16 { - break; + return; } } + super::swar::match_uri_vectored(bytes); } #[inline(always)] @@ -67,9 +68,10 @@ pub unsafe fn match_header_value_vectored(bytes: &mut Bytes) { bytes.advance(advance); if advance != 16 { - break; + return; } } + super::swar::match_header_value_vectored(bytes); } #[inline(always)] diff --git a/src/simd/swar.rs b/src/simd/swar.rs new file mode 100644 index 0000000..ad8f211 --- /dev/null +++ b/src/simd/swar.rs @@ -0,0 +1,228 @@ +/// SWAR: SIMD Within A Register +/// SIMD validator backend that validates register-sized chunks of data at a time. +// TODO: current impl assumes 64-bit registers, optimize for 32-bit +use crate::{is_header_name_token, is_header_value_token, is_uri_token, Bytes}; + +#[inline] +pub fn match_uri_vectored(bytes: &mut Bytes) { + loop { + if let Some(bytes8) = bytes.peek_n() { + let n = match_uri_char_8_swar(bytes8); + unsafe { + bytes.advance(n); + } + if n == 8 { + continue; + } + } + if let Some(b) = bytes.peek() { + if is_uri_token(b) { + unsafe { bytes.advance(1); } + continue; + } + } + break; + } +} + +#[inline] +pub fn match_header_value_vectored(bytes: &mut Bytes) { + loop { + if let Some(bytes8) = bytes.peek_n() { + let n = match_header_value_char_8_swar(bytes8); + unsafe { + bytes.advance(n); + } + if n == 8 { + continue; + } + } + if let Some(b) = bytes.peek() { + if is_header_value_token(b) { + unsafe { bytes.advance(1); } + continue; + } + } + break; + } +} + +#[inline] +pub fn match_header_name_vectored(bytes: &mut Bytes) { + while let Some(block) = bytes.peek_n() { + let n = match_block(is_header_name_token, block); + unsafe { + bytes.advance(n); + } + if n != 8 { + return; + } + } + unsafe { bytes.advance(match_tail(is_header_name_token, bytes.as_ref())) }; +} + +// Matches "tail", i.e: when we have <8 bytes in the buffer, should be uncommon +#[cold] +#[inline] +fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize { + for (i, &b) in bytes.iter().enumerate() { + if !f(b) { + return i; + } + } + bytes.len() +} + +// Naive fallback block matcher +#[inline(always)] +fn match_block(f: impl Fn(u8) -> bool, block: [u8; 8]) -> usize { + for (i, &b) in block.iter().enumerate() { + if !f(b) { + return i; + } + } + 8 +} + +/// // A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44) +// creates a u64 whose bytes are each equal to b +const fn uniform_block(b: u8) -> u64 { + b as u64 * 0x01_01_01_01_01_01_01_01 // [1_u8; 8] +} + +// A byte-wise range-check on an enire word/block, +// ensuring all bytes in the word satisfy +// `33 <= x <= 126 && x != '>' && x != '<'` +// IMPORTANT: it false negatives if the block contains '?' +#[inline] +fn match_uri_char_8_swar(block: [u8; 8]) -> usize { + // 33 <= x <= 126 + const M: u8 = 0x21; + const N: u8 = 0x7E; + const BM: u64 = uniform_block(M); + const BN: u64 = uniform_block(127 - N); + const M128: u64 = uniform_block(128); + + let x = u64::from_ne_bytes(block); // Really just a transmute + let lt = x.wrapping_sub(BM) & !x; // <= m + let gt = x.wrapping_add(BN) | x; // >= n + + // XOR checks to catch '<' & '>' for correctness + // + // XOR can be thought of as a "distance function" + // (somewhat extrapolating from the `xor(x, x) = 0` identity and ∀ x != y: xor(x, y) != 0` + // (each u8 "xor key" providing a unique total ordering of u8) + // '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`) + // xor(x, '>') <= 2 => {'>', '?', '<'} + // xor(x, '<') <= 2 => {'<', '=', '>'} + // + // We assume P('=') > P('?'), + // given well/commonly-formatted URLs with querystrings contain + // a single '?' but possibly many '=' + // + // Thus it's preferable/near-optimal to "xor distance" on '>', + // since we'll slowpath at most one block per URL + // + // Some rust code to sanity check this yourself: + // ```rs + // fn xordist(x: u8, n: u8) -> Vec<(char, u8)> { + // (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect() + // } + // (xordist(b'<', 2), xordist(b'>', 2)) + // ``` + const B3: u64 = uniform_block(3); // (dist <= 2) + 1 to wrap + const BGT: u64 = uniform_block(b'>'); + + let xgt = x ^ BGT; + let ltgtq = xgt.wrapping_sub(B3) & !xgt; + + offsetnz((ltgtq | lt | gt) & M128) +} + +// A byte-wise range-check on an entire word/block, +// ensuring all bytes in the word satisfy `32 <= x <= 126` +// IMPORTANT: false negatives if obs-text is present (0x80..=0xFF) +#[inline] +fn match_header_value_char_8_swar(block: [u8; 8]) -> usize { + // 32 <= x <= 126 + const M: u8 = 0x20; + const N: u8 = 0x7E; + const BM: u64 = uniform_block(M); + const BN: u64 = uniform_block(127 - N); + const M128: u64 = uniform_block(128); + + let x = u64::from_ne_bytes(block); // Really just a transmute + let lt = x.wrapping_sub(BM) & !x; // <= m + let gt = x.wrapping_add(BN) | x; // >= n + offsetnz((lt | gt) & M128) +} + +/// Check block to find offset of first non-zero byte +// NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit +#[inline] +fn offsetnz(block: u64) -> usize { + // fast path optimistic case (common for long valid sequences) + if block == 0 { + return 8; + } + + // perf: rust will unroll this loop + for (i, b) in block.to_ne_bytes().iter().copied().enumerate() { + if b != 0 { + return i; + } + } + unreachable!() +} + +#[test] +fn test_is_header_value_block() { + let is_header_value_block = |b| match_header_value_char_8_swar(b) == 8; + + // 0..32 => false + for b in 0..32_u8 { + assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); + } + // 32..127 => true + for b in 32..127_u8 { + assert_eq!(is_header_value_block([b; 8]), true, "b={}", b); + } + // 127..=255 => false + for b in 127..=255_u8 { + assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); + } + + // A few sanity checks on non-uniform bytes for safe-measure + assert!(!is_header_value_block(*b"foo.com\n")); + assert!(!is_header_value_block(*b"o.com\r\nU")); +} + +#[test] +fn test_is_uri_block() { + let is_uri_block = |b| match_uri_char_8_swar(b) == 8; + + // 0..33 => false + for b in 0..33_u8 { + assert_eq!(is_uri_block([b; 8]), false, "b={}", b); + } + // 33..127 => true if b not in { '<', '?', '>' } + let falsy = |b| b"".contains(&b); + for b in 33..127_u8 { + assert_eq!(is_uri_block([b; 8]), !falsy(b), "b={}", b); + } + // 127..=255 => false + for b in 127..=255_u8 { + assert_eq!(is_uri_block([b; 8]), false, "b={}", b); + } +} + +#[test] +fn test_offsetnz() { + let seq = [0_u8; 8]; + for i in 0..8 { + let mut seq = seq.clone(); + seq[i] = 1; + let x = u64::from_ne_bytes(seq); + assert_eq!(offsetnz(x), i); + } +}