Skip to content

Commit

Permalink
Introduce async callbacks
Browse files Browse the repository at this point in the history
We introduce tokio_boring::SslContextBuilderExt, with 2 methods:

* set_async_select_certificate_callback
* set_async_private_key_method
  • Loading branch information
nox committed Aug 24, 2023
1 parent cb27511 commit 7a6d57a
Show file tree
Hide file tree
Showing 7 changed files with 577 additions and 3 deletions.
13 changes: 13 additions & 0 deletions boring/src/ssl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t);
impl SelectCertError {
/// A fatal error occured and the handshake should be terminated.
pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error);

/// The operation could not be completed and should be retried later.
pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry);
}

/// Extension types, to be used with `ClientHello::get_extension`.
Expand Down Expand Up @@ -3197,6 +3200,11 @@ impl<S> MidHandshakeSslStream<S> {
self.stream.ssl()
}

/// Returns a mutable reference to the `Ssl` of the stream.
pub fn ssl_mut(&mut self) -> &mut SslRef {
self.stream.ssl_mut()
}

/// Returns the underlying error which interrupted this handshake.
pub fn error(&self) -> &Error {
&self.error
Expand Down Expand Up @@ -3451,6 +3459,11 @@ impl<S> SslStream<S> {
pub fn ssl(&self) -> &SslRef {
&self.ssl
}

/// Returns a mutable reference to the `Ssl` object associated with this stream.
pub fn ssl_mut(&mut self) -> &mut SslRef {
&mut self.ssl
}
}

impl<S: Read + Write> Read for SslStream<S> {
Expand Down
1 change: 1 addition & 0 deletions tokio-boring/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pq-experimental = ["boring/pq-experimental"]
[dependencies]
boring = { workspace = true }
boring-sys = { workspace = true }
once_cell = { workspace = true }
tokio = { workspace = true }

[dev-dependencies]
Expand Down
262 changes: 262 additions & 0 deletions tokio-boring/src/async_callbacks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
use boring::ex_data::Index;
use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder};
use once_cell::sync::Lazy;
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll, Waker};

type BoxSelectCertFuture = ExDataFuture<Result<BoxSelectCertFinish, AsyncSelectCertError>>;

type BoxSelectCertFinish = Box<dyn FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError>>;

/// The type of futures returned by [`AsyncPrivateKeyMethod`] methods.
pub type BoxPrivateKeyMethodFuture =
ExDataFuture<Result<BoxPrivateKeyMethodFinish, AsyncPrivateKeyMethodError>>;

/// The type of callbacks returned by [`BoxPrivateKeyMethodFuture`].
pub type BoxPrivateKeyMethodFinish =
Box<dyn FnOnce(&mut ssl::SslRef, &mut [u8]) -> Result<usize, AsyncPrivateKeyMethodError>>;

type ExDataFuture<T> = Pin<Box<dyn Future<Output = T> + Send + Sync>>;

pub(crate) static TASK_WAKER_INDEX: Lazy<Index<Ssl, Option<Waker>>> =
Lazy::new(|| Ssl::new_ex_index().unwrap());
pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, BoxSelectCertFuture>> =
Lazy::new(|| Ssl::new_ex_index().unwrap());
pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy<
Index<Ssl, BoxPrivateKeyMethodFuture>,
> = Lazy::new(|| Ssl::new_ex_index().unwrap());

/// Extensions to [`SslContextBuilder`].
///
/// This trait provides additional methods to use async callbacks with boring.
pub trait SslContextBuilderExt: private::Sealed {
/// Sets a callback that is called before most [`ClientHello`] processing
/// and before the decision whether to resume a session is made. The
/// callback may inspect the [`ClientHello`] and configure the connection.
///
/// This method uses a function that returns a future whose output is
/// itself a closure that will be passed [`ClientHello`] to configure
/// the connection based on the computations done in the future.
///
/// See [`SslContextBuilder::set_select_certificate_callback`] for the sync
/// setter of this callback.
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
where
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static;

/// Configures a custom private key method on the context.
///
/// See [`AsyncPrivateKeyMethod`] for more details.
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod);
}

