Skip to content

Commit f08bd39

Browse files
committed
wip: dynamic tls cert resolver
1 parent 8f3061b commit f08bd39

File tree

4 files changed

+152
-51
lines changed

4 files changed

+152
-51
lines changed

core/lib/src/listener/tls.rs

+16-49
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@ use std::io;
22
use std::sync::Arc;
33

44
use serde::Deserialize;
5-
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
65
use tokio::io::{AsyncRead, AsyncWrite};
7-
use tokio_rustls::TlsAcceptor;
6+
use tokio_rustls::LazyConfigAcceptor;
7+
use rustls::server::{Acceptor, ServerConfig};
88

9-
use crate::tls::{TlsConfig, Error};
10-
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
9+
use crate::tls::{Error, Resolver, TlsConfig};
1110
use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint};
1211

1312
#[doc(inline)]
@@ -16,7 +15,8 @@ pub use tokio_rustls::server::TlsStream;
1615
/// A TLS listener over some listener interface L.
1716
pub struct TlsListener<L> {
1817
listener: L,
19-
acceptor: TlsAcceptor,
18+
resolver: Option<Box<dyn Resolver>>,
19+
default: Arc<ServerConfig>,
2020
config: TlsConfig,
2121
}
2222

@@ -27,48 +27,6 @@ pub struct TlsBindable<I> {
2727
pub tls: TlsConfig,
2828
}
2929

30-
impl TlsConfig {
31-
pub(crate) fn server_config(&self) -> Result<ServerConfig, Error> {
32-
let provider = rustls::crypto::CryptoProvider {
33-
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
34-
..rustls::crypto::ring::default_provider()
35-
};
36-
37-
#[cfg(feature = "mtls")]
38-
let verifier = match self.mutual {
39-
Some(ref mtls) => {
40-
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?;
41-
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs));
42-
match mtls.mandatory {
43-
true => verifier.build()?,
44-
false => verifier.allow_unauthenticated().build()?,
45-
}
46-
},
47-
None => WebPkiClientVerifier::no_client_auth(),
48-
};
49-
50-
#[cfg(not(feature = "mtls"))]
51-
let verifier = WebPkiClientVerifier::no_client_auth();
52-
53-
let key = load_key(&mut self.key_reader()?)?;
54-
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
55-
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
56-
.with_safe_default_protocol_versions()?
57-
.with_client_cert_verifier(verifier)
58-
.with_single_cert(cert_chain, key)?;
59-
60-
tls_config.ignore_client_order = self.prefer_server_cipher_order;
61-
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
62-
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
63-
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
64-
if cfg!(feature = "http2") {
65-
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
66-
}
67-
68-
Ok(tls_config)
69-
}
70-
}
71-
7230
impl<I: Bindable> Bindable for TlsBindable<I>
7331
where I::Listener: Listener<Accept = <I::Listener as Listener>::Connection>,
7432
<I::Listener as Listener>::Connection: AsyncRead + AsyncWrite
@@ -79,7 +37,8 @@ impl<I: Bindable> Bindable for TlsBindable<I>
7937

8038
async fn bind(self) -> Result<Self::Listener, Self::Error> {
8139
Ok(TlsListener {
82-
acceptor: TlsAcceptor::from(Arc::new(self.tls.server_config()?)),
40+
default: Arc::new(self.tls.to_server_config()?),
41+
resolver: None,
8342
listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?,
8443
config: self.tls,
8544
})
@@ -104,7 +63,15 @@ impl<L> Listener for TlsListener<L>
10463
}
10564

10665
async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
107-
self.acceptor.accept(conn).await
66+
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), conn);
67+
let handshake = acceptor.await?;
68+
let hello = handshake.client_hello();
69+
let config = match &self.resolver {
70+
Some(r) => r.resolve(hello).await.unwrap_or_else(|| self.default.clone()),
71+
None => self.default.clone(),
72+
};
73+
74+
handshake.into_stream(config).await
10875
}
10976

