Skip to content

Commit

Permalink
remove hyper::Request from RequestContext (#557)
Browse files Browse the repository at this point in the history
  • Loading branch information
davepacheco authored Jan 19, 2023
1 parent 538e99b commit 2d526b3
Show file tree
Hide file tree
Showing 16 changed files with 326 additions and 102 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ There are a number of breaking changes in this release but we expect they will b
1. For any endpoint functions that use a `TypedBody`, `UntypedBody`, or `WebsocketConnection` extractor, this extractor must be the last argument to the function. Otherwise, you will get a compile error about the extractor not impl'ing `SharedExtractor`.
2. If you have your own type that impls `Extractor`, you will need to change that to either `ExclusiveExtractor` (if the impl needs a `mut` reference to the underlying `hyper::Request`, which is usually because it needs to read the request body) or `SharedExtractor`. If your extractor only needs to look at the URL or request headers and not the body, it can probably be a `SharedExtractor`. If it's an exclusive extractor, any function that accepts it must accept it as the last argument to the function.
3. If you have your own type that impls `Extractor`, you will also need to change the type signature of the `from_request` method to accept a `&RequestContext<T>` instead of `Arc<RequestContext<T>>`. (This should not be a problem unless your extractor was hanging on to a reference via the Arc. We don't know a reason this would be useful. If you were doing this, please https://github.com/oxidecomputer/dropshot/discussions[start a discussion] or https://github.com/oxidecomputer/dropshot/issues/new[file an issue]. In the meantime, you likely can copy whatever information you need out of the `RequestContext` rather than cloning the Arc.)
* https://github.com/oxidecomputer/dropshot/pull/557[#557] Simpler, safer access to raw request. Prior to this change, the raw `hyper::Request` (`http::Request`) was accessible to endpoint functions via the `RequestContext`, but behind an `Arc<Mutex<...>>`. This was a little strange because your endpoint function was usually the only one with a reference to this object. (You could get into trouble if you defined your own Extractor that cloned one of the `Arc` objects -- your extractor could deadlock with the handler.) After this change, the raw request is available only through a separate `RawRequest` extractor. This is an exclusive extractor, which means you cannot use it with `TypedBody` or `UntypedBody`. As a result, there is no way to wind up with multiple references to the request. There's no lock and no way to get into this sort of trouble.
+
After this change, the `hyper::Request` is passed as a separate argument to `ExclusiveExtractor::from_request()`.
+
**What you need to do:**
+
1. If you have a request handler that accesses `rqctx.request`, it's typically doing `let request = rqctx.request.lock().await`.
a. If that code is only accessing the HTTP method, URI, headers, or version, then _you can skip this step_. However, it's recommended that you replace that with `let request = &rqctx.request`. (That object has methods compatible with `http::Request` for accessing the method, URI, headers, and version.)
b. If that code is accessing other parts of the request (e.g., reading the body or doing a protocol upgrade), then you must instead add a `raw_request: RawRequest` argument to your endpoint function. Then you can use `let request = raw_request.into_inner()`.
2. If you have an extractor that access `rqctx.request`, then it too is typically doing something like `let request = rqctx.request.lock().await`.
a. If that code is only accessing the HTTP method, URI, headers, or version, then just like above _you can skip this step_, but it's recommended that you replace that with `let request = &rqctx.request`. This can be done from a `SharedExtractor` or an `ExclusiveExtractor`.
b. If that code is accessing other parts of the request (e.g., reading the body or doing a protocol upgrade), then this extractor must impl `ExclusiveExtractor` (not `SharedExtractor`). With `ExclusiveExtractor`, the `hyper::Request` is available as an argument to `from_request()`.
+
* https://github.com/oxidecomputer/dropshot/pull/504[#504] Dropshot now allows TLS configuration to be supplied either by path or as bytes. For compatibility, the `AsFile` variant of `ConfigTls` contains the `cert_file` and `key_file` fields, and may be used similarly to the old variant.
* https://github.com/oxidecomputer/dropshot/pull/502[#502] Dropshot exposes a `refresh_tls` method to update the TLS certificates being used by a running server.
+
Expand Down
3 changes: 1 addition & 2 deletions dropshot/examples/request-headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ async fn main() -> Result<(), String> {
async fn example_api_get_header_generic(
rqctx: Arc<RequestContext<()>>,
) -> Result<HttpResponseOk<String>, HttpError> {
let request = rqctx.request.lock().await;
// Note that clients can provide multiple values for a header. See
// http::HeaderMap for ways to get all of them.
let header = request.headers().get("demo-header");
let header = rqctx.request.headers().get("demo-header");
Ok(HttpResponseOk(format!("value for header: {:?}", header)))
}
42 changes: 42 additions & 0 deletions dropshot/src/dtrace.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2023 Oxide Computer Company
//! DTrace probes and support

#[derive(Debug, Clone, serde::Serialize)]
pub(crate) struct RequestInfo {
pub id: String,
pub local_addr: std::net::SocketAddr,
pub remote_addr: std::net::SocketAddr,
pub method: String,
pub path: String,
pub query: Option<String>,
}

#[derive(Debug, Clone, serde::Serialize)]
pub(crate) struct ResponseInfo {
pub id: String,
pub local_addr: std::net::SocketAddr,
pub remote_addr: std::net::SocketAddr,
pub status_code: u16,
pub message: String,
}

#[cfg(feature = "usdt-probes")]
#[usdt::provider(provider = "dropshot")]
mod probes {
use crate::dtrace::{RequestInfo, ResponseInfo};
fn request__start(_: &RequestInfo) {}
fn request__done(_: &ResponseInfo) {}
}

/// The result of registering a server's DTrace USDT probes.
#[derive(Debug, Clone, PartialEq)]
pub enum ProbeRegistration {
/// The probes are explicitly disabled at compile time.
Disabled,

/// Probes were successfully registered.
Succeeded,

/// Registration failed, with an error message explaining the cause.
Failed(String),
}
15 changes: 11 additions & 4 deletions dropshot/src/extractor/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub trait ExclusiveExtractor: Send + Sync + Sized {
/// Construct an instance of this type from a `RequestContext`.
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
request: hyper::Request<hyper::Body>,
) -> Result<Self, HttpError>;

fn metadata(
Expand Down Expand Up @@ -55,6 +56,7 @@ pub trait SharedExtractor: Send + Sync + Sized {
impl<S: SharedExtractor> ExclusiveExtractor for S {
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
_request: hyper::Request<hyper::Body>,
) -> Result<Self, HttpError> {
<S as SharedExtractor>::from_request(rqctx).await
}
Expand Down Expand Up @@ -94,6 +96,7 @@ pub trait RequestExtractor: Send + Sync + Sized {
/// Construct an instance of this type from a `RequestContext`.
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
request: hyper::Request<hyper::Body>,
) -> Result<Self, HttpError>;

fn metadata(
Expand All @@ -106,6 +109,7 @@ pub trait RequestExtractor: Send + Sync + Sized {
impl RequestExtractor for () {
async fn from_request<Context: ServerContext>(
_rqctx: &RequestContext<Context>,
_request: hyper::Request<hyper::Body>,
) -> Result<Self, HttpError> {
Ok(())
}
Expand All @@ -125,8 +129,9 @@ impl RequestExtractor for () {
impl<X: ExclusiveExtractor + 'static> RequestExtractor for (X,) {
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
request: hyper::Request<hyper::Body>,
) -> Result<Self, HttpError> {
Ok((X::from_request(rqctx).await?,))
Ok((X::from_request(rqctx, request).await?,))
}

fn metadata(
Expand Down Expand Up @@ -155,12 +160,14 @@ macro_rules! impl_rqextractor_for_tuple {
RequestExtractor
for ($($S,)+ X)
{
async fn from_request<Context: ServerContext>(rqctx: &RequestContext<Context>)
-> Result<( $($S,)+ X ), HttpError>
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
request: hyper::Request<hyper::Body>
) -> Result<( $($S,)+ X ), HttpError>
{
futures::try_join!(
$($S::from_request(rqctx),)+
X::from_request(rqctx)
X::from_request(rqctx, request)
)
}

Expand Down
49 changes: 41 additions & 8 deletions dropshot/src/extractor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ use crate::RequestContext;

use async_trait::async_trait;
use bytes::Bytes;
use hyper::Body;
use hyper::Request;
use schemars::schema::InstanceType;
use schemars::schema::SchemaObject;
use schemars::JsonSchema;
Expand All @@ -33,6 +31,7 @@ use std::fmt::Debug;

mod common;

use crate::RequestInfo;
pub use common::ExclusiveExtractor;
pub use common::ExtractorMetadata;
pub use common::RequestExtractor;
Expand All @@ -59,7 +58,7 @@ impl<QueryType: DeserializeOwned + JsonSchema + Send + Sync> Query<QueryType> {
/// Given an HTTP request, pull out the query string and attempt to deserialize
/// it as an instance of `QueryType`.
fn http_request_load_query<QueryType>(
request: &Request<Body>,
request: &RequestInfo,
) -> Result<Query<QueryType>, HttpError>
where
QueryType: DeserializeOwned + JsonSchema + Send + Sync,
Expand Down Expand Up @@ -89,8 +88,7 @@ where
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
) -> Result<Query<QueryType>, HttpError> {
let request = rqctx.request.lock().await;
http_request_load_query(&request)
http_request_load_query(&rqctx.request)
}

fn metadata(
Expand Down Expand Up @@ -224,12 +222,12 @@ impl<BodyType: JsonSchema + DeserializeOwned + Send + Sync>
/// to the content type, and deserialize it to an instance of `BodyType`.
async fn http_request_load_body<Context: ServerContext, BodyType>(
rqctx: &RequestContext<Context>,
mut request: hyper::Request<hyper::Body>,
) -> Result<TypedBody<BodyType>, HttpError>
where
BodyType: JsonSchema + DeserializeOwned + Send + Sync,
{
let server = &rqctx.server;
let mut request = rqctx.request.lock().await;
let body = http_read_body(
request.body_mut(),
server.config.request_body_max_bytes,
Expand Down Expand Up @@ -300,8 +298,9 @@ where
{
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
request: hyper::Request<hyper::Body>,
) -> Result<TypedBody<BodyType>, HttpError> {
http_request_load_body(rqctx).await
http_request_load_body(rqctx, request).await
}

fn metadata(content_type: ApiEndpointBodyContentType) -> ExtractorMetadata {
Expand Down Expand Up @@ -353,9 +352,9 @@ impl UntypedBody {
impl ExclusiveExtractor for UntypedBody {
async fn from_request<Context: ServerContext>(
rqctx: &RequestContext<Context>,
mut request: hyper::Request<hyper::Body>,
) -> Result<UntypedBody, HttpError> {
let server = &rqctx.server;
let mut request = rqctx.request.lock().await;
let body_bytes = http_read_body(
request.body_mut(),
server.config.request_body_max_bytes,
Expand Down Expand Up @@ -389,6 +388,40 @@ impl ExclusiveExtractor for UntypedBody {
}
}

// RawRequest: extractor for the raw underlying hyper::Request

/// `RawRequest` is an extractor providing access to the raw underlying
/// [`hyper::Request`].
#[derive(Debug)]
pub struct RawRequest {
request: hyper::Request<hyper::Body>,
}

impl RawRequest {
pub fn into_inner(self) -> hyper::Request<hyper::Body> {
self.request
}
}

#[async_trait]
impl ExclusiveExtractor for RawRequest {
async fn from_request<Context: ServerContext>(
_rqctx: &RequestContext<Context>,
request: hyper::Request<hyper::Body>,
) -> Result<RawRequest, HttpError> {
Ok(RawRequest { request })
}

fn metadata(
_content_type: ApiEndpointBodyContentType,
) -> ExtractorMetadata {
ExtractorMetadata {
parameters: vec![],
extension_mode: ExtensionMode::None,
}
}
}

#[cfg(test)]
mod test {
use crate::api_description::ExtensionMode;
Expand Down
90 changes: 78 additions & 12 deletions dropshot/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,9 @@ use crate::schema_util::ReferenceVisitor;
use crate::to_map::to_map;

use async_trait::async_trait;
use futures::lock::Mutex;
use http::HeaderMap;
use http::StatusCode;
use hyper::Body;
use hyper::Request;
use hyper::Response;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
Expand All @@ -73,19 +71,10 @@ use std::sync::Arc;
pub type HttpHandlerResult = Result<Response<Body>, HttpError>;

/// Handle for various interfaces useful during request processing.
// TODO-cleanup What's the right way to package up "request"? The only time we
// need it to be mutable is when we're reading the body (e.g., as part of the
// JSON extractor). In order to support that, we wrap it in something that
// supports interior mutability. It also needs to be thread-safe, since we're
// using async/await. That brings us to Arc<Mutex<...>>, but it seems like
// overkill since it will only really be used by one thread at a time (at all,
// let alone mutably) and there will never be contention on the Mutex.
#[derive(Debug)]
pub struct RequestContext<Context: ServerContext> {
/// shared server state
pub server: Arc<DropshotState<Context>>,
/// HTTP request details
pub request: Arc<Mutex<Request<Body>>>,
/// HTTP request routing variables
pub path_variables: VariableSet,
/// expected request body mime type
Expand All @@ -94,6 +83,80 @@ pub struct RequestContext<Context: ServerContext> {
pub request_id: String,
/// logger for this specific request
pub log: Logger,

/// basic request information (method, URI, etc.)
pub request: RequestInfo,
}

// This is deliberately as close to compatible with `hyper::Request` as
// reasonable.
#[derive(Debug)]
pub struct RequestInfo {
method: http::Method,
uri: http::Uri,
version: http::Version,
headers: http::HeaderMap<http::HeaderValue>,
}

impl<B> From<&hyper::Request<B>> for RequestInfo {
fn from(request: &hyper::Request<B>) -> Self {
RequestInfo {
method: request.method().clone(),
uri: request.uri().clone(),
version: request.version(),
headers: request.headers().clone(),
}
}
}

impl RequestInfo {
pub fn method(&self) -> &http::Method {
&self.method
}

pub fn uri(&self) -> &http::Uri {
&self.uri
}

pub fn version(&self) -> &http::Version {
&self.version
}

pub fn headers(&self) -> &http::HeaderMap<http::HeaderValue> {
&self.headers
}

/// Returns a reference to the `RequestInfo` itself
///
/// This is provided for source compatibility. In previous versions of
/// Dropshot, `RequestContext.request` was an
/// `Arc<Mutex<hyper::Request<hyper::Body>>>`. Now, it's just
/// `RequestInfo`, which provides many of the same functions as
/// `hyper::Request` does. Consumers _should_ just use `rqctx.request`
/// instead of this function.
///
/// For example, in previous versions of Dropshot, you might have:
///
/// ```ignore
/// let request = rqctx.request.lock().await;
/// let headers = request.headers();
/// ```
///
/// Now, you would do this:
///
/// ```ignore
/// let headers = rqctx.request.headers();
/// ```
///
/// This function allows the older code to continue to work.
#[deprecated(
since = "0.9.0",
note = "use `rqctx.request` directly instead of \
`rqctx.request.lock().await`"
)]
pub async fn lock(&self) -> &Self {
self
}
}

impl<Context: ServerContext> RequestContext<Context> {
Expand Down Expand Up @@ -304,6 +367,7 @@ pub trait RouteHandler<Context: ServerContext>: Debug + Send + Sync {
async fn handle_request(
&self,
rqctx: RequestContext<Context>,
request: hyper::Request<hyper::Body>,
) -> HttpHandlerResult;
}

Expand Down Expand Up @@ -366,6 +430,7 @@ where
async fn handle_request(
&self,
rqctx_raw: RequestContext<Context>,
request: hyper::Request<hyper::Body>,
) -> HttpHandlerResult {
// This is where the magic happens: in the code below, `funcparams` has
// type `FuncParams`, which is a tuple type describing the extractor
Expand All @@ -384,7 +449,8 @@ where
// actual handler function. From this point down, all of this is
// resolved statically.
let rqctx = Arc::new(rqctx_raw);
let funcparams = RequestExtractor::from_request(&rqctx).await?;
let funcparams =
RequestExtractor::from_request(&rqctx, request).await?;
let future = self.handler.handle_request(rqctx, funcparams);
future.await
}
Expand Down
Loading

0 comments on commit 2d526b3

Please sign in to comment.