Skip to content

Commit

Permalink
Support RFC 5077 TLS session ticket reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
cyang1 committed May 9, 2020
1 parent 97b77f4 commit 1fd8a7c
Show file tree
Hide file tree
Showing 7 changed files with 369 additions and 20 deletions.
8 changes: 5 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ readme = "README.md"
vendored = ["openssl/vendored"]

[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
security-framework = "0.4.1"
security-framework-sys = "0.4.1"
security-framework = { version = "0.4.4", features = ["session-tickets"] }
security-framework-sys = "0.4.3"
lazy_static = "1.0"
libc = "0.2"
tempfile = "3.0"

[target.'cfg(target_os = "windows")'.dependencies]
schannel = "0.1.16"
schannel = "0.1.18"

[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies]
linked_hash_set = "0.1"
log = "0.4.5"
once_cell = "1.0"
openssl = "0.10.29"
openssl-sys = "0.9.55"
openssl-probe = "0.1"
Expand Down
1 change: 1 addition & 0 deletions appveyor.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
image: Visual Studio 2017
environment:
RUST_VERSION: 1.37.0
TARGET: x86_64-pc-windows-msvc
Expand Down
3 changes: 3 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ fn main() {
if version >= 0x1_01_00_00_0 {
println!("cargo:rustc-cfg=have_min_max_version");
}
if version >= 0x1_01_01_00_0 {
println!("cargo:rustc-cfg=ossl111");
}
}

if let Ok(version) = env::var("DEP_OPENSSL_LIBRESSL_VERSION_NUMBER") {
Expand Down
207 changes: 204 additions & 3 deletions src/imp/openssl.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
extern crate linked_hash_set;
extern crate once_cell;
extern crate openssl;
extern crate openssl_probe;

use self::linked_hash_set::LinkedHashSet;
use self::once_cell::sync::OnceCell;
use self::openssl::error::ErrorStack;
use self::openssl::ex_data::Index;
use self::openssl::hash::MessageDigest;
use self::openssl::nid::Nid;
use self::openssl::pkcs12::Pkcs12;
use self::openssl::pkey::PKey;
use self::openssl::ssl::{
self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
SslVerifyMode,
self, MidHandshakeSslStream, Ssl, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
SslSession, SslSessionCacheMode, SslSessionRef, SslVerifyMode,
};
use self::openssl::x509::{X509, store::X509StoreBuilder, X509VerifyResult};
use std::borrow::Borrow;
use std::collections::hash_map::{Entry, HashMap};
use std::error;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::io;
use std::sync::Once;
use std::sync::{Arc, Mutex, Once};

use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
use self::openssl::pkey::Private;
Expand Down Expand Up @@ -248,6 +256,8 @@ pub struct TlsConnector {
use_sni: bool,
accept_invalid_hostnames: bool,
accept_invalid_certs: bool,
session_tickets_enabled: bool,
session_cache: Arc<Mutex<SessionCache>>,
}

impl TlsConnector {
Expand Down Expand Up @@ -277,11 +287,37 @@ impl TlsConnector {
#[cfg(target_os = "android")]
load_android_root_certs(&mut connector)?;

let session_cache = Arc::new(Mutex::new(SessionCache::new()));
if builder.session_tickets_enabled {
connector.set_session_cache_mode(SslSessionCacheMode::CLIENT);

connector.set_new_session_callback({
let session_cache = session_cache.clone();
move |ssl, session| {
if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
if let Ok(mut session_cache) = session_cache.lock() {
session_cache.insert(key.clone(), session);
}
}
}
});
connector.set_remove_session_callback({
let session_cache = session_cache.clone();
move |_, session| {
if let Ok(mut session_cache) = session_cache.lock() {
session_cache.remove(session);
}
}
});
}

Ok(TlsConnector {
connector: connector.build(),
use_sni: builder.use_sni,
accept_invalid_hostnames: builder.accept_invalid_hostnames,
accept_invalid_certs: builder.accept_invalid_certs,
session_tickets_enabled: builder.session_tickets_enabled,
session_cache,
})
}

Expand All @@ -297,6 +333,23 @@ impl TlsConnector {
if self.accept_invalid_certs {
ssl.set_verify(SslVerifyMode::NONE);
}
if self.session_tickets_enabled {
let key = SessionKey {
host: domain.to_string(),
};

if let Ok(mut session_cache) = self.session_cache.lock() {
if let Some(session) = session_cache.get(&key) {
// Note: the `unsafe`-ty here is because the `session` is required to come from the
// same SSL_CTX that the ssl object (`ssl`) is from, since it maintains internal
// pointers and refcounts. Here, we only have one SSL_CTX, so this is safe.
unsafe { ssl.set_session(&session)? };
}
}

let idx = key_index()?;
ssl.set_ex_data(idx, key);
}

let s = ssl.connect(domain, stream)?;
Ok(TlsStream(s))
Expand Down Expand Up @@ -412,3 +465,151 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
self.0.flush()
}
}

fn key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack> {
static IDX: OnceCell<Index<Ssl, SessionKey>> = OnceCell::new();
IDX.get_or_try_init(|| Ssl::new_ex_index()).map(|v| *v)
}

#[derive(Hash, PartialEq, Eq, Clone)]
pub struct SessionKey {
pub host: String,
}

#[derive(Clone)]
struct HashSession(SslSession);

impl PartialEq for HashSession {
fn eq(&self, other: &HashSession) -> bool {
self.0.id() == other.0.id()
}
}

impl Eq for HashSession {}

impl Hash for HashSession {
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
self.0.id().hash(state);
}
}