11077
fn endpoint(&self) -> io::Result<Endpoint> {

core/lib/src/tls/config.rs

+51-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
use std::io;
2+
use std::sync::Arc;
23

34
use figment::value::magic::{Either, RelativePathBuf};
45
use serde::{Deserialize, Serialize};
56
use indexmap::IndexSet;
67

8+
use crate::tls::Result;
9+
710
/// TLS configuration: certificate chain, key, and ciphersuites.
811
///
912
/// Four parameters control `tls` configuration:
@@ -426,8 +429,54 @@ impl TlsConfig {
426429
self.mutual.as_ref()
427430
}
428431

429-
pub fn validate(&self) -> Result<(), crate::tls::Error> {
430-
self.server_config().map(|_| ())
432+
/// Try to convert `self` into a [rustls] [`ServerConfig`].
433+
///
434+
/// [`ServerConfig`]: rustls::server::ServerConfig
435+
pub fn to_server_config(&self) -> Result<rustls::server::ServerConfig> {
436+
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
437+
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
438+
439+
let provider = rustls::crypto::CryptoProvider {
440+
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
441+
..rustls::crypto::ring::default_provider()
442+
};
443+
444+
#[cfg(feature = "mtls")]
445+
let verifier = match self.mutual {
446+
Some(ref mtls) => {
447+
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?;
448+
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs));
449+
match mtls.mandatory {
450+
true => verifier.build()?,
451+
false => verifier.allow_unauthenticated().build()?,
452+
}
453+
},
454+
None => WebPkiClientVerifier::no_client_auth(),
455+
};
456+
457+
#[cfg(not(feature = "mtls"))]
458+
let verifier = WebPkiClientVerifier::no_client_auth();
459+
460+
let key = load_key(&mut self.key_reader()?)?;
461+
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
462+
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
463+
.with_safe_default_protocol_versions()?
464+
.with_client_cert_verifier(verifier)
465+
.with_single_cert(cert_chain, key)?;
466+
467+
tls_config.ignore_client_order = self.prefer_server_cipher_order;
468+
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
469+
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
470+
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
471+
if cfg!(feature = "http2") {
472+
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
473+
}
474+
475+
Ok(tls_config)
476+
}
477+
478+
pub fn validate(&self) -> Result<()> {
479+
self.to_server_config().map(|_| ())
431480
}
432481
}
433482

core/lib/src/tls/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
mod error;
2+
mod resolver;
23
pub(crate) mod config;
34
pub(crate) mod util;
45

6+
pub use rustls;
7+
58
pub use error::Result;
69
pub use config::{TlsConfig, CipherSuite};
710
pub use error::Error;
11+
pub use resolver::{Resolver, ClientHello, ServerConfig};

core/lib/src/tls/resolver.rs

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
use std::sync::Arc;
2+
3+
pub use rustls::server::{ClientHello, ServerConfig};
4+
5+
/// A dynamic TLS configuration resolver.
6+
#[crate::async_trait]
7+
pub trait Resolver: Send + Sync {
8+
async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>>;
9+
}
10+
11+
#[cfg(test)]
12+
mod tests {
13+
use std::sync::Arc;
14+
use std::collections::HashMap;
15+
use serde::Deserialize;
16+
use crate::http::uri::Host;
17+
use crate::tls::{TlsConfig, ServerConfig, Error, Resolver, ClientHello};
18+
19+
/// ```toml
20+
/// [sni."api.rocket.rs"]
21+
/// certs = "private/api_rocket_rs.rsa_sha256_cert.pem"
22+
/// key = "private/api_rocket_rs.rsa_sha256_key.pem"
23+
///
24+
/// [sni."blob.rocket.rs"]
25+
/// certs = "private/blob_rsa_sha256_cert.pem"
26+
/// key = "private/blob_rsa_sha256_key.pem"
27+
/// ```
28+
#[derive(Deserialize)]
29+
struct SniConfig {
30+
sni: HashMap<Host<'static>, TlsConfig>,
31+
}
32+
33+
struct SniResolver {
34+
sni_map: HashMap<Host<'static>, Arc<ServerConfig>>
35+
}
36+
37+
#[crate::async_trait]
38+
impl Resolver for SniResolver {
39+
async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
40+
let host = Host::parse(hello.server_name()?).ok()?;
41+
self.sni_map.get(&host).cloned()
42+
}
43+
}
44+
45+
#[test]
46+
fn test_config() {
47+
figment::Jail::expect_with(|jail| {
48+
use crate::fs::relative;
49+
50+
let cert_path = relative!("../../examples/tls/private/rsa_sha256_cert.pem");
51+
let key_path = relative!("../../examples/tls/private/rsa_sha256_key.pem");
52+
53+
jail.create_file("Rocket.toml", &format!(r#"
54+
[default.sni."api.rocket.rs"]
55+
certs = "{cert_path}"
56+
key = "{key_path}"
57+
58+
[default.sni."blob.rocket.rs"]
59+
certs = "{cert_path}"
60+
key = "{key_path}"
61+
"#))?;
62+
63+
let config = crate::Config::figment().extract::<SniConfig>()?;
64+
assert!(config.sni.contains_key(&Host::parse("api.rocket.rs").unwrap()));
65+
assert!(config.sni.contains_key(&Host::parse("blob.rocket.rs").unwrap()));
66+
Ok(())
67+
});
68+
}
69+
70+
#[test]
71+
fn test() {
72+
let rocket = crate::build();
73+
let config = rocket.figment().extract::<SniConfig>().unwrap();
74+
let sni_map = config.sni.into_iter()
75+
.map(|(k, v)| Ok((k, Arc::new(v.to_server_config()?))))
76+
.collect::<Result<HashMap<_, _>, Error>>()
77+
.unwrap();
78+
79+
let _ = SniResolver { sni_map, };
80+
}
81+
}

0 commit comments

Comments
 (0)