diff --git a/dc/s2n-quic-dc/src/path/secret/map.rs b/dc/s2n-quic-dc/src/path/secret/map.rs index a6e7d2c3e..c62e426c3 100644 --- a/dc/s2n-quic-dc/src/path/secret/map.rs +++ b/dc/s2n-quic-dc/src/path/secret/map.rs @@ -5,7 +5,7 @@ use crate::{ credentials::{Credentials, Id}, event, packet::{secret_control as control, Packet}, - path::secret::{open, seal, stateless_reset, HandshakeKind}, + path::secret::{open, seal, stateless_reset}, stream::TransportFeatures, }; use s2n_quic_core::{dc, time}; @@ -86,16 +86,34 @@ impl Map { self.store.drop_state(); } - pub fn contains(&self, peer: SocketAddr) -> bool { + pub fn contains(&self, peer: &SocketAddr) -> bool { self.store.contains(peer) } + /// Check whether we would like to (re-)handshake with this peer. + /// + /// Note that this is distinct from `contains`, we may already have *some* credentials for a + /// peer but still be interested in handshaking (e.g., due to periodic refresh of the + /// credentials). + pub fn needs_handshake(&self, peer: &SocketAddr) -> bool { + self.store.needs_handshake(peer) + } + + /// Gets the [`Peer`] entry for the given address + /// + /// NOTE: This function is used to track cache hit ratios so it + /// should only be used for connection attempts. + pub fn get_tracked(&self, peer: SocketAddr) -> Option { + let entry = self.store.get_by_addr_tracked(&peer)?; + Some(Peer::new(&entry, self)) + } + /// Gets the [`Peer`] entry for the given address /// /// NOTE: This function is used to track cache hit ratios so it /// should only be used for connection attempts. - pub fn get_tracked(&self, peer: SocketAddr, handshake: HandshakeKind) -> Option { - let entry = self.store.get_by_addr_tracked(&peer, handshake)?; + pub fn get_untracked(&self, peer: SocketAddr) -> Option { + let entry = self.store.get_by_addr_untracked(&peer)?; Some(Peer::new(&entry, self)) } diff --git a/dc/s2n-quic-dc/src/path/secret/map/state.rs b/dc/s2n-quic-dc/src/path/secret/map/state.rs index c37e20899..68b546f1a 100644 --- a/dc/s2n-quic-dc/src/path/secret/map/state.rs +++ b/dc/s2n-quic-dc/src/path/secret/map/state.rs @@ -8,7 +8,7 @@ use crate::{ event::{self, EndpointPublisher as _, IntoEvent as _}, fixed_map::{self, ReadGuard}, packet::{secret_control as control, Packet}, - path::secret::{receiver, HandshakeKind}, + path::secret::receiver, }; use s2n_quic_core::{ inet::SocketAddress, @@ -355,8 +355,12 @@ where self.peers.clear(); } - fn contains(&self, peer: SocketAddr) -> bool { - self.peers.contains_key(&peer) && !self.requested_handshakes.pin().contains(&peer) + fn contains(&self, peer: &SocketAddr) -> bool { + self.peers.contains_key(peer) + } + + fn needs_handshake(&self, peer: &SocketAddr) -> bool { + self.requested_handshakes.pin().contains(peer) } fn on_new_path_secrets(&self, entry: Arc) { @@ -408,29 +412,21 @@ where }); } - fn get_by_addr_tracked( - &self, - peer: &SocketAddr, - handshake: HandshakeKind, - ) -> Option>> { - let result = self.peers.get_by_key(peer)?; - - // If this is trying to use a cached handshake but we've got a request to do a handshake, then - // force the application to do a new handshake. This is consistent with the `contains` method. - if matches!(handshake, HandshakeKind::Cached) - && self.requested_handshakes.pin().contains(peer) - { - return None; - } + fn get_by_addr_untracked(&self, peer: &SocketAddr) -> Option>> { + self.peers.get_by_key(peer) + } + + fn get_by_addr_tracked(&self, peer: &SocketAddr) -> Option>> { + let result = self.peers.get_by_key(peer); self.subscriber().on_path_secret_map_address_cache_accessed( event::builder::PathSecretMapAddressCacheAccessed { peer_address: SocketAddress::from(*peer).into_event(), - hit: matches!(handshake, HandshakeKind::Cached), + hit: result.is_some(), }, ); - Some(result) + result } fn get_by_id_untracked(&self, id: &Id) -> Option>> { diff --git a/dc/s2n-quic-dc/src/path/secret/map/store.rs b/dc/s2n-quic-dc/src/path/secret/map/store.rs index 52e1d7691..1464bdce1 100644 --- a/dc/s2n-quic-dc/src/path/secret/map/store.rs +++ b/dc/s2n-quic-dc/src/path/secret/map/store.rs @@ -6,7 +6,7 @@ use crate::{ credentials::{Credentials, Id}, fixed_map::ReadGuard, packet::{secret_control as control, Packet, WireVersion}, - path::secret::{receiver, stateless_reset, HandshakeKind}, + path::secret::{receiver, stateless_reset}, }; use core::time::Duration; use s2n_codec::EncoderBuffer; @@ -21,17 +21,17 @@ pub trait Store: 'static + Send + Sync { fn drop_state(&self); - fn contains(&self, peer: SocketAddr) -> bool; - fn on_new_path_secrets(&self, entry: Arc); fn on_handshake_complete(&self, entry: Arc); - fn get_by_addr_tracked( - &self, - peer: &SocketAddr, - handshake: HandshakeKind, - ) -> Option>>; + fn contains(&self, peer: &SocketAddr) -> bool; + + fn needs_handshake(&self, peer: &SocketAddr) -> bool; + + fn get_by_addr_untracked(&self, peer: &SocketAddr) -> Option>>; + + fn get_by_addr_tracked(&self, peer: &SocketAddr) -> Option>>; fn get_by_id_untracked(&self, id: &Id) -> Option>>; diff --git a/dc/s2n-quic-dc/src/stream/testing.rs b/dc/s2n-quic-dc/src/stream/testing.rs index c284f19aa..9100c6598 100644 --- a/dc/s2n-quic-dc/src/stream/testing.rs +++ b/dc/s2n-quic-dc/src/stream/testing.rs @@ -4,7 +4,7 @@ use super::{server::tokio::stats, socket::Protocol}; use crate::{ event, - path::secret::{self, HandshakeKind}, + path::secret, stream::{ application::Stream, client::tokio as client, @@ -40,7 +40,7 @@ impl Client { ) -> io::Result { let server = server.as_ref(); let peer = server.local_addr; - if let Some(peer) = self.map.get_tracked(peer, HandshakeKind::Cached) { + if let Some(peer) = self.map.get_tracked(peer) { return Ok(peer); } @@ -48,11 +48,10 @@ impl Client { self.map .test_insert_pair(local_addr, &server.map, server.local_addr); - self.map - .get_tracked(peer, HandshakeKind::Fresh) - .ok_or_else(|| { - io::Error::new(io::ErrorKind::AddrNotAvailable, "path secret not available") - }) + // cache hit already tracked above + self.map.get_untracked(peer).ok_or_else(|| { + io::Error::new(io::ErrorKind::AddrNotAvailable, "path secret not available") + }) } pub async fn connect_to>(