diff --git a/src/protocol/connection.rs b/src/protocol/connection.rs
index 00277b8a..cfb9cfe0 100644
--- a/src/protocol/connection.rs
+++ b/src/protocol/connection.rs
@@ -94,6 +94,15 @@ impl ConnectionHandle {
         }
     }
 
+    /// Try to upgrade the connection to active state.
+    pub fn try_upgrade(&mut self) {
+        if let ConnectionType::Inactive(inactive) = &self.connection {
+            if let Some(active) = inactive.upgrade() {
+                self.connection = ConnectionType::Active(active);
+            }
+        }
+    }
+
     /// Attempt to acquire permit which will keep the connection open for indefinite time.
     pub fn try_get_permit(&self) -> Option<Permit> {
         match &self.connection {
@@ -120,6 +129,7 @@ impl ConnectionHandle {
                 protocol: protocol.clone(),
                 fallback_names,
                 substream_id,
+                connection_id: self.connection_id.clone(),
                 permit,
             })
             .map_err(|error| match error {
@@ -141,10 +151,15 @@ impl ConnectionHandle {
             TrySendError::Closed(_) => Error::ConnectionClosed,
         })
     }
+
+    /// Check if the connection is active.
+    pub fn is_active(&self) -> bool {
+        matches!(self.connection, ConnectionType::Active(_))
+    }
 }
 
 /// Type which allows the connection to be kept open.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Permit {
     /// Active connection.
     _connection: Sender<ProtocolCommand>,
diff --git a/src/protocol/notification/tests/notification.rs b/src/protocol/notification/tests/notification.rs
index fc2c1ee7..37f0a297 100644
--- a/src/protocol/notification/tests/notification.rs
+++ b/src/protocol/notification/tests/notification.rs
@@ -475,6 +475,7 @@ async fn remote_opens_multiple_inbound_substreams() {
                 SubstreamId::from(0usize),
                 Box::new(DummySubstream::new()),
             ),
+            connection_id: ConnectionId::from(0usize),
         })
         .await
         .unwrap();
@@ -511,6 +512,7 @@ async fn remote_opens_multiple_inbound_substreams() {
                 SubstreamId::from(0usize),
                 Box::new(substream),
             ),
+            connection_id: ConnectionId::from(0usize),
         })
         .await
         .unwrap();
diff --git a/src/protocol/protocol_set.rs b/src/protocol/protocol_set.rs
index a03c9e82..c7bec04d 100644
--- a/src/protocol/protocol_set.rs
+++ b/src/protocol/protocol_set.rs
@@ -122,6 +122,9 @@ pub enum InnerTransportEvent {
         /// distinguish between different outbound substreams.
         direction: Direction,
 
+        /// Connection ID.
+        connection_id: ConnectionId,
+
         /// Substream.
         substream: Substream,
     },
@@ -149,6 +152,7 @@ impl From<InnerTransportEvent> for TransportEvent {
                 fallback,
                 direction,
                 substream,
+                ..
             } => TransportEvent::SubstreamOpened {
                 peer,
                 protocol,
@@ -164,7 +168,7 @@ impl From<InnerTransportEvent> for TransportEvent {
 }
 
 /// Events emitted by the installed protocols to transport.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum ProtocolCommand {
     /// Open substream.
     OpenSubstream {
@@ -192,6 +196,9 @@ pub enum ProtocolCommand {
         /// and associate incoming substreams with whatever logic it has.
         substream_id: SubstreamId,
 
+        /// Connection ID.
+        connection_id: ConnectionId,
+
         /// Connection permit.
         ///
         /// `Permit` allows the connection to be kept open while the permit is held and it is given
@@ -300,6 +307,7 @@ impl ProtocolSet {
             fallback,
             direction,
             substream,
+            connection_id: self.connection.connection_id().clone(),
         };
 
         protocol_context
diff --git a/src/protocol/transport_service.rs b/src/protocol/transport_service.rs
index 7fb6ac37..64e75018 100644
--- a/src/protocol/transport_service.rs
+++ b/src/protocol/transport_service.rs
@@ -40,8 +40,8 @@ use std::{
         atomic::{AtomicUsize, Ordering},
         Arc,
     },
-    task::{Context, Poll},
-    time::Duration,
+    task::{Context, Poll, Waker},
+    time::{Duration, Instant},
 };
 
 /// Logging target for the file.
@@ -96,6 +96,168 @@ impl ConnectionContext {
             "connection doesn't exist, cannot downgrade",
         );
     }
+
+    /// Try to upgrade the connection to active state.
+    fn try_upgrade(&mut self, connection_id: &ConnectionId) {
+        if self.primary.connection_id() == connection_id {
+            self.primary.try_upgrade();
+            return;
+        }
+
+        if let Some(handle) = &mut self.secondary {
+            if handle.connection_id() == connection_id {
+                handle.try_upgrade();
+                return;
+            }
+        }
+
+        tracing::debug!(
+            target: LOG_TARGET,
+            primary = ?self.primary.connection_id(),
+            secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()),
+            ?connection_id,
+            "connection doesn't exist, cannot upgrade",
+        );
+    }
+}
+
+/// Tracks connection keep-alive timeouts.
+///
+/// A connection keep-alive timeout is started when a connection is established.
+/// If no substreams are opened over the connection within the timeout,
+/// the connection is downgraded. However, if a substream is opened over the connection,
+/// the timeout is reset.
+#[derive(Debug)]
+struct KeepAliveTracker {
+    /// Close the connection if no substreams are open within this time frame.
+    keep_alive_timeout: Duration,
+
+    /// Track substream last activity.
+    last_activity: HashMap<(PeerId, ConnectionId), Instant>,
+
+    /// Pending keep-alive timeouts.
+    pending_keep_alive_timeouts: FuturesUnordered<Pin<Box<dyn Future<Output = (PeerId, ConnectionId)> + Send>>>,
+
+    /// Saved waker.
+    waker: Option<Waker>,
+}
+
+impl KeepAliveTracker {
+    /// Create new [`KeepAliveTracker`].
+    pub fn new(keep_alive_timeout: Duration) -> Self {
+        Self {
+            keep_alive_timeout,
+            last_activity: HashMap::new(),
+            pending_keep_alive_timeouts: FuturesUnordered::new(),
+            waker: None,
+        }
+    }
+
+    /// Called on connection established event to add a new keep-alive timeout.
+    pub fn on_connection_established(&mut self, peer: PeerId, connection_id: ConnectionId) {
+        self.substream_activity(peer, connection_id);
+    }
+
+    /// Called on connection closed event.
+    pub fn on_connection_closed(&mut self, peer: PeerId, connection_id: ConnectionId) {
+        self.last_activity.remove(&(peer, connection_id));
+    }
+
+    /// Called on substream opened event to track the last activity.
+    pub fn substream_activity(&mut self, peer: PeerId, connection_id: ConnectionId) {
+        // Keep track of the connection ID and the time the substream was opened.
+        if self.last_activity.insert((peer, connection_id), Instant::now()).is_none() {
+            // This key was not yet tracked; arm a keep-alive timeout future for it.
+            let timeout = self.keep_alive_timeout;
+            self.pending_keep_alive_timeouts.push(Box::pin(async move {
+                tokio::time::sleep(timeout).await;
+                (peer, connection_id)
+            }));
+        }
+
+        tracing::trace!(
+            target: LOG_TARGET,
+            ?peer,
+            ?connection_id,
+            ?self.keep_alive_timeout,
+            last_activity = ?self.last_activity,
+            pending_keep_alive_timeouts = ?self.pending_keep_alive_timeouts.len(),
+            "substream activity",
+        );
+
+        // Wake any pending poll.
+        self.waker.take().map(|waker| waker.wake());
+    }
+}
+
+impl Stream for KeepAliveTracker {
+    type Item = (PeerId, ConnectionId);
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        if self.pending_keep_alive_timeouts.is_empty() {
+            // No pending keep-alive timeouts.
+            self.waker = Some(cx.waker().clone());
+            return Poll::Pending;
+        }
+
+        match self.pending_keep_alive_timeouts.poll_next_unpin(cx) {
+            Poll::Ready(Some(key)) => {
+                // Check last-activity time.
+                let Some(last_activity) = self.last_activity.get(&key) else {
+                    tracing::debug!(
+                        target: LOG_TARGET,
+                        peer = ?key.0,
+                        connection_id = ?key.1,
+                        "Last activity no longer tracks the connection (closed event triggered)",
+                    );
+
+                    // We have effectively ignored this `Poll::Ready` event. To prevent the
+                    // future from getting stuck, we need to tell the executor to poll again
+                    // for more events.
+                    cx.waker().wake_by_ref();
+                    return Poll::Pending;
+                };
+
+                // Keep-alive timeout not reached yet.
+                let inactive_for = last_activity.elapsed();
+                if inactive_for < self.keep_alive_timeout {
+                    let timeout = self.keep_alive_timeout.saturating_sub(inactive_for);
+
+                    tracing::trace!(
+                        target: LOG_TARGET,
+                        peer = ?key.0,
+                        connection_id = ?key.1,
+                        ?timeout,
+                        "keep-alive timeout not yet reached",
+                    );
+
+                    // Re-arm the keep-alive timeout for the remaining time.
+                    self.pending_keep_alive_timeouts.push(Box::pin(async move {
+                        tokio::time::sleep(timeout).await;
+                        key
+                    }));
+
+                    // As with the `last_activity` check above, we need to inform
+                    // the executor that this object may produce more events.
+                    cx.waker().wake_by_ref();
+                    return Poll::Pending;
+                }
+
+                // Keep-alive timeout reached.
+                tracing::debug!(
+                    target: LOG_TARGET,
+                    peer = ?key.0,
+                    connection_id = ?key.1,
+                    "keep-alive timeout triggered",
+                );
+                self.last_activity.remove(&key);
+                return Poll::Ready(Some(key));
+            }
+            Poll::Ready(None) | Poll::Pending => {
+                return Poll::Pending;
+            }
+        }
+    }
 }
 
 /// Provides an interface for [`Litep2p`](crate::Litep2p) protocols to interact
@@ -124,10 +286,7 @@ pub struct TransportService {
     next_substream_id: Arc<AtomicUsize>,
 
     /// Close the connection if no substreams are open within this time frame.
-    keep_alive_timeout: Duration,
-
-    /// Pending keep-alive timeouts.
-    pending_keep_alive_timeouts: FuturesUnordered<Pin<Box<dyn Future<Output = (PeerId, ConnectionId)> + Send>>>,
+    keep_alive_tracker: KeepAliveTracker,
 }
 
 impl TransportService {
@@ -142,6 +301,8 @@ impl TransportService {
     ) -> (Self, Sender<InnerTransportEvent>) {
         let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
 
+        let keep_alive_tracker = KeepAliveTracker::new(keep_alive_timeout);
+
         (
             Self {
                 rx,
@@ -151,8 +312,7 @@ impl TransportService {
                 transport_handle,
                 next_substream_id,
                 connections: HashMap::new(),
-                keep_alive_timeout,
-                pending_keep_alive_timeouts: FuturesUnordered::new(),
+                keep_alive_tracker,
             },
             tx,
         )
@@ -184,7 +344,6 @@ impl TransportService {
             ?connection_id,
             "connection established",
         );
-        let keep_alive_timeout = self.keep_alive_timeout;
 
         match self.connections.get_mut(&peer) {
             Some(context) => match context.secondary {
@@ -199,10 +358,8 @@ impl TransportService {
                     None
                 }
                 None => {
-                    self.pending_keep_alive_timeouts.push(Box::pin(async move {
-                        tokio::time::sleep(keep_alive_timeout).await;
-                        (peer, connection_id)
-                    }));
+                    self.keep_alive_tracker.on_connection_established(peer, connection_id);
+
                     context.secondary = Some(handle);
 
                     None
@@ -210,10 +367,8 @@ impl TransportService {
             },
             None => {
                 self.connections.insert(peer, ConnectionContext::new(handle));
-                self.pending_keep_alive_timeouts.push(Box::pin(async move {
-                    tokio::time::sleep(keep_alive_timeout).await;
-                    (peer, connection_id)
-                }));
+
+                self.keep_alive_tracker.on_connection_established(peer, connection_id);
 
                 Some(TransportEvent::ConnectionEstablished { peer, endpoint })
             }
@@ -226,6 +381,8 @@ impl TransportService {
         peer: PeerId,
         connection_id: ConnectionId,
     ) -> Option<TransportEvent> {
+        self.keep_alive_tracker.on_connection_closed(peer, connection_id);
+
         let Some(context) = self.connections.get_mut(&peer) else {
             tracing::warn!(
                 target: LOG_TARGET,
@@ -335,6 +492,8 @@ impl TransportService {
         .ok_or(SubstreamError::PeerDoesNotExist(peer))?
         .primary;
 
+        let connection_id = connection.connection_id().clone();
+
         let permit = connection.try_get_permit().ok_or(SubstreamError::ConnectionClosed)?;
         let substream_id =
             SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed));
@@ -344,9 +503,13 @@ impl TransportService {
             ?peer,
             protocol = %self.protocol,
             ?substream_id,
+            ?connection_id,
             "open substream",
         );
 
+        self.keep_alive_tracker.substream_activity(peer, connection_id);
+        connection.try_upgrade();
+
         connection
             .open_substream(
                 self.protocol.clone(),
@@ -362,7 +525,7 @@ impl TransportService {
         let connection =
             &mut self.connections.get_mut(&peer).ok_or(Error::PeerDoesntExist(peer))?;
 
-        tracing::debug!(
+        tracing::trace!(
             target: LOG_TARGET,
             ?peer,
             protocol = %self.protocol,
@@ -387,6 +550,9 @@ impl Stream for TransportService {
     type Item = TransportEvent;
 
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let protocol_name = self.protocol.clone();
+        let duration = self.keep_alive_tracker.keep_alive_timeout;
+
         while let Poll::Ready(event) = self.rx.poll_recv(cx) {
             match event {
                 None => {
@@ -410,18 +576,43 @@ impl Stream for TransportService {
                         return Poll::Ready(Some(event));
                     }
                 }
+                Some(InnerTransportEvent::SubstreamOpened {
+                    peer,
+                    protocol,
+                    fallback,
+                    direction,
+                    substream,
+                    connection_id,
+                }) => {
+                    if protocol == self.protocol {
+                        self.keep_alive_tracker.substream_activity(peer, connection_id);
+                        if let Some(context) = self.connections.get_mut(&peer) {
+                            context.try_upgrade(&connection_id);
+                        }
+                    }
+
+                    return Poll::Ready(Some(TransportEvent::SubstreamOpened {
+                        peer,
+                        protocol,
+                        fallback,
+                        direction,
+                        substream,
+                    }));
+                }
                 Some(event) => return Poll::Ready(Some(event.into())),
             }
         }
 
         while let Poll::Ready(Some((peer, connection_id))) =
-            self.pending_keep_alive_timeouts.poll_next_unpin(cx)
+            self.keep_alive_tracker.poll_next_unpin(cx)
         {
             if let Some(context) = self.connections.get_mut(&peer) {
-                tracing::trace!(
+                tracing::debug!(
                     target: LOG_TARGET,
                     ?peer,
                     ?connection_id,
+                    protocol = ?protocol_name,
+                    ?duration,
                     "keep-alive timeout over, downgrade connection",
                 );
@@ -437,7 +628,7 @@ mod tests {
     use super::*;
     use crate::{
-        protocol::TransportService,
+        protocol::{ProtocolCommand, TransportService},
         transport::{
             manager::{handle::InnerTransportManagerCommand, TransportManagerHandle},
             KEEP_ALIVE_TIMEOUT,
@@ -612,7 +803,11 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn secondary_closing_doesnt_emit_event() {
+    async fn secondary_closing_does_not_emit_event() {
+        let _ = tracing_subscriber::fmt()
+            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+            .try_init();
+
         let (mut service, sender, _) = transport_service();
         let peer = PeerId::random();
 
@@ -786,6 +981,10 @@ mod tests {
 
     #[tokio::test]
     async fn keep_alive_timeout_expires_for_a_stale_connection() {
+        let _ = tracing_subscriber::fmt()
+            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+            .try_init();
+
         let (mut service, sender, _) = transport_service();
         let peer = PeerId::random();
 
@@ -813,7 +1012,7 @@ mod tests {
         };
 
         // verify the first connection state is correct
-        assert_eq!(service.pending_keep_alive_timeouts.len(), 1);
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
         match service.connections.get(&peer) {
             Some(context) => {
                 assert_eq!(
@@ -844,15 +1043,12 @@ mod tests {
             panic!("expected event from `TransportService`");
         }
 
-        // verify that the keep-alive timeout still exists for the peer but the peer itself
-        // doesn't exist anymore
-        //
-        // the peer is removed because there is no connection to them
-        assert_eq!(service.pending_keep_alive_timeouts.len(), 1);
+        // Because the connection was closed, the peer is no longer tracked for keep-alive.
+        // This leads to better tracking overall since we don't have to track stale connections.
+        assert!(service.keep_alive_tracker.last_activity.is_empty());
         assert!(service.connections.get(&peer).is_none());
 
-        // register new primary connection but verify that there are now two pending keep-alive
-        // timeouts
+        // Register new primary connection.
         let (cmd_tx1, _cmd_rx1) = channel(64);
         sender
             .send(InnerTransportEvent::ConnectionEstablished {
                 peer,
                 connection: ConnectionId::from(1337usize),
                 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
                 sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
             })
             .await
             .unwrap();
@@ -875,8 +1071,7 @@ mod tests {
             panic!("expected event from `TransportService`");
        };
 
-        // verify the first connection state is correct
-        assert_eq!(service.pending_keep_alive_timeouts.len(), 2);
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
         match service.connections.get(&peer) {
             Some(context) => {
                 assert_eq!(
@@ -893,4 +1088,534 @@ mod tests {
             Err(_) => {}
         }
     }
+
+    async fn poll_service(service: &mut TransportService) {
+        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
+            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
+            std::task::Poll::Pending => std::task::Poll::Ready(()),
+        })
+        .await;
+    }
+
+    #[tokio::test]
+    async fn keep_alive_timeout_downgrades_connections() {
+        let _ = tracing_subscriber::fmt()
+            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+            .try_init();
+
+        let (mut service, sender, _) = transport_service();
+        let peer = PeerId::random();
+
+        // register first connection
+        let (cmd_tx1, _cmd_rx1) = channel(64);
+        sender
+            .send(InnerTransportEvent::ConnectionEstablished {
+                peer,
+                connection: ConnectionId::from(1337usize),
+                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
+                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
+            })
+            .await
+            .unwrap();
+
+        if let Some(TransportEvent::ConnectionEstablished {
+            peer: connected_peer,
+            endpoint,
+        }) = service.next().await
+        {
+            assert_eq!(connected_peer, peer);
+            assert_eq!(endpoint.address(), &Multiaddr::empty());
+        } else {
+            panic!("expected event from `TransportService`");
+        };
+
+        // verify the first connection state is correct
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                // Check the connection is still active.
+                assert!(context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+
+        poll_service(&mut service).await;
+        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
+        poll_service(&mut service).await;
+
+        // Verify the connection is downgraded.
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                // Check the connection is not active.
+                assert!(!context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn keep_alive_timeout_reset_when_user_opens_substream() {
+        let _ = tracing_subscriber::fmt()
+            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+            .try_init();
+
+        let (mut service, sender, _) = transport_service();
+        let peer = PeerId::random();
+
+        // register first connection
+        let (cmd_tx1, _cmd_rx1) = channel(64);
+        sender
+            .send(InnerTransportEvent::ConnectionEstablished {
+                peer,
+                connection: ConnectionId::from(1337usize),
+                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
+                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
+            })
+            .await
+            .unwrap();
+
+        if let Some(TransportEvent::ConnectionEstablished {
+            peer: connected_peer,
+            endpoint,
+        }) = service.next().await
+        {
+            assert_eq!(connected_peer, peer);
+            assert_eq!(endpoint.address(), &Multiaddr::empty());
+        } else {
+            panic!("expected event from `TransportService`");
+        };
+
+        // verify the first connection state is correct
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                // Check the connection is still active.
+                assert!(context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+
+        poll_service(&mut service).await;
+        // Sleep for almost the entire keep-alive timeout.
+        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
+
+        // This ensures we reset the keep-alive timer when other protocols
+        // want to open a substream.
+        // We are still tracking the same peer.
+        service.open_substream(peer).unwrap();
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
+
+        poll_service(&mut service).await;
+        // The keep-alive timeout should be advanced.
+        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
+        poll_service(&mut service).await;
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
+        // If `service.open_substream` hadn't been called, the connection would have been
+        // downgraded. Instead, the keep-alive deadline was pushed `KEEP_ALIVE_TIMEOUT`
+        // into the future.
+        // Verify the connection is still active.
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                assert!(context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+
+        poll_service(&mut service).await;
+        tokio::time::sleep(KEEP_ALIVE_TIMEOUT).await;
+        poll_service(&mut service).await;
+
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
+
+        // The connection had no substream activity for `KEEP_ALIVE_TIMEOUT`.
+        // Verify the connection is downgraded.
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                assert!(!context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+    }
+
+    #[tokio::test]
+    async fn downgraded_connection_without_substreams_is_closed() {
+        let _ = tracing_subscriber::fmt()
+            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+            .try_init();
+
+        let (mut service, sender, _) = transport_service();
+        let peer = PeerId::random();
+
+        // register first connection
+        let (cmd_tx1, mut cmd_rx1) = channel(64);
+        sender
+            .send(InnerTransportEvent::ConnectionEstablished {
+                peer,
+                connection: ConnectionId::from(1337usize),
+                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
+                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
+            })
+            .await
+            .unwrap();
+
+        if let Some(TransportEvent::ConnectionEstablished {
+            peer: connected_peer,
+            endpoint,
+        }) = service.next().await
+        {
+            assert_eq!(connected_peer, peer);
+            assert_eq!(endpoint.address(), &Multiaddr::empty());
+        } else {
+            panic!("expected event from `TransportService`");
+        };
+
+        // verify the first connection state is correct
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                // Check the connection is still active.
+                assert!(context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+
+        // Open substreams to the peer.
+        let substream_id = service.open_substream(peer).unwrap();
+        let second_substream_id = service.open_substream(peer).unwrap();
+
+        // Simulate keep-alive timeout expiration.
+        poll_service(&mut service).await;
+        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
+        poll_service(&mut service).await;
+
+        let mut permits = Vec::new();
+
+        // First substream.
+        let protocol_command = cmd_rx1.recv().await.unwrap();
+        match protocol_command {
+            ProtocolCommand::OpenSubstream {
+                protocol,
+                substream_id: opened_substream_id,
+                permit,
+                ..
+            } => {
+                assert_eq!(protocol, ProtocolName::from("/notif/1"));
+                assert_eq!(substream_id, opened_substream_id);
+
+                // Save the substream permit for later.
+                permits.push(permit);
+            }
+            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
+        }
+
+        // Second substream.
+        let protocol_command = cmd_rx1.recv().await.unwrap();
+        match protocol_command {
+            ProtocolCommand::OpenSubstream {
+                protocol,
+                substream_id: opened_substream_id,
+                permit,
+                ..
+            } => {
+                assert_eq!(protocol, ProtocolName::from("/notif/1"));
+                assert_eq!(second_substream_id, opened_substream_id);
+
+                // Save the substream permit for later.
+                permits.push(permit);
+            }
+            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
+        }
+
+        // Drop one permit.
+        let permit = permits.pop();
+        // Individual transports like TCP will open a substream
+        // and then will generate a `SubstreamOpened` event via
+        // the protocol-set handler.
+        //
+        // The substream is used by individual protocols and then
+        // is closed. This simulates the substream being closed.
+        drop(permit);
+
+        // Open a new substream to the peer. This will succeed as long as we still have
+        // one substream open.
+        let substream_id = service.open_substream(peer).unwrap();
+        // Handle the substream.
+        let protocol_command = cmd_rx1.recv().await.unwrap();
+        match protocol_command {
+            ProtocolCommand::OpenSubstream {
+                protocol,
+                substream_id: opened_substream_id,
+                permit,
+                ..
+            } => {
+                assert_eq!(protocol, ProtocolName::from("/notif/1"));
+                assert_eq!(substream_id, opened_substream_id);
+
+                // Save the substream permit for later.
+                permits.push(permit);
+            }
+            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
+        }
+
+        // Drop all substreams.
+        drop(permits);
+
+        poll_service(&mut service).await;
+        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
+        poll_service(&mut service).await;
+
+        // Cannot open a new substream because:
+        // 1. connection was downgraded by keep-alive timeout
+        // 2. all substreams were dropped.
+        assert_eq!(
+            service.open_substream(peer),
+            Err(SubstreamError::ConnectionClosed)
+        );
+    }
+
+    #[tokio::test]
+    async fn substream_opening_upgrades_connection_and_resets_keep_alive() {
+        let _ = tracing_subscriber::fmt()
+            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+            .try_init();
+
+        let (mut service, sender, _) = transport_service();
+        let peer = PeerId::random();
+
+        // register first connection
+        let (cmd_tx1, mut cmd_rx1) = channel(64);
+        sender
+            .send(InnerTransportEvent::ConnectionEstablished {
+                peer,
+                connection: ConnectionId::from(1337usize),
+                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
+                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
+            })
+            .await
+            .unwrap();
+
+        if let Some(TransportEvent::ConnectionEstablished {
+            peer: connected_peer,
+            endpoint,
+        }) = service.next().await
+        {
+            assert_eq!(connected_peer, peer);
+            assert_eq!(endpoint.address(), &Multiaddr::empty());
+        } else {
+            panic!("expected event from `TransportService`");
+        };
+
+        // verify the first connection state is correct
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                // Check the connection is still active.
+                assert!(context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+
+        // Open substreams to the peer.
+        let substream_id = service.open_substream(peer).unwrap();
+        let second_substream_id = service.open_substream(peer).unwrap();
+
+        let mut permits = Vec::new();
+        // First substream.
+        let protocol_command = cmd_rx1.recv().await.unwrap();
+        match protocol_command {
+            ProtocolCommand::OpenSubstream {
+                protocol,
+                substream_id: opened_substream_id,
+                permit,
+                ..
+            } => {
+                assert_eq!(protocol, ProtocolName::from("/notif/1"));
+                assert_eq!(substream_id, opened_substream_id);
+
+                // Save the substream permit for later.
+                permits.push(permit);
+            }
+            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
+        }
+
+        // Second substream.
+        let protocol_command = cmd_rx1.recv().await.unwrap();
+        match protocol_command {
+            ProtocolCommand::OpenSubstream {
+                protocol,
+                substream_id: opened_substream_id,
+                permit,
+                ..
+            } => {
+                assert_eq!(protocol, ProtocolName::from("/notif/1"));
+                assert_eq!(second_substream_id, opened_substream_id);
+
+                // Save the substream permit for later.
+                permits.push(permit);
+            }
+            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
+        }
+
+        // Sleep to trigger keep-alive timeout.
+        poll_service(&mut service).await;
+        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
+        poll_service(&mut service).await;
+
+        // Verify the connection is downgraded.
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                // Check the connection is not active.
+                assert!(!context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
+
+        // Open a new substream to the peer. This will succeed as long as we still have
+        // at least one substream permit.
+        let substream_id = service.open_substream(peer).unwrap();
+        let protocol_command = cmd_rx1.recv().await.unwrap();
+        match protocol_command {
+            ProtocolCommand::OpenSubstream {
+                protocol,
+                substream_id: opened_substream_id,
+                permit,
+                ..
+            } => {
+                assert_eq!(protocol, ProtocolName::from("/notif/1"));
+                assert_eq!(substream_id, opened_substream_id);
+
+                // Save the substream permit for later.
+                permits.push(permit);
+            }
+            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
+        }
+
+        poll_service(&mut service).await;
+
+        // Verify the connection is upgraded and keep-alive is tracked.
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                // Check the connection is active, because it was upgraded by the last substream.
+                assert!(context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
+
+        // Drop all substreams.
+        drop(permits);
+
+        // The connection is still active, because it was upgraded by the last substream open.
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                // Check the connection is active, because it was upgraded by the last substream.
+                assert!(context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
+
+        // Sleep to trigger keep-alive timeout.
+        poll_service(&mut service).await;
+        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
+        poll_service(&mut service).await;
+
+        match service.connections.get(&peer) {
+            Some(context) => {
+                assert_eq!(
+                    context.primary.connection_id(),
+                    &ConnectionId::from(1337usize)
+                );
+                // No longer active because it was downgraded by keep-alive and no
+                // substream opens were made.
+                assert!(!context.primary.is_active());
+                assert!(context.secondary.is_none());
+            }
+            None => panic!("expected {peer} to exist"),
+        }
+
+        // Cannot open a new substream because:
+        // 1. connection was downgraded by keep-alive timeout
+        // 2. all substreams were dropped.
+        assert_eq!(
+            service.open_substream(peer),
+            Err(SubstreamError::ConnectionClosed)
+        );
+    }
+
+    #[tokio::test]
+    async fn keep_alive_pop_elements() {
+        let mut tracker = KeepAliveTracker::new(Duration::from_secs(1));
+
+        let (peer1, connection1) = (PeerId::random(), ConnectionId::from(1usize));
+        let (peer2, connection2) = (PeerId::random(), ConnectionId::from(2usize));
+        let added_keys = HashSet::from([(peer1, connection1), (peer2, connection2)]);
+
+        tracker.on_connection_established(peer1, connection1);
+        tracker.on_connection_established(peer2, connection2);
+
+        tokio::time::sleep(Duration::from_secs(2)).await;
+
+        let key = tracker.next().await.unwrap();
+        assert!(added_keys.contains(&key));
+
+        let key = tracker.next().await.unwrap();
+        assert!(added_keys.contains(&key));
+
+        // No more elements.
+        assert!(tracker.pending_keep_alive_timeouts.is_empty());
+        assert!(tracker.last_activity.is_empty());
+    }
 }
diff --git a/src/transport/quic/connection.rs b/src/transport/quic/connection.rs
index 52c198e7..486e1dcc 100644
--- a/src/transport/quic/connection.rs
+++ b/src/transport/quic/connection.rs
@@ -330,7 +330,7 @@ impl QuicConnection {
                     );
                     return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await;
                 }
-                Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => {
+                Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit, .. }) => {
                     let connection = self.connection.clone();
                     let substream_open_timeout = self.substream_open_timeout;
 
diff --git a/src/transport/s2n-quic/connection.rs b/src/transport/s2n-quic/connection.rs
index 9d1b0d3e..f26d31ae 100644
--- a/src/transport/s2n-quic/connection.rs
+++ b/src/transport/s2n-quic/connection.rs
@@ -325,7 +325,7 @@ impl QuicConnection {
                 }
             }
             protocol = self.protocol_set.next_event() => match protocol {
-                Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => {
+                Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit, .. }) => {
                     let handle = self.connection.handle();
 
                     tracing::trace!(
diff --git a/src/transport/tcp/connection.rs b/src/transport/tcp/connection.rs
index b13420f7..d0d21fde 100644
--- a/src/transport/tcp/connection.rs
+++ b/src/transport/tcp/connection.rs
@@ -654,6 +654,7 @@ impl TcpConnection {
                     protocol,
                     fallback_names,
                     substream_id,
+                    connection_id,
                     permit,
                 }) => {
                     let control = self.control.clone();
@@ -663,6 +664,7 @@ impl TcpConnection {
                         target: LOG_TARGET,
                         ?protocol,
                         ?substream_id,
+                        ?connection_id,
                         "open substream",
                     );
 
diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs
index f31a48b2..52ceb048 100644
--- a/src/transport/webrtc/connection.rs
+++ b/src/transport/webrtc/connection.rs
@@ -816,7 +816,7 @@ impl WebRtcConnection {
                     );
                     return self.on_connection_closed().await;
                 }
-                Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => {
+                Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit, .. }) => {
                     self.on_open_substream(protocol, fallback_names, substream_id, permit);
                 }
             },
diff --git a/src/transport/websocket/connection.rs b/src/transport/websocket/connection.rs
index 8c505607..3635e655 100644
--- a/src/transport/websocket/connection.rs
+++ b/src/transport/websocket/connection.rs
@@ -520,7 +520,7 @@ impl WebSocketConnection {
                 }
             }
             protocol = self.protocol_set.next() => match protocol {
-                Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => {
+                Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit, .. }) => {
                     let control = self.control.clone();
                     let substream_open_timeout = self.substream_open_timeout;
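
A note on the timer strategy in `KeepAliveTracker`: it arms at most one sleep future per `(PeerId, ConnectionId)` key and, when a timer fires early because `last_activity` was refreshed in the meantime, re-arms it for the remaining time instead of spawning a new timer on every substream. Below is a minimal, self-contained sketch of that pattern. It assumes tokio (with the "macros" and "time" features) and futures as dependencies; `Tracker` and the `u64` key are illustrative stand-ins, not crate APIs.

// Sketch of the re-arming keep-alive timer pattern; names are hypothetical.
use std::{
    collections::HashMap,
    future::Future,
    pin::Pin,
    time::{Duration, Instant},
};

use futures::{stream::FuturesUnordered, StreamExt};

struct Tracker {
    timeout: Duration,
    last_activity: HashMap<u64, Instant>,
    timers: FuturesUnordered<Pin<Box<dyn Future<Output = u64> + Send>>>,
}

impl Tracker {
    fn new(timeout: Duration) -> Self {
        Self {
            timeout,
            last_activity: HashMap::new(),
            timers: FuturesUnordered::new(),
        }
    }

    /// Record activity. A timer is armed only for keys that are not already
    /// tracked, so at most one sleep future is in flight per key.
    fn activity(&mut self, key: u64) {
        if self.last_activity.insert(key, Instant::now()).is_none() {
            let timeout = self.timeout;
            self.timers.push(Box::pin(async move {
                tokio::time::sleep(timeout).await;
                key
            }));
        }
    }

    /// Return the next key whose keep-alive genuinely expired. A timer that
    /// fires "early" (because `activity` refreshed `last_activity` in the
    /// meantime) is re-armed for the remaining time instead of being reported.
    async fn expired(&mut self) -> Option<u64> {
        while let Some(key) = self.timers.next().await {
            let Some(last) = self.last_activity.get(&key) else {
                // Key was removed (connection closed); swallow the timer.
                continue;
            };
            let idle = last.elapsed();
            if idle < self.timeout {
                let remaining = self.timeout - idle;
                self.timers.push(Box::pin(async move {
                    tokio::time::sleep(remaining).await;
                    key
                }));
                continue;
            }
            self.last_activity.remove(&key);
            return Some(key);
        }
        None
    }
}

#[tokio::main]
async fn main() {
    let mut tracker = Tracker::new(Duration::from_millis(100));
    tracker.activity(1);
    tokio::time::sleep(Duration::from_millis(60)).await;
    tracker.activity(1); // refresh: the pending timer is re-armed, not duplicated
    assert_eq!(tracker.expired().await, Some(1)); // fires ~160ms in, not ~100ms
}

An `async fn` keeps the sketch short; the tracker in the patch implements `Stream` instead, so that `TransportService::poll_next` can poll it alongside its other event sources.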
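
The manual `Stream::poll_next` above has one subtlety worth calling out: whenever it consumes a `Poll::Ready` value from `pending_keep_alive_timeouts` but still returns `Poll::Pending` (a stale key, or a re-armed timer), it must call `cx.waker().wake_by_ref()`, because the readiness notification it received has already been spent. A minimal sketch of the same situation with a hypothetical filtering stream (`Evens` is illustrative, not from this crate):

// Sketch of why a swallowed Ready event requires an explicit wake.
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use futures::{Stream, StreamExt};

/// Yields only even numbers from an inner stream. It sometimes consumes an
/// item (an odd number) without yielding anything, which is the situation
/// `KeepAliveTracker::poll_next` is in when it swallows a stale timer.
struct Evens<S> {
    inner: S,
}

impl<S: Stream<Item = u64> + Unpin> Stream for Evens<S> {
    type Item = u64;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
        match self.inner.poll_next_unpin(cx) {
            Poll::Ready(Some(n)) if n % 2 == 0 => Poll::Ready(Some(n)),
            Poll::Ready(Some(_odd)) => {
                // The inner stream's readiness was consumed but nothing was
                // produced. Without this wake, the executor has no reason to
                // poll again and the stream can stall even though more items
                // are ready.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

#[tokio::main]
async fn main() {
    let inner = futures::stream::iter([1u64, 2, 3, 4]);
    let evens: Vec<u64> = Evens { inner }.collect().await;
    assert_eq!(evens, vec![2, 4]);
}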
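
Finally, `try_upgrade`/`is_active` on `ConnectionHandle` follow the usual strong/weak reference pattern: an inactive handle keeps only an upgradeable weak reference, and upgrading succeeds only while some permit still holds a strong one. A minimal sketch with `std::sync::{Arc, Weak}` — in the patch the active/inactive variants plausibly wrap the connection's command channel; `ConnectionInner` here is a stand-in:

// Sketch of the active/inactive connection pattern; the payload is hypothetical.
use std::sync::{Arc, Weak};

/// Stand-in for whatever keeps the connection alive (e.g. a channel sender).
struct ConnectionInner;

enum ConnectionType {
    /// Strong reference: the connection is kept open.
    Active(Arc<ConnectionInner>),
    /// Weak reference: the connection may close once all permits are dropped.
    Inactive(Weak<ConnectionInner>),
}

struct ConnectionHandle {
    connection: ConnectionType,
}

impl ConnectionHandle {
    /// Downgrade to a weak reference, as the keep-alive timeout does.
    fn downgrade(&mut self) {
        if let ConnectionType::Active(active) = &self.connection {
            self.connection = ConnectionType::Inactive(Arc::downgrade(active));
        }
    }

    /// Try to upgrade back to a strong reference. This only succeeds while
    /// some other strong reference (a permit) still exists.
    fn try_upgrade(&mut self) {
        if let ConnectionType::Inactive(inactive) = &self.connection {
            if let Some(active) = inactive.upgrade() {
                self.connection = ConnectionType::Active(active);
            }
        }
    }

    fn is_active(&self) -> bool {
        matches!(self.connection, ConnectionType::Active(_))
    }
}

fn main() {
    let permit = Arc::new(ConnectionInner);
    let mut handle = ConnectionHandle {
        connection: ConnectionType::Active(Arc::clone(&permit)),
    };

    handle.downgrade();
    assert!(!handle.is_active());

    // A permit is still held somewhere, so the upgrade succeeds.
    handle.try_upgrade();
    assert!(handle.is_active());

    handle.downgrade();
    drop(permit); // last outside strong reference is gone
    handle.try_upgrade(); // upgrade now fails; the connection stays inactive
    assert!(!handle.is_active());
}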