diff --git a/Cargo.toml b/Cargo.toml index 05b04fe649..cb16c36c05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ num_cpus = "1.0" pretty_env_logger = "0.2.0" spmc = "0.2" url = "1.0" +tokio-mockstream = "1.1.0" [features] default = [ diff --git a/src/error.rs b/src/error.rs index 3ec68dfaab..23da310238 100644 --- a/src/error.rs +++ b/src/error.rs @@ -69,6 +69,7 @@ pub(crate) enum Kind { pub(crate) enum Parse { Method, Version, + VersionH2, Uri, Header, TooLarge, @@ -164,6 +165,10 @@ impl Error { Error::new(Kind::Parse(Parse::Version), None) } + pub(crate) fn new_version_h2() -> Error { + Error::new(Kind::Parse(Parse::VersionH2), None) + } + pub(crate) fn new_mismatched_response() -> Error { Error::new(Kind::MismatchedResponse, None) } @@ -250,6 +255,7 @@ impl StdError for Error { match self.inner.kind { Kind::Parse(Parse::Method) => "invalid Method specified", Kind::Parse(Parse::Version) => "invalid HTTP version specified", + Kind::Parse(Parse::VersionH2) => "invalid HTTP version specified (Http2)", Kind::Parse(Parse::Uri) => "invalid URI", Kind::Parse(Parse::Header) => "invalid Header provided", Kind::Parse(Parse::TooLarge) => "message head is too large", diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index a4d6167c76..055cccc02f 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -12,6 +12,7 @@ use proto::{BodyLength, Decode, Http1Transaction, MessageHead}; use super::io::{Buffered}; use super::{EncodedBuf, Encoder, Decoder}; +const H2_PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; /// This handles a connection, which will have been established over an /// `AsyncRead + AsyncWrite` (like a socket), and will likely include multiple @@ -107,6 +108,11 @@ where I: AsyncRead + AsyncWrite, T::should_error_on_parse_eof() && !self.state.is_idle() } + fn has_h2_prefix(&self) -> bool { + let read_buf = self.io.read_buf(); + read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE + } + pub fn read_head(&mut self) -> Poll, bool)>, ::Error> { debug_assert!(self.can_read_head()); trace!("Conn::read_head"); @@ -124,6 +130,7 @@ where I: AsyncRead + AsyncWrite, self.io.consume_leading_lines(); let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty(); return if was_mid_parse || must_error { + // We check if the buf contains the h2 Preface debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len()); self.on_parse_error(e) .map(|()| Async::NotReady) @@ -529,8 +536,12 @@ where I: AsyncRead + AsyncWrite, // - Client: there is nothing we can do // - Server: if Response hasn't been written yet, we can send a 4xx response fn on_parse_error(&mut self, err: ::Error) -> ::Result<()> { + match self.state.writing { Writing::Init => { + if self.has_h2_prefix() { + return Err(::Error::new_version_h2()) + } if let Some(msg) = T::on_error(&err) { self.write_head(msg, None); self.state.error = Some(err); diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 1f0ceef721..1965f78748 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -332,6 +332,9 @@ impl Server where S: Service { service: service, } } + pub fn into_service(self) -> S { + self.service + } } impl Dispatch for Server diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index a9d066c401..1fae294ede 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -186,14 +186,14 @@ where use ::error::{Kind, Parse}; let status = match *err.kind() { Kind::Parse(Parse::Method) | - Kind::Parse(Parse::Version) | Kind::Parse(Parse::Header) | - Kind::Parse(Parse::Uri) => { + Kind::Parse(Parse::Uri) | + Kind::Parse(Parse::Version) => { StatusCode::BAD_REQUEST }, Kind::Parse(Parse::TooLarge) => { StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE - } + }, _ => return None, }; diff --git a/src/server/conn.rs b/src/server/conn.rs index e650ec8342..8d59030f90 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -13,6 +13,7 @@ use std::fmt; use std::sync::Arc; #[cfg(feature = "runtime")] use std::time::Duration; +use super::rewind::Rewind; use bytes::Bytes; use futures::{Async, Future, Poll, Stream}; use futures::future::{Either, Executor}; @@ -23,6 +24,7 @@ use common::Exec; use proto; use body::{Body, Payload}; use service::{NewService, Service}; +use error::{Kind, Parse}; #[cfg(feature = "runtime")] pub use super::tcp::AddrIncoming; @@ -74,23 +76,24 @@ pub(super) struct SpawnAll { /// /// Polling this future will drive HTTP forward. #[must_use = "futures do nothing unless polled"] -pub struct Connection +pub struct Connection where S: Service, { - pub(super) conn: Either< + pub(super) conn: Option< + Either< proto::h1::Dispatcher< proto::h1::dispatch::Server, S::ResBody, - I, + T, proto::ServerTransaction, >, proto::h2::Server< - I, + Rewind, S, S::ResBody, >, - >, + >>, } /// Deconstructed parts of a `Connection`. @@ -98,7 +101,7 @@ where /// This allows taking apart a `Connection` at a later time, in order to /// reclaim the IO object, and additional related pieces. #[derive(Debug)] -pub struct Parts { +pub struct Parts { /// The original IO object used in the handshake. pub io: T, /// A buffer of bytes that have been read but not processed as HTTP. @@ -239,12 +242,13 @@ impl Http { let sd = proto::h1::dispatch::Server::new(service); Either::A(proto::h1::Dispatcher::new(sd, conn)) } else { - let h2 = proto::h2::Server::new(io, service, self.exec.clone()); + let rewind_io = Rewind::new(io); + let h2 = proto::h2::Server::new(rewind_io, service, self.exec.clone()); Either::B(h2) }; Connection { - conn: either, + conn: Some(either), } } @@ -322,7 +326,7 @@ where /// This `Connection` should continue to be polled until shutdown /// can finish. pub fn graceful_shutdown(&mut self) { - match self.conn { + match *self.conn.as_mut().unwrap() { Either::A(ref mut h1) => { h1.disable_keep_alive(); }, @@ -334,11 +338,12 @@ where /// Return the inner IO object, and additional information. /// + /// If the IO object has been "rewound" the io will not contain those bytes rewound. /// This should only be called after `poll_without_shutdown` signals /// that the connection is "done". Otherwise, it may not have finished /// flushing all necessary HTTP bytes. pub fn into_parts(self) -> Parts { - let (io, read_buf, dispatch) = match self.conn { + let (io, read_buf, dispatch) = match self.conn.unwrap() { Either::A(h1) => { h1.into_inner() }, @@ -349,7 +354,7 @@ where Parts { io: io, read_buf: read_buf, - service: dispatch.service, + service: dispatch.into_service(), _inner: (), } } @@ -362,7 +367,7 @@ where /// but it is not desired to actally shutdown the IO object. Instead you /// would take it back using `into_parts`. pub fn poll_without_shutdown(&mut self) -> Poll<(), ::Error> { - match self.conn { + match *self.conn.as_mut().unwrap() { Either::A(ref mut h1) => { try_ready!(h1.poll_without_shutdown()); Ok(().into()) @@ -370,6 +375,29 @@ where Either::B(ref mut h2) => h2.poll(), } } + + fn try_h2(&mut self) -> Poll<(), ::Error> { + trace!("Trying to upgrade connection to h2"); + let conn = self.conn.take(); + + let (io, read_buf, dispatch) = match conn.unwrap() { + Either::A(h1) => { + h1.into_inner() + }, + Either::B(_h2) => { + panic!("h2 cannot into_inner"); + } + }; + let mut rewind_io = Rewind::new(io); + rewind_io.rewind(read_buf); + let mut h2 = proto::h2::Server::new(rewind_io, dispatch.into_service(), Exec::Default); + let pr = h2.poll(); + + debug_assert!(self.conn.is_none()); + self.conn = Some(Either::B(h2)); + + pr + } } impl Future for Connection @@ -384,7 +412,16 @@ where type Error = ::Error; fn poll(&mut self) -> Poll { - self.conn.poll() + match self.conn.poll() { + Ok(x) => Ok(x.map(|o| o.unwrap_or_else(|| ()))), + Err(e) => { + debug!("error polling connection protocol: {}", e); + match *e.kind() { + Kind::Parse(Parse::VersionH2) => self.try_h2(), + _ => Err(e), + } + } + } } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 0430ee4566..5c1cb9fdff 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -50,6 +50,7 @@ pub mod conn; #[cfg(feature = "runtime")] mod tcp; +mod rewind; use std::fmt; #[cfg(feature = "runtime")] use std::net::SocketAddr; diff --git a/src/server/rewind.rs b/src/server/rewind.rs new file mode 100644 index 0000000000..6d2bf90eda --- /dev/null +++ b/src/server/rewind.rs @@ -0,0 +1,208 @@ +use bytes::{Buf, BufMut, Bytes, IntoBuf}; +use futures::{Async, Poll}; +use std::io::{self, Read, Write}; +use std::cmp; +use tokio_io::{AsyncRead, AsyncWrite}; + +#[derive(Debug)] +pub struct Rewind { + pre: Option, + inner: T, +} + +impl Rewind { + pub(super) fn new(tcp: T) -> Rewind { + Rewind { + pre: None, + inner: tcp, + } + } + pub fn rewind(&mut self, bs: Bytes) { + debug_assert!(self.pre.is_none()); + self.pre = Some(bs); + } +} + +impl Read for Rewind +where + T: Read, +{ + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if let Some(pre_bs) = self.pre.take() { + // If there are no remaining bytes, let the bytes get dropped. + if pre_bs.len() > 0 { + let mut pre_reader = pre_bs.into_buf().reader(); + let read_cnt = pre_reader.read(buf)?; + + let mut new_pre = pre_reader.into_inner().into_inner(); + new_pre.advance(read_cnt); + + // Put back whats left + if new_pre.len() > 0 { + self.pre = Some(new_pre); + } + + return Ok(read_cnt); + } + } + self.inner.read(buf) + } +} + +impl Write for Rewind +where + T: Write, +{ + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.inner.write(buf) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + self.inner.flush() + } +} + +impl AsyncRead for Rewind +where + T: AsyncRead, +{ + #[inline] + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.inner.prepare_uninitialized_buffer(buf) + } + + #[inline] + fn read_buf(&mut self, buf: &mut B) -> Poll { + if let Some(bs) = self.pre.take() { + let pre_len = bs.len(); + // If there are no remaining bytes, let the bytes get dropped. + if pre_len > 0 { + let cnt = cmp::min(buf.remaining_mut(), pre_len); + let pre_buf = bs.into_buf(); + let mut xfer = Buf::take(pre_buf, cnt); + buf.put(&mut xfer); + + let mut new_pre = xfer.into_inner().into_inner(); + new_pre.advance(cnt); + + // Put back whats left + if new_pre.len() > 0 { + self.pre = Some(new_pre); + } + + return Ok(Async::Ready(cnt)); + } + } + self.inner.read_buf(buf) + } +} + +impl AsyncWrite for Rewind +where + T: AsyncWrite, +{ + #[inline] + fn shutdown(&mut self) -> Poll<(), io::Error> { + AsyncWrite::shutdown(&mut self.inner) + } + + #[inline] + fn write_buf(&mut self, buf: &mut B) -> Poll { + self.inner.write_buf(buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + extern crate tokio_mockstream; + use self::tokio_mockstream::MockStream; + use std::io::Cursor; + + // Test a partial rewind + #[test] + fn async_partial_rewind() { + let bs = &mut [104, 101, 108, 108, 111]; + let o1 = &mut [0, 0]; + let o2 = &mut [0, 0, 0, 0, 0]; + + let mut stream = Rewind::new(MockStream::new(bs)); + let mut o1_cursor = Cursor::new(o1); + // Read off some bytes, ensure we filled o1 + match stream.read_buf(&mut o1_cursor).unwrap() { + Async::NotReady => panic!("should be ready"), + Async::Ready(cnt) => assert_eq!(2, cnt), + } + + // Rewind the stream so that it is as if we never read in the first place. + let read_buf = Bytes::from(&o1_cursor.into_inner()[..]); + stream.rewind(read_buf); + + // We poll 2x here since the first time we'll only get what is in the + // prefix (the rewinded part) of the Rewind.\ + let mut o2_cursor = Cursor::new(o2); + stream.read_buf(&mut o2_cursor).unwrap(); + stream.read_buf(&mut o2_cursor).unwrap(); + let o2_final = o2_cursor.into_inner(); + + // At this point we should have read everything that was in the MockStream + assert_eq!(&o2_final, &bs); + } + // Test a full rewind + #[test] + fn async_full_rewind() { + let bs = &mut [104, 101, 108, 108, 111]; + let o1 = &mut [0, 0, 0, 0, 0]; + let o2 = &mut [0, 0, 0, 0, 0]; + + let mut stream = Rewind::new(MockStream::new(bs)); + let mut o1_cursor = Cursor::new(o1); + match stream.read_buf(&mut o1_cursor).unwrap() { + Async::NotReady => panic!("should be ready"), + Async::Ready(cnt) => assert_eq!(5, cnt), + } + + let read_buf = Bytes::from(&o1_cursor.into_inner()[..]); + stream.rewind(read_buf); + + let mut o2_cursor = Cursor::new(o2); + stream.read_buf(&mut o2_cursor).unwrap(); + stream.read_buf(&mut o2_cursor).unwrap(); + let o2_final = o2_cursor.into_inner(); + + assert_eq!(&o2_final, &bs); + } + #[test] + fn partial_rewind() { + let bs = &mut [104, 101, 108, 108, 111]; + let o1 = &mut [0, 0]; + let o2 = &mut [0, 0, 0, 0, 0]; + + let mut stream = Rewind::new(MockStream::new(bs)); + stream.read(o1).unwrap(); + + let read_buf = Bytes::from(&o1[..]); + stream.rewind(read_buf); + let cnt = stream.read(o2).unwrap(); + stream.read(&mut o2[cnt..]).unwrap(); + assert_eq!(&o2, &bs); + } + #[test] + fn full_rewind() { + let bs = &mut [104, 101, 108, 108, 111]; + let o1 = &mut [0, 0, 0, 0, 0]; + let o2 = &mut [0, 0, 0, 0, 0]; + + let mut stream = Rewind::new(MockStream::new(bs)); + stream.read(o1).unwrap(); + + let read_buf = Bytes::from(&o1[..]); + stream.rewind(read_buf); + let cnt = stream.read(o2).unwrap(); + stream.read(&mut o2[cnt..]).unwrap(); + assert_eq!(&o2, &bs); + } +} diff --git a/tests/server.rs b/tests/server.rs index 7454b425e5..4b86795854 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -31,6 +31,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use hyper::{Body, Request, Response, StatusCode}; +use hyper::client::Client; use hyper::server::conn::Http; use hyper::service::{service_fn, Service}; @@ -39,6 +40,24 @@ fn tcp_bind(addr: &SocketAddr, handle: &Handle) -> ::tokio::io::Result = Client::builder().http2_only(true).build_http(); + let uri = addr_str.parse::().expect("server addr should parse"); + + client.get(uri) + .and_then(|_res| { Ok(()) }) + .map(|_| { () }) + .map_err(|_e| { () }) + })); + + assert_eq!(server.body(), b""); +} + #[test] fn get_should_ignore_body() { let server = serve();