diff --git a/boring/src/ssl/connector.rs b/boring/src/ssl/connector.rs index 7be740d37..549ebadae 100644 --- a/boring/src/ssl/connector.rs +++ b/boring/src/ssl/connector.rs @@ -189,6 +189,10 @@ impl ConnectConfiguration { self.verify_hostname = verify_hostname; } + pub fn ssl_mut(&mut self) -> &mut SslRef { + &mut self.ssl + } + /// Initiates a client-side TLS session on a stream. /// /// The domain is used for SNI and hostname verification if enabled. @@ -324,8 +328,12 @@ impl SslAcceptor { where S: Read + Write, { - let ssl = Ssl::new(&self.0)?; - ssl.accept(stream) + self.new_session()?.accept(stream) + } + + /// Creates a new TLS session, ready to accept a stream. + pub fn new_session(&self) -> Result { + Ssl::new(&self.0) } /// Consumes the `SslAcceptor`, returning the inner raw `SslContext`. diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 242cafeed..7f1de9847 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -478,6 +478,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t); impl SelectCertError { /// A fatal error occured and the handshake should be terminated. pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error); + + /// The operation could not be completed and should be retried later. + pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry); } /// Extension types, to be used with `ClientHello::get_extension`. @@ -3132,6 +3135,11 @@ impl MidHandshakeSslStream { self.stream.ssl() } + /// Returns a mutable reference to the `Ssl` of the stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + self.stream.ssl_mut() + } + /// Returns the underlying error which interrupted this handshake. pub fn error(&self) -> &Error { &self.error @@ -3386,6 +3394,11 @@ impl SslStream { pub fn ssl(&self) -> &SslRef { &self.ssl } + + /// Returns a mutable reference to the `Ssl` object associated with this stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + &mut self.ssl + } } impl Read for SslStream { diff --git a/tokio-boring/Cargo.toml b/tokio-boring/Cargo.toml index 0da4a0586..71b5c2fca 100644 --- a/tokio-boring/Cargo.toml +++ b/tokio-boring/Cargo.toml @@ -31,6 +31,7 @@ pq-experimental = ["boring/pq-experimental"] [dependencies] boring = { workspace = true } boring-sys = { workspace = true } +once_cell = { workspace = true } tokio = { workspace = true } [dev-dependencies] diff --git a/tokio-boring/src/async_callbacks.rs b/tokio-boring/src/async_callbacks.rs new file mode 100644 index 000000000..83fd69aed --- /dev/null +++ b/tokio-boring/src/async_callbacks.rs @@ -0,0 +1,152 @@ +use boring::ex_data::Index; +use boring::ssl::{self, ClientHello, Ssl, SslContextBuilder}; +use once_cell::sync::Lazy; +use std::future::Future; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +type BoxSelectCertFuture = Pin< + Box< + dyn Future> + + Send + + Sync + + 'static, + >, +>; + +type BoxSelectCertFinish = + Box) -> Result<(), AsyncSelectCertError> + 'static>; + +pub(crate) static TASK_CONTEXT_INDEX: Lazy> = + Lazy::new(|| Ssl::new_ex_index().unwrap()); +pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy> = + Lazy::new(|| Ssl::new_ex_index().unwrap()); + +/// Extensions to [`SslContextBuilder`]. +/// +/// This trait provides additional methods to use async callbacks with boring. +pub trait SslContextBuilderExt: private::Sealed { + /// Sets a callback that is called before most [`ClientHello`] processing + /// and before the decision whether to resume a session is made. The + /// callback may inspect the [`ClientHello`] and configure the connection. + /// + /// This method uses a function that returns a future whose output is + /// itself a closure that will be passed [`ClientHello`] to configure + /// the connection based on the computations done in the future. + /// + /// See [`SslContextBuilder::set_select_certificate_callback`] for the sync + /// setter of this callback. + fn set_async_select_certificate_callback(&mut self, callback: Init) + where + Init: Fn(&mut ClientHello<'_>) -> Result + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, + Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static; + + /// Sets a callback that is called before most [`ClientHello`] processing + /// and before the decision whether to resume a session is made. The + /// callback may inspect the [`ClientHello`] and configure the connection. + /// + /// This method uses a polling function. + /// + /// See [`SslContextBuilder::set_select_certificate_callback`] for the sync + /// setter of this callback. + fn set_polling_select_certificate_callback( + &mut self, + callback: impl Fn(ClientHello<'_>, &mut Context<'_>) -> Poll> + + Send + + Sync + + 'static, + ); +} + +impl SslContextBuilderExt for SslContextBuilder { + fn set_async_select_certificate_callback(&mut self, callback: Init) + where + Init: Fn(&mut ClientHello<'_>) -> Result + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, + Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static, + { + self.set_select_certificate_callback(async_select_certificate_callback(callback)) + } + + fn set_polling_select_certificate_callback( + &mut self, + callback: impl Fn(ClientHello<'_>, &mut Context<'_>) -> Poll> + + Send + + Sync + + 'static, + ) { + self.set_select_certificate_callback(polling_select_certificate_callback(callback)); + } +} + +/// A fatal error to be returned from select certificate callbacks. +pub struct AsyncSelectCertError; + +fn async_select_certificate_callback( + callback: Init, +) -> impl Fn(ClientHello<'_>) -> Result<(), ssl::SelectCertError> + Send + Sync + 'static +where + Init: Fn(&mut ClientHello<'_>) -> Result + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, + Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static, +{ + polling_select_certificate_callback(move |mut client_hello, cx| { + let fut_result = match client_hello + .ssl_mut() + .ex_data_mut(*SELECT_CERT_FUTURE_INDEX) + { + Some(fut) => ready!(fut.as_mut().poll(cx)), + None => { + let fut = callback(&mut client_hello)?; + let mut box_fut = + Box::pin(async move { Ok(Box::new(fut.await?) as BoxSelectCertFinish) }) + as BoxSelectCertFuture; + + match box_fut.as_mut().poll(cx) { + Poll::Ready(fut_result) => fut_result, + Poll::Pending => { + client_hello + .ssl_mut() + .set_ex_data(*SELECT_CERT_FUTURE_INDEX, box_fut); + + return Poll::Pending; + } + } + } + }; + + // NOTE(nox): For memory usage concerns, maybe we should implement + // a way to remove the stored future from the `Ssl` value here? + + Poll::Ready(fut_result?(client_hello)) + }) +} + +fn polling_select_certificate_callback( + callback: impl Fn(ClientHello<'_>, &mut Context<'_>) -> Poll> + + Send + + Sync + + 'static, +) -> impl Fn(ClientHello<'_>) -> Result<(), ssl::SelectCertError> + Send + Sync + 'static { + move |client_hello| { + let cx = unsafe { + &mut *(*client_hello + .ssl() + .ex_data(*TASK_CONTEXT_INDEX) + .expect("task context should be set") as *mut Context<'_>) + }; + + match callback(client_hello, cx) { + Poll::Ready(Ok(())) => Ok(()), + Poll::Ready(Err(AsyncSelectCertError)) => Err(ssl::SelectCertError::ERROR), + Poll::Pending => Err(ssl::SelectCertError::RETRY), + } + } +} + +mod private { + pub trait Sealed {} +} + +impl private::Sealed for SslContextBuilder {} diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index a0dd58c52..336b14474 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -26,15 +26,22 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +mod async_callbacks; + +use self::async_callbacks::TASK_CONTEXT_INDEX; +pub use self::async_callbacks::{AsyncSelectCertError, SslContextBuilderExt}; + /// Asynchronously performs a client-side TLS handshake over the provided stream. pub async fn connect( - config: ConnectConfiguration, + mut config: ConnectConfiguration, domain: &str, stream: S, ) -> Result, HandshakeError> where S: AsyncRead + AsyncWrite + Unpin, { + config.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0); + handshake(|s| config.connect(domain, s), stream).await } @@ -43,7 +50,13 @@ pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| acceptor.accept(s), stream).await + let mut ssl = acceptor + .new_session() + .map_err(|e| HandshakeError(e.into()))?; + + ssl.set_ex_data(*TASK_CONTEXT_INDEX, 0); + + handshake(|s| ssl.accept(s), stream).await } async fn handshake(f: F, stream: S) -> Result, HandshakeError> @@ -163,6 +176,11 @@ impl SslStream { self.0.ssl() } + /// Returns a mutable reference to the `Ssl` object associated with this stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + self.0.ssl_mut() + } + /// Returns a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { &self.0.get_ref().stream @@ -367,13 +385,18 @@ where stream: inner.stream, context: ctx as *mut _ as usize, }; + match (inner.f)(stream) { Ok(mut s) => { s.get_mut().context = 0; + s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0); + Poll::Ready(Ok(StartedHandshake::Done(SslStream(s)))) } Err(ssl::HandshakeError::WouldBlock(mut s)) => { s.get_mut().context = 0; + s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0); + Poll::Ready(Ok(StartedHandshake::Mid(s))) } Err(e) => Poll::Ready(Err(HandshakeError(e))), @@ -396,13 +419,20 @@ where let mut s = self.0.take().expect("future polled after completion"); s.get_mut().context = ctx as *mut _ as usize; + s.ssl_mut() + .set_ex_data(*TASK_CONTEXT_INDEX, ctx as *mut _ as usize); + match s.handshake() { Ok(mut s) => { s.get_mut().context = 0; + s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0); + Poll::Ready(Ok(SslStream(s))) } Err(ssl::HandshakeError::WouldBlock(mut s)) => { s.get_mut().context = 0; + s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0); + self.0 = Some(s); Poll::Pending }