Skip to content

Commit

Permalink
feat(bindings): expose context on cert chain
Browse files Browse the repository at this point in the history
  • Loading branch information
jmayclin committed Feb 20, 2025
1 parent 4ae43ec commit 32f0167
Showing 1 changed file with 148 additions and 2 deletions.
150 changes: 148 additions & 2 deletions bindings/rust/extended/s2n-tls/src/cert_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
use crate::error::{Error, Fallible};
use s2n_tls_sys::*;
use std::{
any::Any,
ffi::c_void,
marker::PhantomData,
ptr::{self, NonNull},
sync::Arc,
Expand Down Expand Up @@ -45,20 +47,50 @@ impl CertificateChainHandle<'_> {
_lifetime: PhantomData,
}
}

fn internal_context_mut(&mut self) -> Option<&mut InternalContext> {
let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) };
if context.is_null() {
None
} else {
Some(unsafe { &mut *(context as *mut InternalContext) })
}
}

fn internal_context(&self) -> Option<&InternalContext> {
let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) };
if context.is_null() {
None
} else {
Some(unsafe { &*(context as *const InternalContext) })
}
}
}

impl Drop for CertificateChainHandle<'_> {
/// Corresponds to [s2n_cert_chain_and_key_free].
fn drop(&mut self) {
// ignore failures since there's not much we can do about it
if self.is_owned {
if let Some(internal_context) = self.internal_context_mut() {
drop(unsafe { Box::from_raw(internal_context) });
}
unsafe {
// ignore failures since there's not much we can do about it
let _ = s2n_cert_chain_and_key_free(self.cert.as_ptr()).into_result();
}
}
}
}

/// An internal container to hold the customer supplied application context.
///
/// We can't directly store the application context on the `s2n_cert_chain_and_key`,
/// because `*mut dyn Any` is a fat pointer (16 bytes) and can not be stored as
/// a c_void (8 bytes).
struct InternalContext {
context: Box<dyn Any + Send + Sync>,
}

pub struct Builder {
cert_handle: CertificateChainHandle<'static>,
}
Expand Down Expand Up @@ -125,6 +157,36 @@ impl Builder {
Ok(self)
}

/// Associates an arbitrary application context with the CertificateChain to
/// be later retrieved via [`CertificateChain::application_context()`].
///
/// This API will override an existing application context set on the Builder.
///
/// Corresponds to [s2n_cert_chain_and_key_set_ctx].
pub fn set_application_context<T: Send + Sync + 'static>(
&mut self,
app_context: T,
) -> Result<&mut Self, Error> {
let app_context = Box::new(app_context);
match self.cert_handle.internal_context_mut() {
// set_application_context was previously called, overwrite the existing value
Some(c) => c.context = app_context,
None => {
let internal_context = Box::new(InternalContext {
context: app_context,
});
unsafe {
s2n_cert_chain_and_key_set_ctx(
self.cert_handle.cert.as_ptr(),
Box::into_raw(internal_context) as *mut c_void,
)
.into_result()
}?;
}
}
Ok(self)
}

/// Return an immutable, internally-reference counted CertificateChain.
pub fn build(self) -> Result<CertificateChain<'static>, Error> {
// This method is currently infallible, but returning a result allows
Expand Down Expand Up @@ -177,6 +239,23 @@ impl CertificateChain<'_> {
}
}

/// Retrieves a reference to the application context associated with the
/// CertificateChain.
///
/// If an application context hasn't been set on the CertificateChain or if
/// the set application context isn't of type `T`, `None` will be returned.
///
/// To set a context on the connection, use [`Builder::set_application_context()`].
///
/// Corresponds to [s2n_connection_get_ctx].
pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
if let Some(internal_context) = self.cert_handle.internal_context() {
internal_context.context.downcast_ref()
} else {
None
}
}

/// Return the length of this certificate chain.
///
/// Note that the underlying API currently traverses a linked list, so this is a relatively
Expand Down Expand Up @@ -271,11 +350,13 @@ unsafe impl Send for Certificate<'_> {}

#[cfg(test)]
mod tests {
use crate::error::Error as S2NError;
use crate::testing::config_builder;
use crate::{
config,
error::{ErrorSource, ErrorType},
security::DEFAULT_TLS13,
testing::{InsecureAcceptAllCertificatesHandler, SniTestCerts, TestPair},
testing::{CertKeyPair, InsecureAcceptAllCertificatesHandler, SniTestCerts, TestPair},
};

use super::*;
Expand Down Expand Up @@ -495,4 +576,69 @@ mod tests {
fn assert_send_sync<T: 'static + Send + Sync>() {}
assert_send_sync::<CertificateChain<'static>>();
}

/// sanity check for basic cert chain context interactions
#[test]
fn application_context_workflow() -> Result<(), S2NError> {
let context: Arc<u64> = Arc::new(0xC0FFEE);
let handle = Arc::clone(&context);
assert_eq!(Arc::strong_count(&handle), 2);

let default = CertKeyPair::default();
let mut chain = Builder::new()?;
chain.load_pem(default.cert(), default.key())?;
chain.set_application_context(context)?;
let chain = chain.build()?;

let invalid_type_get = chain.application_context::<u64>();
assert!(invalid_type_get.is_none());

let retrieved_context = chain.application_context::<Arc<u64>>().unwrap();
assert_eq!(*retrieved_context.as_ref(), 0xC0FFEE);
assert_eq!(Arc::strong_count(&handle), 2);
drop(chain);
assert_eq!(Arc::strong_count(&handle), 1);
Ok(())
}

/// When an application context is overridden, it should be properly dropped.
#[test]
fn application_context_override() -> Result<(), S2NError> {
let initial: Arc<u64> = Arc::new(0xC0FFEE);
let initial_handle = Arc::clone(&initial);
let overridden: Arc<[u8; 6]> = Arc::new(*b"coffee");

let mut builder = Builder::new()?;
builder.set_application_context(initial)?;
assert_eq!(Arc::strong_count(&initial_handle), 2);
builder.set_application_context(overridden)?;
assert_eq!(Arc::strong_count(&initial_handle), 1);

Ok(())
}

/// An application context should be retrievable from a selected cert after
/// the handshake.
#[test]
fn application_context_from_selected_cert() -> Result<(), S2NError> {
let default = CertKeyPair::default();
let mut chain = Builder::new()?;
chain.load_pem(default.cert(), default.key())?;
chain.set_application_context(0xC0FFEE_u64)?;

let mut server_config = config::Builder::new();
server_config.load_chain(chain.build()?)?;

let client_config = config_builder(&crate::security::DEFAULT).unwrap();

let mut test_pair =
TestPair::from_configs(&client_config.build()?, &server_config.build()?);
test_pair.handshake()?;

let selected_cert = test_pair.server.selected_cert().unwrap();
let context = selected_cert.application_context::<u64>();
assert_eq!(context, Some(&0xC0FFEE_u64));

Ok(())
}
}

0 comments on commit 32f0167

Please sign in to comment.