Skip to content

Commit

Permalink
Introduce async callbacks for set_select_certificate_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
nox committed Aug 1, 2023
1 parent 9ddf5fe commit 3691580
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 4 deletions.
12 changes: 10 additions & 2 deletions boring/src/ssl/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ impl ConnectConfiguration {
self.verify_hostname = verify_hostname;
}

pub fn ssl_mut(&mut self) -> &mut SslRef {
&mut self.ssl
}

/// Initiates a client-side TLS session on a stream.
///
/// The domain is used for SNI and hostname verification if enabled.
Expand Down Expand Up @@ -324,8 +328,12 @@ impl SslAcceptor {
where
S: Read + Write,
{
let ssl = Ssl::new(&self.0)?;
ssl.accept(stream)
self.new_session()?.accept(stream)
}

/// Creates a new TLS session, ready to accept a stream.
pub fn new_session(&self) -> Result<Ssl, ErrorStack> {
Ssl::new(&self.0)
}

/// Consumes the `SslAcceptor`, returning the inner raw `SslContext`.
Expand Down
13 changes: 13 additions & 0 deletions boring/src/ssl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t);
impl SelectCertError {
/// A fatal error occured and the handshake should be terminated.
pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error);

/// The operation could not be completed and should be retried later.
pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry);
}

/// Extension types, to be used with `ClientHello::get_extension`.
Expand Down Expand Up @@ -3132,6 +3135,11 @@ impl<S> MidHandshakeSslStream<S> {
self.stream.ssl()
}

/// Returns a mutable reference to the `Ssl` of the stream.
pub fn ssl_mut(&mut self) -> &mut SslRef {
self.stream.ssl_mut()
}

