diff --git a/network/src/network/mod.rs b/network/src/network/mod.rs index 1e699d4c4..3e7b0fba2 100644 --- a/network/src/network/mod.rs +++ b/network/src/network/mod.rs @@ -345,10 +345,12 @@ struct NetworkShutdownError; #[cfg(test)] mod tests { + use futures_util::stream::FuturesUnordered; + use futures_util::StreamExt; use tracing_test::traced_test; use super::*; - use crate::types::{service_query_fn, BoxCloneService, PeerInfo, Request}; + use crate::types::{service_message_fn, service_query_fn, BoxCloneService, PeerInfo, Request}; use crate::util::NetworkExt; fn echo_service() -> BoxCloneService { @@ -447,4 +449,59 @@ mod tests { Ok(()) } + + #[traced_test] + #[tokio::test(flavor = "multi_thread")] + async fn uni_message_handler() -> Result<()> { + std::panic::set_hook(Box::new(|info| { + use std::io::Write; + + tracing::error!("{}", info); + std::io::stderr().flush().ok(); + std::io::stdout().flush().ok(); + std::process::exit(1); + })); + + fn noop_service() -> BoxCloneService { + let handle = |request: ServiceRequest| async move { + tracing::trace!("received: {} bytes", request.body.len()); + }; + service_message_fn(handle).boxed_clone() + } + + fn make_network() -> Result { + Network::builder() + .with_config(NetworkConfig { + enable_0rtt: true, + ..Default::default() + }) + .with_random_private_key() + .with_service_name("tycho") + .build("127.0.0.1:0", noop_service()) + } + + let left = make_network()?; + let right = make_network()?; + + let _left_to_right = left.known_peers().insert(make_peer_info(&right), false)?; + let _right_to_left = right.known_peers().insert(make_peer_info(&left), false)?; + + let req = Request { + version: Default::default(), + body: vec![0xff; 750 * 1024].into(), + }; + + for _ in 0..10 { + let mut futures = FuturesUnordered::new(); + for _ in 0..100 { + futures.push(left.send(&right.peer_id(), req.clone())); + } + + while let Some(res) = futures.next().await { + res?; + } + } + + Ok(()) + } } diff --git a/network/src/network/request_handler.rs b/network/src/network/request_handler.rs index 3f8c74a8a..1c460c768 100644 --- a/network/src/network/request_handler.rs +++ b/network/src/network/request_handler.rs @@ -37,6 +37,26 @@ impl InboundRequestHandler { pub async fn start(self) { tracing::debug!(peer_id = %self.connection.peer_id(), "request handler started"); + struct ClearOnDrop<'a> { + handler: &'a InboundRequestHandler, + reason: DisconnectReason, + } + + impl Drop for ClearOnDrop<'_> { + fn drop(&mut self) { + self.handler.active_peers.remove_with_stable_id( + self.handler.connection.peer_id(), + self.handler.connection.stable_id(), + self.reason, + ); + } + } + + let mut clear_on_drop = ClearOnDrop { + handler: &self, + reason: DisconnectReason::LocallyClosed, + }; + let mut inflight_requests = JoinSet::<()>::new(); let reason: quinn::ConnectionError = loop { @@ -107,12 +127,7 @@ impl InboundRequestHandler { } } }; - - self.active_peers.remove_with_stable_id( - self.connection.peer_id(), - self.connection.stable_id(), - DisconnectReason::from(reason), - ); + clear_on_drop.reason = reason.into(); inflight_requests.shutdown().await; tracing::debug!(peer_id = %self.connection.peer_id(), "request handler stopped");