diff --git a/bindings/rust/s2n-tls/src/cert_chain.rs b/bindings/rust/s2n-tls/src/cert_chain.rs new file mode 100644 index 00000000000..007657d15f5 --- /dev/null +++ b/bindings/rust/s2n-tls/src/cert_chain.rs @@ -0,0 +1,154 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::error::{Error, Fallible}; +use s2n_tls_sys::*; +use std::{ + marker::PhantomData, + ptr::{self, NonNull}, +}; + +/// A CertificateChain represents a chain of X.509 certificates. +pub struct CertificateChain<'a> { + ptr: NonNull, + is_owned: bool, + _lifetime: PhantomData<&'a s2n_cert_chain_and_key>, +} + +impl CertificateChain<'_> { + /// This allocates a new certificate chain from s2n. + pub(crate) fn new() -> Result, Error> { + unsafe { + let ptr = s2n_cert_chain_and_key_new().into_result()?; + Ok(CertificateChain { + ptr, + is_owned: true, + _lifetime: PhantomData, + }) + } + } + + pub(crate) unsafe fn from_ptr_reference<'a>( + ptr: NonNull, + ) -> CertificateChain<'a> { + CertificateChain { + ptr, + is_owned: false, + _lifetime: PhantomData, + } + } + + pub fn iter(&self) -> CertificateChainIter<'_> { + CertificateChainIter { + idx: 0, + // Cache the length as it's O(n) to compute it, the chain is stored as a linked list. + // It shouldn't change while we have access to the iterator. + len: self.len(), + chain: self, + } + } + + /// Return the length of this certificate chain. + /// + /// Note that the underyling API currently traverses a linked list, so this is a relatively + /// expensive API to call. + pub fn len(&self) -> usize { + let mut length: u32 = 0; + let res = + unsafe { s2n_cert_chain_get_length(self.ptr.as_ptr(), &mut length).into_result() }; + if res.is_err() { + // Errors should only happen on empty chains (we guarantee that `ptr` is a valid chain). + return 0; + } + // u32 should always fit into usize on the platforms we support. + length.try_into().unwrap() + } + + /// Check if the certificate chain has any certificates. + /// + /// Note that the underyling API currently traverses a linked list, so this is a relatively + /// expensive API to call. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub(crate) fn as_mut_ptr(&mut self) -> NonNull { + self.ptr + } +} + +// # Safety +// +// s2n_cert_chain_and_key objects can be sent across threads. +unsafe impl Send for CertificateChain<'_> {} + +impl Drop for CertificateChain<'_> { + fn drop(&mut self) { + if self.is_owned { + // ignore failures since there's not much we can do about it + unsafe { + let _ = s2n_cert_chain_and_key_free(self.ptr.as_ptr()).into_result(); + } + } + } +} + +pub struct CertificateChainIter<'a> { + idx: u32, + len: usize, + chain: &'a CertificateChain<'a>, +} + +impl<'a> Iterator for CertificateChainIter<'a> { + type Item = Result, Error>; + + fn next(&mut self) -> Option { + let idx = self.idx; + // u32 fits into usize on platforms we support. + if usize::try_from(idx).unwrap() >= self.len { + return None; + } + self.idx += 1; + let mut out = ptr::null_mut(); + unsafe { + if let Err(e) = + s2n_cert_chain_get_cert(self.chain.ptr.as_ptr(), &mut out, idx).into_result() + { + return Some(Err(e)); + } + } + let out = match NonNull::new(out) { + Some(out) => out, + None => return Some(Err(Error::INVALID_INPUT)), + }; + Some(Ok(Certificate { + chain: PhantomData, + certificate: out, + })) + } +} + +pub struct Certificate<'a> { + // The chain owns the memory for this certificate. + chain: PhantomData<&'a CertificateChain<'a>>, + + certificate: NonNull, +} + +impl<'a> Certificate<'a> { + pub fn der(&self) -> Result<&[u8], Error> { + unsafe { + let mut buffer = ptr::null(); + let mut length = 0; + s2n_cert_get_der(self.certificate.as_ptr(), &mut buffer, &mut length).into_result()?; + let length = usize::try_from(length).map_err(|_| Error::INVALID_INPUT)?; + + Ok(std::slice::from_raw_parts(buffer, length)) + } + } +} + +// # Safety +// +// Certificates just reference data in the chain, so share the Send-ness of the chain. +unsafe impl Send for Certificate<'_> {} diff --git a/bindings/rust/s2n-tls/src/connection.rs b/bindings/rust/s2n-tls/src/connection.rs index 332688589fd..efbf2757afb 100644 --- a/bindings/rust/s2n-tls/src/connection.rs +++ b/bindings/rust/s2n-tls/src/connection.rs @@ -5,6 +5,7 @@ use crate::{ callbacks::*, + cert_chain::CertificateChain, config::Config, enums::*, error::{Error, Fallible, Pollable}, @@ -853,6 +854,44 @@ impl Connection { .map(|_| ()) } } + + /// Returns the validated peer certificate chain. + // 'static lifetime is because this copies the certificate chain from the connection into a new + // chain, so the lifetime is independent of the connection. + pub fn peer_cert_chain(&self) -> Result, Error> { + unsafe { + let mut chain = CertificateChain::new()?; + s2n_connection_get_peer_cert_chain( + self.connection.as_ptr(), + chain.as_mut_ptr().as_ptr(), + ) + .into_result() + .map(|_| ())?; + Ok(chain) + } + } + + /// Get the certificate used during the TLS handshake + /// + /// - If `self` is a server connection, the certificate selected will depend on the + /// ServerName sent by the client and supported ciphers. + /// - If `self` is a client connection, the certificate sent in response to a CertificateRequest + /// message is returned. Currently s2n-tls supports loading only one certificate in client mode. Note that + /// not all TLS endpoints will request a certificate. + pub fn selected_cert(&self) -> Option> { + unsafe { + // The API only returns null, no error is actually set. + // Clippy doesn't realize from_ptr_reference is unsafe. + #[allow(clippy::manual_map)] + if let Some(ptr) = + NonNull::new(s2n_connection_get_selected_cert(self.connection.as_ptr())) + { + Some(CertificateChain::from_ptr_reference(ptr)) + } else { + None + } + } + } } struct Context { diff --git a/bindings/rust/s2n-tls/src/lib.rs b/bindings/rust/s2n-tls/src/lib.rs index b09c8fe5b30..78b4e81c572 100644 --- a/bindings/rust/s2n-tls/src/lib.rs +++ b/bindings/rust/s2n-tls/src/lib.rs @@ -14,6 +14,7 @@ static ALLOCATOR: checkers::Allocator = checkers::Allocator::system(); pub mod error; pub mod callbacks; +pub mod cert_chain; #[cfg(feature = "unstable-fingerprint")] pub mod client_hello; pub mod config; diff --git a/bindings/rust/s2n-tls/src/testing.rs b/bindings/rust/s2n-tls/src/testing.rs index 78e9d14dd8a..015ba9b6152 100644 --- a/bindings/rust/s2n-tls/src/testing.rs +++ b/bindings/rust/s2n-tls/src/testing.rs @@ -258,11 +258,7 @@ pub fn config_builder(cipher_prefs: &security::Policy) -> Result Result<(), Error> { + use crate::enums::ClientAuthType; + + let config = { + let mut config = config_builder(&security::DEFAULT_TLS13)?; + config.set_client_auth_type(ClientAuthType::Optional)?; + config.build()? + }; + + let server = { + let mut server = crate::connection::Connection::new_server(); + server.set_config(config.clone())?; + Harness::new(server) + }; + + let client = { + let mut client = crate::connection::Connection::new_client(); + client.set_config(config)?; + Harness::new(client) + }; + + let pair = Pair::new(server, client); + let pair = poll_tls_pair(pair); + let server = pair.server.0.connection; + let client = pair.client.0.connection; + + for conn in [server, client] { + let chain = conn.peer_cert_chain()?; + assert_eq!(chain.len(), 1); + for cert in chain.iter() { + let cert = cert?; + let cert = cert.der()?; + assert!(!cert.is_empty()); + } + } + + Ok(()) + } + + #[test] + fn selected_cert() -> Result<(), Error> { + use crate::enums::ClientAuthType; + + let config = { + let mut config = config_builder(&security::DEFAULT_TLS13)?; + config.set_client_auth_type(ClientAuthType::Required)?; + config.build()? + }; + + let server = { + let mut server = crate::connection::Connection::new_server(); + server.set_config(config.clone())?; + Harness::new(server) + }; + + let client = { + let mut client = crate::connection::Connection::new_client(); + client.set_config(config)?; + Harness::new(client) + }; + + // None before handshake... + assert!(server.connection.selected_cert().is_none()); + assert!(client.connection.selected_cert().is_none()); + + let pair = Pair::new(server, client); + + let pair = poll_tls_pair(pair); + let server = pair.server.0.connection; + let client = pair.client.0.connection; + + for conn in [&server, &client] { + let chain = conn.selected_cert().unwrap(); + assert_eq!(chain.len(), 1); + for cert in chain.iter() { + let cert = cert?; + let cert = cert.der()?; + assert!(!cert.is_empty()); + } + } + + // Same config is used for both and we are doing mTLS, so both should select the same + // certificate. + assert_eq!( + server + .selected_cert() + .unwrap() + .iter() + .next() + .unwrap()? + .der()?, + client + .selected_cert() + .unwrap() + .iter() + .next() + .unwrap()? + .der()? + ); + + Ok(()) + } }