diff --git a/Cargo.lock b/Cargo.lock index 2f199724a..df8ef0653 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,12 @@ dependencies = [ "syn 2.0.18", ] +[[package]] +name = "atomic-waker" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1181e1e0d1fce796a03db1ae795d67167da795f9cf4a39c37589e85ef57f26d3" + [[package]] name = "atty" version = "0.2.14" @@ -284,6 +290,12 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" +[[package]] +name = "debug-ignore" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe7ed1d93f4553003e20b629abe9085e1e81b1429520f897f8f8860bc6dfc21" + [[package]] name = "digest" version = "0.8.1" @@ -346,6 +358,7 @@ dependencies = [ "bytes", "camino", "chrono", + "debug-ignore", "dropshot_endpoint", "expectorate", "form_urlencoded", @@ -388,6 +401,7 @@ dependencies = [ "usdt", "uuid", "version_check", + "waitgroup", ] [[package]] @@ -2097,6 +2111,15 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "waitgroup" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1f50000a783467e6c0200f9d10642f4bc424e39efc1b770203e88b488f79292" +dependencies = [ + "atomic-waker", +] + [[package]] name = "want" version = "0.3.0" diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 4d430667b..de9557fa8 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -16,6 +16,7 @@ async-trait = "0.1.68" base64 = "0.21.2" bytes = "1" camino = { version = "1.1.4", features = ["serde1"] } +debug-ignore = "1.0.5" form_urlencoded = "1.1.0" futures = "0.3.28" hostname = "0.3.0" @@ -37,6 +38,7 @@ slog-json = "2.6.1" slog-term = "2.9.0" tokio-rustls = "0.24.0" toml = "0.7.4" +waitgroup = "0.1.2" [dependencies.chrono] version = "0.4.26" diff --git a/dropshot/src/server.rs b/dropshot/src/server.rs index c14285400..d76dfe838 100644 --- a/dropshot/src/server.rs +++ b/dropshot/src/server.rs @@ -12,6 +12,7 @@ use super::router::HttpRouter; use super::ProbeRegistration; use async_stream::stream; +use debug_ignore::DebugIgnore; use futures::future::{ BoxFuture, FusedFuture, FutureExt, Shared, TryFutureExt, }; @@ -28,6 +29,7 @@ use hyper::Response; use rustls; use std::convert::TryFrom; use std::future::Future; +use std::mem; use std::net::SocketAddr; use std::num::NonZeroU32; use std::panic; @@ -39,6 +41,7 @@ use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; use tokio_rustls::{server::TlsStream, TlsAcceptor}; use uuid::Uuid; +use waitgroup::WaitGroup; use crate::config::HandlerTaskMode; use crate::RequestInfo; @@ -69,6 +72,9 @@ pub struct DropshotState { pub local_addr: SocketAddr, /// Identifies how to accept TLS connections pub(crate) tls_acceptor: Option>>, + /// Worker for the handler_waitgroup associated with this server, allowing + /// graceful shutdown to wait for all handlers to complete. + pub(crate) handler_waitgroup_worker: DebugIgnore, } impl DropshotState { @@ -96,6 +102,7 @@ pub struct HttpServerStarter { app_state: Arc>, local_addr: SocketAddr, wrapped: WrappedHttpServerStarter, + handler_waitgroup: WaitGroup, } impl HttpServerStarter { @@ -123,6 +130,7 @@ impl HttpServerStarter { default_handler_task_mode: config.default_handler_task_mode, }; + let handler_waitgroup = WaitGroup::new(); let starter = match &tls { Some(tls) => { let (starter, app_state, local_addr) = @@ -133,11 +141,13 @@ impl HttpServerStarter { private, log, tls, + handler_waitgroup.worker(), )?; HttpServerStarter { app_state, local_addr, wrapped: WrappedHttpServerStarter::Https(starter), + handler_waitgroup, } } None => { @@ -148,11 +158,13 @@ impl HttpServerStarter { api, private, log, + handler_waitgroup.worker(), )?; HttpServerStarter { app_state, local_addr, wrapped: WrappedHttpServerStarter::Http(starter), + handler_waitgroup, } } }; @@ -182,6 +194,15 @@ impl HttpServerStarter { }); info!(self.app_state.log, "listening"); + let handler_waitgroup = self.handler_waitgroup; + let join_handle = async move { + // After the server shuts down, we also want to wait for any + // detached handler futures to complete. + () = join_handle.await?; + () = handler_waitgroup.wait().await; + Ok(()) + }; + #[cfg(feature = "usdt-probes")] let probe_registration = match usdt::register_probes() { Ok(_) => { @@ -258,6 +279,7 @@ impl InnerHttpServerStarter { api: ApiDescription, private: C, log: &Logger, + handler_waitgroup_worker: waitgroup::Worker, ) -> Result, hyper::Error> { let incoming = AddrIncoming::bind(&config.bind_address)?; let local_addr = incoming.local_addr(); @@ -269,6 +291,7 @@ impl InnerHttpServerStarter { log: log.new(o!("local_addr" => local_addr)), local_addr, tls_acceptor: None, + handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), }); let make_service = ServerConnectionHandler::new(app_state.clone()); @@ -546,6 +569,7 @@ impl InnerHttpsServerStarter { private: C, log: &Logger, tls: &ConfigTls, + handler_waitgroup_worker: waitgroup::Worker, ) -> Result, GenericError> { let acceptor = Arc::new(Mutex::new(TlsAcceptor::from(Arc::new( rustls::ServerConfig::try_from(tls)?, @@ -572,6 +596,7 @@ impl InnerHttpsServerStarter { log: logger, local_addr, tls_acceptor: Some(acceptor), + handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker), }); let make_service = ServerConnectionHandler::new(Arc::clone(&app_state)); @@ -689,6 +714,14 @@ impl HttpServer { .expect("cannot close twice") .send(()) .expect("failed to send close signal"); + + // We _must_ explicitly drop our app state before awaiting join_future. + // If we are running handlers in `Detached` mode, our `app_state` has a + // `waitgroup::Worker` that they all clone, and `join_future` will await + // all of them being dropped. That means we must drop our "primary" + // clone of it, too! + mem::drop(self.app_state); + self.join_future.await } } @@ -875,6 +908,7 @@ async fn http_request_handle( // to completion. let (tx, rx) = oneshot::channel(); let request_log = rqctx.log.clone(); + let worker = server.handler_waitgroup_worker.clone(); let handler_task = tokio::spawn(async move { let request_log = rqctx.log.clone(); let result = handler.handle_request(rqctx, request).await; @@ -887,6 +921,10 @@ async fn http_request_handle( "client disconnected before response returned" ); } + + // Drop our waitgroup worker, allowing graceful shutdown to + // complete (if it's waiting on us). + mem::drop(worker); }); // The only way we can fail to receive on `rx` is if `tx` is diff --git a/dropshot/src/test_util.rs b/dropshot/src/test_util.rs index ee11f7347..f69b19a8f 100644 --- a/dropshot/src/test_util.rs +++ b/dropshot/src/test_util.rs @@ -72,6 +72,7 @@ const ALLOWED_HEADERS: [AllowedHeader<'static>; 8] = [ /// ClientTestContext encapsulates several facilities associated with using an /// HTTP client for testing. +#[derive(Clone)] pub struct ClientTestContext { /// actual bind address of the HTTP server under test pub bind_address: SocketAddr, diff --git a/dropshot/src/websocket.rs b/dropshot/src/websocket.rs index 2cf081604..dfcea7530 100644 --- a/dropshot/src/websocket.rs +++ b/dropshot/src/websocket.rs @@ -301,12 +301,14 @@ mod tests { ExclusiveExtractor, HttpError, RequestContext, RequestInfo, WebsocketUpgrade, }; + use debug_ignore::DebugIgnore; use http::Request; use hyper::Body; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::num::NonZeroU32; use std::sync::Arc; use std::time::Duration; + use waitgroup::WaitGroup; async fn ws_upg_from_mock_rqctx() -> Result { let log = slog::Logger::root(slog::Discard, slog::o!()).new(slog::o!()); @@ -336,6 +338,9 @@ mod tests { 8080, ), tls_acceptor: None, + handler_waitgroup_worker: DebugIgnore( + WaitGroup::new().worker(), + ), }), request: RequestInfo::new(&request, remote_addr), path_variables: Default::default(), diff --git a/dropshot/tests/common/mod.rs b/dropshot/tests/common/mod.rs index 865dcc8c5..a7cb320f8 100644 --- a/dropshot/tests/common/mod.rs +++ b/dropshot/tests/common/mod.rs @@ -8,6 +8,8 @@ use dropshot::ConfigDropshot; use dropshot::ConfigLogging; use dropshot::ConfigLoggingIfExists; use dropshot::ConfigLoggingLevel; +use dropshot::HandlerTaskMode; +use dropshot::ServerContext; use slog::o; use std::io::Write; use tempfile::NamedTempFile; @@ -16,6 +18,15 @@ pub fn test_setup( test_name: &str, api: ApiDescription, ) -> TestContext { + test_setup_with_context(test_name, api, 0_usize, HandlerTaskMode::Detached) +} + +pub fn test_setup_with_context( + test_name: &str, + api: ApiDescription, + ctx: Context, + default_handler_task_mode: HandlerTaskMode, +) -> TestContext { // The IP address to which we bind can be any local IP, but we use // 127.0.0.1 because we know it's present, it shouldn't expose this server // on any external network, and we don't have to go looking for some other @@ -23,11 +34,12 @@ pub fn test_setup( // request any available port. This is important because we may run // multiple concurrent tests, so any fixed port could result in spurious // failures due to port conflicts. - let config_dropshot: ConfigDropshot = Default::default(); + let config_dropshot: ConfigDropshot = + ConfigDropshot { default_handler_task_mode, ..Default::default() }; let logctx = create_log_context(test_name); let log = logctx.log.new(o!()); - TestContext::new(api, 0_usize, &config_dropshot, Some(logctx), log) + TestContext::new(api, ctx, &config_dropshot, Some(logctx), log) } pub fn create_log_context(test_name: &str) -> LogContext { diff --git a/dropshot/tests/test_detached_shutdown.rs b/dropshot/tests/test_detached_shutdown.rs new file mode 100644 index 000000000..b5e755ff8 --- /dev/null +++ b/dropshot/tests/test_detached_shutdown.rs @@ -0,0 +1,94 @@ +// Copyright 2023 Oxide Computer Company + +//! Test cases for graceful shutdown of a server running tasks in +//! `HandlerTaskMode::Detached`. + +use dropshot::{ + endpoint, ApiDescription, HandlerTaskMode, HttpError, RequestContext, +}; +use http::{Method, Response, StatusCode}; +use hyper::Body; +use std::time::Duration; +use tokio::sync::mpsc; + +pub mod common; + +struct Context { + endpoint_started_tx: mpsc::UnboundedSender<()>, + release_endpoint_rx: async_channel::Receiver<()>, +} + +fn api() -> ApiDescription { + let mut api = ApiDescription::new(); + api.register(root).unwrap(); + api +} + +#[endpoint { + method = GET, + path = "/", +}] +async fn root( + rqctx: RequestContext, +) -> Result, HttpError> { + let ctx = rqctx.context(); + + // Notify test driver we've started handling a request. + ctx.endpoint_started_tx.send(()).unwrap(); + + // Wait until the test driver tells us to return. + () = ctx.release_endpoint_rx.recv().await.unwrap(); + + Ok(Response::builder().status(StatusCode::OK).body(Body::empty())?) +} + +#[tokio::test] +async fn test_graceful_shutdown_with_detached_handler() { + let (endpoint_started_tx, mut endpoint_started_rx) = + mpsc::unbounded_channel(); + let (release_endpoint_tx, release_endpoint_rx) = async_channel::unbounded(); + + let api = api(); + let testctx = common::test_setup_with_context( + "graceful_shutdown_with_detached_handler", + api, + Context { endpoint_started_tx, release_endpoint_rx }, + HandlerTaskMode::Detached, + ); + let client = testctx.client_testctx.clone(); + + // Spawn a task sending a request to our endpoint. + let client_task = tokio::spawn(async move { + client + .make_request_no_body(Method::GET, "/", StatusCode::OK) + .await + .expect("Expected GET request to succeed") + }); + + // Wait for the handler to start running. + () = endpoint_started_rx.recv().await.unwrap(); + + // Kill the client, which cancels the dropshot server future that spawned + // our detached handler, but does not cancel the endpoint future itself + // (because we're using HandlerTaskMode::Detached). + client_task.abort(); + + // Create a future to tear down the server. + let teardown_fut = testctx.teardown(); + tokio::pin!(teardown_fut); + + // Actually tearing down should time out, because it's waiting for the + // handler to return (which in turn is waiting on us to signal it!). + if tokio::time::timeout(Duration::from_secs(2), &mut teardown_fut) + .await + .is_ok() + { + panic!("server shutdown returned while handler running"); + } + + // Signal the handler to complete. + release_endpoint_tx.send(()).await.unwrap(); + + // Now we can finish waiting for server shutdown. + teardown_fut.await; +}