Skip to content

Commit

Permalink
Convert header::name to use MaybeUninit (#555)
Browse files Browse the repository at this point in the history
Co-authored-by: Steven Bosnick <[email protected]>
  • Loading branch information
seanmonstar and sbosnick authored Jun 6, 2022
1 parent fecfdfb commit 5b2719a
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 18 deletions.
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ seahash = "3.0.5"
serde = "1.0"
serde_json = "1.0"
doc-comment = "0.3"
criterion = "0.3.2"

[[bench]]
name = "header_map"
Expand All @@ -45,6 +46,11 @@ path = "benches/header_map/mod.rs"
name = "header_name"
path = "benches/header_name.rs"

[[bench]]
name = "header_name2"
path = "benches/header_name2.rs"
harness = false

[[bench]]
name = "header_value"
path = "benches/header_value.rs"
Expand Down
52 changes: 52 additions & 0 deletions benches/header_name2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use criterion::{criterion_group, criterion_main, BenchmarkId,Criterion, Throughput};
use http::header::HeaderName;

// This is a list of some of the standard headers ordered by increasing size.
// It has exactly one standard header per size (some sizes don't have a standard
// header).
const STANDARD_HEADERS_BY_SIZE: &[&str] = &[
"te",
"age",
"date",
"allow",
"accept",
"alt-svc",
"if-match",
"forwarded",
"connection",
"retry-after",
"content-type",
"accept-ranges",
"accept-charset",
"accept-encoding",
"content-encoding",
"if-modified-since",
"proxy-authenticate",
"content-disposition",
"sec-websocket-accept",
"sec-websocket-version",
"access-control-max-age",
"content-security-policy",
"sec-websocket-extensions",
"strict-transport-security",
"access-control-allow-origin",
"access-control-allow-headers",
"access-control-expose-headers",
"access-control-request-headers",
"access-control-allow-credentials",
"content-security-policy-report-only",
];

fn header_name_by_size(c: &mut Criterion) {
let mut group = c.benchmark_group("std_hdr");
for name in STANDARD_HEADERS_BY_SIZE {
group.throughput(Throughput::Bytes(name.len() as u64));
group.bench_with_input(BenchmarkId::from_parameter(name), name, |b, name| {
b.iter(|| HeaderName::from_static(name) );
});
}
group.finish();
}

criterion_group!(benches, header_name_by_size);
criterion_main!(benches);
80 changes: 62 additions & 18 deletions src/header/name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use std::borrow::Borrow;
use std::error::Error;
use std::convert::{TryFrom};
use std::hash::{Hash, Hasher};
use std::mem::MaybeUninit;
use std::str::FromStr;
use std::{fmt, mem};
use std::fmt;

/// Represents an HTTP header field name
///
Expand Down Expand Up @@ -50,6 +51,7 @@ enum Repr<T> {
struct Custom(ByteStr);

#[derive(Debug, Clone)]
// Invariant: If lower then buf is valid UTF-8.
struct MaybeLower<'a> {
buf: &'a [u8],
lower: bool,
Expand Down Expand Up @@ -986,6 +988,8 @@ standard_headers! {
/// / DIGIT / ALPHA
/// ; any VCHAR, except delimiters
/// ```
// HEADER_CHARS maps every byte that is 128 or larger to 0 so everything that is
// mapped by HEADER_CHARS, maps to a valid single-byte UTF-8 codepoint.
const HEADER_CHARS: [u8; 256] = [
// 0 1 2 3 4 5 6 7 8 9
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // x
Expand Down Expand Up @@ -1017,6 +1021,8 @@ const HEADER_CHARS: [u8; 256] = [
];

/// Valid header name characters for HTTP/2.0 and HTTP/3.0
// HEADER_CHARS_H2 maps every byte that is 128 or larger to 0 so everything that is
// mapped by HEADER_CHARS_H2, maps to a valid single-byte UTF-8 codepoint.
const HEADER_CHARS_H2: [u8; 256] = [
// 0 1 2 3 4 5 6 7 8 9
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // x
Expand Down Expand Up @@ -1049,15 +1055,18 @@ const HEADER_CHARS_H2: [u8; 256] = [

fn parse_hdr<'a>(
data: &'a [u8],
b: &'a mut [u8; 64],
b: &'a mut [MaybeUninit<u8>; SCRATCH_BUF_SIZE],
table: &[u8; 256],
) -> Result<HdrName<'a>, InvalidHeaderName> {
match data.len() {
0 => Err(InvalidHeaderName::new()),
len @ 1..=64 => {
len @ 1..=SCRATCH_BUF_SIZE => {
// Read from data into the buffer - transforming using `table` as we go
data.iter().zip(b.iter_mut()).for_each(|(index, out)| *out = table[*index as usize]);
let name = &b[0..len];
data.iter()
.zip(b.iter_mut())
.for_each(|(index, out)| *out = MaybeUninit::new(table[*index as usize]));
// Safety: len bytes of b were just initialized.
let name: &'a [u8] = unsafe { slice_assume_init(&b[0..len]) };
match StandardHeader::from_bytes(name) {
Some(sh) => Ok(sh.into()),
None => {
Expand All @@ -1069,7 +1078,7 @@ fn parse_hdr<'a>(
}
}
}
65..=super::MAX_HEADER_NAME_LEN => Ok(HdrName::custom(data, false)),
SCRATCH_BUF_OVERFLOW..=super::MAX_HEADER_NAME_LEN => Ok(HdrName::custom(data, false)),
_ => Err(InvalidHeaderName::new()),
}
}
Expand All @@ -1086,14 +1095,14 @@ impl HeaderName {
/// Converts a slice of bytes to an HTTP header name.
///
/// This function normalizes the input.
#[allow(deprecated)]
pub fn from_bytes(src: &[u8]) -> Result<HeaderName, InvalidHeaderName> {
#[allow(deprecated)]
let mut buf = unsafe { mem::uninitialized() };
let mut buf = uninit_u8_array();
// Precondition: HEADER_CHARS is a valid table for parse_hdr().
match parse_hdr(src, &mut buf, &HEADER_CHARS)?.inner {
Repr::Standard(std) => Ok(std.into()),
Repr::Custom(MaybeLower { buf, lower: true }) => {
let buf = Bytes::copy_from_slice(buf);
// Safety: the invariant on MaybeLower ensures buf is valid UTF-8.
let val = unsafe { ByteStr::from_utf8_unchecked(buf) };
Ok(Custom(val).into())
}
Expand All @@ -1102,6 +1111,7 @@ impl HeaderName {
let mut dst = BytesMut::with_capacity(buf.len());

for b in buf.iter() {
// HEADER_CHARS maps all bytes to valid single-byte UTF-8
let b = HEADER_CHARS[*b as usize];

if b == 0 {
Expand All @@ -1111,6 +1121,9 @@ impl HeaderName {
dst.put_u8(b);
}

// Safety: the loop above maps all bytes in buf to valid single byte
// UTF-8 before copying them into dst. This means that dst (and hence
// dst.freeze()) is valid UTF-8.
let val = unsafe { ByteStr::from_utf8_unchecked(dst.freeze()) };

Ok(Custom(val).into())
Expand All @@ -1136,25 +1149,29 @@ impl HeaderName {
/// // Parsing a header that contains uppercase characters
/// assert!(HeaderName::from_lowercase(b"Content-Length").is_err());
/// ```
#[allow(deprecated)]
pub fn from_lowercase(src: &[u8]) -> Result<HeaderName, InvalidHeaderName> {
#[allow(deprecated)]
let mut buf = unsafe { mem::uninitialized() };
let mut buf = uninit_u8_array();
// Precondition: HEADER_CHARS_H2 is a valid table for parse_hdr()
match parse_hdr(src, &mut buf, &HEADER_CHARS_H2)?.inner {
Repr::Standard(std) => Ok(std.into()),
Repr::Custom(MaybeLower { buf, lower: true }) => {
let buf = Bytes::copy_from_slice(buf);
// Safety: the invariant on MaybeLower ensures buf is valid UTF-8.
let val = unsafe { ByteStr::from_utf8_unchecked(buf) };
Ok(Custom(val).into())
}
Repr::Custom(MaybeLower { buf, lower: false }) => {
for &b in buf.iter() {
// HEADER_CHARS maps all bytes that are not valid single-byte
// UTF-8 to 0 so this check returns an error for invalid UTF-8.
if b != HEADER_CHARS[b as usize] {
return Err(InvalidHeaderName::new());
}
}

let buf = Bytes::copy_from_slice(buf);
// Safety: the loop above checks that each byte of buf (either
// version) is valid UTF-8.
let val = unsafe { ByteStr::from_utf8_unchecked(buf) };
Ok(Custom(val).into())
}
Expand Down Expand Up @@ -1481,33 +1498,33 @@ impl Error for InvalidHeaderName {}
// ===== HdrName =====

impl<'a> HdrName<'a> {
// Precondition: if lower then buf is valid UTF-8
fn custom(buf: &'a [u8], lower: bool) -> HdrName<'a> {
HdrName {
// Invariant (on MaybeLower): follows from the precondition
inner: Repr::Custom(MaybeLower {
buf: buf,
lower: lower,
}),
}
}

#[allow(deprecated)]
pub fn from_bytes<F, U>(hdr: &[u8], f: F) -> Result<U, InvalidHeaderName>
where F: FnOnce(HdrName<'_>) -> U,
{
#[allow(deprecated)]
let mut buf = unsafe { mem::uninitialized() };
let mut buf = uninit_u8_array();
// Precondition: HEADER_CHARS is a valid table for parse_hdr().
let hdr = parse_hdr(hdr, &mut buf, &HEADER_CHARS)?;
Ok(f(hdr))
}

#[allow(deprecated)]
pub fn from_static<F, U>(hdr: &'static str, f: F) -> U
where
F: FnOnce(HdrName<'_>) -> U,
{
#[allow(deprecated)]
let mut buf = unsafe { mem::uninitialized() };
let mut buf = uninit_u8_array();
let hdr =
// Precondition: HEADER_CHARS is a valid table for parse_hdr().
parse_hdr(hdr.as_bytes(), &mut buf, &HEADER_CHARS).expect("static str is invalid name");
f(hdr)
}
Expand All @@ -1523,6 +1540,7 @@ impl<'a> From<HdrName<'a>> for HeaderName {
Repr::Custom(maybe_lower) => {
if maybe_lower.lower {
let buf = Bytes::copy_from_slice(&maybe_lower.buf[..]);
// Safety: the invariant on MaybeLower ensures buf is valid UTF-8.
let byte_str = unsafe { ByteStr::from_utf8_unchecked(buf) };

HeaderName {
Expand All @@ -1533,9 +1551,14 @@ impl<'a> From<HdrName<'a>> for HeaderName {
let mut dst = BytesMut::with_capacity(maybe_lower.buf.len());

for b in maybe_lower.buf.iter() {
// HEADER_CHARS maps each byte to a valid single-byte UTF-8
// codepoint.
dst.put_u8(HEADER_CHARS[*b as usize]);
}

// Safety: the loop above maps each byte of maybe_lower.buf to a
// valid single-byte UTF-8 codepoint before copying it into dst.
// dst (and hence dst.freeze()) is thus valid UTF-8.
let buf = unsafe { ByteStr::from_utf8_unchecked(dst.freeze()) };

HeaderName {
Expand Down Expand Up @@ -1606,6 +1629,26 @@ fn eq_ignore_ascii_case(lower: &[u8], s: &[u8]) -> bool {
})
}

// Utility functions for MaybeUninit<>. These are drawn from unstable API's on
// MaybeUninit<> itself.
const SCRATCH_BUF_SIZE: usize = 64;
const SCRATCH_BUF_OVERFLOW: usize = SCRATCH_BUF_SIZE + 1;

fn uninit_u8_array() -> [MaybeUninit<u8>; SCRATCH_BUF_SIZE] {
let arr = MaybeUninit::<[MaybeUninit<u8>; SCRATCH_BUF_SIZE]>::uninit();
// Safety: assume_init() is claiming that an array of MaybeUninit<>
// has been initilized, but MaybeUninit<>'s do not require initilizaton.
unsafe { arr.assume_init() }
}

// Assuming all the elements are initilized, get a slice of them.
//
// Safety: All elements of `slice` must be initilized to prevent
// undefined behavior.
unsafe fn slice_assume_init<T>(slice: &[MaybeUninit<T>]) -> &[T] {
&*(slice as *const [MaybeUninit<T>] as *const [T])
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1652,6 +1695,7 @@ mod tests {
#[test]
#[should_panic]
fn test_static_invalid_name_lengths() {
// Safety: ONE_TOO_LONG contains only the UTF-8 safe, single-byte codepoint b'a'.
let _ = HeaderName::from_static(unsafe { std::str::from_utf8_unchecked(ONE_TOO_LONG) });
}

Expand Down

0 comments on commit 5b2719a

Please sign in to comment.