Skip to content

Commit

Permalink
fix example
Browse files Browse the repository at this point in the history
  • Loading branch information
toidiu committed Mar 5, 2024
1 parent 05efba5 commit e3295c5
Showing 1 changed file with 69 additions and 48 deletions.
117 changes: 69 additions & 48 deletions examples/rustls-mtls/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use rustls::{
cipher_suite, ClientConfig, Error, RootCertStore, ServerConfig, SupportedCipherSuite,
use crate::rustls::server::WebPkiClientVerifier;
use rustls::{ClientConfig, Error, RootCertStore, ServerConfig};
use s2n_quic::provider::{
tls,
tls::{
default::{
default_crypto_provider, CertificateDer, PrivateKeyDer, TLS13_PROTOCOL_VERSIONS,
},
rustls::rustls,
},
};
use s2n_quic::provider::{tls, tls::rustls::rustls};
use std::{io::Cursor, path::Path, sync::Arc};
use tokio::{fs::File, io::AsyncReadExt};
use tracing::Level;

static PROTOCOL_VERSIONS: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13];

pub static DEFAULT_CIPHERSUITES: &[SupportedCipherSuite] = &[
cipher_suite::TLS13_AES_128_GCM_SHA256,
cipher_suite::TLS13_AES_256_GCM_SHA384,
cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
];

pub fn initialize_logger(endpoint: &str) {
use std::sync::Once;

Expand All @@ -41,8 +40,8 @@ pub fn initialize_logger(endpoint: &str) {

pub struct MtlsProvider {
root_store: rustls::RootCertStore,
my_cert_chain: Vec<rustls::Certificate>,
my_private_key: rustls::PrivateKey,
my_cert_chain: Vec<CertificateDer<'static>>,
my_private_key: PrivateKeyDer<'static>,
}

impl tls::Provider for MtlsProvider {
Expand All @@ -51,12 +50,16 @@ impl tls::Provider for MtlsProvider {
type Error = rustls::Error;

fn start_server(self) -> Result<Self::Server, Self::Error> {
let verifier = rustls::server::AllowAnyAuthenticatedClient::new(self.root_store);
let mut cfg = ServerConfig::builder()
.with_cipher_suites(DEFAULT_CIPHERSUITES)
.with_safe_default_kx_groups()
.with_protocol_versions(PROTOCOL_VERSIONS)?
.with_client_cert_verifier(Arc::new(verifier))
let tls13_cipher_suite_crypto_provider = Arc::new(default_crypto_provider()?);
let verifier = WebPkiClientVerifier::builder_with_provider(
Arc::new(self.root_store),
tls13_cipher_suite_crypto_provider.clone(),
)
.build()
.unwrap();
let mut cfg = ServerConfig::builder_with_provider(tls13_cipher_suite_crypto_provider)
.with_protocol_versions(TLS13_PROTOCOL_VERSIONS)?
.with_client_cert_verifier(verifier)
.with_single_cert(self.my_cert_chain, self.my_private_key)?;

cfg.ignore_client_order = true;
Expand All @@ -66,12 +69,12 @@ impl tls::Provider for MtlsProvider {
}

fn start_client(self) -> Result<Self::Client, Self::Error> {
let mut cfg = ClientConfig::builder()
.with_cipher_suites(DEFAULT_CIPHERSUITES)
.with_safe_default_kx_groups()
.with_protocol_versions(PROTOCOL_VERSIONS)?
.with_root_certificates(self.root_store)
.with_client_auth_cert(self.my_cert_chain, self.my_private_key)?;
let tls13_cipher_suite_crypto_provider = default_crypto_provider()?;
let mut cfg =
ClientConfig::builder_with_provider(tls13_cipher_suite_crypto_provider.into())
.with_protocol_versions(TLS13_PROTOCOL_VERSIONS)?
.with_root_certificates(self.root_store)
.with_client_auth_cert(self.my_cert_chain, self.my_private_key)?;

cfg.max_fragment_size = None;
cfg.alpn_protocols = vec![b"h3".to_vec()];
Expand All @@ -90,8 +93,8 @@ impl MtlsProvider {
let private_key = into_private_key(my_key_pem.as_ref()).await?;
Ok(MtlsProvider {
root_store,
my_cert_chain: cert_chain.into_iter().map(rustls::Certificate).collect(),
my_private_key: rustls::PrivateKey(private_key),
my_cert_chain: cert_chain.into_iter().map(CertificateDer::from).collect(),
my_private_key: private_key,
})
}
}
Expand All @@ -112,13 +115,24 @@ async fn into_certificate(path: &Path) -> Result<Vec<Vec<u8>>, Error> {
}

async fn into_root_store(path: &Path) -> Result<RootCertStore, Error> {
let ca_certs = into_certificate(path).await?;
let ca_certs: Vec<CertificateDer<'static>> = into_certificate(path)
.await
.map(|certs| certs.into_iter().map(CertificateDer::from))?
.collect();
let mut cert_store = RootCertStore::empty();
cert_store.add_parsable_certificates(ca_certs.as_slice());
cert_store.add_parsable_certificates(ca_certs);
Ok(cert_store)
}

async fn into_private_key(path: &Path) -> Result<Vec<u8>, Error> {
fn construct_pkcs1_key(key: Vec<u8>) -> Result<PrivateKeyDer<'static>, Error> {
Ok(PrivateKeyDer::Pkcs1(key.into()))
}

fn construct_pkcs8_key(key: Vec<u8>) -> Result<PrivateKeyDer<'static>, Error> {
Ok(PrivateKeyDer::Pkcs8(key.into()))
}

async fn into_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, Error> {
let mut f = File::open(path)
.await
.map_err(|e| Error::General(format!("Failed to load file: {}", e)))?;
Expand All @@ -128,26 +142,33 @@ async fn into_private_key(path: &Path) -> Result<Vec<u8>, Error> {
.map_err(|e| Error::General(format!("Failed to read file: {}", e)))?;
let mut cursor = Cursor::new(buf);

let parsers = [
rustls_pemfile::rsa_private_keys,
rustls_pemfile::pkcs8_private_keys,
];
for parser in parsers.iter() {
cursor.set_position(0);

match parser(&mut cursor) {
Ok(keys) if keys.is_empty() => continue,
Ok(mut keys) if keys.len() == 1 => return Ok(rustls::PrivateKey(keys.pop().unwrap()).0),
Ok(keys) => {
return Err(Error::General(format!(
"Unexpected number of keys: {} (only 1 supported)",
keys.len()
)));
macro_rules! parse_key {
($parser:ident, $constructor:ident) => {
cursor.set_position(0);

match rustls_pemfile::$parser(&mut cursor) {
// try the next parser
Err(_) => (),
// try the next parser
Ok(keys) if keys.is_empty() => (),
Ok(mut keys) if keys.len() == 1 => {
return $constructor(keys.pop().unwrap());
}
Ok(keys) => {
return Err(Error::General(format!(
"Unexpected number of keys: {} (only 1 supported)",
keys.len()
)));
}
}
// try the next parser
Err(_) => continue,
}
};
}

// attempt to parse PKCS8 encoded key
parse_key!(pkcs8_private_keys, construct_pkcs8_key);
// attempt to parse RSA key
parse_key!(rsa_private_keys, construct_pkcs1_key);

Err(Error::General(
"could not load any valid private keys".to_string(),
))
Expand Down

0 comments on commit e3295c5

Please sign in to comment.