impl SslContextBuilderExt for SslContextBuilder {
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
where
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static,
{
self.set_select_certificate_callback(move |mut client_hello| {
let fut_poll_result = with_ex_data_future(
&mut client_hello,
*SELECT_CERT_FUTURE_INDEX,
ClientHello::ssl_mut,
|client_hello| {
let fut = callback(client_hello)?;

Ok(Box::pin(async move {
Ok(Box::new(fut.await?) as BoxSelectCertFinish)
}))
},
);

let fut_result = match fut_poll_result {
Poll::Ready(fut_result) => fut_result,
Poll::Pending => return Err(ssl::SelectCertError::RETRY),
};

let finish = fut_result.or(Err(ssl::SelectCertError::ERROR))?;

finish(client_hello).or(Err(ssl::SelectCertError::ERROR))
})
}

fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) {
self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method)));
}
}

/// A fatal error to be returned from async select certificate callbacks.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct AsyncSelectCertError;

/// Describes async private key hooks. This is used to off-load signing
/// operations to a custom, potentially asynchronous, backend. Metadata about the
/// key such as the type and size are parsed out of the certificate.
///
/// See [`PrivateKeyMethod`] for the sync version of those hooks.
///
/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st
pub trait AsyncPrivateKeyMethod: Send + Sync + 'static {
/// Signs the message `input` using the specified signature algorithm.
///
/// This method uses a function that returns a future whose output is
/// itself a closure that will be passed `ssl` and `output`
/// to finish writing the signature.
///
/// See [`PrivateKeyMethod::sign`] for the sync version of this method.
fn sign(
&self,
ssl: &mut ssl::SslRef,
input: &[u8],
signature_algorithm: ssl::SslSignatureAlgorithm,
output: &mut [u8],
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;

/// Decrypts `input`.
///
/// This method uses a function that returns a future whose output is
/// itself a closure that will be passed `ssl` and `output`
/// to finish decrypting the input.
///
/// See [`PrivateKeyMethod::decrypt`] for the sync version of this method.
fn decrypt(
&self,
ssl: &mut ssl::SslRef,
input: &[u8],
output: &mut [u8],
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;
}

/// A fatal error to be returned from async private key methods.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct AsyncPrivateKeyMethodError;

struct AsyncPrivateKeyMethodBridge(Box<dyn AsyncPrivateKeyMethod>);

impl PrivateKeyMethod for AsyncPrivateKeyMethodBridge {
fn sign(
&self,
ssl: &mut ssl::SslRef,
input: &[u8],
signature_algorithm: ssl::SslSignatureAlgorithm,
output: &mut [u8],
) -> Result<usize, ssl::PrivateKeyMethodError> {
with_private_key_method(ssl, output, |ssl, output| {
<dyn AsyncPrivateKeyMethod>::sign(&*self.0, ssl, input, signature_algorithm, output)
})
}

fn decrypt(
&self,
ssl: &mut ssl::SslRef,
input: &[u8],
output: &mut [u8],
) -> Result<usize, ssl::PrivateKeyMethodError> {
with_private_key_method(ssl, output, |ssl, output| {
<dyn AsyncPrivateKeyMethod>::decrypt(&*self.0, ssl, input, output)
})
}

fn complete(
&self,
ssl: &mut ssl::SslRef,
output: &mut [u8],
) -> Result<usize, ssl::PrivateKeyMethodError> {
with_private_key_method(ssl, output, |_, _| {
// This should never be reached, if it does, that's a bug on boring's side,
// which called `complete` without having been returned to with a pending
// future from `sign` or `decrypt`.

if cfg!(debug_assertions) {
panic!("BUG: boring called complete without a pending operation");
}

Err(AsyncPrivateKeyMethodError)
})
}
}

/// Creates and drives a private key method future.
///
/// This is a convenience function for the three methods of impl `PrivateKeyMethod``
/// for `dyn AsyncPrivateKeyMethod`. It relies on [`with_ex_data_future`] to
/// drive the future and then immediately calls the final [`BoxPrivateKeyMethodFinish`]
/// when the future is ready.
fn with_private_key_method(
ssl: &mut ssl::SslRef,
output: &mut [u8],
create_fut: impl FnOnce(
&mut ssl::SslRef,
&mut [u8],
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>,
) -> Result<usize, ssl::PrivateKeyMethodError> {
let fut_poll_result = with_ex_data_future(
ssl,
*SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX,
|ssl| ssl,
|ssl| create_fut(ssl, output),
);

let fut_result = match fut_poll_result {
Poll::Ready(fut_result) => fut_result,
Poll::Pending => return Err(ssl::PrivateKeyMethodError::RETRY),
};

let finish = fut_result.or(Err(ssl::PrivateKeyMethodError::FAILURE))?;

finish(ssl, output).or(Err(ssl::PrivateKeyMethodError::FAILURE))
}

/// Creates and drives a future stored in `ssl_handle`'s `Ssl` at ex data index `index`.
///
/// This function won't even bother storing the future in `index` if the future
/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call.
fn with_ex_data_future<H, T, E>(
ssl_handle: &mut H,
index: Index<ssl::Ssl, ExDataFuture<Result<T, E>>>,
get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef,
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<Result<T, E>>, E>,
) -> Poll<Result<T, E>> {
let ssl = get_ssl_mut(ssl_handle);
let waker = ssl
.ex_data(*TASK_WAKER_INDEX)
.cloned()
.flatten()
.expect("task waker should be set");

let mut ctx = Context::from_waker(&waker);

match ssl.ex_data_mut(index) {
Some(fut) => {
let fut_result = ready!(fut.as_mut().poll(&mut ctx));

// NOTE(nox): For memory usage concerns, maybe we should implement
// a way to remove the stored future from the `Ssl` value here?

Poll::Ready(fut_result)
}
None => {
let mut fut = create_fut(ssl_handle)?;

match fut.as_mut().poll(&mut ctx) {
Poll::Ready(fut_result) => Poll::Ready(fut_result),
Poll::Pending => {
get_ssl_mut(ssl_handle).set_ex_data(index, fut);

Poll::Pending
}
}
}
}
}

mod private {
pub trait Sealed {}
}

impl private::Sealed for SslContextBuilder {}
5 changes: 2 additions & 3 deletions tokio-boring/src/bridge.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//! Bridge between sync IO traits and async tokio IO traits.

///! Bridge between sync IO traits and async tokio IO traits.
use std::fmt;
use std::io;
use std::pin::Pin;
Expand Down Expand Up @@ -35,7 +34,7 @@ impl<S> AsyncStreamBridge<S> {
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
{
let mut ctx =
Context::from_waker(self.waker.as_ref().expect("missing task context pointer"));
Context::from_waker(self.waker.as_ref().expect("BUG: missing waker in bridge"));

f(&mut ctx, Pin::new(&mut self.stream))
}
Expand Down
16 changes: 16 additions & 0 deletions tokio-boring/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,14 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

mod async_callbacks;
mod bridge;

use self::async_callbacks::TASK_WAKER_INDEX;
pub use self::async_callbacks::{
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError,
BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, SslContextBuilderExt,
};
use self::bridge::AsyncStreamBridge;

/// Asynchronously performs a client-side TLS handshake over the provided stream.
Expand Down Expand Up @@ -90,6 +96,11 @@ impl<S> SslStream<S> {
self.0.ssl()
}

/// Returns a mutable reference to the `Ssl` object associated with this stream.
pub fn ssl_mut(&mut self) -> &mut SslRef {
self.0.ssl_mut()
}

/// Returns a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.0.get_ref().stream
Expand Down Expand Up @@ -285,15 +296,20 @@ where
let mut mid_handshake = self.0.take().expect("future polled after completion");

mid_handshake.get_mut().set_waker(Some(ctx));
mid_handshake
.ssl_mut()
.set_ex_data(*TASK_WAKER_INDEX, Some(ctx.waker().clone()));

match mid_handshake.handshake() {
Ok(mut stream) => {
stream.get_mut().set_waker(None);
stream.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None);

Poll::Ready(Ok(SslStream(stream)))
}
Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
mid_handshake.get_mut().set_waker(None);
mid_handshake.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None);

self.0 = Some(mid_handshake);

Expand Down
Loading

0 comments on commit 7a6d57a

Please sign in to comment.