From 4ae43ecc538ca7a2ec0931005ab54c7a5d3168b8 Mon Sep 17 00:00:00 2001 From: James Mayclin Date: Wed, 19 Feb 2025 12:53:38 -0800 Subject: [PATCH] fix(bindings): remove mutation behind Arc (#5124) --- .../rust/extended/s2n-tls/src/cert_chain.rs | 106 ++++++++---------- .../rust/extended/s2n-tls/src/connection.rs | 15 ++- 2 files changed, 58 insertions(+), 63 deletions(-) diff --git a/bindings/rust/extended/s2n-tls/src/cert_chain.rs b/bindings/rust/extended/s2n-tls/src/cert_chain.rs index 48a3f57dfd8..4c5790cf654 100644 --- a/bindings/rust/extended/s2n-tls/src/cert_chain.rs +++ b/bindings/rust/extended/s2n-tls/src/cert_chain.rs @@ -13,34 +13,41 @@ use std::{ /// /// [CertificateChain] is internally reference counted. The reference counted `T` /// must have a drop implementation. -struct CertificateChainHandle { - cert: NonNull, +pub(crate) struct CertificateChainHandle<'a> { + pub(crate) cert: NonNull, is_owned: bool, + _lifetime: PhantomData<&'a s2n_cert_chain_and_key>, } // # Safety // // s2n_cert_chain_and_key objects can be sent across threads. -unsafe impl Send for CertificateChainHandle {} -unsafe impl Sync for CertificateChainHandle {} +unsafe impl Send for CertificateChainHandle<'_> {} +unsafe impl Sync for CertificateChainHandle<'_> {} -impl CertificateChainHandle { - fn from_owned(cert: NonNull) -> Self { - Self { - cert, +impl CertificateChainHandle<'_> { + /// Allocate an uninitialized CertificateChainHandle. + /// + /// Corresponds to [s2n_cert_chain_and_key_new]. + pub(crate) fn allocate() -> Result, crate::error::Error> { + crate::init::init(); + Ok(CertificateChainHandle { + cert: unsafe { s2n_cert_chain_and_key_new().into_result() }?, is_owned: true, - } + _lifetime: PhantomData, + }) } fn from_reference(cert: NonNull) -> Self { Self { cert, is_owned: false, + _lifetime: PhantomData, } } } -impl Drop for CertificateChainHandle { +impl Drop for CertificateChainHandle<'_> { /// Corresponds to [s2n_cert_chain_and_key_free]. fn drop(&mut self) { // ignore failures since there's not much we can do about it @@ -53,13 +60,13 @@ impl Drop for CertificateChainHandle { } pub struct Builder { - cert: CertificateChain<'static>, + cert_handle: CertificateChainHandle<'static>, } impl Builder { pub fn new() -> Result { Ok(Self { - cert: CertificateChain::allocate_owned()?, + cert_handle: CertificateChainHandle::allocate()?, }) } @@ -73,7 +80,7 @@ impl Builder { // `private_key_pem` are not modified. // https://github.com/aws/s2n-tls/issues/4140 s2n_cert_chain_and_key_load_pem_bytes( - self.cert.as_mut_ptr(), + self.cert_handle.cert.as_ptr(), chain.as_ptr() as *mut _, chain.len() as u32, key.as_ptr() as *mut _, @@ -95,7 +102,7 @@ impl Builder { // is not modified // https://github.com/aws/s2n-tls/issues/4140 s2n_cert_chain_and_key_load_public_pem_bytes( - self.cert.as_mut_ptr(), + self.cert_handle.cert.as_ptr(), chain.as_ptr() as *mut _, chain.len() as u32, ) @@ -109,7 +116,7 @@ impl Builder { pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> { unsafe { s2n_cert_chain_and_key_set_ocsp_data( - self.cert.as_mut_ptr(), + self.cert_handle.cert.as_ptr(), data.as_ptr(), data.len() as u32, ) @@ -122,7 +129,7 @@ impl Builder { pub fn build(self) -> Result, Error> { // This method is currently infallible, but returning a result allows // us to add validation in the future. - Ok(self.cert) + Ok(CertificateChain::from_allocated(self.cert_handle)) } } @@ -135,22 +142,16 @@ impl Builder { // safe to mutate CertificateChains. #[derive(Clone)] pub struct CertificateChain<'a> { - ptr: Arc, - _lifetime: PhantomData<&'a s2n_cert_chain_and_key>, + cert_handle: Arc>, } impl CertificateChain<'_> { - /// This allocates a new certificate chain from s2n. - /// - /// Corresponds to [s2n_cert_chain_and_key_new]. - pub(crate) fn allocate_owned() -> Result, Error> { - crate::init::init(); - unsafe { - let ptr = s2n_cert_chain_and_key_new().into_result()?; - Ok(CertificateChain { - ptr: Arc::new(CertificateChainHandle::from_owned(ptr)), - _lifetime: PhantomData, - }) + /// Construct a CertificateChain from an allocated [CertificateChainHandle]. + pub(crate) fn from_allocated( + handle: CertificateChainHandle<'static>, + ) -> CertificateChain<'static> { + CertificateChain { + cert_handle: Arc::new(handle), } } @@ -162,8 +163,7 @@ impl CertificateChain<'_> { let handle = Arc::new(CertificateChainHandle::from_reference(ptr)); CertificateChain { - ptr: handle, - _lifetime: PhantomData, + cert_handle: handle, } } @@ -202,16 +202,8 @@ impl CertificateChain<'_> { self.len() == 0 } - /// SAFETY: Only one instance of `CertificateChain` may exist when this method - /// is called. s2n_cert_chain_and_key is not thread-safe, so it is not safe - /// to mutate the certificate chain if references are held across multiple threads. - pub(crate) unsafe fn as_mut_ptr(&mut self) -> *mut s2n_cert_chain_and_key { - debug_assert_eq!(Arc::strong_count(&self.ptr), 1); - self.ptr.cert.as_ptr() - } - pub(crate) fn as_ptr(&self) -> *const s2n_cert_chain_and_key { - self.ptr.cert.as_ptr() as *const _ + self.cert_handle.cert.as_ptr() as *const _ } } @@ -339,19 +331,19 @@ mod tests { #[test] fn reference_count_increment() -> Result<(), crate::error::Error> { let cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain(); - assert_eq!(Arc::strong_count(&cert.ptr), 1); + assert_eq!(Arc::strong_count(&cert.cert_handle), 1); { let mut server = config::Builder::new(); server.load_chain(cert.clone())?; // after being added, the reference count should have increased - assert_eq!(Arc::strong_count(&cert.ptr), 2); + assert_eq!(Arc::strong_count(&cert.cert_handle), 2); } // after the config goes out of scope and is dropped, the ref count should // decrement - assert_eq!(Arc::strong_count(&cert.ptr), 1); + assert_eq!(Arc::strong_count(&cert.cert_handle), 1); Ok(()) } @@ -359,8 +351,8 @@ mod tests { fn cert_is_dropped() { let weak_ref = { let cert = SniTestCerts::AlligatorEcdsa.get().into_certificate_chain(); - assert_eq!(Arc::strong_count(&cert.ptr), 1); - Arc::downgrade(&cert.ptr) + assert_eq!(Arc::strong_count(&cert.cert_handle), 1); + Arc::downgrade(&cert.cert_handle) }; assert_eq!(weak_ref.strong_count(), 0); assert!(weak_ref.upgrade().is_none()); @@ -377,17 +369,17 @@ mod tests { let mut test_pair_2 = sni_test_pair(vec![cert.clone()], None, &[SniTestCerts::AlligatorRsa])?; - assert_eq!(Arc::strong_count(&cert.ptr), 3); + assert_eq!(Arc::strong_count(&cert.cert_handle), 3); assert!(test_pair_1.handshake().is_ok()); assert!(test_pair_2.handshake().is_ok()); - assert_eq!(Arc::strong_count(&cert.ptr), 3); + assert_eq!(Arc::strong_count(&cert.cert_handle), 3); drop(test_pair_1); - assert_eq!(Arc::strong_count(&cert.ptr), 2); + assert_eq!(Arc::strong_count(&cert.cert_handle), 2); drop(test_pair_2); - assert_eq!(Arc::strong_count(&cert.ptr), 1); + assert_eq!(Arc::strong_count(&cert.cert_handle), 1); Ok(()) } @@ -396,7 +388,7 @@ mod tests { // 5 certs in the maximum allowed, 6 should error. const FAILING_NUMBER: usize = 6; let certs = vec![SniTestCerts::AlligatorRsa.get().into_certificate_chain(); FAILING_NUMBER]; - assert_eq!(Arc::strong_count(&certs[0].ptr), FAILING_NUMBER); + assert_eq!(Arc::strong_count(&certs[0].cert_handle), FAILING_NUMBER); let mut config = config::Builder::new(); let err = config.set_default_chains(certs.clone()).err().unwrap(); @@ -405,7 +397,7 @@ mod tests { // The config should not hold a reference when the error was detected // in the bindings - assert_eq!(Arc::strong_count(&certs[0].ptr), FAILING_NUMBER); + assert_eq!(Arc::strong_count(&certs[0].cert_handle), FAILING_NUMBER); Ok(()) } @@ -430,8 +422,8 @@ mod tests { &test_pair.client.peer_cert_chain().unwrap() )); - assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2); - assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2); + assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2); + assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 2); } // set an explicit default @@ -449,10 +441,10 @@ mod tests { &test_pair.client.peer_cert_chain().unwrap() )); - assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2); + assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2); // beaver has an additional reference because it was used in multiple // calls - assert_eq!(Arc::strong_count(&beaver_cert.ptr), 3); + assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 3); } // set a default without adding it to the store @@ -470,8 +462,8 @@ mod tests { &test_pair.client.peer_cert_chain().unwrap() )); - assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2); - assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2); + assert_eq!(Arc::strong_count(&alligator_cert.cert_handle), 2); + assert_eq!(Arc::strong_count(&beaver_cert.cert_handle), 2); } Ok(()) diff --git a/bindings/rust/extended/s2n-tls/src/connection.rs b/bindings/rust/extended/s2n-tls/src/connection.rs index aa03d97f747..3bd64b6e233 100644 --- a/bindings/rust/extended/s2n-tls/src/connection.rs +++ b/bindings/rust/extended/s2n-tls/src/connection.rs @@ -7,7 +7,7 @@ use crate::renegotiate::RenegotiateState; use crate::{ callbacks::*, - cert_chain::CertificateChain, + cert_chain::{CertificateChain, CertificateChainHandle}, config::Config, enums::*, error::{Error, Fallible, Pollable}, @@ -1219,11 +1219,14 @@ impl Connection { /// Corresponds to [s2n_connection_get_peer_cert_chain]. pub fn peer_cert_chain(&self) -> Result, Error> { unsafe { - let mut chain = CertificateChain::allocate_owned()?; - s2n_connection_get_peer_cert_chain(self.connection.as_ptr(), chain.as_mut_ptr()) - .into_result() - .map(|_| ())?; - Ok(chain) + let chain_handle = CertificateChainHandle::allocate()?; + s2n_connection_get_peer_cert_chain( + self.connection.as_ptr(), + chain_handle.cert.as_ptr(), + ) + .into_result() + .map(|_| ())?; + Ok(CertificateChain::from_allocated(chain_handle)) } }