Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: simd swar #134

Merged
merged 2 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 18 additions & 216 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,65 +87,10 @@ static URI_MAP: [bool; 256] = byte_map![
];

#[inline]
fn is_uri_token(b: u8) -> bool {
pub(crate) fn is_uri_token(b: u8) -> bool {
URI_MAP[b as usize]
}

// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44)
// creates a u64 whose bytes are each equal to b
const fn uniform_block(b: u8) -> u64 {
b as u64 * 0x01_01_01_01_01_01_01_01 // [1_u8; 8]
}

// A byte-wise range-check on an enire word/block,
// ensuring all bytes in the word satisfy
// `33 <= x <= 126 && x != '>' && x != '<'`
// it false negatives if the block contains '?'
#[inline]
fn validate_uri_block(block: [u8; 8]) -> usize {
// 33 <= x <= 126
const M: u8 = 0x21;
const N: u8 = 0x7E;
const BM: u64 = uniform_block(M);
const BN: u64 = uniform_block(127-N);
const M128: u64 = uniform_block(128);

let x = u64::from_ne_bytes(block); // Really just a transmute
let lt = x.wrapping_sub(BM) & !x; // <= m
let gt = x.wrapping_add(BN) | x; // >= n

// XOR checks to catch '<' & '>' for correctness
//
// XOR can be thought of as a "distance function"
// (somewhat extrapolating from the `xor(x, x) = 0` identity and ∀ x != y: xor(x, y) != 0`
// (each u8 "xor key" providing a unique total ordering of u8)
// '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`)
// xor(x, '>') <= 2 => {'>', '?', '<'}
// xor(x, '<') <= 2 => {'<', '=', '>'}
//
// We assume P('=') > P('?'),
// given well/commonly-formatted URLs with querystrings contain
// a single '?' but possibly many '='
//
// Thus it's preferable/near-optimal to "xor distance" on '>',
// since we'll slowpath at most one block per URL
//
// Some rust code to sanity check this yourself:
// ```rs
// fn xordist(x: u8, n: u8) -> Vec<(char, u8)> {
// (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect()
// }
// (xordist(b'<', 2), xordist(b'>', 2))
// ```
const B3: u64 = uniform_block(3); // (dist <= 2) + 1 to wrap
const BGT: u64 = uniform_block(b'>');

let xgt = x ^ BGT;
let ltgtq = xgt.wrapping_sub(B3) & !xgt;

offsetnz((ltgtq | lt | gt) & M128)
}

static HEADER_NAME_MAP: [bool; 256] = byte_map![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
Expand All @@ -166,7 +111,7 @@ static HEADER_NAME_MAP: [bool; 256] = byte_map![
];

#[inline]
fn is_header_name_token(b: u8) -> bool {
pub(crate) fn is_header_name_token(b: u8) -> bool {
HEADER_NAME_MAP[b as usize]
}

Expand All @@ -191,45 +136,10 @@ static HEADER_VALUE_MAP: [bool; 256] = byte_map![


#[inline]
fn is_header_value_token(b: u8) -> bool {
pub(crate) fn is_header_value_token(b: u8) -> bool {
HEADER_VALUE_MAP[b as usize]
}

// A byte-wise range-check on an entire word/block,
// ensuring all bytes in the word satisfy `32 <= x <= 126`
#[inline]
fn validate_header_value_block(block: [u8; 8]) -> usize {
// 32 <= x <= 126
const M: u8 = 0x20;
const N: u8 = 0x7E;
const BM: u64 = uniform_block(M);
const BN: u64 = uniform_block(127-N);
const M128: u64 = uniform_block(128);

let x = u64::from_ne_bytes(block); // Really just a transmute
let lt = x.wrapping_sub(BM) & !x; // <= m
let gt = x.wrapping_add(BN) | x; // >= n
offsetnz((lt | gt) & M128)
}

#[inline]
/// Check block to find offset of first non-zero byte
// NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit
fn offsetnz(block: u64) -> usize {
// fast path optimistic case (common for long valid sequences)
if block == 0 {
return 8;
}

// perf: rust will unroll this loop
for (i, b) in block.to_ne_bytes().iter().copied().enumerate() {
if b != 0 {
return i;
}
}
unreachable!()
}

/// An error in parsing.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum Error {
Expand Down Expand Up @@ -966,28 +876,14 @@ fn parse_token<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> {
#[allow(missing_docs)]
// WARNING: Exported for internal benchmarks, not fit for public consumption
pub fn parse_uri<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> {
let b = next!(bytes);
if !is_uri_token(b) {
// First char must be a URI char, it can't be a space which would indicate an empty path.
let start = bytes.pos();
simd::match_uri_vectored(bytes);
// URI must have at least one char
if bytes.pos() == start {
return Err(Error::Token);
}

simd::match_uri_vectored(bytes);

let mut b;
loop {
if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
let n = validate_uri_block(bytes8);
unsafe { bytes.advance(n); }
if n == 8 { continue; }
}
b = next!(bytes);
if !is_uri_token(b) {
break;
}
}

if b == b' ' {
if next!(bytes) == b' ' {
return Ok(Status::Complete(unsafe {
// all bytes up till `i` must have been `is_token`.
str::from_utf8_unchecked(bytes.slice_skip(1))
Expand Down Expand Up @@ -1099,7 +995,8 @@ fn parse_headers_iter_uninit<'a, 'b>(
headers,
num_headers: 0,
};
let mut count: usize = 0;
// Track starting pointer to calculate the number of bytes parsed.
let start = bytes.as_ref().as_ptr() as usize;
let mut result = Err(Error::TooManyHeaders);

let mut iter = autoshrink.headers.iter_mut();
Expand Down Expand Up @@ -1155,7 +1052,6 @@ fn parse_headers_iter_uninit<'a, 'b>(
b = next!($bytes);
}

count += $bytes.pos();
$bytes.slice();

continue 'headers;
Expand All @@ -1166,50 +1062,24 @@ fn parse_headers_iter_uninit<'a, 'b>(
let b = next!(bytes);
if b == b'\r' {
expect!(bytes.next() == b'\n' => Err(Error::NewLine));
result = Ok(Status::Complete(count + bytes.pos()));
let end = bytes.as_ref().as_ptr() as usize;
result = Ok(Status::Complete(end - start));
break;
}
if b == b'\n' {
result = Ok(Status::Complete(count + bytes.pos()));
let end = bytes.as_ref().as_ptr() as usize;
result = Ok(Status::Complete(end - start));
break;
}
if !is_header_name_token(b) {
handle_invalid_char!(bytes, b, HeaderName);
}

// parse header name until colon
let mut b;
let header_name: &str = 'name: loop {
'name_inner: loop {
if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
macro_rules! check {
($bytes:ident, $i:literal) => ({
b = $bytes[$i];
if !is_header_name_token(b) {
unsafe { bytes.advance($i + 1); }
break 'name_inner;
}
});
}

check!(bytes8, 0);
check!(bytes8, 1);
check!(bytes8, 2);
check!(bytes8, 3);
check!(bytes8, 4);
check!(bytes8, 5);
check!(bytes8, 6);
check!(bytes8, 7);
unsafe { bytes.advance(8); }
} else {
b = next!(bytes);
if !is_header_name_token(b) {
break 'name_inner;
}
}
}

count += bytes.pos();
simd::match_header_name_vectored(bytes);
let mut b = next!(bytes);

let name = unsafe {
str::from_utf8_unchecked(bytes.slice_skip(1))
};
Expand All @@ -1223,7 +1093,6 @@ fn parse_headers_iter_uninit<'a, 'b>(
b = next!(bytes);

if b == b':' {
count += bytes.pos();
bytes.slice();
break 'name name;
}
Expand All @@ -1240,7 +1109,6 @@ fn parse_headers_iter_uninit<'a, 'b>(
'whitespace_after_colon: loop {
b = next!(bytes);
if b == b' ' || b == b'\t' {
count += bytes.pos();
bytes.slice();
continue 'whitespace_after_colon;
}
Expand All @@ -1256,7 +1124,6 @@ fn parse_headers_iter_uninit<'a, 'b>(

maybe_continue_after_obsolete_line_folding!(bytes, 'whitespace_after_colon);

count += bytes.pos();
let whitespace_slice = bytes.slice();

// This produces an empty slice that points to the beginning
Expand All @@ -1268,18 +1135,7 @@ fn parse_headers_iter_uninit<'a, 'b>(
// parse value till EOL

simd::match_header_value_vectored(bytes);

'value_line: loop {
if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
let n = validate_header_value_block(bytes8);
unsafe { bytes.advance(n); }
if n == 8 { continue 'value_line; }
}
b = next!(bytes);
if !is_header_value_token(b) {
break 'value_line;
}
}
let b = next!(bytes);

//found_ctl
let skip = if b == b'\r' {
Expand All @@ -1293,7 +1149,6 @@ fn parse_headers_iter_uninit<'a, 'b>(

maybe_continue_after_obsolete_line_folding!(bytes, 'value_lines);

count += bytes.pos();
// having just checked that a newline exists, it's safe to skip it.
unsafe {
break 'value bytes.slice_skip(skip);
Expand Down Expand Up @@ -1409,7 +1264,6 @@ pub fn parse_chunk_size(buf: &[u8])
#[cfg(test)]
mod tests {
use super::{Request, Response, Status, EMPTY_HEADER, parse_chunk_size};
use super::{offsetnz, validate_header_value_block, validate_uri_block};

const NUM_OF_HEADERS: usize = 4;

Expand Down Expand Up @@ -2376,58 +2230,6 @@ mod tests {
assert_eq!(response.headers[0].value, &b"baguette"[..]);
}

#[test]
fn test_is_header_value_block() {
let is_header_value_block = |b| validate_header_value_block(b) == 8;

// 0..32 => false
for b in 0..32_u8 {
assert_eq!(is_header_value_block([b; 8]), false, "b={}", b);
}
// 32..127 => true
for b in 32..127_u8 {
assert_eq!(is_header_value_block([b; 8]), true, "b={}", b);
}
// 127..=255 => false
for b in 127..=255_u8 {
assert_eq!(is_header_value_block([b; 8]), false, "b={}", b);
}

// A few sanity checks on non-uniform bytes for safe-measure
assert!(!is_header_value_block(*b"foo.com\n"));
assert!(!is_header_value_block(*b"o.com\r\nU"));
}

#[test]
fn test_is_uri_block() {
let is_uri_block = |b| validate_uri_block(b) == 8;

// 0..33 => false
for b in 0..33_u8 {
assert_eq!(is_uri_block([b; 8]), false, "b={}", b);
}
// 33..127 => true if b not in { '<', '?', '>' }
let falsy = |b| b"<?>".contains(&b);
for b in 33..127_u8 {
assert_eq!(is_uri_block([b; 8]), !falsy(b), "b={}", b);
}
// 127..=255 => false
for b in 127..=255_u8 {
assert_eq!(is_uri_block([b; 8]), false, "b={}", b);
}
}

#[test]
fn test_offsetnz() {
let seq = [0_u8; 8];
for i in 0..8 {
let mut seq = seq.clone();
seq[i] = 1;
let x = u64::from_ne_bytes(seq);
assert_eq!(offsetnz(x), i);
}
}

#[test]
fn test_method_within_buffer() {
const REQUEST: &[u8] = b"GET / HTTP/1.1\r\n\r\n";
Expand Down
8 changes: 0 additions & 8 deletions src/simd/fallback.rs

This file was deleted.

Loading