From 5b2719acccada786eaa643a18a508250aeb98a6f Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 6 Jun 2022 14:08:18 -0700 Subject: [PATCH] Convert header::name to use MaybeUninit (#555) Co-authored-by: Steven Bosnick --- Cargo.toml | 6 ++++ benches/header_name2.rs | 52 +++++++++++++++++++++++++++ src/header/name.rs | 80 +++++++++++++++++++++++++++++++---------- 3 files changed, 120 insertions(+), 18 deletions(-) create mode 100644 benches/header_name2.rs diff --git a/Cargo.toml b/Cargo.toml index 3542e1c1..209dd235 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ seahash = "3.0.5" serde = "1.0" serde_json = "1.0" doc-comment = "0.3" +criterion = "0.3.2" [[bench]] name = "header_map" @@ -45,6 +46,11 @@ path = "benches/header_map/mod.rs" name = "header_name" path = "benches/header_name.rs" +[[bench]] +name = "header_name2" +path = "benches/header_name2.rs" +harness = false + [[bench]] name = "header_value" path = "benches/header_value.rs" diff --git a/benches/header_name2.rs b/benches/header_name2.rs new file mode 100644 index 00000000..4562fd66 --- /dev/null +++ b/benches/header_name2.rs @@ -0,0 +1,52 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId,Criterion, Throughput}; +use http::header::HeaderName; + +// This is a list of some of the standard headers ordered by increasing size. +// It has exactly one standard header per size (some sizes don't have a standard +// header). +const STANDARD_HEADERS_BY_SIZE: &[&str] = &[ + "te", + "age", + "date", + "allow", + "accept", + "alt-svc", + "if-match", + "forwarded", + "connection", + "retry-after", + "content-type", + "accept-ranges", + "accept-charset", + "accept-encoding", + "content-encoding", + "if-modified-since", + "proxy-authenticate", + "content-disposition", + "sec-websocket-accept", + "sec-websocket-version", + "access-control-max-age", + "content-security-policy", + "sec-websocket-extensions", + "strict-transport-security", + "access-control-allow-origin", + "access-control-allow-headers", + "access-control-expose-headers", + "access-control-request-headers", + "access-control-allow-credentials", + "content-security-policy-report-only", +]; + +fn header_name_by_size(c: &mut Criterion) { + let mut group = c.benchmark_group("std_hdr"); + for name in STANDARD_HEADERS_BY_SIZE { + group.throughput(Throughput::Bytes(name.len() as u64)); + group.bench_with_input(BenchmarkId::from_parameter(name), name, |b, name| { + b.iter(|| HeaderName::from_static(name) ); + }); + } + group.finish(); +} + +criterion_group!(benches, header_name_by_size); +criterion_main!(benches); diff --git a/src/header/name.rs b/src/header/name.rs index f8872257..f0eaeb77 100644 --- a/src/header/name.rs +++ b/src/header/name.rs @@ -5,8 +5,9 @@ use std::borrow::Borrow; use std::error::Error; use std::convert::{TryFrom}; use std::hash::{Hash, Hasher}; +use std::mem::MaybeUninit; use std::str::FromStr; -use std::{fmt, mem}; +use std::fmt; /// Represents an HTTP header field name /// @@ -50,6 +51,7 @@ enum Repr { struct Custom(ByteStr); #[derive(Debug, Clone)] +// Invariant: If lower then buf is valid UTF-8. struct MaybeLower<'a> { buf: &'a [u8], lower: bool, @@ -986,6 +988,8 @@ standard_headers! { /// / DIGIT / ALPHA /// ; any VCHAR, except delimiters /// ``` +// HEADER_CHARS maps every byte that is 128 or larger to 0 so everything that is +// mapped by HEADER_CHARS, maps to a valid single-byte UTF-8 codepoint. const HEADER_CHARS: [u8; 256] = [ // 0 1 2 3 4 5 6 7 8 9 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // x @@ -1017,6 +1021,8 @@ const HEADER_CHARS: [u8; 256] = [ ]; /// Valid header name characters for HTTP/2.0 and HTTP/3.0 +// HEADER_CHARS_H2 maps every byte that is 128 or larger to 0 so everything that is +// mapped by HEADER_CHARS_H2, maps to a valid single-byte UTF-8 codepoint. const HEADER_CHARS_H2: [u8; 256] = [ // 0 1 2 3 4 5 6 7 8 9 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // x @@ -1049,15 +1055,18 @@ const HEADER_CHARS_H2: [u8; 256] = [ fn parse_hdr<'a>( data: &'a [u8], - b: &'a mut [u8; 64], + b: &'a mut [MaybeUninit; SCRATCH_BUF_SIZE], table: &[u8; 256], ) -> Result, InvalidHeaderName> { match data.len() { 0 => Err(InvalidHeaderName::new()), - len @ 1..=64 => { + len @ 1..=SCRATCH_BUF_SIZE => { // Read from data into the buffer - transforming using `table` as we go - data.iter().zip(b.iter_mut()).for_each(|(index, out)| *out = table[*index as usize]); - let name = &b[0..len]; + data.iter() + .zip(b.iter_mut()) + .for_each(|(index, out)| *out = MaybeUninit::new(table[*index as usize])); + // Safety: len bytes of b were just initialized. + let name: &'a [u8] = unsafe { slice_assume_init(&b[0..len]) }; match StandardHeader::from_bytes(name) { Some(sh) => Ok(sh.into()), None => { @@ -1069,7 +1078,7 @@ fn parse_hdr<'a>( } } } - 65..=super::MAX_HEADER_NAME_LEN => Ok(HdrName::custom(data, false)), + SCRATCH_BUF_OVERFLOW..=super::MAX_HEADER_NAME_LEN => Ok(HdrName::custom(data, false)), _ => Err(InvalidHeaderName::new()), } } @@ -1086,14 +1095,14 @@ impl HeaderName { /// Converts a slice of bytes to an HTTP header name. /// /// This function normalizes the input. - #[allow(deprecated)] pub fn from_bytes(src: &[u8]) -> Result { - #[allow(deprecated)] - let mut buf = unsafe { mem::uninitialized() }; + let mut buf = uninit_u8_array(); + // Precondition: HEADER_CHARS is a valid table for parse_hdr(). match parse_hdr(src, &mut buf, &HEADER_CHARS)?.inner { Repr::Standard(std) => Ok(std.into()), Repr::Custom(MaybeLower { buf, lower: true }) => { let buf = Bytes::copy_from_slice(buf); + // Safety: the invariant on MaybeLower ensures buf is valid UTF-8. let val = unsafe { ByteStr::from_utf8_unchecked(buf) }; Ok(Custom(val).into()) } @@ -1102,6 +1111,7 @@ impl HeaderName { let mut dst = BytesMut::with_capacity(buf.len()); for b in buf.iter() { + // HEADER_CHARS maps all bytes to valid single-byte UTF-8 let b = HEADER_CHARS[*b as usize]; if b == 0 { @@ -1111,6 +1121,9 @@ impl HeaderName { dst.put_u8(b); } + // Safety: the loop above maps all bytes in buf to valid single byte + // UTF-8 before copying them into dst. This means that dst (and hence + // dst.freeze()) is valid UTF-8. let val = unsafe { ByteStr::from_utf8_unchecked(dst.freeze()) }; Ok(Custom(val).into()) @@ -1136,25 +1149,29 @@ impl HeaderName { /// // Parsing a header that contains uppercase characters /// assert!(HeaderName::from_lowercase(b"Content-Length").is_err()); /// ``` - #[allow(deprecated)] pub fn from_lowercase(src: &[u8]) -> Result { - #[allow(deprecated)] - let mut buf = unsafe { mem::uninitialized() }; + let mut buf = uninit_u8_array(); + // Precondition: HEADER_CHARS_H2 is a valid table for parse_hdr() match parse_hdr(src, &mut buf, &HEADER_CHARS_H2)?.inner { Repr::Standard(std) => Ok(std.into()), Repr::Custom(MaybeLower { buf, lower: true }) => { let buf = Bytes::copy_from_slice(buf); + // Safety: the invariant on MaybeLower ensures buf is valid UTF-8. let val = unsafe { ByteStr::from_utf8_unchecked(buf) }; Ok(Custom(val).into()) } Repr::Custom(MaybeLower { buf, lower: false }) => { for &b in buf.iter() { + // HEADER_CHARS maps all bytes that are not valid single-byte + // UTF-8 to 0 so this check returns an error for invalid UTF-8. if b != HEADER_CHARS[b as usize] { return Err(InvalidHeaderName::new()); } } let buf = Bytes::copy_from_slice(buf); + // Safety: the loop above checks that each byte of buf (either + // version) is valid UTF-8. let val = unsafe { ByteStr::from_utf8_unchecked(buf) }; Ok(Custom(val).into()) } @@ -1481,8 +1498,10 @@ impl Error for InvalidHeaderName {} // ===== HdrName ===== impl<'a> HdrName<'a> { + // Precondition: if lower then buf is valid UTF-8 fn custom(buf: &'a [u8], lower: bool) -> HdrName<'a> { HdrName { + // Invariant (on MaybeLower): follows from the precondition inner: Repr::Custom(MaybeLower { buf: buf, lower: lower, @@ -1490,24 +1509,22 @@ impl<'a> HdrName<'a> { } } - #[allow(deprecated)] pub fn from_bytes(hdr: &[u8], f: F) -> Result where F: FnOnce(HdrName<'_>) -> U, { - #[allow(deprecated)] - let mut buf = unsafe { mem::uninitialized() }; + let mut buf = uninit_u8_array(); + // Precondition: HEADER_CHARS is a valid table for parse_hdr(). let hdr = parse_hdr(hdr, &mut buf, &HEADER_CHARS)?; Ok(f(hdr)) } - #[allow(deprecated)] pub fn from_static(hdr: &'static str, f: F) -> U where F: FnOnce(HdrName<'_>) -> U, { - #[allow(deprecated)] - let mut buf = unsafe { mem::uninitialized() }; + let mut buf = uninit_u8_array(); let hdr = + // Precondition: HEADER_CHARS is a valid table for parse_hdr(). parse_hdr(hdr.as_bytes(), &mut buf, &HEADER_CHARS).expect("static str is invalid name"); f(hdr) } @@ -1523,6 +1540,7 @@ impl<'a> From> for HeaderName { Repr::Custom(maybe_lower) => { if maybe_lower.lower { let buf = Bytes::copy_from_slice(&maybe_lower.buf[..]); + // Safety: the invariant on MaybeLower ensures buf is valid UTF-8. let byte_str = unsafe { ByteStr::from_utf8_unchecked(buf) }; HeaderName { @@ -1533,9 +1551,14 @@ impl<'a> From> for HeaderName { let mut dst = BytesMut::with_capacity(maybe_lower.buf.len()); for b in maybe_lower.buf.iter() { + // HEADER_CHARS maps each byte to a valid single-byte UTF-8 + // codepoint. dst.put_u8(HEADER_CHARS[*b as usize]); } + // Safety: the loop above maps each byte of maybe_lower.buf to a + // valid single-byte UTF-8 codepoint before copying it into dst. + // dst (and hence dst.freeze()) is thus valid UTF-8. let buf = unsafe { ByteStr::from_utf8_unchecked(dst.freeze()) }; HeaderName { @@ -1606,6 +1629,26 @@ fn eq_ignore_ascii_case(lower: &[u8], s: &[u8]) -> bool { }) } +// Utility functions for MaybeUninit<>. These are drawn from unstable API's on +// MaybeUninit<> itself. +const SCRATCH_BUF_SIZE: usize = 64; +const SCRATCH_BUF_OVERFLOW: usize = SCRATCH_BUF_SIZE + 1; + +fn uninit_u8_array() -> [MaybeUninit; SCRATCH_BUF_SIZE] { + let arr = MaybeUninit::<[MaybeUninit; SCRATCH_BUF_SIZE]>::uninit(); + // Safety: assume_init() is claiming that an array of MaybeUninit<> + // has been initilized, but MaybeUninit<>'s do not require initilizaton. + unsafe { arr.assume_init() } +} + +// Assuming all the elements are initilized, get a slice of them. +// +// Safety: All elements of `slice` must be initilized to prevent +// undefined behavior. +unsafe fn slice_assume_init(slice: &[MaybeUninit]) -> &[T] { + &*(slice as *const [MaybeUninit] as *const [T]) +} + #[cfg(test)] mod tests { use super::*; @@ -1652,6 +1695,7 @@ mod tests { #[test] #[should_panic] fn test_static_invalid_name_lengths() { + // Safety: ONE_TOO_LONG contains only the UTF-8 safe, single-byte codepoint b'a'. let _ = HeaderName::from_static(unsafe { std::str::from_utf8_unchecked(ONE_TOO_LONG) }); }