Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a default OPTIONS handler and complement the handler for HTTP 405 in router #743

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 169 additions & 8 deletions src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use route_recognizer::{Match, Params, Router as MethodRouter};
use std::collections::HashMap;

use crate::endpoint::DynEndpoint;
use crate::{Request, Response, StatusCode};
use crate::{http::headers, http::Method, Request, Response, StatusCode};

/// The routing table used by `Server`
///
Expand Down Expand Up @@ -71,11 +71,20 @@ impl<State: Clone + Send + Sync + 'static> Router<State> {
.filter(|(k, _)| **k != method)
.any(|(_, r)| r.recognize(path).is_ok())
{
// If this `path` can be handled by a callback registered with a different HTTP method
// should return 405 Method Not Allowed
// If this `path` can be handled by a callback registered with a different HTTP method,
// the server should return 405 Method Not Allowed.
// Or for an OPTIONS request, it should response with a success and supported methods.
let supported_methods = self.get_supported_methods(path).join(", ");
let mut params = Params::new();
params.insert(String::from(SUPPORTED_METHODS_PARAM_KEY), supported_methods);
// TODO: How to pass a closure as the endpoint here?
Selection {
endpoint: &method_not_allowed,
params: Params::new(),
endpoint: if method == Method::Options {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to pass a closure here as the endpoint. But it seems impossible as Selection.endpoint takes a reference?

To circumvent the limitation, I am trying to pass it in the params from which it is then extracted in the endpoint function. I am not sure it is good or not as it is a little tricky and incurring unnecessary overhead. But if the way is acceptable, we can let the CorsMiddleware also extract the supported methods to fill Access-Control-Allow-Methods in CORS preflight requests without refactoring a lot.

&http_options_endpoint
} else {
&method_not_allowed_endpoint
},
params: params,
}
} else {
Selection {
Expand All @@ -84,6 +93,37 @@ impl<State: Clone + Send + Sync + 'static> Router<State> {
}
}
}

/// Get supported methods for a target resource path
fn get_supported_methods<'a>(&'a self, path: &'a str) -> Vec<&str> {
let basic_methods: &[&str]; // implicitly supported methods not registered in the map
if !self
.method_map
.get(&Method::Head)
.and_then(|r| r.recognize(path).ok())
.is_some()
&& self
.method_map
.get(&Method::Get)
.and_then(|r| r.recognize(path).ok())
.is_some()
{
// If the endpoint has no handler for HEAD, but a handler for GET.
basic_methods = &["OPTIONS", "HEAD"];
} else {
basic_methods = &["OPTIONS"];
}
let registered_methods = self
.method_map
.iter()
.filter(|(_, r)| r.recognize(path).is_ok())
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is inevitable to query every MethodRouter in the method_map to get all registered/supported methods which are used by the Allow header.
As is discussed in #51 (comment), it is somewhat inefficient. But now that L71 already do this to determine whether it falls to the branch handling 405 or not, I suppose it is fine for now (at least, it is not a new problem).

.map(|(m, _)| m.as_ref());
basic_methods
.iter()
.map(|&s| s)
.chain(registered_methods)
.collect::<Vec<&str>>()
}
}

async fn not_found_endpoint<State: Clone + Send + Sync + 'static>(
Expand All @@ -92,8 +132,129 @@ async fn not_found_endpoint<State: Clone + Send + Sync + 'static>(
Ok(Response::new(StatusCode::NotFound))
}

async fn method_not_allowed<State: Clone + Send + Sync + 'static>(
_req: Request<State>,
pub(crate) const SUPPORTED_METHODS_PARAM_KEY: &'static str = "_SUPPORTED_METHODS";

/// The endpoint that responses with HTTP status `405 Method Not Allowed`
///
/// The comma-seperated list of supported methods to be set in the HTTP header `Allow` will be
/// extracted from the request param named [`SUPPORTED_METHODS_PARAM_KEY`].
/// Ref: [Section 6.5.5 of IETC RFC 7231](https://tools.ietf.org/html/rfc7231#section-6.5.5).
async fn method_not_allowed_endpoint<State: Clone + Send + Sync + 'static>(
req: Request<State>,
) -> crate::Result {
let mut resp = Response::new(StatusCode::MethodNotAllowed);
if let Some(supported_methods) = req.param(SUPPORTED_METHODS_PARAM_KEY).ok() {
resp.insert_header(headers::ALLOW, supported_methods);
}
Ok(resp)
}

