Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graceful shutdown futures lite #766

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
942 changes: 942 additions & 0 deletions .vscode/launch.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async-sse = "4.0.1"
async-std = { version = "1.6.5", features = ["unstable"] }
async-trait = "0.1.41"
femme = { version = "2.1.1", optional = true }
futures-lite = "1.11.2"
futures-util = "0.3.6"
http-client = { version = "6.1.0", default-features = false }
http-types = "2.5.0"
Expand Down
76 changes: 76 additions & 0 deletions src/cancelation_token.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};

#[derive(Debug)]
pub struct CancelationToken {
shared_state: Arc<Mutex<CancelationTokenState>>
}

#[derive(Debug)]
struct CancelationTokenState {
canceled: bool,
waker: Option<Waker>
}

#[derive(Debug)]
pub struct ReturnOnCancel<T> {
result: Mutex<RefCell<Option<T>>>,
shared_state: Arc<Mutex<CancelationTokenState>>
}

/// Future that allows gracefully shutting down the server
impl CancelationToken {
pub fn new() -> CancelationToken {
CancelationToken {
shared_state: Arc::new(Mutex::new(CancelationTokenState {
canceled: false,
waker: None
}))
}
}

/// Call to shut down the server
pub fn complete(&self) {
let mut shared_state = self.shared_state.lock().unwrap();

shared_state.canceled = true;
if let Some(waker) = shared_state.waker.take() {
waker.wake()
}
}

pub fn return_on_cancel<T>(&self, result: T) -> ReturnOnCancel<T> {
ReturnOnCancel {
result: Mutex::new(RefCell::new(Some(result))),
shared_state: self.shared_state.clone()
}
}
}

impl<T> Future for ReturnOnCancel<T> {
type Output = T;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut shared_state = self.shared_state.lock().unwrap();

if shared_state.canceled {
let result_refcell = self.result.lock().unwrap();
let result = result_refcell.replace(None).expect("Result was already returned");
Poll::Ready(result)
} else {
shared_state.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}

impl Clone for CancelationToken {
fn clone(&self) -> Self {
CancelationToken {
shared_state: self.shared_state.clone()
}
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
#![doc(html_favicon_url = "https://yoshuawuyts.com/assets/http-rs/favicon.ico")]
#![doc(html_logo_url = "https://yoshuawuyts.com/assets/http-rs/logo-rounded.png")]

mod cancelation_token;
#[cfg(feature = "cookies")]
mod cookies;
mod endpoint;
Expand All @@ -85,6 +86,7 @@ pub mod utils;
#[cfg(feature = "sessions")]
pub mod sessions;

pub use cancelation_token::CancelationToken;
pub use endpoint::Endpoint;
pub use middleware::{Middleware, Next};
pub use redirect::Redirect;
Expand Down
19 changes: 15 additions & 4 deletions src/listener/concurrent_listener.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::listener::{ListenInfo, Listener, ToListener};
use crate::Server;
use crate::{CancelationToken, Server};

use std::fmt::{self, Debug, Display, Formatter};

use async_std::io;
use async_std::{io, task};
use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt};

/// ConcurrentListener allows tide to listen on any number of transports
Expand Down Expand Up @@ -97,13 +97,24 @@ where
Ok(())
}

async fn accept(&mut self) -> io::Result<()> {
async fn accept(&mut self, cancelation_token: CancelationToken) -> io::Result<()> {
let mut futures_unordered = FuturesUnordered::new();

let mut cancelation_tokens = Vec::new();

for listener in self.listeners.iter_mut() {
futures_unordered.push(listener.accept());
let sub_cancelation_token = CancelationToken::new();
futures_unordered.push(listener.accept(sub_cancelation_token.clone()));
cancelation_tokens.push(sub_cancelation_token);
}

task::spawn(async move {
cancelation_token.return_on_cancel::<()>(()).await;
for sub_cancelation_token in cancelation_tokens.iter_mut() {
sub_cancelation_token.complete();
}
});

while let Some(result) = futures_unordered.next().await {
result?;
}
Expand Down
6 changes: 3 additions & 3 deletions src/listener/failover_listener.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::listener::{Listener, ToListener};
use crate::Server;
use crate::{CancelationToken, Server};

use std::fmt::{self, Debug, Display, Formatter};

Expand Down Expand Up @@ -123,11 +123,11 @@ where
))
}

