Skip to content

Commit

Permalink
Add Rust bindings for certificate chains (#4398)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark-Simulacrum authored Feb 20, 2024
1 parent 80a6913 commit 89dba0e
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 5 deletions.
154 changes: 154 additions & 0 deletions bindings/rust/s2n-tls/src/cert_chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::error::{Error, Fallible};
use s2n_tls_sys::*;
use std::{
marker::PhantomData,
ptr::{self, NonNull},
};

/// A CertificateChain represents a chain of X.509 certificates.
pub struct CertificateChain<'a> {
ptr: NonNull<s2n_cert_chain_and_key>,
is_owned: bool,
_lifetime: PhantomData<&'a s2n_cert_chain_and_key>,
}

impl CertificateChain<'_> {
/// This allocates a new certificate chain from s2n.
pub(crate) fn new() -> Result<CertificateChain<'static>, Error> {
unsafe {
let ptr = s2n_cert_chain_and_key_new().into_result()?;
Ok(CertificateChain {
ptr,
is_owned: true,
_lifetime: PhantomData,
})
}
}

pub(crate) unsafe fn from_ptr_reference<'a>(
ptr: NonNull<s2n_cert_chain_and_key>,
) -> CertificateChain<'a> {
CertificateChain {
ptr,
is_owned: false,
_lifetime: PhantomData,
}
}

pub fn iter(&self) -> CertificateChainIter<'_> {
CertificateChainIter {
idx: 0,
// Cache the length as it's O(n) to compute it, the chain is stored as a linked list.
// It shouldn't change while we have access to the iterator.
len: self.len(),
chain: self,
}
}

/// Return the length of this certificate chain.
///
/// Note that the underyling API currently traverses a linked list, so this is a relatively
/// expensive API to call.
pub fn len(&self) -> usize {
let mut length: u32 = 0;
let res =
unsafe { s2n_cert_chain_get_length(self.ptr.as_ptr(), &mut length).into_result() };
if res.is_err() {
// Errors should only happen on empty chains (we guarantee that `ptr` is a valid chain).
return 0;
}
// u32 should always fit into usize on the platforms we support.
length.try_into().unwrap()
}

/// Check if the certificate chain has any certificates.
///
/// Note that the underyling API currently traverses a linked list, so this is a relatively
/// expensive API to call.
pub fn is_empty(&self) -> bool {
self.len() == 0
}

pub(crate) fn as_mut_ptr(&mut self) -> NonNull<s2n_cert_chain_and_key> {
self.ptr
}
}

// # Safety
//
// s2n_cert_chain_and_key objects can be sent across threads.
unsafe impl Send for CertificateChain<'_> {}

impl Drop for CertificateChain<'_> {
fn drop(&mut self) {
if self.is_owned {
// ignore failures since there's not much we can do about it
unsafe {
let _ = s2n_cert_chain_and_key_free(self.ptr.as_ptr()).into_result();
}
}
}
}

pub struct CertificateChainIter<'a> {
idx: u32,
len: usize,
chain: &'a CertificateChain<'a>,
}

impl<'a> Iterator for CertificateChainIter<'a> {
type Item = Result<Certificate<'a>, Error>;

fn next(&mut self) -> Option<Self::Item> {
let idx = self.idx;
// u32 fits into usize on platforms we support.
if usize::try_from(idx).unwrap() >= self.len {
return None;
}
self.idx += 1;
let mut out = ptr::null_mut();
unsafe {
if let Err(e) =
s2n_cert_chain_get_cert(self.chain.ptr.as_ptr(), &mut out, idx).into_result()
{
return Some(Err(e));
}
}
let out = match NonNull::new(out) {
Some(out) => out,
None => return Some(Err(Error::INVALID_INPUT)),
};
Some(Ok(Certificate {
chain: PhantomData,
certificate: out,
}))
}
}

pub struct Certificate<'a> {
// The chain owns the memory for this certificate.
chain: PhantomData<&'a CertificateChain<'a>>,

certificate: NonNull<s2n_cert>,
}

impl<'a> Certificate<'a> {
pub fn der(&self) -> Result<&[u8], Error> {
unsafe {
let mut buffer = ptr::null();
let mut length = 0;
s2n_cert_get_der(self.certificate.as_ptr(), &mut buffer, &mut length).into_result()?;
let length = usize::try_from(length).map_err(|_| Error::INVALID_INPUT)?;

Ok(std::slice::from_raw_parts(buffer, length))
}
}
}

// # Safety
//
// Certificates just reference data in the chain, so share the Send-ness of the chain.
unsafe impl Send for Certificate<'_> {}
39 changes: 39 additions & 0 deletions bindings/rust/s2n-tls/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

use crate::{
callbacks::*,
cert_chain::CertificateChain,
config::Config,
enums::*,
error::{Error, Fallible, Pollable},
Expand Down Expand Up @@ -853,6 +854,44 @@ impl Connection {
.map(|_| ())
}
}

