From a1e1b255e6d19e4f5dfefd186ac6ed30cc619eba Mon Sep 17 00:00:00 2001 From: Cameron Bytheway Date: Mon, 24 Feb 2025 15:31:48 -0700 Subject: [PATCH] feat(s2n-quic-dc): implement recv path packet pool --- dc/s2n-quic-dc/Cargo.toml | 20 +- dc/s2n-quic-dc/src/socket.rs | 2 + dc/s2n-quic-dc/src/socket/recv.rs | 7 + .../corpus.tar.gz | 3 + dc/s2n-quic-dc/src/socket/recv/descriptor.rs | 446 ++++++++++++++++++ dc/s2n-quic-dc/src/socket/recv/pool.rs | 378 +++++++++++++++ dc/s2n-quic-dc/src/socket/recv/router.rs | 159 +++++++ dc/s2n-quic-dc/src/socket/recv/udp.rs | 75 +++ dc/s2n-quic-dc/src/testing.rs | 4 + 9 files changed, 1087 insertions(+), 7 deletions(-) create mode 100644 dc/s2n-quic-dc/src/socket/recv.rs create mode 100644 dc/s2n-quic-dc/src/socket/recv/__fuzz__/socket__recv__pool__tests__model/corpus.tar.gz create mode 100644 dc/s2n-quic-dc/src/socket/recv/descriptor.rs create mode 100644 dc/s2n-quic-dc/src/socket/recv/pool.rs create mode 100644 dc/s2n-quic-dc/src/socket/recv/router.rs create mode 100644 dc/s2n-quic-dc/src/socket/recv/udp.rs diff --git a/dc/s2n-quic-dc/Cargo.toml b/dc/s2n-quic-dc/Cargo.toml index 920d982ad4..2d8599edc1 100644 --- a/dc/s2n-quic-dc/Cargo.toml +++ b/dc/s2n-quic-dc/Cargo.toml @@ -12,7 +12,12 @@ exclude = ["corpus.tar.gz"] [features] default = ["tokio"] -testing = ["bolero-generator", "s2n-quic-core/testing", "s2n-quic-platform/testing", "tracing-subscriber"] +testing = [ + "bolero-generator", + "s2n-quic-core/testing", + "s2n-quic-platform/testing", + "tracing-subscriber", +] tokio = ["tokio/io-util", "tokio/net", "tokio/rt-multi-thread", "tokio/time"] [dependencies] @@ -41,7 +46,9 @@ hashbrown = "0.15" thiserror = "2" tokio = { version = "1", default-features = false, features = ["sync"] } tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"], optional = true } +tracing-subscriber = { version = "0.3", features = [ + "env-filter", +], optional = true } zerocopy = { version = "0.7", features = ["derive"] } zeroize = "1" parking_lot = "0.12" @@ -53,13 +60,12 @@ bolero-generator = "0.12" insta = "1" s2n-codec = { path = "../../common/s2n-codec", features = ["testing"] } s2n-quic-core = { path = "../../quic/s2n-quic-core", features = ["testing"] } -s2n-quic-platform = { path = "../../quic/s2n-quic-platform", features = ["testing"] } +s2n-quic-platform = { path = "../../quic/s2n-quic-platform", features = [ + "testing", +] } tokio = { version = "1", features = ["full"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] } [lints.rust.unexpected_cfgs] level = "warn" -check-cfg = [ - 'cfg(kani)', - 'cfg(todo)', -] +check-cfg = ['cfg(fuzzing)', 'cfg(kani)', 'cfg(todo)'] diff --git a/dc/s2n-quic-dc/src/socket.rs b/dc/s2n-quic-dc/src/socket.rs index 998aa8a904..368a6f9550 100644 --- a/dc/s2n-quic-dc/src/socket.rs +++ b/dc/s2n-quic-dc/src/socket.rs @@ -13,4 +13,6 @@ pub use bpf::Pair; #[cfg(not(target_os = "linux"))] pub use pair::Pair; +pub mod recv; + pub use s2n_quic_platform::socket::options::{Options, ReusePort}; diff --git a/dc/s2n-quic-dc/src/socket/recv.rs b/dc/s2n-quic-dc/src/socket/recv.rs new file mode 100644 index 0000000000..ae84dc3248 --- /dev/null +++ b/dc/s2n-quic-dc/src/socket/recv.rs @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod descriptor;
+pub mod pool;
+pub mod router;
+pub mod udp;
diff --git a/dc/s2n-quic-dc/src/socket/recv/__fuzz__/socket__recv__pool__tests__model/corpus.tar.gz b/dc/s2n-quic-dc/src/socket/recv/__fuzz__/socket__recv__pool__tests__model/corpus.tar.gz
new file mode 100644
index 0000000000..e9c7cf8dc4
--- /dev/null
+++ b/dc/s2n-quic-dc/src/socket/recv/__fuzz__/socket__recv__pool__tests__model/corpus.tar.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8cbb74f24a784d9c26aa6d4a973e600ef8545302e8f0e5e24612e0ac20a406ac
+size 2232320
diff --git a/dc/s2n-quic-dc/src/socket/recv/descriptor.rs b/dc/s2n-quic-dc/src/socket/recv/descriptor.rs
new file mode 100644
index 0000000000..787fb0336d
--- /dev/null
+++ b/dc/s2n-quic-dc/src/socket/recv/descriptor.rs
@@ -0,0 +1,446 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::msg::{addr::Addr, cmsg};
+use core::fmt;
+use s2n_quic_core::inet::ExplicitCongestionNotification;
+use std::{
+    io::IoSliceMut,
+    marker::PhantomData,
+    ptr::NonNull,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Weak,
+    },
+};
+use tracing::trace;
+
+/// Callback which releases a descriptor back into the free list
+pub(super) trait FreeList: 'static + Send + Sync {
+    fn free(&self, descriptor: Descriptor);
+}
+
+/// A handle to the various parts of a descriptor group instance
+pub(super) struct Memory {
+    capacity: u16,
+    references: AtomicUsize,
+    free_list: Weak<dyn FreeList>,
+    #[allow(dead_code)]
+    region: Box<dyn Send + Sync>,
+}
+
+impl Memory {
+    pub(super) fn new(
+        capacity: u16,
+        free_list: Weak<dyn FreeList>,
+        region: Box<dyn Send + Sync>,
+    ) -> Box<Self> {
+        Box::new(Self {
+            capacity,
+            references: AtomicUsize::new(0),
+            free_list,
+            region,
+        })
+    }
+}
+
+/// A pointer to a single descriptor in a group
+pub(super) struct Descriptor {
+    ptr: NonNull<DescriptorInner>,
+    phantom: PhantomData<DescriptorInner>,
+}
+
+impl Descriptor {
+    #[inline]
+    pub(super) fn new(ptr: NonNull<DescriptorInner>) -> Self {
+        Self {
+            ptr,
+            phantom: PhantomData,
+        }
+    }
+
+    #[inline]
+    pub(super) fn id(&self) -> u64 {
+        self.inner().id
+    }
+
+    #[inline]
+    fn inner(&self) -> &DescriptorInner {
+        unsafe { self.ptr.as_ref() }
+    }
+
+    #[inline]
+    fn addr(&self) -> &Addr {
+        unsafe { self.inner().address.as_ref() }
+    }
+
+    #[inline]
+    fn data(&self) -> NonNull<u8> {
+        self.inner().data
+    }
+
+    #[inline]
+    fn upgrade(&self) {
+        let inner = self.inner();
+        trace!(upgrade = inner.id);
+        inner.references.fetch_add(1, Ordering::Relaxed);
+        unsafe {
+            inner
+                .memory
+                .as_ref()
+                .references
+                .fetch_add(1, Ordering::Relaxed);
+        }
+    }
+
+    #[inline]
+    fn clone_filled(&self) -> Self {
+        // https://github.com/rust-lang/rust/blob/28b83ee59698ae069f5355b8e03f976406f410f5/library/alloc/src/sync.rs#L2175
+        // > Using a relaxed ordering is alright here, as knowledge of the
+        // > original reference prevents other threads from erroneously deleting
+        // > the object.
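+        //
+        // Note that only the descriptor-local count is bumped here: the shared
+        // `Memory::references` count tracks descriptors checked out of the free
+        // list (bumped in `upgrade`), not individual `Filled` clones.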
+        let inner = self.inner();
+        inner.references.fetch_add(1, Ordering::Relaxed);
+        trace!(clone = inner.id);
+        Self {
+            ptr: self.ptr,
+            phantom: PhantomData,
+        }
+    }
+
+    #[inline]
+    fn drop_filled(&self) {
+        let inner = self.inner();
+        let desc_ref = inner.references.fetch_sub(1, Ordering::Release);
+        debug_assert_ne!(desc_ref, 0, "reference count underflow");
+
+        // based on the implementation in:
+        // https://github.com/rust-lang/rust/blob/28b83ee59698ae069f5355b8e03f976406f410f5/library/alloc/src/sync.rs#L2551
+        if desc_ref != 1 {
+            trace!(drop_desc_ref = inner.id);
+            return;
+        }
+
+        core::sync::atomic::fence(Ordering::Acquire);
+
+        let mem = inner.free(self);
+
+        trace!(free_desc = inner.id, state = %"filled");
+
+        drop(mem);
+    }
+
+    #[inline]
+    pub(super) fn drop_unfilled(&self) {
+        let inner = self.inner();
+        inner.references.store(0, Ordering::Release);
+        let mem = inner.free(self);
+
+        trace!(free_desc = inner.id, state = %"unfilled");
+
+        drop(mem);
+    }
+}
+
+unsafe impl Send for Descriptor {}
+unsafe impl Sync for Descriptor {}
+
+pub(super) struct DescriptorInner {
+    id: u64,
+    address: NonNull<Addr>,
+    data: NonNull<u8>,
+
+    references: AtomicUsize,
+
+    memory: NonNull<Memory>,
+}
+
+impl DescriptorInner {
+    pub(super) fn new(
+        id: u64,
+        address: NonNull<Addr>,
+        data: NonNull<u8>,
+        memory: NonNull<Memory>,
+    ) -> Self {
+        Self {
+            id,
+            address,
+            data,
+            references: AtomicUsize::new(0),
+            memory,
+        }
+    }
+
+    #[inline]
+    fn capacity(&self) -> u16 {
+        unsafe { self.memory.as_ref().capacity }
+    }
+
+    /// Frees the descriptor back into the pool
+    #[inline]
+    fn free(&self, desc: &Descriptor) -> Option<Box<Memory>> {
+        let memory = unsafe { self.memory.as_ref() };
+        let mem_refs = memory.references.fetch_sub(1, Ordering::Release);
+        debug_assert_ne!(mem_refs, 0, "reference count underflow");
+
+        // if the free_list is still active (the allocator hasn't dropped) then just push the id
+        // TODO Weak::upgrade is a bit expensive since it clones the `Arc`, only to drop it again
+        if let Some(free_list) = memory.free_list.upgrade() {
+            free_list.free(Descriptor {
+                ptr: desc.ptr,
+                phantom: PhantomData,
+            });
+            return None;
+        }
+
+        // the free_list is no longer active so we need to clean up the memory
+
+        // based on the implementation in:
+        // https://github.com/rust-lang/rust/blob/28b83ee59698ae069f5355b8e03f976406f410f5/library/alloc/src/sync.rs#L2551
+        if mem_refs != 1 {
+            trace!(memory_draining = mem_refs - 1, desc = self.id);
+            return None;
+        }
+
+        core::sync::atomic::fence(Ordering::Acquire);
+
+        trace!(memory_free = ?self.memory.as_ptr(), desc = self.id);
+
+        // return the boxed memory rather than free it here - this works around
+        // any stacked borrowing issues found by Miri
+        Some(unsafe { Box::from_raw(self.memory.as_ptr()) })
+    }
+}
+
+/// An unfilled packet
+pub struct Unfilled {
+    desc: Option<Descriptor>,
+}
+
+impl fmt::Debug for Unfilled {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let desc = self.desc.as_ref().expect("invalid state");
+        f.debug_struct("Unfilled").field("id", &desc.id()).finish()
+    }
+}
+
+impl Unfilled {
+    #[inline]
+    pub(super) fn from_descriptor(desc: Descriptor) -> Self {
+        desc.upgrade();
+        Self { desc: Some(desc) }
+    }
+
+    /// Fills the packet with the given callback, if the callback is successful
+    #[inline]
+    pub fn recv_with<F, E>(mut self, f: F) -> Result<Segments, (Self, E)>
+    where
+        F: FnOnce(&mut Addr, &mut cmsg::Receiver, IoSliceMut) -> Result<usize, E>,
+    {
+        let desc = self.desc.take().expect("invalid state");
+        let inner = desc.inner();
+        let addr = unsafe { &mut *inner.address.as_ptr() };
+        let capacity = inner.capacity() as usize;
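+        // expose the entire buffer capacity to the callback; `len` below is
+        // clamped to the number of bytes the callback reports as written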
+        let data = unsafe { core::slice::from_raw_parts_mut(inner.data.as_ptr(), capacity) };
+        let iov = IoSliceMut::new(data);
+        let mut cmsg = cmsg::Receiver::default();
+
+        let len = match f(addr, &mut cmsg, iov) {
+            Ok(len) => {
+                debug_assert!(len <= capacity);
+                len.min(capacity) as u16
+            }
+            Err(err) => {
+                let unfilled = Self { desc: Some(desc) };
+                return Err((unfilled, err));
+            }
+        };
+
+        let segment_len = cmsg.segment_len();
+        let ecn = cmsg.ecn();
+        let desc = Filled {
+            desc,
+            offset: 0,
+            len,
+            ecn,
+        };
+        let segments = Segments {
+            descriptor: Some(desc),
+            segment_len,
+        };
+        Ok(segments)
+    }
+}
+
+impl Drop for Unfilled {
+    #[inline]
+    fn drop(&mut self) {
+        if let Some(desc) = self.desc.take() {
+            // put the descriptor back in the pool if it wasn't filled
+            desc.drop_unfilled();
+        }
+    }
+}
+
+pub struct Filled {
+    desc: Descriptor,
+    offset: u16,
+    len: u16,
+    ecn: ExplicitCongestionNotification,
+}
+
+impl fmt::Debug for Filled {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let alt = f.alternate();
+
+        let mut s = f.debug_struct("Filled");
+        s.field("id", &self.desc.id())
+            .field("remote_address", &self.remote_address().get())
+            .field("ecn", &self.ecn);
+
+        if alt {
+            s.field("payload", &self.payload());
+        } else {
+            s.field("payload_len", &self.len);
+        }
+
+        s.finish()
+    }
+}
+
+impl Filled {
+    /// Returns the ECN markings for the packet
+    #[inline]
+    pub fn ecn(&self) -> ExplicitCongestionNotification {
+        self.ecn
+    }
+
+    /// Returns the length of the payload
+    #[inline]
+    pub fn len(&self) -> u16 {
+        self.len
+    }
+
+    /// Returns `true` if the payload is empty
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.len == 0
+    }
+
+    /// Returns the remote address of the packet
+    #[inline]
+    pub fn remote_address(&self) -> &Addr {
+        // NOTE: addr_mut can't be used since the `inner` is reference counted to allow for GRO
+        self.desc.addr()
+    }
+
+    /// Returns the packet payload
+    #[inline]
+    pub fn payload(&self) -> &[u8] {
+        unsafe {
+            let ptr = self.desc.data().as_ptr().add(self.offset as _);
+            let len = self.len as usize;
+            core::slice::from_raw_parts(ptr, len)
+        }
+    }
+
+    /// Returns a mutable packet payload
+    // NOTE: this is safe since we guarantee no `Filled` references overlap
+    #[inline]
+    pub fn payload_mut(&mut self) -> &mut [u8] {
+        unsafe {
+            let ptr = self.desc.data().as_ptr().add(self.offset as _);
+            let len = self.len as usize;
+            core::slice::from_raw_parts_mut(ptr, len)
+        }
+    }
+
+    /// Splits the packet into two at the given index.
+    ///
+    /// Afterwards `self` contains elements `[at, len)`, and the returned
+    /// [`Filled`] contains elements `[0, at)`.
+    ///
+    /// This is an `O(1)` operation that just increases the reference count and
+    /// sets a few indices.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `at > len`.
+    #[must_use = "consider Filled::advance if you don't need the other half"]
+    #[inline]
+    pub fn split_to(&mut self, at: u16) -> Self {
+        assert!(at <= self.len);
+        let offset = self.offset;
+        let ecn = self.ecn;
+        self.offset += at;
+        self.len -= at;
+        Self {
+            desc: self.desc.clone_filled(),
+            offset,
+            len: at,
+            ecn,
+        }
+    }
+
+    /// Shortens the packet, keeping the first `len` bytes and dropping the
+    /// rest.
+    ///
+    /// If `len` is greater than the packet's current length, this has no
+    /// effect.
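+    ///
+    /// A minimal sketch of the intended use (the packet and length here are
+    /// hypothetical):
+    ///
+    /// ```ignore
+    /// // keep at most 1200 bytes of a received packet
+    /// packet.truncate(1200);
+    /// assert!(packet.len() <= 1200);
+    /// ```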
+    #[inline]
+    pub fn truncate(&mut self, len: u16) {
+        self.len = len.min(self.len);
+    }
+
+    /// Advances the start of the packet by `len`
+    ///
+    /// # Panics
+    ///
+    /// This function panics if `len > self.len()`
+    #[inline]
+    pub fn advance(&mut self, len: u16) {
+        assert!(len <= self.len);
+        self.offset += len;
+        self.len -= len;
+    }
+}
+
+impl Drop for Filled {
+    #[inline]
+    fn drop(&mut self) {
+        self.desc.drop_filled()
+    }
+}
+
+/// An iterator over all of the filled segments in a packet
+///
+/// This is used when the socket interface allows for receiving multiple packets
+/// in a single syscall, e.g. GRO.
+pub struct Segments {
+    descriptor: Option<Filled>,
+    segment_len: u16,
+}
+
+impl Iterator for Segments {
+    type Item = Filled;
+
+    #[inline]
+    fn next(&mut self) -> Option<Self::Item> {
+        // if the segment length wasn't specified, then just return the entire thing
+        if self.segment_len == 0 {
+            return self.descriptor.take();
+        }
+
+        let descriptor = self.descriptor.as_mut()?;
+
+        // if the current descriptor exceeds the segment length then we need to split it off and
+        // bump the reference counts
+        if descriptor.len() > self.segment_len {
+            return Some(descriptor.split_to(self.segment_len as _));
+        }
+
+        // the segment len was bigger than the overall descriptor so return the whole thing to avoid
+        // reference count churn
+        self.descriptor.take()
+    }
+}
diff --git a/dc/s2n-quic-dc/src/socket/recv/pool.rs b/dc/s2n-quic-dc/src/socket/recv/pool.rs
new file mode 100644
index 0000000000..0d15cded6b
--- /dev/null
+++ b/dc/s2n-quic-dc/src/socket/recv/pool.rs
@@ -0,0 +1,378 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    msg::addr::Addr,
+    socket::recv::descriptor::{Descriptor, DescriptorInner, FreeList, Memory, Unfilled},
+};
+use std::{
+    alloc::Layout,
+    ptr::NonNull,
+    sync::{Arc, Mutex},
+};
+
+#[derive(Clone)]
+pub struct Pool {
+    free: Arc<Free>,
+}
+
+impl Pool {
+    /// Creates a pool with the given `max_packet_size` and `packet_count`.
+    ///
+    /// # Notes
+    ///
+    /// `max_packet_size` does not account for GRO capabilities of the underlying socket. If
+    /// GRO is enabled, the `max_packet_size` should be set to `u16::MAX`.
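+    ///
+    /// # Example
+    ///
+    /// A usage sketch (the sizes are illustrative, not recommendations):
+    ///
+    /// ```ignore
+    /// let pool = Pool::new(1500, 64);
+    /// let unfilled = pool.alloc().expect("a fresh pool has free packets");
+    /// // dropping an `Unfilled` packet returns it to the free list
+    /// drop(unfilled);
+    /// ```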
+    #[inline]
+    pub fn new(max_packet_size: u16, packet_count: usize) -> Self {
+        let free = Arc::new(Free(Mutex::new(Vec::with_capacity(packet_count))));
+
+        let (region, layout) = Region::alloc(max_packet_size, packet_count);
+
+        let ptr = region.ptr;
+        let packet = layout.packet;
+        let addr_offset = layout.addr_offset;
+        let packet_offset = layout.packet_offset;
+        let max_packet_size = layout.max_packet_size;
+        let region = Box::new(region);
+
+        let memory = Memory::new(max_packet_size, Arc::downgrade(&free), region);
+        // we leak the memory pointer since it frees itself when the final reference is dropped
+        let memory = Box::leak(memory);
+        let memory = unsafe { NonNull::new_unchecked(memory) };
+
+        for idx in 0..packet_count {
+            let offset = packet.size() * idx;
+            unsafe {
+                let descriptor = ptr.as_ptr().add(offset).cast::<DescriptorInner>();
+                let addr = ptr.as_ptr().add(offset + addr_offset).cast::<Addr>();
+                let data = ptr.as_ptr().add(offset + packet_offset);
+
+                // `data` pointer is already zeroed out with the initial allocation
+                // initialize the address
+                addr.write(Addr::default());
+                // initialize the descriptor - note that it is self-referential to `addr`, `data`, and `memory`
+                descriptor.write(DescriptorInner::new(
+                    idx as _,
+                    NonNull::new_unchecked(addr),
+                    NonNull::new_unchecked(data),
+                    memory,
+                ));
+
+                // push the descriptor into the free list
+                let descriptor = Descriptor::new(NonNull::new_unchecked(descriptor));
+                let descriptor = Unfilled::from_descriptor(descriptor);
+                free.0.lock().unwrap().push(descriptor);
+            }
+        }
+
+        Self { free: free.clone() }
+    }
+
+    /// Allocates an [`Unfilled`] packet from the [`Pool`]
+    #[inline]
+    pub fn alloc(&self) -> Option<Unfilled> {
+        self.free.alloc()
+    }
+}
+
+struct Region {
+    ptr: NonNull<u8>,
+    layout: Layout,
+}
+
+struct RegionLayout {
+    packet: Layout,
+    addr_offset: usize,
+    packet_offset: usize,
+    max_packet_size: u16,
+}
+
+unsafe impl Send for Region {}
+unsafe impl Sync for Region {}
+
+impl Region {
+    #[inline]
+    fn alloc(mut max_packet_size: u16, packet_count: usize) -> (Self, RegionLayout) {
+        debug_assert!(max_packet_size > 0, "packets need to be at least 1 byte");
+        debug_assert!(packet_count > 0, "there needs to be at least 1 packet");
+
+        // first create the descriptor layout
+        let descriptor = Layout::new::<DescriptorInner>();
+        // extend it with the address value
+        let (header, addr_offset) = descriptor.extend(Layout::new::<Addr>()).unwrap();
+        // finally place the packet data at the end
+        let (packet, packet_offset) = header
+            .extend(Layout::array::<u8>(max_packet_size as usize).unwrap())
+            .unwrap();
+
+        // add any extra padding we need
+        let without_padding_len = packet.size();
+        let packet = packet.pad_to_align();
+
+        // if we needed to add padding then use that for the packet buffer since it will go to waste otherwise
+        let padding_len = packet.size() - without_padding_len;
+        let max_packet_size = max_packet_size.saturating_add(padding_len as u16);
+
+        let packets = {
+            // TODO use `packet.repeat(packet_count)` once stable
+            // https://doc.rust-lang.org/stable/core/alloc/struct.Layout.html#method.repeat
+            Layout::from_size_align(packet.size() * packet_count, packet.align()).unwrap()
+        };
+
+        let ptr = unsafe {
+            // SAFETY: the layout is non-zero size
+            debug_assert_ne!(packets.size(), 0);
+            std::alloc::alloc_zeroed(packets)
+        };
+        let ptr = NonNull::new(ptr).expect("failed to allocate memory");
+
+        let region = Self {
+            ptr,
+            layout: packets,
+        };
+
+        let layout = RegionLayout {
+            packet,
+            addr_offset,
+            packet_offset,
+            max_packet_size,
+        };
+
+        (region, layout)
+    }
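+
+    // The resulting region is `packet_count` repetitions of:
+    //
+    //   [ DescriptorInner | Addr | packet bytes ... | padding ]
+    //
+    // where `addr_offset` and `packet_offset` locate the `Addr` and the
+    // payload within each repetition.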
+}
+
+impl Drop for Region {
+    #[inline]
+    fn drop(&mut self) {
+        unsafe {
+            std::alloc::dealloc(self.ptr.as_ptr(), self.layout);
+        }
+    }
+}
+
+struct Free(Mutex<Vec<Unfilled>>);
+
+impl Free {
+    #[inline]
+    fn alloc(&self) -> Option<Unfilled> {
+        self.0.lock().unwrap().pop()
+    }
+}
+
+impl FreeList for Free {
+    #[inline]
+    fn free(&self, descriptor: Descriptor) {
+        // convert it back to an `Unfilled` descriptor so the reference counting works
+        let descriptor = Unfilled::from_descriptor(descriptor);
+        self.0.lock().unwrap().push(descriptor);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{socket::recv::descriptor::Filled, testing::init_tracing};
+    use bolero::{check, TypeGenerator};
+    use std::{
+        collections::{HashMap, VecDeque},
+        net::{Ipv4Addr, SocketAddr},
+    };
+
+    #[derive(TypeGenerator, Debug)]
+    enum Op {
+        Alloc,
+        DropUnfilled {
+            idx: u8,
+        },
+        Fill {
+            idx: u8,
+            port: u8,
+            segment_count: u8,
+            segment_len: u8,
+        },
+        DropFilled {
+            idx: u8,
+        },
+    }
+
+    struct Model {
+        pool: Pool,
+        epoch: u64,
+        references: HashMap<u64, usize>,
+        unfilled: VecDeque<Unfilled>,
+        filled: VecDeque<(u64, Filled)>,
+        expected_free_packets: usize,
+    }
+
+    impl Model {
+        fn new(max_packet_size: u16, packet_count: usize) -> Self {
+            let pool = Pool::new(max_packet_size, packet_count);
+            Self {
+                pool,
+                epoch: 0,
+                references: HashMap::new(),
+                unfilled: VecDeque::new(),
+                filled: VecDeque::new(),
+                expected_free_packets: packet_count,
+            }
+        }
+
+        fn alloc(&mut self) {
+            if let Some(desc) = self.pool.alloc() {
+                self.unfilled.push_back(desc);
+                self.expected_free_packets -= 1;
+            } else {
+                assert_eq!(self.expected_free_packets, 0);
+            }
+        }
+
+        fn drop_unfilled(&mut self, idx: usize) {
+            if self.unfilled.is_empty() {
+                return;
+            }
+
+            let idx = idx % self.unfilled.len();
+            let _ = self.unfilled.remove(idx).unwrap();
+            self.expected_free_packets += 1;
+        }
+
+        fn drop_filled(&mut self, idx: usize) {
+            if self.filled.is_empty() {
+                return;
+            }
+            let idx = idx % self.filled.len();
+            let (epoch, _descriptor) = self.filled.remove(idx).unwrap();
+            let count = self.references.entry(epoch).or_default();
+            *count -= 1;
+            if *count == 0 {
+                self.references.remove(&epoch);
+                self.expected_free_packets += 1;
+            }
+        }
+
+        fn fill(&mut self, idx: usize, port: u16, segment_count: u8, segment_len: u8) {
+            let Self {
+                epoch,
+                references,
+                unfilled,
+                filled,
+                expected_free_packets,
+                ..
+            } = self;
+
+            if unfilled.is_empty() {
+                return;
+            }
+            let idx = idx % unfilled.len();
+            let unfilled = unfilled.remove(idx).unwrap();
+
+            let src = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port);
+
+            let segment_len = segment_len as usize;
+            let segment_count = segment_count as usize;
+            let mut actual_segment_count = 0;
+
+            let res = unfilled.recv_with(|addr, cmsg, mut payload| {
+                if port == 0 {
+                    return Err(());
+                }
+
+                addr.set(src.into());
+
+                if segment_count > 1 {
+                    cmsg.set_segment_len(segment_len as _);
+                }
+                let mut offset = 0;
+
+                for segment_idx in 0..segment_count {
+                    let remaining = &mut payload[offset..];
+                    let len = remaining.len().min(segment_len);
+                    if len == 0 {
+                        break;
+                    }
+
+                    actual_segment_count += 1;
+                    remaining[..len].fill(segment_idx as u8);
+                    offset += len;
+                }
+
+                Ok(offset)
+            });
+
+            assert_eq!(res.is_err(), port == 0);
+
+            if let Ok(segments) = res {
+                if actual_segment_count > 0 {
+                    references.insert(*epoch, actual_segment_count);
+                }
+
+                for (idx, segment) in segments.enumerate() {
+                    // We allow only one segment to be empty; this makes it easier to
+                    // log when we get empty packets, which are unexpected.
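+                    // (An empty segment can only be produced when the recv
+                    // callback returned Ok(0), i.e. actual_segment_count == 0.)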
+                    if segment.is_empty() {
+                        assert_eq!(actual_segment_count, 0);
+                        assert_eq!(idx, 0);
+                        *expected_free_packets += 1;
+                        continue;
+                    }
+
+                    assert!(
+                        idx < actual_segment_count,
+                        "{idx} < {actual_segment_count}, {:?}",
+                        segment.payload()
+                    );
+
+                    // the final segment is allowed to be undersized
+                    if idx == actual_segment_count - 1 {
+                        assert!(segment.len() as usize <= segment_len);
+                    } else {
+                        assert_eq!(segment.len() as usize, segment_len);
+                    }
+
+                    // make sure bytes match the segment pattern
+                    for byte in segment.payload().iter() {
+                        assert_eq!(*byte, idx as u8);
+                    }
+
+                    filled.push_back((*epoch, segment));
+                }
+
+                *epoch += 1;
+            } else {
+                *expected_free_packets += 1;
+            }
+        }
+
+        fn apply(&mut self, op: &Op) {
+            match op {
+                Op::Alloc => self.alloc(),
+                Op::DropUnfilled { idx } => self.drop_unfilled(*idx as usize),
+                Op::Fill {
+                    idx,
+                    port,
+                    segment_count,
+                    segment_len,
+                } => self.fill(*idx as _, *port as _, *segment_count, *segment_len),
+                Op::DropFilled { idx } => self.drop_filled(*idx as usize),
+            }
+        }
+    }
+
+    #[test]
+    fn model_test() {
+        init_tracing();
+
+        check!()
+            .with_type::<Vec<Op>>()
+            .with_test_time(core::time::Duration::from_secs(20))
+            .for_each(|ops| {
+                let max_packet_size = 1000;
+                let expected_free_packets = 16;
+                let mut model = Model::new(max_packet_size, expected_free_packets);
+                for op in ops {
+                    model.apply(op);
+                }
+            });
+    }
+}
diff --git a/dc/s2n-quic-dc/src/socket/recv/router.rs b/dc/s2n-quic-dc/src/socket/recv/router.rs
new file mode 100644
index 0000000000..2a7701955c
--- /dev/null
+++ b/dc/s2n-quic-dc/src/socket/recv/router.rs
@@ -0,0 +1,159 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    credentials::Credentials,
+    packet::{self, stream},
+    socket::recv::descriptor,
+};
+use s2n_codec::DecoderBufferMut;
+use s2n_quic_core::inet::SocketAddress;
+
+/// Routes incoming packet segments to the appropriate destination
+pub trait Router {
+    const TAG_LEN: usize = 16;
+
+    #[inline]
+    fn on_segment(&self, mut segment: descriptor::Filled) {
+        let remote_address = segment.remote_address().get();
+        let decoder = DecoderBufferMut::new(segment.payload_mut());
+        match decoder.decode_parameterized(Self::TAG_LEN) {
+            // We don't check `remaining` since we currently assume one packet per segment.
+            // If we ever support multiple packets per segment, we'll need to split the segment
+            // up even further and correctly dispatch to the right place.
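+            //
+            // Tags and credentials are copied out of the borrowed decode below
+            // so that `segment` itself can be moved into the handler by value.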
+            Ok((packet, _remaining)) => match packet {
+                packet::Packet::Control(c) => {
+                    let tag = c.tag();
+                    let stream_id = c.stream_id().copied();
+                    let credentials = *c.credentials();
+                    self.on_control_packet(tag, stream_id, credentials, segment);
+                }
+                packet::Packet::Stream(packet) => {
+                    let tag = packet.tag();
+                    let stream_id = *packet.stream_id();
+                    let credentials = *packet.credentials();
+                    self.on_stream_packet(tag, stream_id, credentials, segment);
+                }
+                packet::Packet::Datagram(packet) => {
+                    let tag = packet.tag();
+                    let credentials = *packet.credentials();
+                    self.on_datagram_packet(tag, credentials, segment);
+                }
+                packet::Packet::StaleKey(packet) => {
+                    self.on_stale_key_packet(packet, remote_address);
+                }
+                packet::Packet::ReplayDetected(packet) => {
+                    self.on_replay_detected_packet(packet, remote_address);
+                }
+                packet::Packet::UnknownPathSecret(packet) => {
+                    self.on_unknown_path_secret_packet(packet, remote_address);
+                }
+            },
+            Err(error) => {
+                self.on_decode_error(error, remote_address, segment);
+            }
+        }
+    }
+
+    #[inline]
+    fn on_control_packet(
+        &self,
+        tag: packet::control::Tag,
+        id: Option<stream::Id>,
+        credentials: Credentials,
+        segment: descriptor::Filled,
+    ) {
+        tracing::warn!(
+            unhandled_packet = "control",
+            ?tag,
+            ?id,
+            ?credentials,
+            remote_address = ?segment.remote_address(),
+            packet_len = segment.len()
+        );
+    }
+
+    #[inline]
+    fn on_stream_packet(
+        &self,
+        tag: stream::Tag,
+        id: stream::Id,
+        credentials: Credentials,
+        segment: descriptor::Filled,
+    ) {
+        tracing::warn!(
+            unhandled_packet = "stream",
+            ?tag,
+            ?id,
+            ?credentials,
+            remote_address = ?segment.remote_address(),
+            packet_len = segment.len()
+        );
+    }
+
+    #[inline]
+    fn on_datagram_packet(
+        &self,
+        tag: packet::datagram::Tag,
+        credentials: Credentials,
+        segment: descriptor::Filled,
+    ) {
+        tracing::warn!(
+            unhandled_packet = "datagram",
+            ?tag,
+            ?credentials,
+            remote_address = ?segment.remote_address(),
+            packet_len = segment.len()
+        );
+    }
+
+    #[inline]
+    fn on_stale_key_packet(
+        &self,
+        packet: packet::secret_control::stale_key::Packet,
+        remote_address: SocketAddress,
+    ) {
+        tracing::warn!(unhandled_packet = "stale_key", ?packet, ?remote_address);
+    }
+
+    #[inline]
+    fn on_replay_detected_packet(
+        &self,
+        packet: packet::secret_control::replay_detected::Packet,
+        remote_address: SocketAddress,
+    ) {
+        tracing::warn!(
+            unhandled_packet = "replay_detected",
+            ?packet,
+            ?remote_address,
+        );
+    }
+
+    #[inline]
+    fn on_unknown_path_secret_packet(
+        &self,
+        packet: packet::secret_control::unknown_path_secret::Packet,
+        remote_address: SocketAddress,
+    ) {
+        tracing::warn!(
+            unhandled_packet = "unknown_path_secret",
+            ?packet,
+            ?remote_address,
+        );
+    }
+
+    #[inline]
+    fn on_decode_error(
+        &self,
+        error: s2n_codec::DecoderError,
+        remote_address: SocketAddress,
+        segment: descriptor::Filled,
+    ) {
+        tracing::warn!(
+            ?error,
+            ?remote_address,
+            packet_len = segment.len(),
+            "failed to decode packet"
+        );
+    }
+}
diff --git a/dc/s2n-quic-dc/src/socket/recv/udp.rs b/dc/s2n-quic-dc/src/socket/recv/udp.rs
new file mode 100644
index 0000000000..3382b40b7f
--- /dev/null
+++ b/dc/s2n-quic-dc/src/socket/recv/udp.rs
@@ -0,0 +1,75 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    socket::recv::{descriptor, pool, router::Router},
+    stream::socket::fd::udp,
+};
+use std::{collections::VecDeque, net::UdpSocket};
+
+pub struct Allocator {
+    queue: VecDeque<pool::Pool>,
+    max_packet_size: u16,
+    packet_count: usize,
+}
+
+impl Allocator {
+    pub fn new(max_packet_size: u16, packet_count: usize) -> Self {
+        // the `Pool` struct is quite small, so reserve 16 slots up front in
+        // case we need the space later
+        let mut queue = VecDeque::with_capacity(16);
+        queue.push_back(pool::Pool::new(max_packet_size, packet_count));
+        Self {
+            queue,
+            max_packet_size,
+            packet_count,
+        }
+    }
+
+    #[inline]
+    fn alloc(&mut self) -> descriptor::Unfilled {
+        let mut rotate_count = 0;
+
+        // search through the list for a pool with a free segment
+        while rotate_count < self.queue.len() {
+            let front = self.queue.front_mut().unwrap();
+            if let Some(message) = front.alloc() {
+                return message;
+            }
+
+            self.queue.rotate_left(1);
+            rotate_count += 1;
+        }
+
+        // we've exhausted all of the current pools so create a new one
+        let pool = pool::Pool::new(self.max_packet_size, self.packet_count);
+        let desc = pool.alloc().unwrap();
+        self.queue.push_front(pool);
+        desc
+    }
+}
+
+/// Receives packets from a blocking [`UdpSocket`] and dispatches into the provided [`Router`]
+pub fn blocking<R: Router>(socket: UdpSocket, mut alloc: Allocator, router: R) {
+    loop {
+        let mut unfilled = alloc.alloc();
+        loop {
+            let res = unfilled.recv_with(|addr, cmsg, buffer| {
+                udp::recv(&socket, addr, cmsg, &mut [buffer], Default::default())
+            });
+
+            match res {
+                Ok(segments) => {
+                    for segment in segments {
+                        router.on_segment(segment);
+                    }
+                    break;
+                }
+                Err((desc, err)) => {
+                    tracing::error!("socket recv error: {err}");
+                    unfilled = desc;
+                    continue;
+                }
+            }
+        }
+    }
+}
diff --git a/dc/s2n-quic-dc/src/testing.rs b/dc/s2n-quic-dc/src/testing.rs
index 59de185367..de2ef9b40f 100644
--- a/dc/s2n-quic-dc/src/testing.rs
+++ b/dc/s2n-quic-dc/src/testing.rs
@@ -9,6 +9,10 @@ pub fn assert_async_read(_v: &T) {}
 pub fn assert_async_write(_v: &T) {}
 
 pub fn init_tracing() {
+    if cfg!(any(miri, fuzzing)) {
+        return;
+    }
+
     use std::sync::Once;
     static TRACING: Once = Once::new();
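
As a sketch of how the pieces above are intended to compose: a caller builds an Allocator, binds a socket, and runs the blocking receive loop against a Router. The `NoopRouter` and socket setup here are hypothetical and only illustrate the API in this patch; a real router would dispatch to stream/datagram state rather than inherit the trait's warn-and-drop defaults:

    use s2n_quic_dc::socket::recv::{router::Router, udp};

    // hypothetical router that relies on the trait's default handlers, which
    // log each packet as unhandled and then drop it
    struct NoopRouter;
    impl Router for NoopRouter {}

    fn main() -> std::io::Result<()> {
        let socket = std::net::UdpSocket::bind("127.0.0.1:4433")?;
        // 1500-byte packets, 64 packets per pool; the allocator grows by
        // pushing additional pools, so this only sets the allocation granularity
        let alloc = udp::Allocator::new(1500, 64);
        // never returns: receives packets and dispatches each GRO segment
        udp::blocking(socket, alloc, NoopRouter);
        Ok(())
    }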