diff --git a/dc/s2n-quic-dc/src/packet/stream/id.rs b/dc/s2n-quic-dc/src/packet/stream/id.rs index 3acd7b191..e8308956a 100644 --- a/dc/s2n-quic-dc/src/packet/stream/id.rs +++ b/dc/s2n-quic-dc/src/packet/stream/id.rs @@ -12,7 +12,7 @@ use s2n_quic_core::{probe, varint::VarInt}; )] pub struct Id { #[cfg_attr(any(feature = "testing", test), generator(Self::GENERATOR))] - pub key_id: VarInt, + pub queue_id: VarInt, pub is_reliable: bool, pub is_bidirectional: bool, } @@ -22,7 +22,7 @@ impl fmt::Debug for Id { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if f.alternate() { f.debug_struct("stream::Id") - .field("key_id", &self.key_id) + .field("queue_id", &self.queue_id) .field("is_reliable", &self.is_reliable) .field("is_bidirectional", &self.is_bidirectional) .finish() @@ -66,7 +66,7 @@ impl Id { #[inline] pub fn next(&self) -> Option { Some(Self { - key_id: self.key_id.checked_add_usize(1)?, + queue_id: self.queue_id.checked_add_usize(1)?, is_reliable: self.is_reliable, is_bidirectional: self.is_bidirectional, }) @@ -84,7 +84,7 @@ impl Id { #[inline] pub fn into_varint(self) -> VarInt { - let key_id = *self.key_id; + let key_id = *self.queue_id; let is_reliable = if self.is_reliable { IS_RELIABLE_MASK } else { @@ -108,7 +108,7 @@ impl Id { let is_reliable = *value & IS_RELIABLE_MASK == IS_RELIABLE_MASK; let is_bidirectional = *value & IS_BIDIRECTIONAL_MASK == IS_BIDIRECTIONAL_MASK; Self { - key_id: VarInt::new(*value >> 2).unwrap(), + queue_id: VarInt::new(*value >> 2).unwrap(), is_reliable, is_bidirectional, } diff --git a/dc/s2n-quic-dc/src/stream/endpoint.rs b/dc/s2n-quic-dc/src/stream/endpoint.rs index 2bbc0b322..9f9483392 100644 --- a/dc/s2n-quic-dc/src/stream/endpoint.rs +++ b/dc/s2n-quic-dc/src/stream/endpoint.rs @@ -51,9 +51,9 @@ where parameters = o(parameters); } - let key_id = crypto.credentials.key_id; let stream_id = packet::stream::Id { - key_id, + // the client starts with routing to 0 until the server updates the value + queue_id: VarInt::ZERO, is_reliable: true, is_bidirectional: true, }; @@ -90,6 +90,7 @@ pub fn accept_stream( env: &Env, mut peer: P, packet: &server::InitialPacket, + queue_id: VarInt, recv_buffer: recv::shared::RecvBuffer, map: &Map, subscriber: Env::Subscriber, @@ -124,11 +125,18 @@ where // inform the value of what the source_control_port is peer.with_source_control_port(packet.source_control_port); + let stream_id = packet::stream::Id { + // select our own route key for this stream + queue_id, + // inherit the rest of the parameters from the client + ..packet.stream_id + }; + let res = build_stream( now, env, peer, - packet.stream_id, + stream_id, packet.source_stream_port, crypto, map, diff --git a/dc/s2n-quic-dc/src/stream/server/handshake.rs b/dc/s2n-quic-dc/src/stream/server/handshake.rs index 0c84a6678..b6f3720b9 100644 --- a/dc/s2n-quic-dc/src/stream/server/handshake.rs +++ b/dc/s2n-quic-dc/src/stream/server/handshake.rs @@ -8,7 +8,7 @@ use tokio::sync::mpsc; type Sender = mpsc::Sender; type ReceiverChan = mpsc::Receiver; -type Key = (credentials::Id, u64); +type Key = credentials::Id; type HashMap = flurry::HashMap; pub enum Outcome { @@ -36,20 +36,19 @@ impl Default for Map { impl Map { #[inline] pub fn handle(&mut self, packet: &super::InitialPacket, msg: &mut recv::Message) -> Outcome { - let stream_id = packet.stream_id.into_varint().as_u64(); let (sender, receiver) = self .next .take() .unwrap_or_else(|| mpsc::channel(self.channel_size)); - let key = (packet.credentials.id, stream_id); + let key = packet.credentials.id; let guard = self.inner.guard(); match self.inner.try_insert(key, sender, &guard) { Ok(_) => { drop(guard); let map = Arc::downgrade(&self.inner); - tracing::trace!(action = "register", credentials = ?&key.0, stream_id = key.1); + tracing::trace!(action = "register", credentials = ?&key); let receiver = ReceiverState { map, key, @@ -61,18 +60,18 @@ impl Map { Err(err) => { self.next = Some((err.not_inserted, receiver)); - tracing::trace!(action = "forward", credentials = ?&key.0, stream_id = key.1); + tracing::trace!(action = "forward", credentials = ?&key); if let Err(err) = err.current.try_send(msg.take()) { match err { mpsc::error::TrySendError::Closed(_) => { // remove the channel from the map since we're closed self.inner.remove(&key, &guard); - tracing::debug!(stream_id, error = "channel_closed"); + tracing::debug!(credentials = ?key, error = "channel_closed"); } mpsc::error::TrySendError::Full(_) => { // drop the packet let _ = msg; - tracing::debug!(stream_id, error = "channel_full"); + tracing::debug!(credentials = ?key, error = "channel_full"); } } } @@ -104,7 +103,7 @@ impl Drop for Receiver { #[inline] fn drop(&mut self) { if let Some(map) = self.0.map.upgrade() { - tracing::trace!(action = "unregister", credentials = ?&self.0.key.0, stream_id = self.0.key.1); + tracing::trace!(action = "unregister", credentials = ?&self.0.key); let _ = map.remove(&self.0.key, &map.guard()); } } diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs index c10833153..002fc33a0 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs @@ -24,6 +24,7 @@ use s2n_quic_core::{ inet::SocketAddress, ready, time::{Clock, Timestamp}, + varint::VarInt, }; use std::io; use tokio::{io::AsyncWrite as _, net::TcpStream}; @@ -313,6 +314,8 @@ impl WorkerState { let subscriber_ctx = subscriber_ctx.take().unwrap(); let (socket, remote_address) = stream.take().unwrap(); + // TCP doesn't use the route key so just pick 0 + let queue_id = VarInt::ZERO; let recv_buffer = recv::buffer::Local::new(recv_buffer.take(), None); let stream_builder = match endpoint::accept_stream( @@ -324,6 +327,7 @@ impl WorkerState { local_port: context.local_port, }, &initial_packet, + queue_id, recv_buffer, &context.secrets, context.subscriber.clone(), diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/udp.rs b/dc/s2n-quic-dc/src/stream/server/tokio/udp.rs index 4f4b67717..da9645b5a 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/udp.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/udp.rs @@ -17,7 +17,7 @@ use crate::{ }, }; use core::ops::ControlFlow; -use s2n_quic_core::{inet::SocketAddress, time::Clock}; +use s2n_quic_core::{inet::SocketAddress, time::Clock, varint::VarInt}; use std::io; use tracing::debug; @@ -111,6 +111,8 @@ where let subscriber_ctx = self.subscriber.create_connection_context(&meta, &info); + // TODO allocate a queue for this stream + let queue_id = VarInt::ZERO; let recv_buffer = recv::buffer::Local::new(self.recv_buffer.take(), Some(handshake)); let stream = match endpoint::accept_stream( @@ -118,6 +120,7 @@ where &self.env, env::UdpUnbound(remote_addr), &packet, + queue_id, recv_buffer, &self.secrets, self.subscriber.clone(), diff --git a/dc/wireshark/src/dissect.rs b/dc/wireshark/src/dissect.rs index 8028f667f..8f0e0c0dc 100644 --- a/dc/wireshark/src/dissect.rs +++ b/dc/wireshark/src/dissect.rs @@ -371,8 +371,8 @@ fn record_stream_id( stream_id: Parsed, ) -> stream::Id { stream_id - .map(|v| v.key_id) - .record(buffer, tree, fields.stream_id); + .map(|v| v.queue_id) + .record(buffer, tree, fields.queue_id); let id = stream_id.value; tree.add_boolean( diff --git a/dc/wireshark/src/field.rs b/dc/wireshark/src/field.rs index e652bcaa8..a780aa3d4 100644 --- a/dc/wireshark/src/field.rs +++ b/dc/wireshark/src/field.rs @@ -51,7 +51,7 @@ pub struct Registration { pub is_bidirectional: i32, pub is_reliable: i32, - pub stream_id: i32, + pub queue_id: i32, pub relative_packet_number: i32, pub stream_offset: i32, pub final_offset: i32, @@ -416,8 +416,8 @@ fn init() -> Registration { ) .with_mask(0x2) .register(), - stream_id: protocol - .field(c"Stream ID", c"dcquic.stream_id", UINT64, BASE_DEC, c"") + queue_id: protocol + .field(c"Route Key", c"dcquic.queue_id", UINT64, BASE_DEC, c"") .register(), relative_packet_number: protocol .field( diff --git a/dc/wireshark/src/test.rs b/dc/wireshark/src/test.rs index 8fb717b4b..11e07de5c 100644 --- a/dc/wireshark/src/test.rs +++ b/dc/wireshark/src/test.rs @@ -94,8 +94,8 @@ fn check_stream_parse() { .map(|v| Field::Integer(v.get() as u64)) ); assert_eq!( - tracker.remove(fields.stream_id), - Field::Integer(u64::from(packet.stream_id.key_id)) + tracker.remove(fields.queue_id), + Field::Integer(u64::from(packet.stream_id.queue_id)) ); assert_eq!( tracker.remove(fields.is_reliable),