/// Returns the validated peer certificate chain.
// 'static lifetime is because this copies the certificate chain from the connection into a new
// chain, so the lifetime is independent of the connection.
pub fn peer_cert_chain(&self) -> Result<CertificateChain<'static>, Error> {
unsafe {
let mut chain = CertificateChain::new()?;
s2n_connection_get_peer_cert_chain(
self.connection.as_ptr(),
chain.as_mut_ptr().as_ptr(),
)
.into_result()
.map(|_| ())?;
Ok(chain)
}
}

/// Get the certificate used during the TLS handshake
///
/// - If `self` is a server connection, the certificate selected will depend on the
/// ServerName sent by the client and supported ciphers.
/// - If `self` is a client connection, the certificate sent in response to a CertificateRequest
/// message is returned. Currently s2n-tls supports loading only one certificate in client mode. Note that
/// not all TLS endpoints will request a certificate.
pub fn selected_cert(&self) -> Option<CertificateChain<'_>> {
unsafe {
// The API only returns null, no error is actually set.
// Clippy doesn't realize from_ptr_reference is unsafe.
#[allow(clippy::manual_map)]
if let Some(ptr) =
NonNull::new(s2n_connection_get_selected_cert(self.connection.as_ptr()))
{
Some(CertificateChain::from_ptr_reference(ptr))
} else {
None
}
}
}
}

struct Context {
Expand Down
1 change: 1 addition & 0 deletions bindings/rust/s2n-tls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ static ALLOCATOR: checkers::Allocator = checkers::Allocator::system();
pub mod error;

pub mod callbacks;
pub mod cert_chain;
#[cfg(feature = "unstable-fingerprint")]
pub mod client_hello;
pub mod config;
Expand Down
6 changes: 1 addition & 5 deletions bindings/rust/s2n-tls/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,7 @@ pub fn config_builder(cipher_prefs: &security::Policy) -> Result<crate::config::
builder
.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})
.expect("Unable to set a host verify callback.");
unsafe {
builder
.disable_x509_verification()
.expect("Unable to disable x509 verification");
};
builder.trust_pem(keypair.cert()).expect("load cert pem");
Ok(builder)
}

Expand Down
104 changes: 104 additions & 0 deletions bindings/rust/s2n-tls/src/testing/s2n_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -784,4 +784,108 @@ mod tests {
establish_connection(config_with_system_certs);
});
}

#[test]
fn peer_chain() -> Result<(), Error> {
use crate::enums::ClientAuthType;

let config = {
let mut config = config_builder(&security::DEFAULT_TLS13)?;
config.set_client_auth_type(ClientAuthType::Optional)?;
config.build()?
};

let server = {
let mut server = crate::connection::Connection::new_server();
server.set_config(config.clone())?;
Harness::new(server)
};

let client = {
let mut client = crate::connection::Connection::new_client();
client.set_config(config)?;
Harness::new(client)
};

let pair = Pair::new(server, client);
let pair = poll_tls_pair(pair);
let server = pair.server.0.connection;
let client = pair.client.0.connection;

for conn in [server, client] {
let chain = conn.peer_cert_chain()?;
assert_eq!(chain.len(), 1);
for cert in chain.iter() {
let cert = cert?;
let cert = cert.der()?;
assert!(!cert.is_empty());
}
}

Ok(())
}

#[test]
fn selected_cert() -> Result<(), Error> {
use crate::enums::ClientAuthType;

let config = {
let mut config = config_builder(&security::DEFAULT_TLS13)?;
config.set_client_auth_type(ClientAuthType::Required)?;
config.build()?
};

let server = {
let mut server = crate::connection::Connection::new_server();
server.set_config(config.clone())?;
Harness::new(server)
};

let client = {
let mut client = crate::connection::Connection::new_client();
client.set_config(config)?;
Harness::new(client)
};

// None before handshake...
assert!(server.connection.selected_cert().is_none());
assert!(client.connection.selected_cert().is_none());

let pair = Pair::new(server, client);

let pair = poll_tls_pair(pair);
let server = pair.server.0.connection;
let client = pair.client.0.connection;

for conn in [&server, &client] {
let chain = conn.selected_cert().unwrap();
assert_eq!(chain.len(), 1);
for cert in chain.iter() {
let cert = cert?;
let cert = cert.der()?;
assert!(!cert.is_empty());
}
}

// Same config is used for both and we are doing mTLS, so both should select the same
// certificate.
assert_eq!(
server
.selected_cert()
.unwrap()
.iter()
.next()
.unwrap()?
.der()?,
client
.selected_cert()
.unwrap()
.iter()
.next()
.unwrap()?
.der()?
);

Ok(())
}
}

0 comments on commit 89dba0e

Please sign in to comment.