diff --git a/Cargo.toml b/Cargo.toml index 7b2bd3e9..26cfa1d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,17 +11,19 @@ readme = "README.md" vendored = ["openssl/vendored"] [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies] -security-framework = "0.4.1" -security-framework-sys = "0.4.1" +security-framework = { version = "0.4.4", features = ["session-tickets"] } +security-framework-sys = "0.4.3" lazy_static = "1.0" libc = "0.2" tempfile = "3.0" [target.'cfg(target_os = "windows")'.dependencies] -schannel = "0.1.16" +schannel = "0.1.18" [target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies] +linked_hash_set = "0.1" log = "0.4.5" +once_cell = "1.0" openssl = "0.10.29" openssl-sys = "0.9.55" openssl-probe = "0.1" diff --git a/appveyor.yml b/appveyor.yml index 473dd0e9..49fe535a 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,3 +1,4 @@ +image: Visual Studio 2017 environment: RUST_VERSION: 1.37.0 TARGET: x86_64-pc-windows-msvc diff --git a/build.rs b/build.rs index cbac306a..b6619258 100644 --- a/build.rs +++ b/build.rs @@ -7,6 +7,9 @@ fn main() { if version >= 0x1_01_00_00_0 { println!("cargo:rustc-cfg=have_min_max_version"); } + if version >= 0x1_01_01_00_0 { + println!("cargo:rustc-cfg=ossl111"); + } } if let Ok(version) = env::var("DEP_OPENSSL_LIBRESSL_VERSION_NUMBER") { diff --git a/src/imp/openssl.rs b/src/imp/openssl.rs index 5d835b7a..8b4fdbfb 100644 --- a/src/imp/openssl.rs +++ b/src/imp/openssl.rs @@ -1,20 +1,28 @@ +extern crate linked_hash_set; +extern crate once_cell; extern crate openssl; extern crate openssl_probe; +use self::linked_hash_set::LinkedHashSet; +use self::once_cell::sync::OnceCell; use self::openssl::error::ErrorStack; +use self::openssl::ex_data::Index; use self::openssl::hash::MessageDigest; use self::openssl::nid::Nid; use self::openssl::pkcs12::Pkcs12; use self::openssl::pkey::PKey; use self::openssl::ssl::{ - self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod, - SslVerifyMode, + self, MidHandshakeSslStream, Ssl, SslAcceptor, SslConnector, SslContextBuilder, SslMethod, + SslSession, SslSessionCacheMode, SslSessionRef, SslVerifyMode, }; use self::openssl::x509::{X509, store::X509StoreBuilder, X509VerifyResult}; +use std::borrow::Borrow; +use std::collections::hash_map::{Entry, HashMap}; use std::error; use std::fmt; +use std::hash::{Hash, Hasher}; use std::io; -use std::sync::Once; +use std::sync::{Arc, Mutex, Once}; use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder}; use self::openssl::pkey::Private; @@ -248,6 +256,8 @@ pub struct TlsConnector { use_sni: bool, accept_invalid_hostnames: bool, accept_invalid_certs: bool, + session_tickets_enabled: bool, + session_cache: Arc>, } impl TlsConnector { @@ -277,11 +287,37 @@ impl TlsConnector { #[cfg(target_os = "android")] load_android_root_certs(&mut connector)?; + let session_cache = Arc::new(Mutex::new(SessionCache::new())); + if builder.session_tickets_enabled { + connector.set_session_cache_mode(SslSessionCacheMode::CLIENT); + + connector.set_new_session_callback({ + let session_cache = session_cache.clone(); + move |ssl, session| { + if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) { + if let Ok(mut session_cache) = session_cache.lock() { + session_cache.insert(key.clone(), session); + } + } + } + }); + connector.set_remove_session_callback({ + let session_cache = session_cache.clone(); + move |_, session| { + if let Ok(mut session_cache) = session_cache.lock() { + session_cache.remove(session); + } + } + }); + } + Ok(TlsConnector { connector: connector.build(), use_sni: builder.use_sni, accept_invalid_hostnames: builder.accept_invalid_hostnames, accept_invalid_certs: builder.accept_invalid_certs, + session_tickets_enabled: builder.session_tickets_enabled, + session_cache, }) } @@ -297,6 +333,23 @@ impl TlsConnector { if self.accept_invalid_certs { ssl.set_verify(SslVerifyMode::NONE); } + if self.session_tickets_enabled { + let key = SessionKey { + host: domain.to_string(), + }; + + if let Ok(mut session_cache) = self.session_cache.lock() { + if let Some(session) = session_cache.get(&key) { + // Note: the `unsafe`-ty here is because the `session` is required to come from the + // same SSL_CTX that the ssl object (`ssl`) is from, since it maintains internal + // pointers and refcounts. Here, we only have one SSL_CTX, so this is safe. + unsafe { ssl.set_session(&session)? }; + } + } + + let idx = key_index()?; + ssl.set_ex_data(idx, key); + } let s = ssl.connect(domain, stream)?; Ok(TlsStream(s)) @@ -412,3 +465,151 @@ impl io::Write for TlsStream { self.0.flush() } } + +fn key_index() -> Result, ErrorStack> { + static IDX: OnceCell> = OnceCell::new(); + IDX.get_or_try_init(|| Ssl::new_ex_index()).map(|v| *v) +} + +#[derive(Hash, PartialEq, Eq, Clone)] +pub struct SessionKey { + pub host: String, +} + +#[derive(Clone)] +struct HashSession(SslSession); + +impl PartialEq for HashSession { + fn eq(&self, other: &HashSession) -> bool { + self.0.id() == other.0.id() + } +} + +impl Eq for HashSession {} + +impl Hash for HashSession { + fn hash(&self, state: &mut H) + where + H: Hasher, + { + self.0.id().hash(state); + } +} + +impl Borrow<[u8]> for HashSession { + fn borrow(&self) -> &[u8] { + self.0.id() + } +} + +pub struct SessionCache { + sessions: HashMap>, + reverse: HashMap, +} + +impl SessionCache { + pub fn new() -> SessionCache { + SessionCache { + sessions: HashMap::new(), + reverse: HashMap::new(), + } + } + + pub fn insert(&mut self, key: SessionKey, session: SslSession) { + let session = HashSession(session); + + self.sessions + .entry(key.clone()) + .or_insert_with(LinkedHashSet::new) + .insert(session.clone()); + self.reverse.insert(session.clone(), key); + } + + pub fn get(&mut self, key: &SessionKey) -> Option { + let session = { + let sessions = self.sessions.get_mut(key)?; + sessions.front().cloned()?.0 + }; + + #[cfg(ossl111)] + { + use self::openssl::ssl::SslVersion; + + // https://tools.ietf.org/html/rfc8446#appendix-C.4 + // OpenSSL will remove the session from its cache after the handshake completes anyway, but this ensures + // that concurrent handshakes don't end up with the same session. + if session.protocol_version() == SslVersion::TLS1_3 { + self.remove(&session); + } + } + + Some(session) + } + + pub fn remove(&mut self, session: &SslSessionRef) { + let key = match self.reverse.remove(session.id()) { + Some(key) => key, + None => return, + }; + + if let Entry::Occupied(mut sessions) = self.sessions.entry(key) { + sessions.get_mut().remove(session.id()); + if sessions.get().is_empty() { + sessions.remove(); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::{Read, Write}; + use std::net::TcpStream; + + use crate::TlsConnector; + + fn connect_and_assert(tls: &TlsConnector, domain: &str, port: u16, should_resume: bool) { + let s = TcpStream::connect((domain, port)).unwrap(); + let mut stream = tls.connect(domain, s).unwrap(); + + // Must write to the stream, as OpenSSL doesn't appear to call the + // session callback until we do. + stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); + let mut result = vec![]; + stream.read_to_end(&mut result).unwrap(); + + assert_eq!((stream.0).0.ssl().session_reused(), should_resume); + + // Must shut down properly, or OpenSSL will invalidate the session. + stream.shutdown().unwrap(); + } + + #[test] + fn connect_no_session_ticket_resumption() { + let tls = TlsConnector::new().unwrap(); + connect_and_assert(&tls, "google.com", 443, false); + connect_and_assert(&tls, "google.com", 443, false); + } + + #[test] + fn connect_session_ticket_resumption() { + let mut builder = TlsConnector::builder(); + builder.session_tickets_enabled(true); + let tls = builder.build().unwrap(); + + connect_and_assert(&tls, "google.com", 443, false); + connect_and_assert(&tls, "google.com", 443, true); + } + + #[test] + fn connect_session_ticket_resumption_two_sites() { + let mut builder = TlsConnector::builder(); + builder.session_tickets_enabled(true); + let tls = builder.build().unwrap(); + + connect_and_assert(&tls, "google.com", 443, false); + connect_and_assert(&tls, "mozilla.org", 443, false); + connect_and_assert(&tls, "google.com", 443, true); + connect_and_assert(&tls, "mozilla.org", 443, true); + } +} diff --git a/src/imp/schannel.rs b/src/imp/schannel.rs index d80ff436..80b35245 100644 --- a/src/imp/schannel.rs +++ b/src/imp/schannel.rs @@ -4,10 +4,13 @@ use self::schannel::cert_context::{CertContext, HashAlgorithm}; use self::schannel::cert_store::{CertAdd, CertStore, Memory, PfxImportOptions}; use self::schannel::schannel_cred::{Direction, Protocol, SchannelCred}; use self::schannel::tls_stream; +use std::collections::VecDeque; use std::error; use std::fmt; use std::io; use std::str; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, SystemTime}; use {TlsAcceptorBuilder, TlsConnectorBuilder}; @@ -20,6 +23,19 @@ static PROTOCOLS: &'static [Protocol] = &[ Protocol::Tls12, ]; +#[derive(Clone)] +struct CacheEntry { + domain: String, + expiry: SystemTime, + credentials: SchannelCred, +} + +// Number of credentials to cache. +const CREDENTIAL_CACHE_SIZE: usize = 10; + +// Credentials live for 10 minutes. +const CREDENTIAL_TTL: Duration = Duration::from_secs(10 * 60); + fn convert_protocols(min: Option<::Protocol>, max: Option<::Protocol>) -> &'static [Protocol] { let mut protocols = PROTOCOLS; if let Some(p) = max.and_then(|max| protocols.get(..max as usize)) { @@ -183,6 +199,8 @@ pub struct TlsConnector { min_protocol: Option<::Protocol>, max_protocol: Option<::Protocol>, use_sni: bool, + session_tickets_enabled: bool, + credentials_cache: Arc>>, accept_invalid_hostnames: bool, accept_invalid_certs: bool, disable_built_in_roots: bool, @@ -202,6 +220,8 @@ impl TlsConnector { min_protocol: builder.min_protocol, max_protocol: builder.max_protocol, use_sni: builder.use_sni, + session_tickets_enabled: builder.session_tickets_enabled, + credentials_cache: Arc::new(Mutex::new(VecDeque::with_capacity(CREDENTIAL_CACHE_SIZE))), accept_invalid_hostnames: builder.accept_invalid_hostnames, accept_invalid_certs: builder.accept_invalid_certs, disable_built_in_roots: builder.disable_built_in_roots, @@ -212,12 +232,7 @@ impl TlsConnector { where S: io::Read + io::Write, { - let mut builder = SchannelCred::builder(); - builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol)); - if let Some(cert) = self.cert.as_ref() { - builder.cert(cert.clone()); - } - let cred = builder.acquire(Direction::Outbound)?; + let cred = self.get_credentials(domain)?; let mut builder = tls_stream::Builder::new(); builder .cert_store(self.roots.clone()) @@ -249,11 +264,73 @@ impl TlsConnector { )) }); } - match builder.connect(cred, stream) { - Ok(s) => Ok(TlsStream(s)), + match builder.connect(cred.clone(), stream) { + Ok(s) => { + self.store_credentials(domain, cred); + Ok(TlsStream(s)) + } Err(e) => Err(e.into()), } } + + fn get_credentials(&self, domain: &str) -> io::Result { + if self.session_tickets_enabled { + let mut found = None; + let mut cache = self.credentials_cache.lock().unwrap(); + for i in 0..cache.len() { + if &cache[i].domain == domain { + found = Some(i); + break; + } + } + + if let Some(idx) = found { + let now = SystemTime::now(); + let mut entry = cache.remove(idx).unwrap(); + + if entry.expiry > now { + let ret = entry.credentials.clone(); + entry.expiry = now + CREDENTIAL_TTL; + cache.push_back(entry); + return Ok(ret); + } + } + } + + let mut builder = SchannelCred::builder(); + builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol)); + if let Some(cert) = self.cert.as_ref() { + builder.cert(cert.clone()); + } + builder.acquire(Direction::Outbound) + } + + fn store_credentials(&self, domain: &str, cred: SchannelCred) { + if self.session_tickets_enabled { + let mut found = None; + let mut cache = self.credentials_cache.lock().unwrap(); + for i in 0..cache.len() { + if &cache[i].domain == domain { + found = Some(i); + break; + } + } + + if let Some(idx) = found { + cache.remove(idx).unwrap(); + } + + if cache.len() == CREDENTIAL_CACHE_SIZE { + cache.pop_front(); + } + + cache.push_back(CacheEntry { + domain: domain.to_owned(), + expiry: SystemTime::now() + CREDENTIAL_TTL, + credentials: cred, + }); + } + } } #[derive(Clone)] @@ -365,3 +442,52 @@ impl io::Write for TlsStream { self.0.flush() } } + +#[cfg(test)] +mod tests { + use std::net::TcpStream; + + use crate::TlsConnector; + + fn connect_and_assert(tls: &TlsConnector, domain: &str, port: u16, should_resume: bool) { + let s = TcpStream::connect((domain, port)).unwrap(); + let stream = tls.connect(domain, s).unwrap(); + + assert_eq!((stream.0).0.session_resumed().unwrap(), should_resume); + } + + /// Expected to fail on Windows versions where RFC 5077 was not implemented (should just be + /// Windows 7 and below). + #[test] + fn connect_no_session_ticket_resumption() { + let tls = TlsConnector::new().unwrap(); + connect_and_assert(&tls, "google.com", 443, false); + connect_and_assert(&tls, "google.com", 443, false); + } + + /// Expected to fail on Windows versions where RFC 5077 was not implemented (should just be + /// Windows 7 and below). + #[test] + fn connect_session_ticket_resumption() { + let mut builder = TlsConnector::builder(); + builder.session_tickets_enabled(true); + let tls = builder.build().unwrap(); + + connect_and_assert(&tls, "google.com", 443, false); + connect_and_assert(&tls, "google.com", 443, true); + } + + /// Expected to fail on Windows versions where RFC 5077 was not implemented (should just be + /// Windows 7 and below). + #[test] + fn connect_session_ticket_resumption_two_sites() { + let mut builder = TlsConnector::builder(); + builder.session_tickets_enabled(true); + let tls = builder.build().unwrap(); + + connect_and_assert(&tls, "google.com", 443, false); + connect_and_assert(&tls, "mozilla.org", 443, false); + connect_and_assert(&tls, "google.com", 443, true); + connect_and_assert(&tls, "mozilla.org", 443, true); + } +} diff --git a/src/imp/security_framework.rs b/src/imp/security_framework.rs index a3510352..17af09cc 100644 --- a/src/imp/security_framework.rs +++ b/src/imp/security_framework.rs @@ -7,6 +7,7 @@ use self::security_framework::base; use self::security_framework::certificate::SecCertificate; use self::security_framework::identity::SecIdentity; use self::security_framework::import_export::{ImportedIdentity, Pkcs12ImportOptions}; +use self::security_framework::os::macos::import_export::Pkcs12ImportOptionsExt; use self::security_framework::secure_transport::{ self, ClientBuilder, SslConnectionType, SslContext, SslProtocol, SslProtocolSide, }; @@ -128,10 +129,9 @@ impl Identity { keychain } }; - let imports = Pkcs12ImportOptions::new() - .passphrase(pass) - .keychain(keychain) - .import(buf)?; + let mut import_opts = Pkcs12ImportOptions::new(); + ::keychain(&mut import_opts, keychain); + let imports = import_opts.passphrase(pass).import(buf)?; Ok(imports) } @@ -260,6 +260,7 @@ pub struct TlsConnector { max_protocol: Option, roots: Vec, use_sni: bool, + session_tickets_enabled: bool, danger_accept_invalid_hostnames: bool, danger_accept_invalid_certs: bool, disable_built_in_roots: bool, @@ -277,6 +278,7 @@ impl TlsConnector { .map(|c| (c.0).0.clone()) .collect(), use_sni: builder.use_sni, + session_tickets_enabled: builder.session_tickets_enabled, danger_accept_invalid_hostnames: builder.accept_invalid_hostnames, danger_accept_invalid_certs: builder.accept_invalid_certs, disable_built_in_roots: builder.disable_built_in_roots, @@ -299,6 +301,7 @@ impl TlsConnector { } builder.anchor_certificates(&self.roots); builder.use_sni(self.use_sni); + builder.enable_session_tickets(self.session_tickets_enabled); builder.danger_accept_invalid_hostnames(self.danger_accept_invalid_hostnames); builder.danger_accept_invalid_certs(self.danger_accept_invalid_certs); builder.trust_anchor_certificates_only(self.disable_built_in_roots); diff --git a/src/lib.rs b/src/lib.rs index c91a2756..935ce8fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -328,6 +328,7 @@ pub struct TlsConnectorBuilder { accept_invalid_hostnames: bool, use_sni: bool, disable_built_in_roots: bool, + session_tickets_enabled: bool, } impl TlsConnectorBuilder { @@ -418,6 +419,17 @@ impl TlsConnectorBuilder { self } + /// Controls the use of RFC 5077 TLS session ticket resumption. + /// + /// Defaults to `false`. + pub fn session_tickets_enabled( + &mut self, + session_tickets_enabled: bool, + ) -> &mut TlsConnectorBuilder { + self.session_tickets_enabled = session_tickets_enabled; + self + } + /// Creates a new `TlsConnector`. pub fn build(&self) -> Result { let connector = imp::TlsConnector::new(self)?; @@ -464,6 +476,7 @@ impl TlsConnector { accept_invalid_certs: false, accept_invalid_hostnames: false, disable_built_in_roots: false, + session_tickets_enabled: false, } } @@ -477,8 +490,8 @@ impl TlsConnector { /// which can be used to restart the handshake when the socket is ready /// again. /// - /// The domain is ignored if both SNI and hostname verification are - /// disabled. + /// The domain is ignored if SNI, hostname verification, and TLS session + /// ticket resumption are all disabled. pub fn connect( &self, domain: &str,