From 5bffd8b9dc747d4c134dfbc132dd20596584f2d8 Mon Sep 17 00:00:00 2001 From: Luqman Aden Date: Fri, 12 Aug 2022 14:27:51 -0700 Subject: [PATCH] Expose way for request handlers to determine if the request came over HTTP or HTTPS. --- dropshot/src/server.rs | 4 ++ dropshot/tests/test_tls.rs | 109 ++++++++++++++++++++++++++++++++++++- 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index 7d37374e..7e005531 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -68,6 +68,8 @@ pub struct DropshotState { pub log: Logger, /** bound local address for the server. */ pub local_addr: SocketAddr, + /** are requests served over HTTPS */ + pub tls: bool, } /** @@ -253,6 +255,7 @@ impl InnerHttpServerStarter { router: api.into_router(), log: log.new(o!("local_addr" => local_addr)), local_addr, + tls: false, }); let make_service = ServerConnectionHandler::new(app_state.clone()); @@ -503,6 +506,7 @@ impl InnerHttpsServerStarter { router: api.into_router(), log: logger, local_addr, + tls: true, }); let make_service = ServerConnectionHandler::new(Arc::clone(&app_state)); diff --git a/dropshot/tests/test_tls.rs b/dropshot/tests/test_tls.rs index 88107bf9..3d0e4df7 100644 --- a/dropshot/tests/test_tls.rs +++ b/dropshot/tests/test_tls.rs @@ -4,7 +4,7 @@ * including certificate loading and supported modes. */ -use dropshot::{ConfigDropshot, ConfigTls, HttpServerStarter}; +use dropshot::{ConfigDropshot, ConfigTls, HttpResponseOk, HttpServerStarter}; use slog::{o, Logger}; use std::convert::TryFrom; use std::path::Path; @@ -241,3 +241,110 @@ async fn test_tls_aborted_negotiation() { logctx.cleanup_successful(); } + +#[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] +pub struct TlsCheckArgs { + tls: bool, +} + +/* + * The same handler is used for both an HTTP and HTTPS server. + * Make sure that we can distinguish between the two. + * The intended version is determined by a query parameter + * that varies between both tests. + */ +#[dropshot::endpoint { + method = GET, + path = "/", +}] +async fn tls_check_handler( + rqctx: Arc>, + query: dropshot::Query, +) -> Result, dropshot::HttpError> { + if rqctx.server.tls != query.into_inner().tls { + return Err(dropshot::HttpError::for_bad_request( + None, + "mismatch between expected and actual tls state".to_string(), + )); + } + Ok(HttpResponseOk(())) +} + +#[tokio::test] +async fn test_server_is_https() { + let logctx = create_log_context("test_server_is_https"); + let log = logctx.log.new(o!()); + + // Generate key for the server + let (certs, key) = common::generate_tls_key(); + let (cert_file, key_file) = common::tls_key_to_file(&certs, &key); + + let config = ConfigDropshot { + bind_address: "127.0.0.1:0".parse().unwrap(), + request_body_max_bytes: 1024, + tls: Some(ConfigTls { + cert_file: cert_file.path().to_path_buf(), + key_file: key_file.path().to_path_buf(), + }), + }; + let mut api = dropshot::ApiDescription::new(); + api.register(tls_check_handler).unwrap(); + let server = HttpServerStarter::new(&config, api, 0, &log).unwrap().start(); + let port = server.local_addr().port(); + + let https_client = make_https_client(make_pki_verifier(&certs)); + + // Expect request with tls=true to pass with https server + let https_request = hyper::Request::builder() + .method(http::method::Method::GET) + .uri(format!("https://localhost:{}/?tls=true", port)) + .body(hyper::Body::empty()) + .unwrap(); + let res = https_client.request(https_request).await.unwrap(); + assert_eq!(res.status(), hyper::StatusCode::OK); + + // Expect request with tls=false to fail with https server + let https_request = hyper::Request::builder() + .method(http::method::Method::GET) + .uri(format!("https://localhost:{}/?tls=false", port)) + .body(hyper::Body::empty()) + .unwrap(); + let res = https_client.request(https_request).await.unwrap(); + assert_eq!(res.status(), hyper::StatusCode::BAD_REQUEST); + + server.close().await.unwrap(); + + logctx.cleanup_successful(); +} + +#[tokio::test] +async fn test_server_is_http() { + let mut api = dropshot::ApiDescription::new(); + api.register(tls_check_handler).unwrap(); + + let testctx = common::test_setup("test_server_is_http", api); + + // Expect request with tls=false to pass with plain http server + testctx + .client_testctx + .make_request( + hyper::Method::GET, + "/?tls=false", + None as Option<()>, + hyper::StatusCode::OK, + ) + .await + .expect("expected success"); + + // Expect request with tls=true to fail with plain http server + testctx + .client_testctx + .make_request( + hyper::Method::GET, + "/?tls=true", + None as Option<()>, + hyper::StatusCode::BAD_REQUEST, + ) + .await + .expect_err("expected failure"); +}