diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index cb39b351f..7dc9dcbb4 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -2,7 +2,7 @@ use std::{ convert::TryInto, mem, net::{Ipv4Addr, Ipv6Addr, SocketAddr}, - sync::Arc, + sync::{Arc, Mutex}, }; use assert_matches::assert_matches; @@ -186,7 +186,7 @@ fn draft_version_compat() { fn stateless_retry() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Validate; + pair.server.handle_incoming = Box::new(validate_incoming); let (client_ch, _server_ch) = pair.connect(); pair.client .connections @@ -200,6 +200,203 @@ fn stateless_retry() { assert_eq!(pair.server.known_cids(), 0); } +#[cfg(feature = "fastbloom")] +#[test] +fn use_token() { + let _guard = subscribe(); + let mut pair = Pair::default(); + let client_config = client_config(); + let (client_ch, _server_ch) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new(|incoming| { + assert!(incoming.remote_address_validated()); + assert!(incoming.may_retry()); + IncomingConnectionBehavior::Accept + }); + let (client_ch_2, _server_ch_2) = pair.connect_with(client_config); + pair.client + .connections + .get_mut(&client_ch_2) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); +} + +#[cfg(feature = "fastbloom")] +#[test] +fn retry_then_use_token() { + let _guard = subscribe(); + let mut pair = Pair::default(); + let client_config = client_config(); + pair.server.handle_incoming = Box::new(validate_incoming); + let (client_ch, _server_ch) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new(|incoming| { + assert!(incoming.remote_address_validated()); + assert!(incoming.may_retry()); + IncomingConnectionBehavior::Accept + }); + let (client_ch_2, _server_ch_2) = pair.connect_with(client_config); + pair.client + .connections + .get_mut(&client_ch_2) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); +} + +#[cfg(feature = "fastbloom")] +#[test] +fn use_token_then_retry() { + let _guard = subscribe(); + let mut pair = Pair::default(); + let client_config = client_config(); + let (client_ch, _server_ch) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new({ + let mut i = 0; + move |incoming| { + if i == 0 { + assert!(incoming.remote_address_validated()); + assert!(incoming.may_retry()); + i += 1; + IncomingConnectionBehavior::Retry + } else if i == 1 { + assert!(incoming.remote_address_validated()); + assert!(!incoming.may_retry()); + i += 1; + IncomingConnectionBehavior::Accept + } else { + panic!("too many handle_incoming iterations") + } + } + }); + let (client_ch_2, _server_ch_2) = pair.connect_with(client_config); + pair.client + .connections + .get_mut(&client_ch_2) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); +} + +#[cfg(feature = "fastbloom")] +#[test] +fn use_same_token_twice() { + #[derive(Default)] + struct EvilTokenStore(Mutex); + + impl TokenStore for EvilTokenStore { + fn insert(&self, _server_name: &str, token: Bytes) { + let mut lock = self.0.lock().unwrap(); + if lock.is_empty() { + *lock = token; + } + } + + fn take(&self, _server_name: &str) -> Option { + let lock = self.0.lock().unwrap(); + if lock.is_empty() { + None + } else { + Some(lock.clone()) + } + } + } + + let _guard = subscribe(); + let mut pair = Pair::default(); + let mut client_config = client_config(); + client_config.token_store(Some(Arc::new(EvilTokenStore::default()))); + let (client_ch, _server_ch) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new(|incoming| { + assert!(incoming.remote_address_validated()); + assert!(incoming.may_retry()); + IncomingConnectionBehavior::Accept + }); + let (client_ch_2, _server_ch_2) = pair.connect_with(client_config.clone()); + pair.client + .connections + .get_mut(&client_ch_2) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); + + pair.server.handle_incoming = Box::new(|incoming| { + assert!(!incoming.remote_address_validated()); + assert!(incoming.may_retry()); + IncomingConnectionBehavior::Accept + }); + let (client_ch_3, _server_ch_3) = pair.connect_with(client_config); + pair.client + .connections + .get_mut(&client_ch_3) + .unwrap() + .close(pair.time, VarInt(42), Bytes::new()); + pair.drive(); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); +} + #[test] fn server_stateless_reset() { let _guard = subscribe(); @@ -554,7 +751,7 @@ fn high_latency_handshake() { fn zero_rtt_happypath() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Validate; + pair.server.handle_incoming = Box::new(validate_incoming); let config = client_config(); // Establish normal connection @@ -723,7 +920,7 @@ fn test_zero_rtt_incoming_limit(configure_server: CLIENT_PORTS.lock().unwrap().next().unwrap(), ); info!("resuming session"); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Wait; + pair.server.handle_incoming = Box::new(|_| IncomingConnectionBehavior::Wait); let client_ch = pair.begin_connect(config); assert!(pair.client_conn_mut(client_ch).has_0rtt()); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); @@ -2993,7 +3190,7 @@ fn pure_sender_voluntarily_acks() { fn reject_manually() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::RejectAll; + pair.server.handle_incoming = Box::new(|_| IncomingConnectionBehavior::Reject); // The server should now reject incoming connections. let client_ch = pair.begin_connect(client_config()); @@ -3013,7 +3210,20 @@ fn reject_manually() { fn validate_then_reject_manually() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.incoming_connection_behavior = IncomingConnectionBehavior::ValidateThenReject; + pair.server.handle_incoming = Box::new({ + let mut i = 0; + move |incoming| { + if incoming.remote_address_validated() { + assert_eq!(i, 1); + i += 1; + IncomingConnectionBehavior::Reject + } else { + assert_eq!(i, 0); + i += 1; + IncomingConnectionBehavior::Retry + } + } + }); // The server should now retry and reject incoming connections. let client_ch = pair.begin_connect(client_config()); diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index 7e927e203..f8f9bcb32 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -297,19 +297,26 @@ pub(super) struct TestEndpoint { conn_events: HashMap>, pub(super) captured_packets: Vec>, pub(super) capture_inbound_packets: bool, - pub(super) incoming_connection_behavior: IncomingConnectionBehavior, + pub(super) handle_incoming: Box IncomingConnectionBehavior>, pub(super) waiting_incoming: Vec, } #[derive(Debug, Copy, Clone)] pub(super) enum IncomingConnectionBehavior { - AcceptAll, - RejectAll, - Validate, - ValidateThenReject, + Accept, + Reject, + Retry, Wait, } +pub(super) fn validate_incoming(incoming: &Incoming) -> IncomingConnectionBehavior { + if incoming.remote_address_validated() { + IncomingConnectionBehavior::Accept + } else { + IncomingConnectionBehavior::Retry + } +} + impl TestEndpoint { fn new(endpoint: Endpoint, addr: SocketAddr) -> Self { let socket = if env::var_os("SSLKEYLOGFILE").is_some() { @@ -334,7 +341,7 @@ impl TestEndpoint { conn_events: HashMap::default(), captured_packets: Vec::new(), capture_inbound_packets: false, - incoming_connection_behavior: IncomingConnectionBehavior::AcceptAll, + handle_incoming: Box::new(|_| IncomingConnectionBehavior::Accept), waiting_incoming: Vec::new(), } } @@ -364,26 +371,15 @@ impl TestEndpoint { { match event { DatagramEvent::NewConnection(incoming) => { - match self.incoming_connection_behavior { - IncomingConnectionBehavior::AcceptAll => { + match (self.handle_incoming)(&incoming) { + IncomingConnectionBehavior::Accept => { let _ = self.try_accept(incoming, now); } - IncomingConnectionBehavior::RejectAll => { + IncomingConnectionBehavior::Reject => { self.reject(incoming); } - IncomingConnectionBehavior::Validate => { - if incoming.remote_address_validated() { - let _ = self.try_accept(incoming, now); - } else { - self.retry(incoming); - } - } - IncomingConnectionBehavior::ValidateThenReject => { - if incoming.remote_address_validated() { - self.reject(incoming); - } else { - self.retry(incoming); - } + IncomingConnectionBehavior::Retry => { + self.retry(incoming); } IncomingConnectionBehavior::Wait => { self.waiting_incoming.push(incoming);