impl Borrow<[u8]> for HashSession {
fn borrow(&self) -> &[u8] {
self.0.id()
}
}

pub struct SessionCache {
sessions: HashMap<SessionKey, LinkedHashSet<HashSession>>,
reverse: HashMap<HashSession, SessionKey>,
}

impl SessionCache {
pub fn new() -> SessionCache {
SessionCache {
sessions: HashMap::new(),
reverse: HashMap::new(),
}
}

pub fn insert(&mut self, key: SessionKey, session: SslSession) {
let session = HashSession(session);

self.sessions
.entry(key.clone())
.or_insert_with(LinkedHashSet::new)
.insert(session.clone());
self.reverse.insert(session.clone(), key);
}

pub fn get(&mut self, key: &SessionKey) -> Option<SslSession> {
let session = {
let sessions = self.sessions.get_mut(key)?;
sessions.front().cloned()?.0
};

#[cfg(ossl111)]
{
use self::openssl::ssl::SslVersion;

// https://tools.ietf.org/html/rfc8446#appendix-C.4
// OpenSSL will remove the session from its cache after the handshake completes anyway, but this ensures
// that concurrent handshakes don't end up with the same session.
if session.protocol_version() == SslVersion::TLS1_3 {
self.remove(&session);
}
}

Some(session)
}

pub fn remove(&mut self, session: &SslSessionRef) {
let key = match self.reverse.remove(session.id()) {
Some(key) => key,
None => return,
};

if let Entry::Occupied(mut sessions) = self.sessions.entry(key) {
sessions.get_mut().remove(session.id());
if sessions.get().is_empty() {
sessions.remove();
}
}
}
}

#[cfg(test)]
mod tests {
use std::io::{Read, Write};
use std::net::TcpStream;

use crate::TlsConnector;

fn connect_and_assert(tls: &TlsConnector, domain: &str, port: u16, should_resume: bool) {
let s = TcpStream::connect((domain, port)).unwrap();
let mut stream = tls.connect(domain, s).unwrap();

// Must write to the stream, as OpenSSL doesn't appear to call the
// session callback until we do.
stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
let mut result = vec![];
stream.read_to_end(&mut result).unwrap();

assert_eq!((stream.0).0.ssl().session_reused(), should_resume);

// Must shut down properly, or OpenSSL will invalidate the session.
stream.shutdown().unwrap();
}

#[test]
fn connect_no_session_ticket_resumption() {
let tls = TlsConnector::new().unwrap();
connect_and_assert(&tls, "google.com", 443, false);
connect_and_assert(&tls, "google.com", 443, false);
}

#[test]
fn connect_session_ticket_resumption() {
let mut builder = TlsConnector::builder();
builder.session_tickets_enabled(true);
let tls = builder.build().unwrap();

connect_and_assert(&tls, "google.com", 443, false);
connect_and_assert(&tls, "google.com", 443, true);
}

#[test]
fn connect_session_ticket_resumption_two_sites() {
let mut builder = TlsConnector::builder();
builder.session_tickets_enabled(true);
let tls = builder.build().unwrap();

connect_and_assert(&tls, "google.com", 443, false);
connect_and_assert(&tls, "mozilla.org", 443, false);
connect_and_assert(&tls, "google.com", 443, true);
connect_and_assert(&tls, "mozilla.org", 443, true);
}
}
Loading

0 comments on commit 1fd8a7c

Please sign in to comment.