From 32f01672ef000998ebb73ab5ba24b70e1da8794d Mon Sep 17 00:00:00 2001 From: James Mayclin Date: Tue, 18 Feb 2025 01:56:24 +0000 Subject: [PATCH] feat(bindings): expose context on cert chain --- .../rust/extended/s2n-tls/src/cert_chain.rs | 150 +++++++++++++++++- 1 file changed, 148 insertions(+), 2 deletions(-) diff --git a/bindings/rust/extended/s2n-tls/src/cert_chain.rs b/bindings/rust/extended/s2n-tls/src/cert_chain.rs index 4c5790cf654..671f7795e68 100644 --- a/bindings/rust/extended/s2n-tls/src/cert_chain.rs +++ b/bindings/rust/extended/s2n-tls/src/cert_chain.rs @@ -4,6 +4,8 @@ use crate::error::{Error, Fallible}; use s2n_tls_sys::*; use std::{ + any::Any, + ffi::c_void, marker::PhantomData, ptr::{self, NonNull}, sync::Arc, @@ -45,20 +47,50 @@ impl CertificateChainHandle<'_> { _lifetime: PhantomData, } } + + fn internal_context_mut(&mut self) -> Option<&mut InternalContext> { + let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) }; + if context.is_null() { + None + } else { + Some(unsafe { &mut *(context as *mut InternalContext) }) + } + } + + fn internal_context(&self) -> Option<&InternalContext> { + let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) }; + if context.is_null() { + None + } else { + Some(unsafe { &*(context as *const InternalContext) }) + } + } } impl Drop for CertificateChainHandle<'_> { /// Corresponds to [s2n_cert_chain_and_key_free]. fn drop(&mut self) { - // ignore failures since there's not much we can do about it if self.is_owned { + if let Some(internal_context) = self.internal_context_mut() { + drop(unsafe { Box::from_raw(internal_context) }); + } unsafe { + // ignore failures since there's not much we can do about it let _ = s2n_cert_chain_and_key_free(self.cert.as_ptr()).into_result(); } } } } +/// An internal container to hold the customer supplied application context. +/// +/// We can't directly store the application context on the `s2n_cert_chain_and_key`, +/// because `*mut dyn Any` is a fat pointer (16 bytes) and can not be stored as +/// a c_void (8 bytes). +struct InternalContext { + context: Box, +} + pub struct Builder { cert_handle: CertificateChainHandle<'static>, } @@ -125,6 +157,36 @@ impl Builder { Ok(self) } + /// Associates an arbitrary application context with the CertificateChain to + /// be later retrieved via [`CertificateChain::application_context()`]. + /// + /// This API will override an existing application context set on the Builder. + /// + /// Corresponds to [s2n_cert_chain_and_key_set_ctx]. + pub fn set_application_context( + &mut self, + app_context: T, + ) -> Result<&mut Self, Error> { + let app_context = Box::new(app_context); + match self.cert_handle.internal_context_mut() { + // set_application_context was previously called, overwrite the existing value + Some(c) => c.context = app_context, + None => { + let internal_context = Box::new(InternalContext { + context: app_context, + }); + unsafe { + s2n_cert_chain_and_key_set_ctx( + self.cert_handle.cert.as_ptr(), + Box::into_raw(internal_context) as *mut c_void, + ) + .into_result() + }?; + } + } + Ok(self) + } + /// Return an immutable, internally-reference counted CertificateChain. pub fn build(self) -> Result, Error> { // This method is currently infallible, but returning a result allows @@ -177,6 +239,23 @@ impl CertificateChain<'_> { } } + /// Retrieves a reference to the application context associated with the + /// CertificateChain. + /// + /// If an application context hasn't been set on the CertificateChain or if + /// the set application context isn't of type `T`, `None` will be returned. + /// + /// To set a context on the connection, use [`Builder::set_application_context()`]. + /// + /// Corresponds to [s2n_connection_get_ctx]. + pub fn application_context(&self) -> Option<&T> { + if let Some(internal_context) = self.cert_handle.internal_context() { + internal_context.context.downcast_ref() + } else { + None + } + } + /// Return the length of this certificate chain. /// /// Note that the underlying API currently traverses a linked list, so this is a relatively @@ -271,11 +350,13 @@ unsafe impl Send for Certificate<'_> {} #[cfg(test)] mod tests { + use crate::error::Error as S2NError; + use crate::testing::config_builder; use crate::{ config, error::{ErrorSource, ErrorType}, security::DEFAULT_TLS13, - testing::{InsecureAcceptAllCertificatesHandler, SniTestCerts, TestPair}, + testing::{CertKeyPair, InsecureAcceptAllCertificatesHandler, SniTestCerts, TestPair}, }; use super::*; @@ -495,4 +576,69 @@ mod tests { fn assert_send_sync() {} assert_send_sync::>(); } + + /// sanity check for basic cert chain context interactions + #[test] + fn application_context_workflow() -> Result<(), S2NError> { + let context: Arc = Arc::new(0xC0FFEE); + let handle = Arc::clone(&context); + assert_eq!(Arc::strong_count(&handle), 2); + + let default = CertKeyPair::default(); + let mut chain = Builder::new()?; + chain.load_pem(default.cert(), default.key())?; + chain.set_application_context(context)?; + let chain = chain.build()?; + + let invalid_type_get = chain.application_context::(); + assert!(invalid_type_get.is_none()); + + let retrieved_context = chain.application_context::>().unwrap(); + assert_eq!(*retrieved_context.as_ref(), 0xC0FFEE); + assert_eq!(Arc::strong_count(&handle), 2); + drop(chain); + assert_eq!(Arc::strong_count(&handle), 1); + Ok(()) + } + + /// When an application context is overridden, it should be properly dropped. + #[test] + fn application_context_override() -> Result<(), S2NError> { + let initial: Arc = Arc::new(0xC0FFEE); + let initial_handle = Arc::clone(&initial); + let overridden: Arc<[u8; 6]> = Arc::new(*b"coffee"); + + let mut builder = Builder::new()?; + builder.set_application_context(initial)?; + assert_eq!(Arc::strong_count(&initial_handle), 2); + builder.set_application_context(overridden)?; + assert_eq!(Arc::strong_count(&initial_handle), 1); + + Ok(()) + } + + /// An application context should be retrievable from a selected cert after + /// the handshake. + #[test] + fn application_context_from_selected_cert() -> Result<(), S2NError> { + let default = CertKeyPair::default(); + let mut chain = Builder::new()?; + chain.load_pem(default.cert(), default.key())?; + chain.set_application_context(0xC0FFEE_u64)?; + + let mut server_config = config::Builder::new(); + server_config.load_chain(chain.build()?)?; + + let client_config = config_builder(&crate::security::DEFAULT).unwrap(); + + let mut test_pair = + TestPair::from_configs(&client_config.build()?, &server_config.build()?); + test_pair.handshake()?; + + let selected_cert = test_pair.server.selected_cert().unwrap(); + let context = selected_cert.application_context::(); + assert_eq!(context, Some(&0xC0FFEE_u64)); + + Ok(()) + } }