diff --git a/dc/s2n-quic-dc/src/stream/client/tokio.rs b/dc/s2n-quic-dc/src/stream/client/tokio.rs index ceb345594..648341c1d 100644 --- a/dc/s2n-quic-dc/src/stream/client/tokio.rs +++ b/dc/s2n-quic-dc/src/stream/client/tokio.rs @@ -11,7 +11,7 @@ use crate::{ socket::Protocol, }, }; -use std::{io, net::SocketAddr}; +use std::{io, net::SocketAddr, time::Duration}; use tokio::net::TcpStream; /// Connects using the UDP transport layer @@ -54,6 +54,7 @@ pub async fn connect_tcp( acceptor_addr: SocketAddr, env: &Environment, subscriber: Sub, + linger: Option, ) -> io::Result> where H: core::future::Future>, @@ -64,7 +65,10 @@ where // Make sure TCP_NODELAY is set let _ = socket.set_nodelay(true); - let _ = socket.set_linger(Some(core::time::Duration::ZERO)); + + if linger.is_some() { + let _ = socket.set_linger(linger); + } // if the acceptor_ip isn't known, then ask the socket to resolve it for us let peer_addr = if acceptor_addr.ip().is_unspecified() { diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs index 72a968281..51f2deaa9 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs @@ -9,6 +9,7 @@ use crate::{ }; use core::{future::poll_fn, task::Poll}; use s2n_quic_core::{inet::SocketAddress, time::Clock}; +use std::time::Duration; use tokio::net::TcpListener; use tracing::debug; @@ -26,6 +27,7 @@ where secrets: secret::Map, backlog: usize, accept_flavor: accept::Flavor, + linger: Option, subscriber: Sub, } @@ -42,6 +44,7 @@ where secrets: &secret::Map, backlog: usize, accept_flavor: accept::Flavor, + linger: Option, subscriber: Sub, ) -> Self { let acceptor = Self { @@ -51,6 +54,7 @@ where secrets: secrets.clone(), backlog, accept_flavor, + linger, subscriber, }; @@ -98,6 +102,7 @@ where workers.insert( remote_address, socket, + self.linger, &mut context, subscriber_ctx, &publisher, diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs index c2381a030..cb5619e92 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs @@ -153,6 +153,7 @@ where &mut self, remote_address: SocketAddress, stream: W::Stream, + linger: Option, cx: &mut W::Context, connection_context: W::ConnectionContext, publisher: &Pub, @@ -179,6 +180,7 @@ where self.inner.workers[idx].worker.replace( remote_address, stream, + linger, connection_context, publisher, clock, @@ -377,6 +379,7 @@ pub trait Worker { &mut self, remote_address: SocketAddress, stream: Self::Stream, + linger: Option, connection_context: Self::ConnectionContext, publisher: &Pub, clock: &C, diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/tests.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/tests.rs index b822d0355..3fd69b8ab 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/tests.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/tests.rs @@ -64,6 +64,7 @@ impl super::Worker for Worker { &mut self, _remote_address: SocketAddress, _stream: Self::Stream, + _linger: Option, _connection_context: Self::ConnectionContext, _publisher: &Pub, clock: &C, @@ -160,6 +161,7 @@ impl Harness { self.manager.insert( SocketAddress::default(), (), + None, &mut (), (), &publisher(&self.subscriber, &self.clock), diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs index 825ffbae2..d7082d8a3 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs @@ -98,6 +98,7 @@ where &mut self, remote_address: SocketAddress, stream: TcpStream, + linger: Option, subscriber_ctx: Self::ConnectionContext, publisher: &Pub, clock: &C, @@ -107,7 +108,10 @@ where { // Make sure TCP_NODELAY is set let _ = stream.set_nodelay(true); - let _ = stream.set_linger(Some(Duration::ZERO)); + + if linger.is_some() { + let _ = stream.set_linger(linger); + } let now = clock.get_time(); @@ -116,7 +120,14 @@ where let prev_stream = core::mem::replace(&mut self.stream, Some((stream, remote_address))); let prev_ctx = core::mem::replace(&mut self.subscriber_ctx, Some(subscriber_ctx)); - if let Some(remote_address) = prev_stream.map(|(_socket, remote_address)| remote_address) { + if let Some(remote_address) = prev_stream.map(|(socket, remote_address)| { + // If linger wasn't already set or it was set to a value other than 0, then override it + if linger.is_none() || linger != Some(Duration::ZERO) { + // close the stream immediately and send a reset to the client + let _ = socket.set_linger(Some(Duration::ZERO)); + } + remote_address + }) { let sojourn_time = now.saturating_duration_since(prev_queue_time); let buffer_len = match prev_state { WorkerState::Init => 0, @@ -331,6 +342,10 @@ impl WorkerState { error: error.error, }; continue; + } else { + // close the stream immediately and send a reset to the client + let _ = socket.set_linger(Some(Duration::ZERO)); + drop(socket); } } return Err(Some(error.error)).into(); @@ -381,16 +396,15 @@ impl WorkerState { } #[inline] - fn poll_initial_packet( + fn poll_initial_packet( cx: &mut task::Context, - stream: &mut S, + stream: &mut TcpStream, remote_address: &SocketAddress, recv_buffer: &mut msg::recv::Message, sojourn_time: Duration, publisher: &Pub, ) -> Poll>> where - S: Socket, Pub: EndpointPublisher, { loop { @@ -403,6 +417,10 @@ impl WorkerState { sojourn_time, }, ); + + // close the stream immediately and send a reset to the client + let _ = stream.set_linger(Some(Duration::ZERO)); + return Err(None).into(); } @@ -437,6 +455,9 @@ impl WorkerState { }, ); + // close the stream immediately and send a reset to the client + let _ = stream.set_linger(Some(Duration::ZERO)); + return Err(None).into(); } } diff --git a/dc/s2n-quic-dc/src/stream/testing.rs b/dc/s2n-quic-dc/src/stream/testing.rs index ddccccbb8..67ddbeb57 100644 --- a/dc/s2n-quic-dc/src/stream/testing.rs +++ b/dc/s2n-quic-dc/src/stream/testing.rs @@ -68,8 +68,14 @@ impl Client { match server.protocol { Protocol::Tcp => { - stream_client::connect_tcp(handshake, server.local_addr, &self.env, subscriber) - .await + stream_client::connect_tcp( + handshake, + server.local_addr, + &self.env, + subscriber, + None, + ) + .await } Protocol::Udp => { stream_client::connect_udp(handshake, server.local_addr, &self.env, subscriber) @@ -181,6 +187,8 @@ mod drop_handle { } pub mod server { + use std::time::Duration; + use super::*; #[derive(Clone)] @@ -201,6 +209,7 @@ pub mod server { flavor: accept::Flavor, protocol: Protocol, map_capacity: usize, + linger: Option, subscriber: event::testing::Subscriber, } @@ -211,6 +220,7 @@ pub mod server { flavor: accept::Flavor::default(), protocol: Protocol::Tcp, map_capacity: 16, + linger: None, subscriber: event::testing::Subscriber::no_snapshot(), } } @@ -255,6 +265,11 @@ pub mod server { self } + pub fn linger(mut self, linger: Duration) -> Self { + self.linger = Some(linger); + self + } + pub fn subscriber(mut self, subscriber: event::testing::Subscriber) -> Self { self.subscriber = subscriber; self @@ -266,6 +281,7 @@ pub mod server { flavor, protocol, map_capacity, + linger, subscriber, } = self; @@ -291,7 +307,7 @@ pub mod server { let socket = tokio::net::TcpListener::from_std(socket).unwrap(); let acceptor = stream_server::tcp::Acceptor::new( - 0, socket, &sender, &env, &map, backlog, flavor, subscriber, + 0, socket, &sender, &env, &map, backlog, flavor, linger, subscriber, ); let acceptor = drop_handle_receiver.wrap(acceptor.run()); let acceptor = acceptor.instrument(tracing::info_span!("tcp"));