Skip to content

Commit

Permalink
Make server shutdown wait on any Detached handler futures
Browse files Browse the repository at this point in the history
  • Loading branch information
jgallagher committed Jun 15, 2023
1 parent 3a42491 commit 9816cb2
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 2 deletions.
23 changes: 23 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions dropshot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ async-trait = "0.1.68"
base64 = "0.21.2"
bytes = "1"
camino = { version = "1.1.4", features = ["serde1"] }
debug-ignore = "1.0.5"
form_urlencoded = "1.1.0"
futures = "0.3.28"
hostname = "0.3.0"
Expand All @@ -37,6 +38,7 @@ slog-json = "2.6.1"
slog-term = "2.9.0"
tokio-rustls = "0.24.0"
toml = "0.7.4"
waitgroup = "0.1.2"

[dependencies.chrono]
version = "0.4.26"
Expand Down
38 changes: 38 additions & 0 deletions dropshot/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use super::router::HttpRouter;
use super::ProbeRegistration;

use async_stream::stream;
use debug_ignore::DebugIgnore;
use futures::future::{
BoxFuture, FusedFuture, FutureExt, Shared, TryFutureExt,
};
Expand All @@ -28,6 +29,7 @@ use hyper::Response;
use rustls;
use std::convert::TryFrom;
use std::future::Future;
use std::mem;
use std::net::SocketAddr;
use std::num::NonZeroU32;
use std::panic;
Expand All @@ -39,6 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use uuid::Uuid;
use waitgroup::WaitGroup;

use crate::config::HandlerTaskMode;
use crate::RequestInfo;
Expand Down Expand Up @@ -69,6 +72,9 @@ pub struct DropshotState<C: ServerContext> {
pub local_addr: SocketAddr,
/// Identifies how to accept TLS connections
pub(crate) tls_acceptor: Option<Arc<Mutex<TlsAcceptor>>>,
/// Worker for the handler_waitgroup associated with this server, allowing
/// graceful shutdown to wait for all handlers to complete.
pub(crate) handler_waitgroup_worker: DebugIgnore<waitgroup::Worker>,
}

