Make server shutdown wait on any Detached handler futures #702

Merged
merged 1 commit into from
Jun 15, 2023
23 changes: 23 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions dropshot/Cargo.toml
@@ -16,6 +16,7 @@ async-trait = "0.1.68"
base64 = "0.21.2"
bytes = "1"
camino = { version = "1.1.4", features = ["serde1"] }
debug-ignore = "1.0.5"
form_urlencoded = "1.1.0"
futures = "0.3.28"
hostname = "0.3.0"
@@ -37,6 +38,7 @@ slog-json = "2.6.1"
slog-term = "2.9.0"
tokio-rustls = "0.24.0"
toml = "0.7.4"
waitgroup = "0.1.2"
Review comment (Collaborator):

This is fine. I see this pulls in two deps. It's a shame if there's nothing in tokio or futures that can already do this. 🤷
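
For comparison, here is a minimal sketch of a dependency-free way to wait for a set of tasks, using only the standard library's `std::sync::mpsc` channel. It is not what this PR does: the function name `run_and_wait` is hypothetical, and it uses threads rather than async tasks. Tokio's `mpsc` channel supports the same idiom, since its `recv` returns `None` once every sender has been dropped.

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

/// Hypothetical helper: spawns `n` worker threads and blocks until all of
/// them have finished. Returns the number of completion messages received.
fn run_and_wait(n: usize) -> usize {
    // Every worker owns a clone of `tx`; the clone is dropped when the
    // worker's closure returns.
    let (tx, rx) = mpsc::channel::<usize>();

    for id in 0..n {
        let tx = tx.clone();
        thread::spawn(move || {
            // Simulate some work of varying length.
            thread::sleep(Duration::from_millis(10 * id as u64));
            // Report completion; ignore the error if the receiver is gone.
            let _ = tx.send(id);
        });
    }

    // Drop the original sender. If it stayed alive, the receiver below
    // would wait forever, because one sender would never go away.
    drop(tx);

    // The iterator ends only once every sender has been dropped.
    rx.iter().count()
}

fn main() {
    let finished = run_and_wait(4);
    println!("{finished} workers finished");
}
```

The explicit `drop(tx)` plays the same role as dropping the server's own clone of the waitgroup worker: the waiter is released only when no handle remains, including the one held by whoever created the group.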


[dependencies.chrono]
version = "0.4.26"
38 changes: 38 additions & 0 deletions dropshot/src/server.rs
@@ -12,6 +12,7 @@ use super::router::HttpRouter;
use super::ProbeRegistration;

use async_stream::stream;
use debug_ignore::DebugIgnore;
Review comment (Contributor):

TIL, this crate seems useful!

use futures::future::{
BoxFuture, FusedFuture, FutureExt, Shared, TryFutureExt,
};
@@ -28,6 +29,7 @@ use hyper::Response;
use rustls;
use std::convert::TryFrom;
use std::future::Future;
use std::mem;
use std::net::SocketAddr;
use std::num::NonZeroU32;
use std::panic;
@@ -39,6 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use uuid::Uuid;
use waitgroup::WaitGroup;

use crate::config::HandlerTaskMode;
use crate::RequestInfo;
@@ -69,6 +72,9 @@ pub struct DropshotState<C: ServerContext> {
pub local_addr: SocketAddr,
/// Identifies how to accept TLS connections
pub(crate) tls_acceptor: Option<Arc<Mutex<TlsAcceptor>>>,
/// Worker for the handler_waitgroup associated with this server, allowing
/// graceful shutdown to wait for all handlers to complete.
pub(crate) handler_waitgroup_worker: DebugIgnore<waitgroup::Worker>,
}

impl<C: ServerContext> DropshotState<C> {
@@ -96,6 +102,7 @@ pub struct HttpServerStarter<C: ServerContext> {
app_state: Arc<DropshotState<C>>,
local_addr: SocketAddr,
wrapped: WrappedHttpServerStarter<C>,
handler_waitgroup: WaitGroup,
}

impl<C: ServerContext> HttpServerStarter<C> {
@@ -123,6 +130,7 @@ impl<C: ServerContext> HttpServerStarter<C> {
default_handler_task_mode: config.default_handler_task_mode,
};

let handler_waitgroup = WaitGroup::new();
let starter = match &tls {
Some(tls) => {
let (starter, app_state, local_addr) =
@@ -133,11 +141,13 @@ impl<C: ServerContext> HttpServerStarter<C> {
private,
log,
tls,
handler_waitgroup.worker(),
)?;
HttpServerStarter {
app_state,
local_addr,
wrapped: WrappedHttpServerStarter::Https(starter),
handler_waitgroup,
}
}
None => {
@@ -148,11 +158,13 @@ impl<C: ServerContext> HttpServerStarter<C> {
api,
private,
log,
handler_waitgroup.worker(),
Review comment (Collaborator):

This is probably not something to worry about right now, but: I wonder how debuggable this is, either in situ or post mortem. Like if you walk up to a server that's shutting down and stuck waiting on one of these, would you have any way to know which request it was waiting on? I imagine eventually we'll want to elevate this to an API but that's probably way down the road.

Just to show what I mean, in the past I built something like this:
https://github.com/TritonDataCenter/node-vasync#barrier-coordinate-multiple-concurrent-operations
In that thing, each outstanding operation has a distinct name. You can take the whole object and expose that over a debug HTTP API to ask the server "what are the operations you're waiting on". This proved incredibly useful. But that was more for stuff like our own RSS, where you've got complicated long-running things that could get stuck somewhere. This particular case seems less likely to be a problem. It would just be neat to have a thing like waitgroup but where it was easy to ask it what it was waiting on.
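
To make the idea concrete, here is a rough sketch of a waitgroup-like type that records a name for each outstanding operation, so a debug endpoint could list what shutdown is waiting on. Everything here is hypothetical (`NamedWaitGroup`, `NamedWorker`, and the `pending` method are invented for illustration and are not part of the `waitgroup` crate or Dropshot), and it blocks a thread instead of being async.

```rust
use std::collections::BTreeSet;
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

/// Hypothetical wait group that remembers a name per outstanding operation.
#[derive(Clone, Default)]
struct NamedWaitGroup {
    inner: Arc<(Mutex<BTreeSet<String>>, Condvar)>,
}

/// Guard for one outstanding operation; removes its name when dropped.
struct NamedWorker {
    name: String,
    inner: Arc<(Mutex<BTreeSet<String>>, Condvar)>,
}

impl NamedWaitGroup {
    /// Registers an operation. Names are assumed to be unique, for example
    /// a request ID.
    fn worker(&self, name: &str) -> NamedWorker {
        let (pending, _) = &*self.inner;
        pending.lock().unwrap().insert(name.to_string());
        NamedWorker { name: name.to_string(), inner: Arc::clone(&self.inner) }
    }

    /// Snapshot of the operations that have not finished yet.
    fn pending(&self) -> Vec<String> {
        let (pending, _) = &*self.inner;
        pending.lock().unwrap().iter().cloned().collect()
    }

    /// Blocks the calling thread until no operations remain.
    fn wait(&self) {
        let (pending, cvar) = &*self.inner;
        let mut guard = pending.lock().unwrap();
        while !guard.is_empty() {
            guard = cvar.wait(guard).unwrap();
        }
    }
}

impl Drop for NamedWorker {
    fn drop(&mut self) {
        let (pending, cvar) = &*self.inner;
        pending.lock().unwrap().remove(&self.name);
        cvar.notify_all();
    }
}

fn main() {
    let wg = NamedWaitGroup::default();
    let a = wg.worker("req-a");
    let b = wg.worker("req-b");
    drop(a);
    // A debug API could expose this list while shutdown is stuck.
    println!("still waiting on: {:?}", wg.pending());
    let handle = thread::spawn(move || drop(b));
    wg.wait();
    handle.join().unwrap();
}
```

The cost compared with a plain counter is a lock and a string per operation, which may or may not be acceptable on a request path.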

Reply from the PR author (Contributor):

> This is probably not something to worry about right now, but: I wonder how debuggable this is, either in situ or post mortem. Like if you walk up to a server that's shutting down and stuck waiting on one of these, would you have any way to know which request it was waiting on?

Probably not! Internally the waitgroup is just a wrapper around an AtomicWaker, which itself is just a glorified AtomicUsize; I think you could pretty quickly find the count, but assuming it's something like 1, I don't know how you'd find the guilty party.
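
As an illustration of that limitation, here is a small counter-only sketch. It is an assumption-laden stand-in, not the `waitgroup` crate's actual implementation: `CountingWaitGroup`, `CountingWorker`, and `outstanding` are hypothetical names, and the waker and async `wait` are left out entirely.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical counting wait group: only the number of live workers is kept.
struct CountingWaitGroup {
    count: Arc<AtomicUsize>,
}

/// Handle held by one outstanding operation.
struct CountingWorker {
    count: Arc<AtomicUsize>,
}

impl CountingWaitGroup {
    fn new() -> Self {
        Self { count: Arc::new(AtomicUsize::new(0)) }
    }

    /// Hands out a new worker and bumps the counter.
    fn worker(&self) -> CountingWorker {
        self.count.fetch_add(1, Ordering::SeqCst);
        CountingWorker { count: Arc::clone(&self.count) }
    }

    /// Number of workers not yet dropped. This is all a debugger can learn.
    fn outstanding(&self) -> usize {
        self.count.load(Ordering::SeqCst)
    }
}

impl Clone for CountingWorker {
    fn clone(&self) -> Self {
        // Each clone counts as one more outstanding worker.
        self.count.fetch_add(1, Ordering::SeqCst);
        CountingWorker { count: Arc::clone(&self.count) }
    }
}

impl Drop for CountingWorker {
    fn drop(&mut self) {
        self.count.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    let wg = CountingWaitGroup::new();
    let primary = wg.worker(); // stands in for the clone kept in app state
    let handler = primary.clone(); // stands in for a detached handler's clone
    drop(handler);
    // One worker is still alive, but nothing records which one it is.
    println!("outstanding = {}", wg.outstanding());
    drop(primary);
    println!("outstanding = {}", wg.outstanding());
}
```

With only a counter, a stuck shutdown shows a nonzero number and nothing else; the sketch also shows why the long-lived "primary" handle has to be dropped before the count can reach zero.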

)?;
HttpServerStarter {
app_state,
local_addr,
wrapped: WrappedHttpServerStarter::Http(starter),
handler_waitgroup,
}
}
};
@@ -182,6 +194,15 @@ impl<C: ServerContext> HttpServerStarter<C> {
});
info!(self.app_state.log, "listening");

let handler_waitgroup = self.handler_waitgroup;
let join_handle = async move {
// After the server shuts down, we also want to wait for any
// detached handler futures to complete.
() = join_handle.await?;
() = handler_waitgroup.wait().await;
Ok(())
};

#[cfg(feature = "usdt-probes")]
let probe_registration = match usdt::register_probes() {
Ok(_) => {
@@ -258,6 +279,7 @@ impl<C: ServerContext> InnerHttpServerStarter<C> {
api: ApiDescription<C>,
private: C,
log: &Logger,
handler_waitgroup_worker: waitgroup::Worker,
) -> Result<InnerHttpServerStarterNewReturn<C>, hyper::Error> {
let incoming = AddrIncoming::bind(&config.bind_address)?;
let local_addr = incoming.local_addr();
@@ -269,6 +291,7 @@ impl<C: ServerContext> InnerHttpServerStarter<C> {
log: log.new(o!("local_addr" => local_addr)),
local_addr,
tls_acceptor: None,
handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker),
});

let make_service = ServerConnectionHandler::new(app_state.clone());
@@ -546,6 +569,7 @@ impl<C: ServerContext> InnerHttpsServerStarter<C> {
private: C,
log: &Logger,
tls: &ConfigTls,
handler_waitgroup_worker: waitgroup::Worker,
) -> Result<InnerHttpsServerStarterNewReturn<C>, GenericError> {
let acceptor = Arc::new(Mutex::new(TlsAcceptor::from(Arc::new(
rustls::ServerConfig::try_from(tls)?,
@@ -572,6 +596,7 @@ impl<C: ServerContext> InnerHttpsServerStarter<C> {
log: logger,
local_addr,
tls_acceptor: Some(acceptor),
handler_waitgroup_worker: DebugIgnore(handler_waitgroup_worker),
});

let make_service = ServerConnectionHandler::new(Arc::clone(&app_state));
@@ -689,6 +714,14 @@ impl<C: ServerContext> HttpServer<C> {
.expect("cannot close twice")
.send(())
.expect("failed to send close signal");

// We _must_ explicitly drop our app state before awaiting join_future.
// If we are running handlers in `Detached` mode, our `app_state` has a
// `waitgroup::Worker` that they all clone, and `join_future` will await
// all of them being dropped. That means we must drop our "primary"
// clone of it, too!
mem::drop(self.app_state);

self.join_future.await
}
}
@@ -875,6 +908,7 @@ async fn http_request_handle<C: ServerContext>(
// to completion.
let (tx, rx) = oneshot::channel();
let request_log = rqctx.log.clone();
let worker = server.handler_waitgroup_worker.clone();
let handler_task = tokio::spawn(async move {
let request_log = rqctx.log.clone();
let result = handler.handle_request(rqctx, request).await;
@@ -887,6 +921,10 @@ async fn http_request_handle<C: ServerContext>(
"client disconnected before response returned"
);
}

// Drop our waitgroup worker, allowing graceful shutdown to
// complete (if it's waiting on us).
mem::drop(worker);
});

// The only way we can fail to receive on `rx` is if `tx` is
1 change: 1 addition & 0 deletions dropshot/src/test_util.rs
@@ -72,6 +72,7 @@ const ALLOWED_HEADERS: [AllowedHeader<'static>; 8] = [

/// ClientTestContext encapsulates several facilities associated with using an
/// HTTP client for testing.
#[derive(Clone)]
pub struct ClientTestContext {
/// actual bind address of the HTTP server under test
pub bind_address: SocketAddr,
5 changes: 5 additions & 0 deletions dropshot/src/websocket.rs
@@ -301,12 +301,14 @@ mod tests {
ExclusiveExtractor, HttpError, RequestContext, RequestInfo,
WebsocketUpgrade,
};
use debug_ignore::DebugIgnore;
use http::Request;
use hyper::Body;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
use waitgroup::WaitGroup;

async fn ws_upg_from_mock_rqctx() -> Result<WebsocketUpgrade, HttpError> {
let log = slog::Logger::root(slog::Discard, slog::o!()).new(slog::o!());
@@ -336,6 +338,9 @@ mod tests {
8080,
),
tls_acceptor: None,
handler_waitgroup_worker: DebugIgnore(
WaitGroup::new().worker(),
),
}),
request: RequestInfo::new(&request, remote_addr),
path_variables: Default::default(),
16 changes: 14 additions & 2 deletions dropshot/tests/common/mod.rs
@@ -8,6 +8,8 @@ use dropshot::ConfigDropshot;
use dropshot::ConfigLogging;
use dropshot::ConfigLoggingIfExists;
use dropshot::ConfigLoggingLevel;
use dropshot::HandlerTaskMode;
use dropshot::ServerContext;
use slog::o;
use std::io::Write;
use tempfile::NamedTempFile;
@@ -16,18 +18,28 @@ pub fn test_setup(
test_name: &str,
api: ApiDescription<usize>,
) -> TestContext<usize> {
test_setup_with_context(test_name, api, 0_usize, HandlerTaskMode::Detached)
}

pub fn test_setup_with_context<Context: ServerContext>(
test_name: &str,
api: ApiDescription<Context>,
ctx: Context,
default_handler_task_mode: HandlerTaskMode,
) -> TestContext<Context> {
// The IP address to which we bind can be any local IP, but we use
// 127.0.0.1 because we know it's present, it shouldn't expose this server
// on any external network, and we don't have to go looking for some other
// local IP (likely in a platform-specific way). We specify port 0 to
// request any available port. This is important because we may run
// multiple concurrent tests, so any fixed port could result in spurious
// failures due to port conflicts.
- let config_dropshot: ConfigDropshot = Default::default();
+ let config_dropshot: ConfigDropshot =
+ ConfigDropshot { default_handler_task_mode, ..Default::default() };

let logctx = create_log_context(test_name);
let log = logctx.log.new(o!());
- TestContext::new(api, 0_usize, &config_dropshot, Some(logctx), log)
+ TestContext::new(api, ctx, &config_dropshot, Some(logctx), log)
}

pub fn create_log_context(test_name: &str) -> LogContext {