/// The default handler for the HTTP `OPTIONS` method, only meant for listing supported methods
///
/// The comma-separated list of allowed methods to be set in the HTTP header `Allow` will be
/// extracted from the request param named [`SUPPORTED_METHODS_PARAM_KEY`].
/// For CORS preflight requests (i.e. the HTTP header `Origin` is set), it is expected be overrided
/// by CORSMiddleware, if the latter is activated.
async fn http_options_endpoint<State: Clone + Send + Sync + 'static>(
req: Request<State>,
) -> crate::Result {
Ok(Response::new(StatusCode::MethodNotAllowed))
let mut resp = Response::new(StatusCode::NoContent);
if let Some(supported_methods) = req.param(SUPPORTED_METHODS_PARAM_KEY).ok() {
resp.insert_header(headers::ALLOW, supported_methods);
}
Ok(resp)
}

#[cfg(test)]
mod test {
use crate::http::{self, Method, Request, StatusCode, Url};
use crate::security::{CorsMiddleware, Origin};
use crate::Response;
use http_types::headers::HeaderValue;
use std::collections::HashSet;

#[async_std::test]
async fn default_handler_for_http_options() {
let mut app = crate::Server::new();
app.at("/endpoint")
.get(|_| async { Ok("Hello, GET.") })
.post(|_| async { Ok("Hello, POST.") });
app.at("/pendoint").post(|_| async { Ok("Hello, POST.") });

let response: Response = app
.respond(Request::new(
Method::Options,
Url::parse("http://example.com/endpoint").unwrap(),
))
.await
.unwrap();
assert!(response.status().is_success());
ensure_methods_allowed(&response, &["get", "head", "post", "options"], true);

let response: Response = app
.respond(Request::new(
Method::Options,
Url::parse("http://example.com/pendoint").unwrap(),
))
.await
.unwrap();
assert!(response.status().is_success());
ensure_methods_allowed(&response, &["options", "post"], true);
ensure_methods_allowed(&response, &["head"], false);
}

#[async_std::test]
async fn return_status_405_if_method_not_allowed() {
let mut app = crate::Server::new();
app.at("/endpoint")
.get(|_| async { Ok("Hello, GET.") })
.post(|_| async { Ok("Hello, POST.") });

let response: Response = app
.respond(Request::new(
Method::Put,
Url::parse("http://example.com/endpoint").unwrap(),
))
.await
.unwrap();
assert_eq!(response.status(), StatusCode::MethodNotAllowed);
ensure_methods_allowed(&response, &["get", "post", "options"], true);
}

#[async_std::test]
async fn options_overrided_for_cors_preflight() {
let mut app = crate::Server::new();
app.at("/").get(|_| async { Ok("Hello, world.") });
app.with(
CorsMiddleware::new()
.allow_methods("GET, POST, OPTIONS".parse::<HeaderValue>().unwrap())
.allow_origin(Origin::Any),
);

let self_origin = "example.org";
let mut request = Request::new(Method::Options, Url::parse("http://example.com/").unwrap());
request.append_header(http::headers::ORIGIN, self_origin);
let response: Response = app.respond(request).await.unwrap();
let allowed_origin = response
.header(http::headers::ACCESS_CONTROL_ALLOW_ORIGIN)
.map(|origin| Origin::from(origin.as_str()));
assert_eq!(allowed_origin.unwrap(), Origin::from(self_origin));
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation of CorsMiddleware responses directly to CORS preflight requests, without reaching the endpoint.

return Ok(self.build_preflight_response(&origins).into());

So the newly added endpoint handler won't interfere with the existing CORS handling flow.

}

fn ensure_methods_allowed(response: &Response, expected_methods: &[&str], positive: bool) {
let allowed_methods = response.header(http::headers::ALLOW).map(|methods| {
methods
.as_str()
.split(",")
.map(|method| method.trim().to_ascii_lowercase())
.collect::<HashSet<String>>()
});
let allowed_methods = allowed_methods.unwrap();
for method in expected_methods
.iter()
.map(|&method| method.to_ascii_lowercase())
{
assert!(!positive ^ allowed_methods.contains(&method));
}
}
}