Skip to content

Commit

Permalink
tls: enable reusable sessions
Browse files Browse the repository at this point in the history
Fixes #1145.
  • Loading branch information
Chilledheart committed Nov 12, 2024
1 parent 545daaa commit 763561d
Show file tree
Hide file tree
Showing 13 changed files with 668 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4140,6 +4140,7 @@ set(files
src/net/x509_util.cpp
src/net/ssl_socket.cpp
src/net/ssl_server_socket.cpp
src/net/ssl_client_session_cache.cpp
src/net/openssl_util.cpp
src/net/base64.cpp
src/net/cipher.cpp
Expand Down Expand Up @@ -4203,6 +4204,7 @@ set(hfiles
src/net/x509_util.hpp
src/net/ssl_socket.hpp
src/net/ssl_server_socket.hpp
src/net/ssl_client_session_cache.hpp
src/net/net_errors.hpp
src/net/net_error_list.hpp
src/net/openssl_util.hpp
Expand Down
4 changes: 2 additions & 2 deletions src/cli/cli_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2054,8 +2054,8 @@ void CliConnection::OnConnect() {
LOG(INFO) << "Connection (client) " << connection_id() << " connect " << remote_domain();
// create lazy
if (enable_upstream_tls_) {
channel_ = ssl_stream::create(ssl_socket_data_index(), *io_context_, remote_host_ips_, remote_host_sni_,
remote_port_, this, upstream_https_fallback_, upstream_ssl_ctx_);
channel_ = ssl_stream::create(ssl_socket_data_index(), ssl_client_session_cache(), *io_context_, remote_host_ips_,
remote_host_sni_, remote_port_, this, upstream_https_fallback_, upstream_ssl_ctx_);

} else {
channel_ = stream::create(*io_context_, remote_host_ips_, remote_host_sni_, remote_port_, this);
Expand Down
9 changes: 8 additions & 1 deletion src/net/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "net/asio.hpp"
#include "net/network.hpp"
#include "net/protocol.hpp"
#include "net/ssl_client_session_cache.hpp"
#include "net/ssl_server_socket.hpp"

#include <absl/functional/any_invocable.h>
Expand Down Expand Up @@ -220,18 +221,21 @@ class Connection {
/// \param the number of connection id
/// \param the pointer of tlsext ctx
/// \param the ssl client data index
/// \param the ssl client session cache
void on_accept(asio::ip::tcp::socket&& socket,
const asio::ip::tcp::endpoint& endpoint,
const asio::ip::tcp::endpoint& peer_endpoint,
int connection_id,
tlsext_ctx_t* tlsext_ctx,
int ssl_socket_data_index) {
int ssl_socket_data_index,
SSLClientSessionCache* ssl_client_session_cache) {
downlink_->on_accept(std::move(socket));
endpoint_ = endpoint;
peer_endpoint_ = peer_endpoint;
connection_id_ = connection_id;
tlsext_ctx_.reset(tlsext_ctx);
ssl_socket_data_index_ = ssl_socket_data_index;
ssl_client_session_cache_ = ssl_client_session_cache;
}

/// set callback
Expand Down Expand Up @@ -264,6 +268,7 @@ class Connection {
}

int ssl_socket_data_index() const { return ssl_socket_data_index_; }
SSLClientSessionCache* ssl_client_session_cache() const { return ssl_client_session_cache_; }

protected:
/// the peek current io
Expand All @@ -289,6 +294,8 @@ class Connection {
std::unique_ptr<tlsext_ctx_t> tlsext_ctx_;
/// the ssl client data index
int ssl_socket_data_index_ = -1;
/// the ssl client context cache
SSLClientSessionCache* ssl_client_session_cache_ = nullptr;

/// if https fallback
bool upstream_https_fallback_;
Expand Down
9 changes: 7 additions & 2 deletions src/net/content_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "net/connection.hpp"
#include "net/network.hpp"
#include "net/protocol.hpp"
#include "net/ssl_client_session_cache.hpp"
#include "net/ssl_socket.hpp"
#include "net/x509_util.hpp"

Expand Down Expand Up @@ -278,7 +279,7 @@ class ContentServer {
}
SetSocketTcpNoDelay(&socket, ec);
conn->on_accept(std::move(socket), ctx.endpoint, ctx.peer_endpoint, connection_id, tlsext_ctx,
ssl_socket_data_index_);
ssl_socket_data_index_, ssl_client_session_cache_.get());
conn->set_disconnect_cb([this, conn]() mutable { on_disconnect(conn); });
connection_map_.insert(std::make_pair(connection_id, conn));
++opened_connections_;
Expand Down Expand Up @@ -555,7 +556,9 @@ class ContentServer {
SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, ::SSL_CTX_get_verify_callback(ctx));
} else {
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, ::SSL_CTX_get_verify_callback(ctx));
SSL_CTX_set_reverify_on_resume(ctx, 1);
// FIXME
// reverify only supported on custom verify callback
// SSL_CTX_set_reverify_on_resume(ctx, 1);
}
if (ec) {
return;
Expand Down Expand Up @@ -599,6 +602,7 @@ class ContentServer {

client_instance_ = this;
ssl_socket_data_index_ = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
ssl_client_session_cache_ = std::make_unique<SSLClientSessionCache>(SSLClientSessionCache::Config{});

// Disable the internal session cache. Session caching is handled
// externally (i.e. by SSLClientSessionCache).
Expand Down Expand Up @@ -646,6 +650,7 @@ class ContentServer {
bool enable_tls_;
std::string upstream_certificate_;
bssl::UniquePtr<SSL_CTX> upstream_ssl_ctx_;
std::unique_ptr<SSLClientSessionCache> ssl_client_session_cache_;

std::string certificate_;
std::string private_key_;
Expand Down
4 changes: 2 additions & 2 deletions src/net/doh_request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ void DoHRequest::OnSocketConnect() {
SetTCPCongestion(socket_.native_handle(), ec);
SetTCPKeepAlive(socket_.native_handle(), ec);
SetSocketTcpNoDelay(&socket_, ec);
ssl_socket_ = SSLSocket::Create(ssl_socket_data_index_, &io_context_, &socket_, ssl_ctx_,
/*https_fallback*/ true, doh_host_);
ssl_socket_ = SSLSocket::Create(ssl_socket_data_index_, nullptr, &io_context_, &socket_, ssl_ctx_,
/*https_fallback*/ true, doh_host_, doh_port_);

ssl_socket_->Connect([this, self](int rv) {
asio::error_code ec;
Expand Down
4 changes: 2 additions & 2 deletions src/net/dot_request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ void DoTRequest::OnSocketConnect() {
SetTCPCongestion(socket_.native_handle(), ec);
SetTCPKeepAlive(socket_.native_handle(), ec);
SetSocketTcpNoDelay(&socket_, ec);
ssl_socket_ = SSLSocket::Create(ssl_socket_data_index_, &io_context_, &socket_, ssl_ctx_,
/*https_fallback*/ true, dot_host_);
ssl_socket_ = SSLSocket::Create(ssl_socket_data_index_, nullptr, &io_context_, &socket_, ssl_ctx_,
/*https_fallback*/ true, dot_host_, dot_port_);

ssl_socket_->Connect([this, self](int rv) {
asio::error_code ec;
Expand Down
175 changes: 175 additions & 0 deletions src/net/ssl_client_session_cache.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2024 Chilledheart */

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
#pragma allow_unsafe_buffers
#endif

#include "net/ssl_client_session_cache.hpp"

#include <tuple>
#include <utility>

#include <absl/time/clock.h>
#include "third_party/boringssl/src/include/openssl/ssl.h"

namespace net {

namespace {

// Returns a tuple of references to fields of |key|, for comparison purposes.
auto TieKeyFields(const SSLClientSessionCache::Key& key) {
return std::tie(key.server, key.dest_ip_addr);
}

} // namespace

SSLClientSessionCache::Key::Key() = default;
SSLClientSessionCache::Key::Key(const Key& other) = default;
SSLClientSessionCache::Key::Key(Key&& other) = default;
SSLClientSessionCache::Key::~Key() = default;
SSLClientSessionCache::Key& SSLClientSessionCache::Key::operator=(const Key& other) = default;
SSLClientSessionCache::Key& SSLClientSessionCache::Key::operator=(Key&& other) = default;

bool SSLClientSessionCache::Key::operator==(const Key& other) const {
return TieKeyFields(*this) == TieKeyFields(other);
}

bool SSLClientSessionCache::Key::operator<(const Key& other) const {
return TieKeyFields(*this) < TieKeyFields(other);
}

SSLClientSessionCache::SSLClientSessionCache(const Config& config) : config_(config), cache_(config.max_entries) {}

SSLClientSessionCache::~SSLClientSessionCache() {
Flush();
}

size_t SSLClientSessionCache::size() const {
return cache_.size();
}

bssl::UniquePtr<SSL_SESSION> SSLClientSessionCache::Lookup(const Key& cache_key) {
// Expire stale sessions.
lookups_since_flush_++;
if (lookups_since_flush_ >= config_.expiration_check_count) {
lookups_since_flush_ = 0;
FlushExpiredSessions();
}

auto iter = cache_.Get(cache_key);
if (iter == cache_.end())
return nullptr;

time_t now = absl::ToTimeT(absl::Now());
bssl::UniquePtr<SSL_SESSION> session = iter->second.Pop();
if (iter->second.ExpireSessions(now))
cache_.Erase(iter);

if (IsExpired(session.get(), now))
session = nullptr;

return session;
}

void SSLClientSessionCache::Insert(const Key& cache_key, bssl::UniquePtr<SSL_SESSION> session) {
auto iter = cache_.Get(cache_key);
if (iter == cache_.end())
iter = cache_.Put(cache_key, Entry());
iter->second.Push(std::move(session));
}

void SSLClientSessionCache::ClearEarlyData(const Key& cache_key) {
auto iter = cache_.Get(cache_key);
if (iter != cache_.end()) {
for (auto& session : iter->second.sessions) {
if (session) {
session.reset(SSL_SESSION_copy_without_early_data(session.get()));
}
}
}
}

void SSLClientSessionCache::Flush() {
cache_.Clear();
}

bool SSLClientSessionCache::IsExpired(SSL_SESSION* session, time_t now) {
if (now < 0)
return true;
uint64_t now_u64 = static_cast<uint64_t>(now);

// now_u64 may be slightly behind because of differences in how
// time is calculated at this layer versus BoringSSL.
// Add a second of wiggle room to account for this.
return now_u64 < SSL_SESSION_get_time(session) - 1 ||
now_u64 >= SSL_SESSION_get_time(session) + SSL_SESSION_get_timeout(session);
}

SSLClientSessionCache::Entry::Entry() = default;
SSLClientSessionCache::Entry::Entry(Entry&&) = default;
SSLClientSessionCache::Entry::~Entry() = default;

void SSLClientSessionCache::Entry::Push(bssl::UniquePtr<SSL_SESSION> session) {
if (sessions[0] != nullptr && SSL_SESSION_should_be_single_use(sessions[0].get())) {
sessions[1] = std::move(sessions[0]);
}
sessions[0] = std::move(session);
}

bssl::UniquePtr<SSL_SESSION> SSLClientSessionCache::Entry::Pop() {
if (sessions[0] == nullptr)
return nullptr;
bssl::UniquePtr<SSL_SESSION> session = bssl::UpRef(sessions[0]);
if (SSL_SESSION_should_be_single_use(session.get())) {
sessions[0] = std::move(sessions[1]);
sessions[1] = nullptr;
}
return session;
}

bool SSLClientSessionCache::Entry::ExpireSessions(time_t now) {
if (sessions[0] == nullptr)
return true;

if (SSLClientSessionCache::IsExpired(sessions[0].get(), now)) {
return true;
}

if (sessions[1] != nullptr && SSLClientSessionCache::IsExpired(sessions[1].get(), now)) {
sessions[1] = nullptr;
}

return false;
}

void SSLClientSessionCache::FlushExpiredSessions() {
time_t now = absl::ToTimeT(absl::Now());
auto iter = cache_.begin();
while (iter != cache_.end()) {
if (iter->second.ExpireSessions(now)) {
iter = cache_.Erase(iter);
} else {
++iter;
}
}
}

#if 0
void SSLClientSessionCache::OnMemoryPressure(
base::MemoryPressureListener::MemoryPressureLevel memory_pressure_level) {
switch (memory_pressure_level) {
case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_NONE:
break;
case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_MODERATE:
FlushExpiredSessions();
break;
case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_CRITICAL:
Flush();
break;
}
}
#endif

} // namespace net
Loading

0 comments on commit 763561d

Please sign in to comment.