async fn accept(&mut self) -> io::Result<()> {
async fn accept(&mut self, cancelation_token: CancelationToken) -> io::Result<()> {
match self.index {
Some(index) => {
let mut listener = self.listeners[index].take().expect("accept called twice");
listener.accept().await?;
listener.accept(cancelation_token).await?;
Ok(())
}
None => Err(io::Error::new(
Expand Down
8 changes: 4 additions & 4 deletions src/listener/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::fmt::{Debug, Display};
use async_std::io;
use async_trait::async_trait;

use crate::Server;
use crate::{CancelationToken, Server};

pub use concurrent_listener::ConcurrentListener;
pub use failover_listener::FailoverListener;
Expand Down Expand Up @@ -46,7 +46,7 @@ where

/// Start accepting incoming connections. This method must be called only
/// after `bind` has succeeded.
async fn accept(&mut self) -> io::Result<()>;
async fn accept(&mut self, cancelation_token: CancelationToken) -> io::Result<()>;

/// Expose information about the connection. This should always return valid
/// data after `bind` has succeeded.
Expand All @@ -63,8 +63,8 @@ where
self.as_mut().bind(app).await
}

async fn accept(&mut self) -> io::Result<()> {
self.as_mut().accept().await
async fn accept(&mut self, cancelation_token: CancelationToken) -> io::Result<()> {
self.as_mut().accept(cancelation_token).await
}

fn info(&self) -> Vec<ListenInfo> {
Expand Down
8 changes: 4 additions & 4 deletions src/listener/parsed_listener.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[cfg(unix)]
use super::UnixListener;
use super::{ListenInfo, Listener, TcpListener};
use crate::Server;
use crate::{CancelationToken, Server};

use async_std::io;
use std::fmt::{self, Debug, Display, Formatter};
Expand Down Expand Up @@ -52,11 +52,11 @@ where
}
}

async fn accept(&mut self) -> io::Result<()> {
async fn accept(&mut self, cancelation_token: CancelationToken) -> io::Result<()> {
match self {
#[cfg(unix)]
Self::Unix(u) => u.accept().await,
Self::Tcp(t) => t.accept().await,
Self::Unix(u) => u.accept(cancelation_token).await,
Self::Tcp(t) => t.accept(cancelation_token).await,
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/listener/tcp_listener.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use super::{is_transient_error, ListenInfo};

use crate::listener::Listener;
use crate::{log, Server};
use crate::{CancelationToken, log, Server};

use std::fmt::{self, Display, Formatter};

use async_std::net::{self, SocketAddr, TcpStream};
use async_std::prelude::*;
use async_std::{io, task};

use futures_lite::future;

/// This represents a tide [Listener](crate::listener::Listener) that
/// wraps an [async_std::net::TcpListener]. It is implemented as an
/// enum in order to allow creation of a tide::listener::TcpListener
Expand Down Expand Up @@ -88,7 +90,7 @@ where
Ok(())
}

async fn accept(&mut self) -> io::Result<()> {
async fn accept(&mut self, cancelation_token: CancelationToken) -> io::Result<()> {
let server = self
.server
.take()
Expand All @@ -100,7 +102,7 @@ where

let mut incoming = listener.incoming();

while let Some(stream) = incoming.next().await {
while let Some(stream) = future::race(incoming.next(), cancelation_token.return_on_cancel(None)).await {
match stream {
Err(ref e) if is_transient_error(e) => continue,
Err(error) => {
Expand Down
4 changes: 2 additions & 2 deletions src/listener/unix_listener.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{is_transient_error, ListenInfo};

use crate::listener::Listener;
use crate::{log, Server};
use crate::{CancelationToken, log, Server};

use std::fmt::{self, Display, Formatter};

Expand Down Expand Up @@ -86,7 +86,7 @@ where
Ok(())
}

async fn accept(&mut self) -> io::Result<()> {
async fn accept(&mut self, _cancelation_token: CancelationToken) -> io::Result<()> {
let server = self
.server
.take()
Expand Down
8 changes: 6 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::listener::{Listener, ToListener};
use crate::log;
use crate::middleware::{Middleware, Next};
use crate::router::{Router, Selection};
use crate::{Endpoint, Request, Route};
use crate::{CancelationToken, Endpoint, Request, Route};

/// An HTTP server.
///
Expand Down Expand Up @@ -206,12 +206,16 @@ where
/// # Ok(()) }) }
/// ```
pub async fn listen<L: ToListener<State>>(self, listener: L) -> io::Result<()> {
self.listen_with_cancelation_token(listener, CancelationToken::new()).await
}

pub async fn listen_with_cancelation_token<L: ToListener<State>>(self, listener: L, cancelation_token: CancelationToken) -> io::Result<()> {
let mut listener = listener.to_listener()?;
listener.bind(self).await?;
for info in listener.info().iter() {
log::info!("Server listening on {}", info);
}
listener.accept().await?;
listener.accept(cancelation_token).await?;
Ok(())
}

Expand Down
29 changes: 16 additions & 13 deletions tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@ use async_std::task;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use tide::{Body, Request};
use tide::{Body, CancelationToken, Request};

// Note: Async tests are now supported. Refactor to avoid tast::block_on?
#[test]
fn hello_world() -> tide::Result<()> {

task::block_on(async {
let cancelation_token = CancelationToken::new();
let server_cancelation_token = cancelation_token.clone();

let port = test_utils::find_port().await;
let server = task::spawn(async move {
let mut app = tide::new();
Expand All @@ -18,22 +23,20 @@ fn hello_world() -> tide::Result<()> {
assert!(req.peer_addr().is_some());
Ok("says hello")
});
app.listen(("localhost", port)).await?;
app.listen_with_cancelation_token(("localhost", port), server_cancelation_token).await?;
Result::<(), http_types::Error>::Ok(())
});

let client = task::spawn(async move {
task::sleep(Duration::from_millis(100)).await;
let string = surf::get(format!("http://localhost:{}", port))
.body(Body::from_string("nori".to_string()))
.recv_string()
.await
.unwrap();
assert_eq!(string, "says hello");
Ok(())
});
task::sleep(Duration::from_millis(100)).await;
let string = surf::get(format!("http://localhost:{}", port))
.body(Body::from_string("nori".to_string()))
.recv_string()
.await
.unwrap();
assert_eq!(string, "says hello");

server.race(client).await
cancelation_token.complete();
server.await
})
}

Expand Down