diff --git a/src/app/app.rs b/src/app/app.rs index c27b53b..4a3b42a 100644 --- a/src/app/app.rs +++ b/src/app/app.rs @@ -4,7 +4,7 @@ use encryption::aes::AesKey; use crate::settings::{ConnectionsSettingsModel, SettingsReader}; -use super::{ClientCertificatesCache, SavedClientCert}; +use super::ClientCertificatesCache; pub const APP_VERSION: &'static str = env!("CARGO_PKG_VERSION"); @@ -13,7 +13,7 @@ pub struct AppContext { pub http_connections: AtomicIsize, id: AtomicI64, pub connection_settings: ConnectionsSettingsModel, - pub saved_client_certs: SavedClientCert, + //pub saved_client_certs: SavedClientCert, pub token_secret_key: AesKey, pub client_certificates: ClientCertificatesCache, } @@ -33,7 +33,7 @@ impl AppContext { http_connections: AtomicIsize::new(0), id: AtomicI64::new(0), connection_settings, - saved_client_certs: SavedClientCert::new(), + // saved_client_certs: SavedClientCert::new(), token_secret_key, client_certificates: ClientCertificatesCache::new(), } diff --git a/src/app/mod.rs b/src/app/mod.rs index 8d64113..90a826b 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -2,8 +2,6 @@ mod app; pub use app::*; mod ssl_certificate; pub use ssl_certificate::*; -mod saved_client_cert; -pub use saved_client_cert::*; pub mod certificates; mod client_certificates_cache; pub use client_certificates_cache::*; diff --git a/src/app/saved_client_cert.rs b/src/app/saved_client_cert.rs deleted file mode 100644 index 0552a5c..0000000 --- a/src/app/saved_client_cert.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::{collections::HashMap, sync::Mutex}; - -pub struct SavedClientCert { - items: Mutex>>, -} - -impl SavedClientCert { - pub fn new() -> Self { - Self { - items: Mutex::new(HashMap::new()), - } - } - - pub fn save(&self, port: u16, id: u64, cert: String) { - let mut read_access = self.items.lock().unwrap(); - - if !read_access.contains_key(&port) { - read_access.insert(port, Vec::new()); - } - - let by_port = read_access.get_mut(&port).unwrap(); - - let index = by_port.iter().position(|x| x.0 == id); - if let Some(index) = index { - by_port.remove(index); - } - by_port.push((id, cert)); - } - - pub fn get(&self, port: u16, id: u64) -> Option { - let mut read_access = self.items.lock().unwrap(); - - if let Some(by_port) = read_access.get_mut(&port) { - let index = by_port.iter().position(|x| x.0 == id)?; - return Some(by_port.remove(index).1); - } - - None - } -} diff --git a/src/http_server/client_cert_cell.rs b/src/http_server/client_cert_cell.rs new file mode 100644 index 0000000..cefe8ac --- /dev/null +++ b/src/http_server/client_cert_cell.rs @@ -0,0 +1,23 @@ +use std::sync::Mutex; + +pub struct ClientCertCell { + pub value: Mutex>, +} + +impl ClientCertCell { + pub fn new() -> Self { + Self { + value: Mutex::new(None), + } + } + + pub fn set(&self, value: String) { + let mut write_access = self.value.lock().unwrap(); + *write_access = Some(value); + } + + pub fn get(&self) -> Option { + let mut read_access = self.value.lock().unwrap(); + return read_access.take(); + } +} diff --git a/src/http_server/client_cert_verifier.rs b/src/http_server/client_cert_verifier.rs index 6bd472a..fcb6790 100644 --- a/src/http_server/client_cert_verifier.rs +++ b/src/http_server/client_cert_verifier.rs @@ -2,29 +2,24 @@ use std::{fmt::Debug, sync::Arc}; use tokio_rustls::rustls::{server::danger::ClientCertVerifier, SignatureScheme}; -use crate::app::AppContext; - -use super::ClientCertificateCa; +use super::{client_cert_cell::ClientCertCell, ClientCertificateCa}; pub struct MyClientCertVerifier { - app: Arc, + client_cert_cell: Arc, pub ca: Arc, endpoint_port: u16, - connection_id: u64, } impl MyClientCertVerifier { pub fn new( - app: Arc, + client_cert_cell: Arc, ca: Arc, endpoint_port: u16, - connection_id: u64, ) -> Self { Self { ca, - app, + client_cert_cell, endpoint_port, - connection_id, } } } @@ -85,9 +80,7 @@ impl ClientCertVerifier for MyClientCertVerifier { if let Some(common_name) = self.ca.check_certificate(end_entity) { println!("Accepted certificate with common name: {}", common_name); - self.app - .saved_client_certs - .save(self.endpoint_port, self.connection_id, common_name); + self.client_cert_cell.set(common_name); return Ok(tokio_rustls::rustls::server::danger::ClientCertVerified::assertion()); } diff --git a/src/http_server/https_server.rs b/src/http_server/https_server.rs index e4a7bad..c5e3105 100644 --- a/src/http_server/https_server.rs +++ b/src/http_server/https_server.rs @@ -28,11 +28,7 @@ async fn start_https_server_loop( // Build TLS configuration. - let mut connection_id = 0; - loop { - connection_id += 1; - let (tcp_stream, socket_addr) = listener.accept().await.unwrap(); println!("Accepted connection"); @@ -40,7 +36,6 @@ async fn start_https_server_loop( let result = lazy_accept_tcp_stream( app.clone(), endpoint_port, - connection_id, certified_key.clone(), tcp_stream, ) @@ -76,7 +71,6 @@ async fn start_https_server_loop( async fn lazy_accept_tcp_stream( app: Arc, endpoint_port: u16, - connection_id: u64, certified_key: Arc, tcp_stream: TcpStream, ) -> Result< @@ -103,7 +97,6 @@ async fn lazy_accept_tcp_stream( app.clone(), server_name, endpoint_port, - connection_id, certified_key, ) .await; @@ -112,12 +105,12 @@ async fn lazy_accept_tcp_stream( return Err(format!("failed to create tls config: {err:#}")); } - let (config, endpoint_info) = config_result.unwrap(); + let (config, endpoint_info, client_cert_cell) = config_result.unwrap(); let tls_stream = start.into_stream(config.into()).await.unwrap(); - let cn_user_name = if endpoint_info.client_certificate_id.is_some() { - app.saved_client_certs.get(endpoint_port, connection_id) + let cn_user_name = if let Some(client_cert_cell) = client_cert_cell { + client_cert_cell.get() } else { None }; diff --git a/src/http_server/mod.rs b/src/http_server/mod.rs index f1dff30..c125929 100644 --- a/src/http_server/mod.rs +++ b/src/http_server/mod.rs @@ -16,5 +16,6 @@ pub use client_cert_verifier::*; mod generate_tech_page; mod handle_request; pub use generate_tech_page::*; +mod client_cert_cell; mod server_cert_resolver; mod tls_acceptor; diff --git a/src/http_server/tls_acceptor.rs b/src/http_server/tls_acceptor.rs index 2fd40ca..9e7b973 100644 --- a/src/http_server/tls_acceptor.rs +++ b/src/http_server/tls_acceptor.rs @@ -6,7 +6,10 @@ use tokio_rustls::rustls::{ ServerConfig, }; -use crate::{app::AppContext, http_proxy_pass::HttpServerConnectionInfo}; +use crate::{ + app::AppContext, http_proxy_pass::HttpServerConnectionInfo, + http_server::client_cert_cell::ClientCertCell, +}; use super::{server_cert_resolver::MyCertResolver, MyClientCertVerifier}; @@ -14,9 +17,15 @@ pub async fn create_config( app: Arc, server_name: &str, endpoint_port: u16, - connection_id: u64, certified_key: Arc, -) -> Result<(ServerConfig, HttpServerConnectionInfo), String> { +) -> Result< + ( + ServerConfig, + HttpServerConnectionInfo, + Option>, + ), + String, +> { let endpoint_info = app .settings_reader .get_https_connection_configuration(server_name, endpoint_port) @@ -26,11 +35,12 @@ pub async fn create_config( let client_cert_ca = crate::flows::get_client_certificate(&app, client_cert_ca_id, endpoint_port).await?; + let client_cert_cell = Arc::new(ClientCertCell::new()); + let client_cert_verifier = Arc::new(MyClientCertVerifier::new( - app.clone(), + client_cert_cell.clone(), client_cert_ca, endpoint_port, - connection_id, )); let mut server_config = @@ -43,7 +53,7 @@ pub async fn create_config( !endpoint_info.http_type.is_http1() ); server_config.alpn_protocols = get_alpn_protocol(!endpoint_info.http_type.is_http1()); - return Ok((server_config, endpoint_info)); + return Ok((server_config, endpoint_info, Some(client_cert_cell))); } let mut server_config = @@ -53,7 +63,7 @@ pub async fn create_config( server_config.alpn_protocols = get_alpn_protocol(!endpoint_info.http_type.is_http1()); - Ok((server_config, endpoint_info)) + Ok((server_config, endpoint_info, None)) } fn get_alpn_protocol(https2: bool) -> Vec> {