From 1f3619bd3ce2cecc98585e0192fb3bd3965d0e4c Mon Sep 17 00:00:00 2001 From: Aaron O'Mullan Date: Fri, 21 Apr 2023 02:29:06 +0300 Subject: [PATCH 1/2] refactor: simd swar Moves the block-wise validators to a "swar" SIMD backend The core logic of validate => extract => chain is now more evident --- src/lib.rs | 234 +++---------------------------- src/simd/mod.rs | 21 +-- src/simd/{fallback.rs => nop.rs} | 0 src/simd/runtime.rs | 8 +- src/simd/sse42.rs | 6 +- src/simd/swar.rs | 228 ++++++++++++++++++++++++++++++ 6 files changed, 268 insertions(+), 229 deletions(-) rename src/simd/{fallback.rs => nop.rs} (100%) create mode 100644 src/simd/swar.rs diff --git a/src/lib.rs b/src/lib.rs index 988d982..7e10dbf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -87,65 +87,10 @@ static URI_MAP: [bool; 256] = byte_map![ ]; #[inline] -fn is_uri_token(b: u8) -> bool { +pub(crate) fn is_uri_token(b: u8) -> bool { URI_MAP[b as usize] } -// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44) -// creates a u64 whose bytes are each equal to b -const fn uniform_block(b: u8) -> u64 { - b as u64 * 0x01_01_01_01_01_01_01_01 // [1_u8; 8] -} - -// A byte-wise range-check on an enire word/block, -// ensuring all bytes in the word satisfy -// `33 <= x <= 126 && x != '>' && x != '<'` -// it false negatives if the block contains '?' 
-#[inline] -fn validate_uri_block(block: [u8; 8]) -> usize { - // 33 <= x <= 126 - const M: u8 = 0x21; - const N: u8 = 0x7E; - const BM: u64 = uniform_block(M); - const BN: u64 = uniform_block(127-N); - const M128: u64 = uniform_block(128); - - let x = u64::from_ne_bytes(block); // Really just a transmute - let lt = x.wrapping_sub(BM) & !x; // <= m - let gt = x.wrapping_add(BN) | x; // >= n - - // XOR checks to catch '<' & '>' for correctness - // - // XOR can be thought of as a "distance function" - // (somewhat extrapolating from the `xor(x, x) = 0` identity and ∀ x != y: xor(x, y) != 0` - // (each u8 "xor key" providing a unique total ordering of u8) - // '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`) - // xor(x, '>') <= 2 => {'>', '?', '<'} - // xor(x, '<') <= 2 => {'<', '=', '>'} - // - // We assume P('=') > P('?'), - // given well/commonly-formatted URLs with querystrings contain - // a single '?' but possibly many '=' - // - // Thus it's preferable/near-optimal to "xor distance" on '>', - // since we'll slowpath at most one block per URL - // - // Some rust code to sanity check this yourself: - // ```rs - // fn xordist(x: u8, n: u8) -> Vec<(char, u8)> { - // (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect() - // } - // (xordist(b'<', 2), xordist(b'>', 2)) - // ``` - const B3: u64 = uniform_block(3); // (dist <= 2) + 1 to wrap - const BGT: u64 = uniform_block(b'>'); - - let xgt = x ^ BGT; - let ltgtq = xgt.wrapping_sub(B3) & !xgt; - - offsetnz((ltgtq | lt | gt) & M128) -} - static HEADER_NAME_MAP: [bool; 256] = byte_map![ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -166,7 +111,7 @@ static HEADER_NAME_MAP: [bool; 256] = byte_map![ ]; #[inline] -fn is_header_name_token(b: u8) -> bool { +pub(crate) fn is_header_name_token(b: u8) -> bool { HEADER_NAME_MAP[b as usize] } @@ -191,45 +136,10 @@ static HEADER_VALUE_MAP: [bool; 256] = byte_map![ #[inline] -fn 
is_header_value_token(b: u8) -> bool { +pub(crate) fn is_header_value_token(b: u8) -> bool { HEADER_VALUE_MAP[b as usize] } -// A byte-wise range-check on an entire word/block, -// ensuring all bytes in the word satisfy `32 <= x <= 126` -#[inline] -fn validate_header_value_block(block: [u8; 8]) -> usize { - // 32 <= x <= 126 - const M: u8 = 0x20; - const N: u8 = 0x7E; - const BM: u64 = uniform_block(M); - const BN: u64 = uniform_block(127-N); - const M128: u64 = uniform_block(128); - - let x = u64::from_ne_bytes(block); // Really just a transmute - let lt = x.wrapping_sub(BM) & !x; // <= m - let gt = x.wrapping_add(BN) | x; // >= n - offsetnz((lt | gt) & M128) -} - -#[inline] -/// Check block to find offset of first non-zero byte -// NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit -fn offsetnz(block: u64) -> usize { - // fast path optimistic case (common for long valid sequences) - if block == 0 { - return 8; - } - - // perf: rust will unroll this loop - for (i, b) in block.to_ne_bytes().iter().copied().enumerate() { - if b != 0 { - return i; - } - } - unreachable!() -} - /// An error in parsing. #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum Error { @@ -966,28 +876,14 @@ fn parse_token<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { #[allow(missing_docs)] // WARNING: Exported for internal benchmarks, not fit for public consumption pub fn parse_uri<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> { - let b = next!(bytes); - if !is_uri_token(b) { - // First char must be a URI char, it can't be a space which would indicate an empty path. 
+ let start = bytes.pos(); + simd::match_uri_vectored(bytes); + // URI must have at least one char + if bytes.pos() == start { return Err(Error::Token); } - simd::match_uri_vectored(bytes); - - let mut b; - loop { - if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { - let n = validate_uri_block(bytes8); - unsafe { bytes.advance(n); } - if n == 8 { continue; } - } - b = next!(bytes); - if !is_uri_token(b) { - break; - } - } - - if b == b' ' { + if next!(bytes) == b' ' { return Ok(Status::Complete(unsafe { // all bytes up till `i` must have been `is_token`. str::from_utf8_unchecked(bytes.slice_skip(1)) @@ -1099,7 +995,8 @@ fn parse_headers_iter_uninit<'a, 'b>( headers, num_headers: 0, }; - let mut count: usize = 0; + // Track starting pointer to calculate the number of bytes parsed. + let start = bytes.as_ref().as_ptr() as usize; let mut result = Err(Error::TooManyHeaders); let mut iter = autoshrink.headers.iter_mut(); @@ -1155,7 +1052,6 @@ fn parse_headers_iter_uninit<'a, 'b>( b = next!($bytes); } - count += $bytes.pos(); $bytes.slice(); continue 'headers; @@ -1166,11 +1062,13 @@ fn parse_headers_iter_uninit<'a, 'b>( let b = next!(bytes); if b == b'\r' { expect!(bytes.next() == b'\n' => Err(Error::NewLine)); - result = Ok(Status::Complete(count + bytes.pos())); + let end = bytes.as_ref().as_ptr() as usize; + result = Ok(Status::Complete(end - start)); break; } if b == b'\n' { - result = Ok(Status::Complete(count + bytes.pos())); + let end = bytes.as_ref().as_ptr() as usize; + result = Ok(Status::Complete(end - start)); break; } if !is_header_name_token(b) { @@ -1178,38 +1076,10 @@ fn parse_headers_iter_uninit<'a, 'b>( } // parse header name until colon - let mut b; let header_name: &str = 'name: loop { - 'name_inner: loop { - if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { - macro_rules! 
check { - ($bytes:ident, $i:literal) => ({ - b = $bytes[$i]; - if !is_header_name_token(b) { - unsafe { bytes.advance($i + 1); } - break 'name_inner; - } - }); - } - - check!(bytes8, 0); - check!(bytes8, 1); - check!(bytes8, 2); - check!(bytes8, 3); - check!(bytes8, 4); - check!(bytes8, 5); - check!(bytes8, 6); - check!(bytes8, 7); - unsafe { bytes.advance(8); } - } else { - b = next!(bytes); - if !is_header_name_token(b) { - break 'name_inner; - } - } - } - - count += bytes.pos(); + simd::match_header_name_vectored(bytes); + let mut b = next!(bytes); + let name = unsafe { str::from_utf8_unchecked(bytes.slice_skip(1)) }; @@ -1223,7 +1093,6 @@ fn parse_headers_iter_uninit<'a, 'b>( b = next!(bytes); if b == b':' { - count += bytes.pos(); bytes.slice(); break 'name name; } @@ -1240,7 +1109,6 @@ fn parse_headers_iter_uninit<'a, 'b>( 'whitespace_after_colon: loop { b = next!(bytes); if b == b' ' || b == b'\t' { - count += bytes.pos(); bytes.slice(); continue 'whitespace_after_colon; } @@ -1256,7 +1124,6 @@ fn parse_headers_iter_uninit<'a, 'b>( maybe_continue_after_obsolete_line_folding!(bytes, 'whitespace_after_colon); - count += bytes.pos(); let whitespace_slice = bytes.slice(); // This produces an empty slice that points to the beginning @@ -1268,18 +1135,7 @@ fn parse_headers_iter_uninit<'a, 'b>( // parse value till EOL simd::match_header_value_vectored(bytes); - - 'value_line: loop { - if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { - let n = validate_header_value_block(bytes8); - unsafe { bytes.advance(n); } - if n == 8 { continue 'value_line; } - } - b = next!(bytes); - if !is_header_value_token(b) { - break 'value_line; - } - } + let b = next!(bytes); //found_ctl let skip = if b == b'\r' { @@ -1293,7 +1149,6 @@ fn parse_headers_iter_uninit<'a, 'b>( maybe_continue_after_obsolete_line_folding!(bytes, 'value_lines); - count += bytes.pos(); // having just checked that a newline exists, it's safe to skip it. 
unsafe { break 'value bytes.slice_skip(skip); @@ -1409,7 +1264,6 @@ pub fn parse_chunk_size(buf: &[u8]) #[cfg(test)] mod tests { use super::{Request, Response, Status, EMPTY_HEADER, parse_chunk_size}; - use super::{offsetnz, validate_header_value_block, validate_uri_block}; const NUM_OF_HEADERS: usize = 4; @@ -2376,58 +2230,6 @@ mod tests { assert_eq!(response.headers[0].value, &b"baguette"[..]); } - #[test] - fn test_is_header_value_block() { - let is_header_value_block = |b| validate_header_value_block(b) == 8; - - // 0..32 => false - for b in 0..32_u8 { - assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); - } - // 32..127 => true - for b in 32..127_u8 { - assert_eq!(is_header_value_block([b; 8]), true, "b={}", b); - } - // 127..=255 => false - for b in 127..=255_u8 { - assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); - } - - // A few sanity checks on non-uniform bytes for safe-measure - assert!(!is_header_value_block(*b"foo.com\n")); - assert!(!is_header_value_block(*b"o.com\r\nU")); - } - - #[test] - fn test_is_uri_block() { - let is_uri_block = |b| validate_uri_block(b) == 8; - - // 0..33 => false - for b in 0..33_u8 { - assert_eq!(is_uri_block([b; 8]), false, "b={}", b); - } - // 33..127 => true if b not in { '<', '?', '>' } - let falsy = |b| b"<?>".contains(&b); - for b in 33..127_u8 { - assert_eq!(is_uri_block([b; 8]), !falsy(b), "b={}", b); - } - // 127..=255 => false - for b in 127..=255_u8 { - assert_eq!(is_uri_block([b; 8]), false, "b={}", b); - } - } - - #[test] - fn test_offsetnz() { - let seq = [0_u8; 8]; - for i in 0..8 { - let mut seq = seq.clone(); - seq[i] = 1; - let x = u64::from_ne_bytes(seq); - assert_eq!(offsetnz(x), i); - } - } - #[test] fn test_method_within_buffer() { const REQUEST: &[u8] = b"GET / HTTP/1.1\r\n\r\n"; diff --git a/src/simd/mod.rs b/src/simd/mod.rs index 26ba6b6..91f682d 100644 --- a/src/simd/mod.rs +++ b/src/simd/mod.rs @@ -1,11 +1,4 @@ -#[cfg(not(all( - httparse_simd, - any( - target_arch = "x86", - 
target_arch = "x86_64", - ), -)))] -mod fallback; +mod swar; #[cfg(not(all( httparse_simd, @@ -14,7 +7,7 @@ mod fallback; target_arch = "x86_64", ), )))] -pub use self::fallback::*; +pub use self::swar::*; #[cfg(all( httparse_simd, @@ -74,6 +67,11 @@ pub use self::runtime::*; ), ))] mod sse42_compile_time { + #[inline(always)] + pub fn match_header_name_vectored(b: &mut crate::iter::Bytes<'_>) { + super::swar::match_header_name_vectored(b); + } + #[inline(always)] pub fn match_uri_vectored(b: &mut crate::iter::Bytes<'_>) { // SAFETY: calls are guarded by a compile time feature check @@ -107,6 +105,11 @@ pub use self::sse42_compile_time::*; ), ))] mod avx2_compile_time { + #[inline(always)] + pub fn match_header_name_vectored(b: &mut crate::iter::Bytes<'_>) { + super::swar::match_header_name_vectored(b); + } + #[inline(always)] pub fn match_uri_vectored(b: &mut crate::iter::Bytes<'_>) { // SAFETY: calls are guarded by a compile time feature check diff --git a/src/simd/fallback.rs b/src/simd/nop.rs similarity index 100% rename from src/simd/fallback.rs rename to src/simd/nop.rs diff --git a/src/simd/runtime.rs b/src/simd/runtime.rs index 3bf4d3b..c523a92 100644 --- a/src/simd/runtime.rs +++ b/src/simd/runtime.rs @@ -30,13 +30,17 @@ fn get_runtime_feature() -> u8 { feature } +pub fn match_header_name_vectored(bytes: &mut Bytes) { + super::swar::match_header_name_vectored(bytes); +} + pub fn match_uri_vectored(bytes: &mut Bytes) { // SAFETY: calls are guarded by a feature check unsafe { match get_runtime_feature() { AVX2 => avx2::match_uri_vectored(bytes), SSE42 => sse42::match_uri_vectored(bytes), - _ => {}, + _ /* NOP */ => super::swar::match_uri_vectored(bytes), } } } @@ -47,7 +51,7 @@ pub fn match_header_value_vectored(bytes: &mut Bytes) { match get_runtime_feature() { AVX2 => avx2::match_header_value_vectored(bytes), SSE42 => sse42::match_header_value_vectored(bytes), - _ => {}, + _ /* NOP */ => super::swar::match_header_value_vectored(bytes), } } } diff --git 
a/src/simd/sse42.rs b/src/simd/sse42.rs index e6b41f9..dcae82b 100644 --- a/src/simd/sse42.rs +++ b/src/simd/sse42.rs @@ -7,9 +7,10 @@ pub unsafe fn match_uri_vectored(bytes: &mut Bytes) { bytes.advance(advance); if advance != 16 { - break; + return; } } + super::swar::match_uri_vectored(bytes); } #[inline(always)] @@ -67,9 +68,10 @@ pub unsafe fn match_header_value_vectored(bytes: &mut Bytes) { bytes.advance(advance); if advance != 16 { - break; + return; } } + super::swar::match_header_value_vectored(bytes); } #[inline(always)] diff --git a/src/simd/swar.rs b/src/simd/swar.rs new file mode 100644 index 0000000..13f58a8 --- /dev/null +++ b/src/simd/swar.rs @@ -0,0 +1,228 @@ +/// SWAR: SIMD Within A Register +/// SIMD validator backend that validates register-sized chunks of data at a time. +// TODO: current impl assumes 64-bit registers, optimize for 32-bit +use crate::{is_header_name_token, is_header_value_token, is_uri_token, Bytes}; + +#[inline] +pub fn match_uri_vectored(bytes: &mut Bytes) { + loop { + if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { + let n = match_uri_char_8_swar(bytes8); + unsafe { + bytes.advance(n); + } + if n == 8 { + continue; + } + } + if let Some(b) = bytes.peek() { + if is_uri_token(b) { + unsafe { bytes.advance(1); } + continue; + } + } + break; + } +} + +#[inline] +pub fn match_header_value_vectored(bytes: &mut Bytes) { + loop { + if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) { + let n = match_header_value_char_8_swar(bytes8); + unsafe { + bytes.advance(n); + } + if n == 8 { + continue; + } + } + if let Some(b) = bytes.peek() { + if is_header_value_token(b) { + unsafe { bytes.advance(1); } + continue; + } + } + break; + } +} + +#[inline] +pub fn match_header_name_vectored(bytes: &mut Bytes) { + while let Some(block) = bytes.peek_n::<[u8; 8]>(8) { + let n = match_block(is_header_name_token, block); + unsafe { + bytes.advance(n); + } + if n != 8 { + return; + } + } + unsafe { bytes.advance(match_tail(is_header_name_token, 
bytes.as_ref())) }; +} + +// Matches "tail", i.e: when we have <8 bytes in the buffer, should be uncommon +#[cold] +#[inline] +fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize { + for (i, &b) in bytes.iter().enumerate() { + if !f(b) { + return i; + } + } + bytes.len() +} + +// Naive fallback block matcher +#[inline(always)] +fn match_block(f: impl Fn(u8) -> bool, block: [u8; 8]) -> usize { + for (i, &b) in block.iter().enumerate() { + if !f(b) { + return i; + } + } + 8 +} + +// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44) +// creates a u64 whose bytes are each equal to b +const fn uniform_block(b: u8) -> u64 { + b as u64 * 0x01_01_01_01_01_01_01_01 // [1_u8; 8] +} + +// A byte-wise range-check on an entire word/block, +// ensuring all bytes in the word satisfy +// `33 <= x <= 126 && x != '>' && x != '<'` +// IMPORTANT: it yields false negatives if the block contains '?' +#[inline] +fn match_uri_char_8_swar(block: [u8; 8]) -> usize { + // 33 <= x <= 126 + const M: u8 = 0x21; + const N: u8 = 0x7E; + const BM: u64 = uniform_block(M); + const BN: u64 = uniform_block(127 - N); + const M128: u64 = uniform_block(128); + + let x = u64::from_ne_bytes(block); // Really just a transmute + let lt = x.wrapping_sub(BM) & !x; // <= m + let gt = x.wrapping_add(BN) | x; // >= n + + // XOR checks to catch '<' & '>' for correctness + // + // XOR can be thought of as a "distance function" + // (somewhat extrapolating from the `xor(x, x) = 0` identity and `∀ x != y: xor(x, y) != 0` + // (each u8 "xor key" providing a unique total ordering of u8) + // '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`) + // xor(x, '>') <= 2 => {'>', '?', '<'} + // xor(x, '<') <= 2 => {'<', '=', '>'} + // + // We assume P('=') > P('?'), + // given well/commonly-formatted URLs with querystrings contain + // a single '?' 
but possibly many '=' + // + // Thus it's preferable/near-optimal to "xor distance" on '>', + // since we'll slowpath at most one block per URL + // + // Some rust code to sanity check this yourself: + // ```rs + // fn xordist(x: u8, n: u8) -> Vec<(char, u8)> { + // (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect() + // } + // (xordist(b'<', 2), xordist(b'>', 2)) + // ``` + const B3: u64 = uniform_block(3); // (dist <= 2) + 1 to wrap + const BGT: u64 = uniform_block(b'>'); + + let xgt = x ^ BGT; + let ltgtq = xgt.wrapping_sub(B3) & !xgt; + + offsetnz((ltgtq | lt | gt) & M128) +} + +// A byte-wise range-check on an entire word/block, +// ensuring all bytes in the word satisfy `32 <= x <= 126` +// IMPORTANT: false negatives if obs-text is present (0x80..=0xFF) +#[inline] +fn match_header_value_char_8_swar(block: [u8; 8]) -> usize { + // 32 <= x <= 126 + const M: u8 = 0x20; + const N: u8 = 0x7E; + const BM: u64 = uniform_block(M); + const BN: u64 = uniform_block(127 - N); + const M128: u64 = uniform_block(128); + + let x = u64::from_ne_bytes(block); // Really just a transmute + let lt = x.wrapping_sub(BM) & !x; // <= m + let gt = x.wrapping_add(BN) | x; // >= n + offsetnz((lt | gt) & M128) +} + +/// Check block to find offset of first non-zero byte +// NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit +#[inline] +fn offsetnz(block: u64) -> usize { + // fast path optimistic case (common for long valid sequences) + if block == 0 { + return 8; + } + + // perf: rust will unroll this loop + for (i, b) in block.to_ne_bytes().iter().copied().enumerate() { + if b != 0 { + return i; + } + } + unreachable!() +} + +#[test] +fn test_is_header_value_block() { + let is_header_value_block = |b| match_header_value_char_8_swar(b) == 8; + + // 0..32 => false + for b in 0..32_u8 { + assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); + } + // 32..127 => true + for b in 32..127_u8 { + 
assert_eq!(is_header_value_block([b; 8]), true, "b={}", b); + } + // 127..=255 => false + for b in 127..=255_u8 { + assert_eq!(is_header_value_block([b; 8]), false, "b={}", b); + } + + // A few sanity checks on non-uniform bytes for safe-measure + assert!(!is_header_value_block(*b"foo.com\n")); + assert!(!is_header_value_block(*b"o.com\r\nU")); +} + +#[test] +fn test_is_uri_block() { + let is_uri_block = |b| match_uri_char_8_swar(b) == 8; + + // 0..33 => false + for b in 0..33_u8 { + assert_eq!(is_uri_block([b; 8]), false, "b={}", b); + } + // 33..127 => true if b not in { '<', '?', '>' } + let falsy = |b| b"<?>".contains(&b); + for b in 33..127_u8 { + assert_eq!(is_uri_block([b; 8]), !falsy(b), "b={}", b); + } + // 127..=255 => false + for b in 127..=255_u8 { + assert_eq!(is_uri_block([b; 8]), false, "b={}", b); + } +} + +#[test] +fn test_offsetnz() { + let seq = [0_u8; 8]; + for i in 0..8 { + let mut seq = seq.clone(); + seq[i] = 1; + let x = u64::from_ne_bytes(seq); + assert_eq!(offsetnz(x), i); + } +} From 5c353c8e3d3b1b4c94b42836bb6c7de5ae988d70 Mon Sep 17 00:00:00 2001 From: Aaron O'Mullan Date: Tue, 25 Apr 2023 21:14:49 +0300 Subject: [PATCH 2/2] Delete simd/nop.rs --- src/simd/nop.rs | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 src/simd/nop.rs diff --git a/src/simd/nop.rs b/src/simd/nop.rs deleted file mode 100644 index 871cd01..0000000 --- a/src/simd/nop.rs +++ /dev/null @@ -1,8 +0,0 @@ -use crate::iter::Bytes; - -// Fallbacks that do nothing... - -#[inline(always)] -pub fn match_uri_vectored(_: &mut Bytes<'_>) {} -#[inline(always)] -pub fn match_header_value_vectored(_: &mut Bytes<'_>) {}