impl<C: ServerContext> DropshotState<C> {
Expand Down Expand Up @@ -96,6 +102,7 @@ pub struct HttpServerStarter<C: ServerContext> {
app_state: Arc<DropshotState<C>>,
local_addr: SocketAddr,
wrapped: WrappedHttpServerStarter<C>,
handler_waitgroup: WaitGroup,
}

impl<C: ServerContext> HttpServerStarter<C> {
Expand Down Expand Up @@ -123,6 +130,7 @@ impl<C: ServerContext> HttpServerStarter<C> {
default_handler_task_mode: config.default_handler_task_mode,
};

let handler_waitgroup = WaitGroup::new();
let starter = match &tls {
Some(tls) => {
let (starter, app_state, local_addr) =
Expand All @@ -133,11 +141,13 @@ impl<C: ServerContext> HttpServerStarter<C> {
private,
log,
tls,
handler_waitgroup.worker(),
)?;
HttpServerStarter {
app_state,
local_addr,
wrapped: WrappedHttpServerStarter::Https(starter),
handler_waitgroup,
}
}
None => {
Expand All @@ -148,11 +158,13 @@ impl<C: ServerContext> HttpServerStarter<C> {
api,
private,
log,
handler_waitgroup.worker(),
)?;
HttpServerStarter {
app_state,
local_addr,
wrapped: WrappedHttpServerStarter::Http(starter),
handler_waitgroup,
}
}
};
Expand Down Expand Up @@ -182,6 +194,15 @@ impl<C: ServerContext> HttpServerStarter<C> {
});
info!(self.app_state.log, "listening");

let handler_waitgroup = self.handler_waitgroup;
let join_handle = async move {
// After the server shuts down, we also want to wait for any
// detached handler futures to complete.
() = join_handle.await?;
() = handler_waitgroup.wait().await;
Ok(())
};

#[cfg(feature = "usdt-probes")]
let probe_registration = match usdt::register_probes() {
Ok(_) => {
Expand Down Expand Up @@ -258,6 +279,7 @@ impl<C: ServerContext> InnerHttpServerStarter<C> {
api: ApiDescription<C>,
private: C,
log: &Logger,
handler_waitgroup_worker: waitgroup::Worker,
) -> Result<InnerHttpServerStarterNewReturn<C>, hyper::Error> {
let incoming = AddrIncoming::bind(&config.bind_address)?;
let local_addr = incoming.local_addr();
Expand All @@ -269,6 +291,7 @@ impl<C: ServerContext> InnerHttpServerStarter<C> {
log: log.new(o!("local_addr" => local_addr)),
local_addr,
tls_acceptor: None,
handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker),
});

let make_service = ServerConnectionHandler::new(app_state.clone());
Expand Down Expand Up @@ -546,6 +569,7 @@ impl<C: ServerContext> InnerHttpsServerStarter<C> {
private: C,
log: &Logger,
tls: &ConfigTls,
handler_waitgroup_worker: waitgroup::Worker,
) -> Result<InnerHttpsServerStarterNewReturn<C>, GenericError> {
let acceptor = Arc::new(Mutex::new(TlsAcceptor::from(Arc::new(
rustls::ServerConfig::try_from(tls)?,
Expand All @@ -572,6 +596,7 @@ impl<C: ServerContext> InnerHttpsServerStarter<C> {
log: logger,
local_addr,
tls_acceptor: Some(acceptor),
handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker),
});

let make_service = ServerConnectionHandler::new(Arc::clone(&app_state));
Expand Down Expand Up @@ -689,6 +714,14 @@ impl<C: ServerContext> HttpServer<C> {
.expect("cannot close twice")
.send(())
.expect("failed to send close signal");

// We _must_ explicitly drop our app state before awaiting join_future.
// If we are running handlers in `Detached` mode, our `app_state` has a
// `waitgroup::Worker` that they all clone, and `join_future` will await
// all of them being dropped. That means we must drop our "primary"
// clone of it, too!
mem::drop(self.app_state);

self.join_future.await
}
}
Expand Down Expand Up @@ -875,6 +908,7 @@ async fn http_request_handle<C: ServerContext>(
// to completion.
let (tx, rx) = oneshot::channel();
let request_log = rqctx.log.clone();
let worker = server.handler_waitgroup_worker.clone();
let handler_task = tokio::spawn(async move {
let request_log = rqctx.log.clone();
let result = handler.handle_request(rqctx, request).await;
Expand All @@ -887,6 +921,10 @@ async fn http_request_handle<C: ServerContext>(
"client disconnected before response returned"
);
}

// Drop our waitgroup worker, allowing graceful shutdown to
// complete (if it's waiting on us).
mem::drop(worker);
});

// The only way we can fail to receive on `rx` is if `tx` is
Expand Down
1 change: 1 addition & 0 deletions dropshot/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const ALLOWED_HEADERS: [AllowedHeader<'static>; 8] = [

/// ClientTestContext encapsulates several facilities associated with using an
/// HTTP client for testing.
#[derive(Clone)]
pub struct ClientTestContext {
/// actual bind address of the HTTP server under test
pub bind_address: SocketAddr,
Expand Down
5 changes: 5 additions & 0 deletions dropshot/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,14 @@ mod tests {
ExclusiveExtractor, HttpError, RequestContext, RequestInfo,
WebsocketUpgrade,
};
use debug_ignore::DebugIgnore;
use http::Request;
use hyper::Body;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
use waitgroup::WaitGroup;

async fn ws_upg_from_mock_rqctx() -> Result<WebsocketUpgrade, HttpError> {
let log = slog::Logger::root(slog::Discard, slog::o!()).new(slog::o!());
Expand Down Expand Up @@ -336,6 +338,9 @@ mod tests {
8080,
),
tls_acceptor: None,
handler_waitgroup_worker: DebugIgnore(
WaitGroup::new().worker(),
),
}),
request: RequestInfo::new(&request, remote_addr),
path_variables: Default::default(),
Expand Down
16 changes: 14 additions & 2 deletions dropshot/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use dropshot::ConfigDropshot;
use dropshot::ConfigLogging;
use dropshot::ConfigLoggingIfExists;
use dropshot::ConfigLoggingLevel;
use dropshot::HandlerTaskMode;
use dropshot::ServerContext;
use slog::o;
use std::io::Write;
use tempfile::NamedTempFile;
Expand All @@ -16,18 +18,28 @@ pub fn test_setup(
test_name: &str,
api: ApiDescription<usize>,
) -> TestContext<usize> {
test_setup_with_context(test_name, api, 0_usize, HandlerTaskMode::Detached)
}

pub fn test_setup_with_context<Context: ServerContext>(
test_name: &str,
api: ApiDescription<Context>,
ctx: Context,
default_handler_task_mode: HandlerTaskMode,
) -> TestContext<Context> {
// The IP address to which we bind can be any local IP, but we use
// 127.0.0.1 because we know it's present, it shouldn't expose this server
// on any external network, and we don't have to go looking for some other
// local IP (likely in a platform-specific way). We specify port 0 to
// request any available port. This is important because we may run
// multiple concurrent tests, so any fixed port could result in spurious
// failures due to port conflicts.
let config_dropshot: ConfigDropshot = Default::default();
let config_dropshot: ConfigDropshot =
ConfigDropshot { default_handler_task_mode, ..Default::default() };

let logctx = create_log_context(test_name);
let log = logctx.log.new(o!());
TestContext::new(api, 0_usize, &config_dropshot, Some(logctx), log)
TestContext::new(api, ctx, &config_dropshot, Some(logctx), log)
}

pub fn create_log_context(test_name: &str) -> LogContext {
Expand Down
Loading

0 comments on commit 9816cb2

Please sign in to comment.