Skip to content

Commit

Permalink
fix(bindings): remove mutation behind Arc (#5124)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmayclin authored Feb 19, 2025
1 parent 2c47d43 commit 4ae43ec
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 63 deletions.
106 changes: 49 additions & 57 deletions bindings/rust/extended/s2n-tls/src/cert_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,41 @@ use std::{
///
/// [CertificateChain] is internally reference counted. The reference counted `T`
/// must have a drop implementation.
struct CertificateChainHandle {
cert: NonNull<s2n_cert_chain_and_key>,
pub(crate) struct CertificateChainHandle<'a> {
pub(crate) cert: NonNull<s2n_cert_chain_and_key>,
is_owned: bool,
_lifetime: PhantomData<&'a s2n_cert_chain_and_key>,
}

// # Safety
//
// s2n_cert_chain_and_key objects can be sent across threads.
unsafe impl Send for CertificateChainHandle {}
unsafe impl Sync for CertificateChainHandle {}
unsafe impl Send for CertificateChainHandle<'_> {}
unsafe impl Sync for CertificateChainHandle<'_> {}

impl CertificateChainHandle {
fn from_owned(cert: NonNull<s2n_cert_chain_and_key>) -> Self {
Self {
cert,
impl CertificateChainHandle<'_> {
/// Allocate an uninitialized CertificateChainHandle.
///
/// Corresponds to [s2n_cert_chain_and_key_new].
pub(crate) fn allocate() -> Result<CertificateChainHandle<'static>, crate::error::Error> {
crate::init::init();
Ok(CertificateChainHandle {
cert: unsafe { s2n_cert_chain_and_key_new().into_result() }?,
is_owned: true,
}
_lifetime: PhantomData,
})
}

fn from_reference(cert: NonNull<s2n_cert_chain_and_key>) -> Self {
Self {
cert,
is_owned: false,
_lifetime: PhantomData,
}
}
}

impl Drop for CertificateChainHandle {
impl Drop for CertificateChainHandle<'_> {
/// Corresponds to [s2n_cert_chain_and_key_free].
fn drop(&mut self) {
// ignore failures since there's not much we can do about it
Expand All @@ -53,13 +60,13 @@ impl Drop for CertificateChainHandle {
}

pub struct Builder {
cert: CertificateChain<'static>,
cert_handle: CertificateChainHandle<'static>,
}

impl Builder {
pub fn new() -> Result<Self, Error> {
Ok(Self {
cert: CertificateChain::allocate_owned()?,
cert_handle: CertificateChainHandle::allocate()?,
})
}

Expand All @@ -73,7 +80,7 @@ impl Builder {
// `private_key_pem` are not modified.
// https://github.com/aws/s2n-tls/issues/4140
s2n_cert_chain_and_key_load_pem_bytes(
self.cert.as_mut_ptr(),
self.cert_handle.cert.as_ptr(),
chain.as_ptr() as *mut _,
chain.len() as u32,
key.as_ptr() as *mut _,
Expand All @@ -95,7 +102,7 @@ impl Builder {
// is not modified
// https://github.com/aws/s2n-tls/issues/4140
s2n_cert_chain_and_key_load_public_pem_bytes(
self.cert.as_mut_ptr(),
self.cert_handle.cert.as_ptr(),
chain.as_ptr() as *mut _,
chain.len() as u32,
)
Expand All @@ -109,7 +116,7 @@ impl Builder {
pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> {
unsafe {
s2n_cert_chain_and_key_set_ocsp_data(
self.cert.as_mut_ptr(),
self.cert_handle.cert.as_ptr(),
data.as_ptr(),
data.len() as u32,
)
Expand All @@ -122,7 +129,7 @@ impl Builder {
pub fn build(self) -> Result<CertificateChain<'static>, Error> {
// This method is currently infallible, but returning a result allows
// us to add validation in the future.
Ok(self.cert)
Ok(CertificateChain::from_allocated(self.cert_handle))
}
}

Expand All @@ -135,22 +142,16 @@ impl Builder {
// safe to mutate CertificateChains.
#[derive(Clone)]
pub struct CertificateChain<'a> {
ptr: Arc<CertificateChainHandle>,
_lifetime: PhantomData<&'a s2n_cert_chain_and_key>,
cert_handle: Arc<CertificateChainHandle<'a>>,
}

impl CertificateChain<'_> {
/// This allocates a new certificate chain from s2n.
///
/// Corresponds to [s2n_cert_chain_and_key_new].
pub(crate) fn allocate_owned() -> Result<CertificateChain<'static>, Error> {
crate::init::init();
unsafe {
let ptr = s2n_cert_chain_and_key_new().into_result()?;
Ok(CertificateChain {
ptr: Arc::new(CertificateChainHandle::from_owned(ptr)),
_lifetime: PhantomData,
})
/// Construct a CertificateChain from an allocated [CertificateChainHandle].
pub(crate) fn from_allocated(
handle: CertificateChainHandle<'static>,
) -> CertificateChain<'static> {
CertificateChain {
cert_handle: Arc::new(handle),
}
}

Expand All @@ -162,8 +163,7 @@ impl CertificateChain<'_> {
let handle = Arc::new(CertificateChainHandle::from_reference(ptr));

CertificateChain {
ptr: handle,
_lifetime: PhantomData,
cert_handle: handle,
}
}

Expand Down Expand Up @@ -202,16 +202,8 @@ impl CertificateChain<'_> {
self.len() == 0
}

/// SAFETY: Only one instance of `CertificateChain` may exist when this method
/// is called. s2n_cert_chain_and_key is not thread-safe, so it is not safe
/// to mutate the certificate chain if references are held across multiple threads.
pub(crate) unsafe fn as_mut_ptr(&mut self) -> *mut s2n_cert_chain_and_key {
debug_assert_eq!(Arc::strong_count(&self.ptr), 1);
self.ptr.cert.as_ptr()
}

pub(crate) fn as_ptr(&self) -> *const s2n_cert_chain_and_key {
self.ptr.cert.as_ptr() as *const _
self.cert_handle.cert.as_ptr() as *const _
}
}

Expand Down Expand Up @@ -339,28 +331,28 @@ mod tests {
#[test]
fn reference_count_increment() -> Result<(), crate::error::Error> {
let cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain();
assert_eq!(Arc::strong_count(&cert.ptr), 1);
assert_eq!(Arc::strong_count(&cert.cert_handle), 1);

{
let mut server = config::Builder::new();
server.load_chain(cert.clone())?;

// after being added, the reference count should have increased
assert_eq!(Arc::strong_count(&cert.ptr), 2);
assert_eq!(Arc::strong_count(&cert.cert_handle), 2);
}

// after the config goes out of scope and is dropped, the ref count should
// decrement
assert_eq!(Arc::strong_count(&cert.ptr), 1);
assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
Ok(())
}

#[test]
fn cert_is_dropped() {
let weak_ref = {
let cert = SniTestCerts::AlligatorEcdsa.get().into_certificate_chain();
assert_eq!(Arc::strong_count(&cert.ptr), 1);
Arc::downgrade(&cert.ptr)
assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
Arc::downgrade(&cert.cert_handle)
};
assert_eq!(weak_ref.strong_count(), 0);
assert!(weak_ref.upgrade().is_none());
Expand All @@ -377,17 +369,17 @@ mod tests {
let mut test_pair_2 =
sni_test_pair(vec![cert.clone()], None, &[SniTestCerts::AlligatorRsa])?;

assert_eq!(Arc::strong_count(&cert.ptr), 3);
assert_eq!(Arc::strong_count(&cert.cert_handle), 3);

assert!(test_pair_1.handshake().is_ok());
assert!(test_pair_2.handshake().is_ok());

assert_eq!(Arc::strong_count(&cert.ptr), 3);
assert_eq!(Arc::strong_count(&cert.cert_handle), 3);

drop(test_pair_1);
assert_eq!(Arc::strong_count(&cert.ptr), 2);
assert_eq!(Arc::strong_count(&cert.cert_handle), 2);
drop(test_pair_2);
assert_eq!(Arc::strong_count(&cert.ptr), 1);
assert_eq!(Arc::strong_count(&cert.cert_handle), 1);
Ok(())
}

Expand All @@ -396,7 +388,7 @@ mod tests {
// 5 certs in the maximum allowed, 6 should error.
const FAILING_NUMBER: usize = 6;
let certs = vec![SniTestCerts::AlligatorRsa.get().into_certificate_chain(); FAILING_NUMBER];
assert_eq!(Arc::strong_count(&certs[0].ptr), FAILING_NUMBER);
assert_eq!(Arc::strong_count(&certs[0].cert_handle), FAILING_NUMBER);

let mut config = config::Builder::new();
let err = config.set_default_chains(certs.clone()).err().unwrap();
Expand All @@ -405,7 +397,7 @@ mod tests {

// The config should not hold a reference when the error was detected
// in the bindings
assert_eq!(Arc::strong_count(&certs[0].ptr), FAILING_NUMBER);
assert_eq!(Arc::strong_count(&certs[0].cert_handle), FAILING_NUMBER);

Ok(())
}
Expand All @@ -430,8 +422,8 @@ mod tests {
&test_pair.client.peer_cert_chain().unwrap()
));

assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2);
assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2);
assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 2);
}

// set an explicit default
Expand All @@ -449,10 +441,10 @@ mod tests {
&test_pair.client.peer_cert_chain().unwrap()
));

assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2);
// beaver has an additional reference because it was used in multiple
// calls
assert_eq!(Arc::strong_count(&beaver_cert.ptr), 3);
assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 3);
}

// set a default without adding it to the store
Expand All @@ -470,8 +462,8 @@ mod tests {
&test_pair.client.peer_cert_chain().unwrap()
));

assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2);
assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2);
assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 2);
}

Ok(())
Expand Down
15 changes: 9 additions & 6 deletions bindings/rust/extended/s2n-tls/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use crate::renegotiate::RenegotiateState;
use crate::{
callbacks::*,
cert_chain::CertificateChain,
cert_chain::{CertificateChain, CertificateChainHandle},
config::Config,
enums::*,
error::{Error, Fallible, Pollable},
Expand Down Expand Up @@ -1219,11 +1219,14 @@ impl Connection {
/// Corresponds to [s2n_connection_get_peer_cert_chain].
pub fn peer_cert_chain(&self) -> Result<CertificateChain<'static>, Error> {
unsafe {
let mut chain = CertificateChain::allocate_owned()?;
s2n_connection_get_peer_cert_chain(self.connection.as_ptr(), chain.as_mut_ptr())
.into_result()
.map(|_| ())?;
Ok(chain)
let chain_handle = CertificateChainHandle::allocate()?;
s2n_connection_get_peer_cert_chain(
self.connection.as_ptr(),
chain_handle.cert.as_ptr(),
)
.into_result()
.map(|_| ())?;
Ok(CertificateChain::from_allocated(chain_handle))
}
}

Expand Down

0 comments on commit 4ae43ec

Please sign in to comment.