diff --git a/Cargo.toml b/Cargo.toml index 88f2b202c1..06a6863624 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,6 +100,7 @@ stream = [] runtime = [ "tcp", "tokio/rt", + "tokio/time", ] tcp = [ "socket2", diff --git a/src/error.rs b/src/error.rs index 470a23b601..0949e1cb85 100644 --- a/src/error.rs +++ b/src/error.rs @@ -44,6 +44,9 @@ pub(super) enum Kind { #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] Accept, + /// User took too long to send headers + #[cfg(all(feature = "http1", feature = "server", feature = "runtime"))] + HeaderTimeout, /// Error while reading a body from connection. #[cfg(any(feature = "http1", feature = "http2", feature = "stream"))] Body, @@ -310,6 +313,11 @@ impl Error { Error::new_user(User::UnexpectedHeader) } + #[cfg(all(feature = "http1", feature = "server", feature = "runtime"))] + pub(super) fn new_header_timeout() -> Error { + Error::new(Kind::HeaderTimeout) + } + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "client")] pub(super) fn new_user_unsupported_version() -> Error { @@ -419,6 +427,8 @@ impl Error { #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] Kind::Accept => "error accepting connection", + #[cfg(all(feature = "http1", feature = "server", feature = "runtime"))] + Kind::HeaderTimeout => "read header from client timeout", #[cfg(any(feature = "http1", feature = "http2", feature = "stream"))] Kind::Body => "error reading a body from connection", #[cfg(any(feature = "http1", feature = "http2"))] diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index eba765226e..ed694ec02c 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -1,12 +1,15 @@ use std::fmt; use std::io; use std::marker::PhantomData; +use std::time::Duration; use bytes::{Buf, Bytes}; use http::header::{HeaderValue, CONNECTION}; use http::{HeaderMap, Method, Version}; use httparse::ParserConfig; use tokio::io::{AsyncRead, AsyncWrite}; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Sleep; use tracing::{debug, error, trace}; use super::io::Buffered; @@ -47,6 +50,12 @@ where keep_alive: KA::Busy, method: None, h1_parser_config: ParserConfig::default(), + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: None, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: None, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: false, preserve_header_case: false, title_case_headers: false, h09_responses: false, @@ -106,6 +115,11 @@ where self.state.h09_responses = true; } + #[cfg(all(feature = "server", feature = "runtime"))] + pub(crate) fn set_http1_header_read_timeout(&mut self, val: Duration) { + self.state.h1_header_read_timeout = Some(val); + } + #[cfg(feature = "server")] pub(crate) fn set_allow_half_close(&mut self) { self.state.allow_half_close = true; @@ -178,6 +192,12 @@ where cached_headers: &mut self.state.cached_headers, req_method: &mut self.state.method, h1_parser_config: self.state.h1_parser_config.clone(), + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: self.state.h1_header_read_timeout, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: &mut self.state.h1_header_read_timeout_fut, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: &mut self.state.h1_header_read_timeout_running, preserve_header_case: self.state.preserve_header_case, h09_responses: self.state.h09_responses, #[cfg(feature = "ffi")] @@ -798,6 +818,12 @@ struct State { /// a body or not. method: Option, h1_parser_config: ParserConfig, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: Option, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: Option>>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: bool, preserve_header_case: bool, title_case_headers: bool, h09_responses: bool, diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 712aad44d7..69c4997073 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -3,10 +3,15 @@ use std::fmt; use std::io::{self, IoSlice}; use std::marker::Unpin; use std::mem::MaybeUninit; +use std::future::Future; +#[cfg(all(feature = "server", feature = "runtime"))] +use std::time::Duration; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Instant; use bytes::{Buf, BufMut, Bytes, BytesMut}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tracing::{debug, trace}; +use tracing::{debug, warn, trace}; use super::{Http1Transaction, ParseContext, ParsedMessage}; use crate::common::buf::BufList; @@ -181,6 +186,12 @@ where cached_headers: parse_ctx.cached_headers, req_method: parse_ctx.req_method, h1_parser_config: parse_ctx.h1_parser_config.clone(), + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: parse_ctx.h1_header_read_timeout, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running, preserve_header_case: parse_ctx.preserve_header_case, h09_responses: parse_ctx.h09_responses, #[cfg(feature = "ffi")] @@ -191,6 +202,16 @@ where )? { Some(msg) => { debug!("parsed {} headers", msg.head.headers.len()); + + #[cfg(all(feature = "server", feature = "runtime"))] + { + *parse_ctx.h1_header_read_timeout_running = false; + + if let Some(h1_header_read_timeout_fut) = parse_ctx.h1_header_read_timeout_fut { + // Reset the timer in order to avoid woken up when the timeout finishes + h1_header_read_timeout_fut.as_mut().reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60)); + } + } return Poll::Ready(Ok(msg)); } None => { @@ -199,6 +220,18 @@ where debug!("max_buf_size ({}) reached, closing", max); return Poll::Ready(Err(crate::Error::new_too_large())); } + + #[cfg(all(feature = "server", feature = "runtime"))] + if *parse_ctx.h1_header_read_timeout_running { + if let Some(h1_header_read_timeout_fut) = parse_ctx.h1_header_read_timeout_fut { + if Pin::new( h1_header_read_timeout_fut).poll(cx).is_ready() { + *parse_ctx.h1_header_read_timeout_running = false; + + warn!("read header from client timeout"); + return Poll::Ready(Err(crate::Error::new_header_timeout())) + } + } + } } } if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { @@ -693,6 +726,9 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 758ac7b073..a39fabf13b 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -1,6 +1,11 @@ +use std::pin::Pin; +use std::time::Duration; + use bytes::BytesMut; use http::{HeaderMap, Method}; use httparse::ParserConfig; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Sleep; use crate::body::DecodedLength; use crate::proto::{BodyLength, MessageHead}; @@ -72,6 +77,12 @@ pub(crate) struct ParseContext<'a> { cached_headers: &'a mut Option, req_method: &'a mut Option, h1_parser_config: ParserConfig, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout: Option, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_fut: &'a mut Option>>, + #[cfg(all(feature = "server", feature = "runtime"))] + h1_header_read_timeout_running: &'a mut bool, preserve_header_case: bool, h09_responses: bool, #[cfg(feature = "ffi")] diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 5c701ed429..794ee7945c 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -1,6 +1,8 @@ use std::fmt::{self, Write}; use std::mem::MaybeUninit; +#[cfg(all(feature = "server", feature = "runtime"))] +use tokio::time::Instant; #[cfg(any(test, feature = "server", feature = "ffi"))] use bytes::Bytes; use bytes::BytesMut; @@ -69,6 +71,25 @@ where let span = trace_span!("parse_headers"); let _s = span.enter(); + + #[cfg(all(feature = "server", feature = "runtime"))] + if !*ctx.h1_header_read_timeout_running { + if let Some(h1_header_read_timeout) = ctx.h1_header_read_timeout { + let deadline = Instant::now() + h1_header_read_timeout; + + match ctx.h1_header_read_timeout_fut { + Some(h1_header_read_timeout_fut) => { + debug!("resetting h1 header read timeout timer"); + h1_header_read_timeout_fut.as_mut().reset(deadline); + }, + None => { + debug!("setting h1 header read timeout timer"); + *ctx.h1_header_read_timeout_fut = Some(Box::pin(tokio::time::sleep_until(deadline))); + } + } + } +} + T::parse(bytes, ctx) } @@ -1428,6 +1449,9 @@ mod tests { cached_headers: &mut None, req_method: &mut method, h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1455,6 +1479,9 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1477,6 +1504,9 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1497,6 +1527,9 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: true, #[cfg(feature = "ffi")] @@ -1519,6 +1552,9 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1545,6 +1581,9 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config, + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1568,6 +1607,9 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1586,6 +1628,9 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: true, h09_responses: false, #[cfg(feature = "ffi")] @@ -1625,6 +1670,9 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1645,6 +1693,9 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1874,6 +1925,9 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1894,6 +1948,9 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(m), h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -1914,6 +1971,9 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -2411,6 +2471,9 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -2495,6 +2558,9 @@ mod tests { cached_headers: &mut headers, req_method: &mut None, h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] @@ -2535,6 +2601,9 @@ mod tests { cached_headers: &mut headers, req_method: &mut None, h1_parser_config: Default::default(), + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, preserve_header_case: false, h09_responses: false, #[cfg(feature = "ffi")] diff --git a/src/server/conn.rs b/src/server/conn.rs index b71f144768..c49c8ae571 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -50,7 +50,6 @@ use std::marker::PhantomData; #[cfg(feature = "tcp")] use std::net::SocketAddr; -#[cfg(all(feature = "runtime", feature = "http2"))] use std::time::Duration; #[cfg(feature = "http2")] @@ -103,6 +102,8 @@ pub struct Http { h1_keep_alive: bool, h1_title_case_headers: bool, h1_preserve_header_case: bool, + #[cfg(all(feature = "http1", feature = "runtime"))] + h1_header_read_timeout: Option, h1_writev: Option, #[cfg(feature = "http2")] h2_builder: proto::h2::server::Config, @@ -285,6 +286,8 @@ impl Http { h1_keep_alive: true, h1_title_case_headers: false, h1_preserve_header_case: false, + #[cfg(all(feature = "http1", feature = "runtime"))] + h1_header_read_timeout: None, h1_writev: None, #[cfg(feature = "http2")] h2_builder: Default::default(), @@ -372,6 +375,17 @@ impl Http { self } + /// Set a timeout for reading client request headers. If a client does not + /// transmit the entire header within this time, the connection is closed. + /// + /// Default is None. + #[cfg(all(feature = "http1", feature = "runtime"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "http1", feature = "runtime"))))] + pub fn http1_header_read_timeout(&mut self, read_timeout: Duration) -> &mut Self { + self.h1_header_read_timeout = Some(read_timeout); + self + } + /// Set whether HTTP/1 connections should try to use vectored writes, /// or always flatten into a single buffer. /// @@ -567,6 +581,8 @@ impl Http { h1_keep_alive: self.h1_keep_alive, h1_title_case_headers: self.h1_title_case_headers, h1_preserve_header_case: self.h1_preserve_header_case, + #[cfg(all(feature = "http1", feature = "runtime"))] + h1_header_read_timeout: self.h1_header_read_timeout, h1_writev: self.h1_writev, #[cfg(feature = "http2")] h2_builder: self.h2_builder, @@ -629,6 +645,10 @@ impl Http { if self.h1_preserve_header_case { conn.set_preserve_header_case(); } + #[cfg(all(feature = "http1", feature = "runtime"))] + if let Some(header_read_timeout) = self.h1_header_read_timeout { + conn.set_http1_header_read_timeout(header_read_timeout); + } if let Some(writev) = self.h1_writev { if writev { conn.set_write_strategy_queue(); diff --git a/src/server/server.rs b/src/server/server.rs index 3f5261cb42..87027bbefc 100644 --- a/src/server/server.rs +++ b/src/server/server.rs @@ -1,7 +1,7 @@ use std::fmt; #[cfg(feature = "tcp")] use std::net::{SocketAddr, TcpListener as StdTcpListener}; -#[cfg(feature = "tcp")] +#[cfg(any(feature = "tcp", feature = "http1"))] use std::time::Duration; #[cfg(all(feature = "tcp", any(feature = "http1", feature = "http2")))] @@ -309,6 +309,17 @@ impl Builder { self } + /// Set a timeout for reading client request headers. If a client does not + /// transmit the entire header within this time, the connection is closed. + /// + /// Default is None. + #[cfg(all(feature = "http1", feature = "runtime"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "http1", feature = "runtime"))))] + pub fn http1_header_read_timeout(mut self, read_timeout: Duration) -> Self { + self.protocol.http1_header_read_timeout(read_timeout); + self + } + /// Sets whether HTTP/1 is required. /// /// Default is `false`. diff --git a/tests/server.rs b/tests/server.rs index 624c1eb8e7..16e905884c 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -1261,6 +1261,127 @@ fn header_name_too_long() { assert!(s(&buf[..n]).starts_with("HTTP/1.1 431 Request Header Fields Too Large\r\n")); } +#[tokio::test] +async fn header_read_timeout_slow_writes() { + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + tcp.write_all( + b"\ + GET / HTTP/1.1\r\n\ + ", + ) + .expect("write 1"); + thread::sleep(Duration::from_secs(3)); + tcp.write_all( + b"\ + Something: 1\r\n\ + \r\n\ + ", + ) + .expect("write 2"); + thread::sleep(Duration::from_secs(6)); + tcp.write_all( + b"\ + Works: 0\r\n\ + ", + ) + .expect_err("write 3"); + }); + + let (socket, _) = listener.accept().await.unwrap(); + let conn = Http::new() + .http1_header_read_timeout(Duration::from_secs(5)) + .serve_connection( + socket, + service_fn(|_| { + let res = Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(); + future::ready(Ok::<_, hyper::Error>(res)) + }), + ); + conn.without_shutdown().await.expect_err("header timeout"); +} + +#[tokio::test] +async fn header_read_timeout_slow_writes_multiple_requests() { + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + + tcp.write_all( + b"\ + GET / HTTP/1.1\r\n\ + ", + ) + .expect("write 1"); + thread::sleep(Duration::from_secs(3)); + tcp.write_all( + b"\ + Something: 1\r\n\ + \r\n\ + ", + ) + .expect("write 2"); + + thread::sleep(Duration::from_secs(3)); + + tcp.write_all( + b"\ + GET / HTTP/1.1\r\n\ + ", + ) + .expect("write 3"); + thread::sleep(Duration::from_secs(3)); + tcp.write_all( + b"\ + Something: 1\r\n\ + \r\n\ + ", + ) + .expect("write 4"); + + thread::sleep(Duration::from_secs(6)); + + tcp.write_all( + b"\ + GET / HTTP/1.1\r\n\ + Something: 1\r\n\ + \r\n\ + ", + ) + .expect("write 5"); + thread::sleep(Duration::from_secs(6)); + tcp.write_all( + b"\ + Works: 0\r\n\ + ", + ) + .expect_err("write 6"); + }); + + let (socket, _) = listener.accept().await.unwrap(); + let conn = Http::new() + .http1_header_read_timeout(Duration::from_secs(5)) + .serve_connection( + socket, + service_fn(|_| { + let res = Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(); + future::ready(Ok::<_, hyper::Error>(res)) + }), + ); + conn.without_shutdown().await.expect_err("header timeout"); +} + #[tokio::test] async fn upgrades() { use tokio::io::{AsyncReadExt, AsyncWriteExt};