Skip to content

Commit 3a8305a

Browse files
committed
Add example of an updating TLS resolver.
1 parent e4e46ef commit 3a8305a

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

core/lib/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,4 @@ version_check = "0.9.1"
131131
tokio = { version = "1", features = ["macros", "io-std"] }
132132
figment = { version = "0.10", features = ["test"] }
133133
pretty_assertions = "1"
134+
arc-swap = "1.7"

core/lib/src/listener/tls.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::io;
22
use std::sync::Arc;
33

4-
use serde::Deserialize;
54
use tokio::io::{AsyncRead, AsyncWrite};
65
use tokio_rustls::LazyConfigAcceptor;
76
use rustls::server::{Acceptor, ServerConfig};

core/lib/src/tls/resolver.rs

+46-1
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,13 @@ impl fairing::Fairing for Fairing {
3737

3838
#[cfg(test)]
3939
mod tests {
40+
use std::sync::atomic::AtomicU64;
41+
use std::sync::atomic::Ordering;
4042
use std::sync::Arc;
4143
use std::collections::HashMap;
44+
use std::time::UNIX_EPOCH;
45+
use arc_swap::ArcSwap;
46+
use either::Either;
4247
use serde::Deserialize;
4348
use crate::http::uri::Host;
4449
use crate::tls::{TlsConfig, ServerConfig, Resolver, ClientHello};
@@ -69,10 +74,49 @@ mod tests {
6974
}
7075
}
7176

77+
struct UpdatingResolver {
78+
timestamp: AtomicU64,
79+
tls_config: TlsConfig,
80+
server_config: ArcSwap<ServerConfig>
81+
}
82+
83+
impl TryFrom<TlsConfig> for UpdatingResolver {
84+
type Error = crate::tls::Error;
85+
86+
fn try_from(tls_config: TlsConfig) -> Result<Self, Self::Error> {
87+
Ok(UpdatingResolver {
88+
timestamp: AtomicU64::new(0),
89+
server_config: ArcSwap::new(Arc::new(tls_config.to_server_config()?)),
90+
tls_config,
91+
})
92+
}
93+
}
94+
95+
#[crate::async_trait]
96+
impl Resolver for UpdatingResolver {
97+
async fn resolve(&self, _: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
98+
if let Either::Left(path) = self.tls_config.certs() {
99+
let metadata = tokio::fs::metadata(&path).await.ok()?;
100+
let modtime = metadata.modified().ok()?;
101+
let timestamp = modtime.duration_since(UNIX_EPOCH).ok()?.as_secs();
102+
let old_timestamp = self.timestamp.load(Ordering::Acquire);
103+
if timestamp > old_timestamp {
104+
let new_config = self.tls_config.to_server_config().ok()?;
105+
self.server_config.store(Arc::new(new_config));
106+
self.timestamp.store(timestamp, Ordering::Release);
107+
}
108+
}
109+
110+
Some(self.server_config.load_full())
111+
}
112+
}
113+
72114
#[test]
73115
fn test_config() {
74116
figment::Jail::expect_with(|jail| {
75117
use crate::fs::relative;
118+
use figment::Figment;
119+
use figment::providers::{Toml, Format};
76120

77121
let cert_path = relative!("../../examples/tls/private/rsa_sha256_cert.pem");
78122
let key_path = relative!("../../examples/tls/private/rsa_sha256_key.pem");
@@ -87,7 +131,8 @@ mod tests {
87131
key = "{key_path}"
88132
"#))?;
89133

90-
let config = crate::Config::figment().extract::<SniConfig>()?;
134+
let toml = Toml::file("Rocket.toml").nested();
135+
let config: SniConfig = Figment::from(toml).extract().unwrap();
91136
assert!(config.sni.contains_key(&Host::parse("api.rocket.rs").unwrap()));
92137
assert!(config.sni.contains_key(&Host::parse("blob.rocket.rs").unwrap()));
93138
Ok(())

0 commit comments

Comments
 (0)