/// Returns the underlying error which interrupted this handshake.
pub fn error(&self) -> &Error {
&self.error
Expand Down Expand Up @@ -3386,6 +3394,11 @@ impl<S> SslStream<S> {
pub fn ssl(&self) -> &SslRef {
&self.ssl
}

/// Returns a mutable reference to the `Ssl` object associated with this stream.
pub fn ssl_mut(&mut self) -> &mut SslRef {
&mut self.ssl
}
}

impl<S: Read + Write> Read for SslStream<S> {
Expand Down
1 change: 1 addition & 0 deletions tokio-boring/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pq-experimental = ["boring/pq-experimental"]
[dependencies]
boring = { workspace = true }
boring-sys = { workspace = true }
once_cell = { workspace = true }
tokio = { workspace = true }

[dev-dependencies]
Expand Down
152 changes: 152 additions & 0 deletions tokio-boring/src/async_callbacks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
use boring::ex_data::Index;
use boring::ssl::{self, ClientHello, Ssl, SslContextBuilder};
use once_cell::sync::Lazy;
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

type BoxSelectCertFuture = Pin<
Box<
dyn Future<Output = Result<BoxSelectCertFinish, AsyncSelectCertError>>
+ Send
+ Sync
+ 'static,
>,
>;

type BoxSelectCertFinish =
Box<dyn FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static>;

pub(crate) static TASK_CONTEXT_INDEX: Lazy<Index<Ssl, usize>> =
Lazy::new(|| Ssl::new_ex_index().unwrap());
pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, BoxSelectCertFuture>> =
Lazy::new(|| Ssl::new_ex_index().unwrap());

/// Extensions to [`SslContextBuilder`].
///
/// This trait provides additional methods to use async callbacks with boring.
pub trait SslContextBuilderExt: private::Sealed {
/// Sets a callback that is called before most [`ClientHello`] processing
/// and before the decision whether to resume a session is made. The
/// callback may inspect the [`ClientHello`] and configure the connection.
///
/// This method uses a function that returns a future whose output is
/// itself a closure that will be passed [`ClientHello`] to configure
/// the connection based on the computations done in the future.
///
/// See [`SslContextBuilder::set_select_certificate_callback`] for the sync
/// setter of this callback.
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
where
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static;

/// Sets a callback that is called before most [`ClientHello`] processing
/// and before the decision whether to resume a session is made. The
/// callback may inspect the [`ClientHello`] and configure the connection.
///
/// This method uses a polling function.
///
/// See [`SslContextBuilder::set_select_certificate_callback`] for the sync
/// setter of this callback.
fn set_polling_select_certificate_callback<F>(
&mut self,
callback: impl Fn(ClientHello<'_>, &mut Context<'_>) -> Poll<Result<(), AsyncSelectCertError>>
+ Send
+ Sync
+ 'static,
);
}

impl SslContextBuilderExt for SslContextBuilder {
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
where
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static,
{
self.set_select_certificate_callback(async_select_certificate_callback(callback))
}

fn set_polling_select_certificate_callback<F>(
&mut self,
callback: impl Fn(ClientHello<'_>, &mut Context<'_>) -> Poll<Result<(), AsyncSelectCertError>>
+ Send
+ Sync
+ 'static,
) {
self.set_select_certificate_callback(polling_select_certificate_callback(callback));
}
}

/// A fatal error to be returned from select certificate callbacks.
pub struct AsyncSelectCertError;

fn async_select_certificate_callback<Init, Fut, Finish>(
callback: Init,
) -> impl Fn(ClientHello<'_>) -> Result<(), ssl::SelectCertError> + Send + Sync + 'static
where
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static,
{
polling_select_certificate_callback(move |mut client_hello, cx| {
let fut_result = match client_hello
.ssl_mut()
.ex_data_mut(*SELECT_CERT_FUTURE_INDEX)
{
Some(fut) => ready!(fut.as_mut().poll(cx)),
None => {
let fut = callback(&mut client_hello)?;
let mut box_fut =
Box::pin(async move { Ok(Box::new(fut.await?) as BoxSelectCertFinish) })
as BoxSelectCertFuture;

match box_fut.as_mut().poll(cx) {
Poll::Ready(fut_result) => fut_result,
Poll::Pending => {
client_hello
.ssl_mut()
.set_ex_data(*SELECT_CERT_FUTURE_INDEX, box_fut);

return Poll::Pending;
}
}
}
};

// NOTE(nox): For memory usage concerns, maybe we should implement
// a way to remove the stored future from the `Ssl` value here?

Poll::Ready(fut_result?(client_hello))
})
}

fn polling_select_certificate_callback(
callback: impl Fn(ClientHello<'_>, &mut Context<'_>) -> Poll<Result<(), AsyncSelectCertError>>
+ Send
+ Sync
+ 'static,
) -> impl Fn(ClientHello<'_>) -> Result<(), ssl::SelectCertError> + Send + Sync + 'static {
move |client_hello| {
let cx = unsafe {
&mut *(*client_hello
.ssl()
.ex_data(*TASK_CONTEXT_INDEX)
.expect("task context should be set") as *mut Context<'_>)
};

match callback(client_hello, cx) {
Poll::Ready(Ok(())) => Ok(()),
Poll::Ready(Err(AsyncSelectCertError)) => Err(ssl::SelectCertError::ERROR),
Poll::Pending => Err(ssl::SelectCertError::RETRY),
}
}
}

mod private {
pub trait Sealed {}
}

impl private::Sealed for SslContextBuilder {}
34 changes: 32 additions & 2 deletions tokio-boring/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,22 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

mod async_callbacks;

use self::async_callbacks::TASK_CONTEXT_INDEX;
pub use self::async_callbacks::{AsyncSelectCertError, SslContextBuilderExt};

/// Asynchronously performs a client-side TLS handshake over the provided stream.
pub async fn connect<S>(
config: ConnectConfiguration,
mut config: ConnectConfiguration,
domain: &str,
stream: S,
) -> Result<SslStream<S>, HandshakeError<S>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
config.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);

handshake(|s| config.connect(domain, s), stream).await
}

Expand All @@ -43,7 +50,13 @@ pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake(|s| acceptor.accept(s), stream).await
let mut ssl = acceptor
.new_session()
.map_err(|e| HandshakeError(e.into()))?;

ssl.set_ex_data(*TASK_CONTEXT_INDEX, 0);

handshake(|s| ssl.accept(s), stream).await
}

async fn handshake<F, S>(f: F, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
Expand Down Expand Up @@ -163,6 +176,11 @@ impl<S> SslStream<S> {
self.0.ssl()
}

/// Returns a mutable reference to the `Ssl` object associated with this stream.
pub fn ssl_mut(&mut self) -> &mut SslRef {
self.0.ssl_mut()
}

/// Returns a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.0.get_ref().stream
Expand Down Expand Up @@ -367,13 +385,18 @@ where
stream: inner.stream,
context: ctx as *mut _ as usize,
};

match (inner.f)(stream) {
Ok(mut s) => {
s.get_mut().context = 0;
s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);

Poll::Ready(Ok(StartedHandshake::Done(SslStream(s))))
}
Err(ssl::HandshakeError::WouldBlock(mut s)) => {
s.get_mut().context = 0;
s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);

Poll::Ready(Ok(StartedHandshake::Mid(s)))
}
Err(e) => Poll::Ready(Err(HandshakeError(e))),
Expand All @@ -396,13 +419,20 @@ where
let mut s = self.0.take().expect("future polled after completion");

s.get_mut().context = ctx as *mut _ as usize;
s.ssl_mut()
.set_ex_data(*TASK_CONTEXT_INDEX, ctx as *mut _ as usize);

match s.handshake() {
Ok(mut s) => {
s.get_mut().context = 0;
s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);

Poll::Ready(Ok(SslStream(s)))
}
Err(ssl::HandshakeError::WouldBlock(mut s)) => {
s.get_mut().context = 0;
s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);

self.0 = Some(s);
Poll::Pending
}
Expand Down

0 comments on commit 3691580

Please sign in to comment.