diff --git a/dc/s2n-quic-dc/Cargo.toml b/dc/s2n-quic-dc/Cargo.toml index 83850d298..18f7b447b 100644 --- a/dc/s2n-quic-dc/Cargo.toml +++ b/dc/s2n-quic-dc/Cargo.toml @@ -12,13 +12,19 @@ exclude = ["corpus.tar.gz"] [features] default = ["tokio"] -testing = ["bolero-generator", "s2n-quic-core/testing", "s2n-quic-platform/testing", "tracing-subscriber"] +testing = [ + "bolero-generator", + "s2n-quic-core/testing", + "s2n-quic-platform/testing", + "tracing-subscriber", +] tokio = ["tokio/io-util", "tokio/net", "tokio/rt-multi-thread", "tokio/time"] [dependencies] arrayvec = "0.7" atomic-waker = "1" aws-lc-rs = "1" +bach = "0.0.10" bitflags = "2" bolero-generator = { version = "0.13", default-features = false, optional = true } bytes = "1" @@ -41,7 +47,9 @@ hashbrown = "0.15" thiserror = "2" tokio = { version = "1", default-features = false, features = ["sync"] } tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"], optional = true } +tracing-subscriber = { version = "0.3", features = [ + "env-filter", +], optional = true } zerocopy = { version = "0.7", features = ["derive"] } zeroize = "1" parking_lot = "0.12" @@ -53,13 +61,12 @@ bolero-generator = "0.13" insta = "1" s2n-codec = { path = "../../common/s2n-codec", features = ["testing"] } s2n-quic-core = { path = "../../quic/s2n-quic-core", features = ["testing"] } -s2n-quic-platform = { path = "../../quic/s2n-quic-platform", features = ["testing"] } +s2n-quic-platform = { path = "../../quic/s2n-quic-platform", features = [ + "testing", +] } tokio = { version = "1", features = ["full"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] } [lints.rust.unexpected_cfgs] level = "warn" -check-cfg = [ - 'cfg(kani)', - 'cfg(todo)', -] +check-cfg = ['cfg(kani)', 'cfg(todo)'] diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/accept.rs b/dc/s2n-quic-dc/src/stream/server/tokio/accept.rs index 043bbee96..49945da97 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/accept.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/accept.rs @@ -8,7 +8,7 @@ use crate::{ application::{Builder as StreamBuilder, Stream}, environment::{tokio::Environment, Environment as _}, }, - sync::channel, + sync::mpmc as channel, }; use core::time::Duration; use s2n_quic_core::time::Clock; diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/stats.rs b/dc/s2n-quic-dc/src/stream/server/tokio/stats.rs index f6cba231c..9c9c64410 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/stats.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/stats.rs @@ -1,7 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::{event::Subscriber, sync::channel as chan}; +use crate::{event::Subscriber, sync::mpmc as chan}; use core::{ sync::atomic::{AtomicU64, Ordering}, time::Duration, diff --git a/dc/s2n-quic-dc/src/sync.rs b/dc/s2n-quic-dc/src/sync.rs index 893bcc384..5a8ea09ab 100644 --- a/dc/s2n-quic-dc/src/sync.rs +++ b/dc/s2n-quic-dc/src/sync.rs @@ -1,5 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -pub mod channel; +pub mod mpmc; +pub mod mpsc; pub mod ring_deque; diff --git a/dc/s2n-quic-dc/src/sync/channel.rs b/dc/s2n-quic-dc/src/sync/mpmc.rs similarity index 69% rename from dc/s2n-quic-dc/src/sync/channel.rs rename to dc/s2n-quic-dc/src/sync/mpmc.rs index c22b81bf7..1c04bfe9c 100644 --- a/dc/s2n-quic-dc/src/sync/channel.rs +++ b/dc/s2n-quic-dc/src/sync/mpmc.rs @@ -59,6 +59,11 @@ impl Channel { } } +/// A message sender +/// +/// Note that this channel implementation does not allow for backpressure on the +/// sending rate. Instead, the queue is rotated to make room for new items and +/// returned to the sender. pub struct Sender { channel: Arc>, } @@ -122,8 +127,6 @@ pin_project! { /// are dropped, the channel becomes closed. /// /// The channel can also be closed manually by calling [`Receiver::close()`]. - /// - /// Receivers implement the [`Stream`] trait. pub struct Receiver { // Inner channel state. channel: Arc>, @@ -148,7 +151,6 @@ pin_project! { } } -#[allow(dead_code)] // TODO remove this once the module is public impl Receiver { /// Attempts to receive a message from the front of the channel. /// @@ -204,6 +206,12 @@ impl Receiver { channel: Arc::downgrade(&self.channel), } } + + /// Closes the channel for receiving + #[inline] + pub fn close(&self) -> Result<(), Closed> { + self.channel.close() + } } impl fmt::Debug for Receiver { @@ -232,7 +240,6 @@ pub struct WeakReceiver { channel: Weak>, } -#[allow(dead_code)] // TODO remove this once the module is public impl WeakReceiver { #[inline] pub fn pop_front_if(&self, priority: Priority, f: F) -> Result, Closed> @@ -316,3 +323,137 @@ impl EventListenerFuture for RecvInner<'_, T> { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::testing::{ext::*, sim, task}; + use std::time::Duration; + + #[test] + fn test_unlimited() { + sim(|| { + let (tx, rx) = new(2); + + async move { + for v in 0u64.. { + if tx.send_back(v).is_err() { + return; + }; + // let the receiver read from the task + task::yield_now().await; + } + } + .primary() + .spawn(); + + async move { + for expected in 0u64..10 { + let actual = rx.recv_front().await.unwrap(); + assert_eq!(actual, expected); + } + } + .primary() + .spawn(); + }); + } + + #[test] + fn test_send_limited() { + sim(|| { + let (tx, rx) = new(2); + + async move { + for v in 0u64.. { + if tx.send_back(v).is_err() { + return; + }; + Duration::from_millis(1).sleep().await; + } + } + .primary() + .spawn(); + + async move { + for expected in 0u64..10 { + let actual = rx.recv_front().await.unwrap(); + assert_eq!(actual, expected); + } + } + .primary() + .spawn(); + }); + } + + #[test] + fn test_recv_limited() { + sim(|| { + let (tx, rx) = new(2); + + async move { + for v in 0u64.. { + match tx.send_back(v) { + Ok(Some(_old)) => { + // the channel doesn't provide backpressure so we'll need to sleep + Duration::from_millis(1).sleep().await; + } + Ok(None) => { + continue; + } + Err(_) => { + // the receiver is done + return; + } + } + } + } + .primary() + .spawn(); + + async move { + let mut min = 0; + for _ in 0u64..10 { + let actual = rx.recv_front().await.unwrap(); + assert!(actual > min || actual == 0); + min = actual; + Duration::from_millis(1).sleep().await; + } + } + .primary() + .spawn(); + }); + } + + #[test] + fn test_multi_recv() { + sim(|| { + let (tx, rx) = new(2); + + async move { + for v in 0u64.. { + if tx.send_back(v).is_err() { + return; + }; + // let the receiver read from the task + task::yield_now().await; + } + } + .primary() + .spawn(); + + for _ in 0..2 { + let rx = rx.clone(); + async move { + let mut min = 0; + for _ in 0u64..10 { + let actual = rx.recv_front().await.unwrap(); + assert!(actual > min || actual == 0, "{actual} > {min}"); + min = actual; + } + } + .primary() + .spawn(); + } + }); + } +} diff --git a/dc/s2n-quic-dc/src/sync/mpsc.rs b/dc/s2n-quic-dc/src/sync/mpsc.rs new file mode 100644 index 000000000..1dafdd046 --- /dev/null +++ b/dc/s2n-quic-dc/src/sync/mpsc.rs @@ -0,0 +1,300 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::sync::ring_deque::{self, RingDeque}; +use core::{fmt, task::Poll}; +use std::{ + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::Waker, +}; + +pub use ring_deque::{Closed, Priority}; + +pub fn new(cap: usize) -> (Sender, Receiver) { + assert!(cap >= 1, "capacity must be at least 2"); + + let channel = Arc::new(Channel { + queue: RingDeque::new(cap), + sender_count: AtomicUsize::new(1), + }); + + let s = Sender { + channel: channel.clone(), + }; + let r = Receiver { channel }; + (s, r) +} + +struct Channel { + queue: RingDeque>, + sender_count: AtomicUsize, +} + +impl Channel { + /// Closes the channel and notifies all blocked operations. + /// + /// Returns `Err` if this call has closed the channel and it was not closed already. + fn close(&self) -> Result<(), Closed> { + self.queue.close()?; + + Ok(()) + } +} + +/// A message sender +/// +/// Note that this channel implementation does not allow for backpressure on the +/// sending rate. Instead, the queue is rotated to make room for new items and +/// returned to the sender. +pub struct Sender { + channel: Arc>, +} + +impl Sender { + #[inline] + pub fn send_back(&self, msg: T) -> Result, Closed> { + let res = self.channel.queue.push_back(msg)?; + + Ok(res) + } + + #[inline] + pub fn send_front(&self, msg: T) -> Result, Closed> { + let res = self.channel.queue.push_front(msg)?; + + Ok(res) + } +} + +impl Drop for Sender { + fn drop(&mut self) { + // Decrement the sender count and close the channel if it drops down to zero. + if self.channel.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 { + let _ = self.channel.close(); + } + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Sender {{ .. }}") + } +} + +impl Clone for Sender { + fn clone(&self) -> Sender { + let count = self.channel.sender_count.fetch_add(1, Ordering::Relaxed); + + // Make sure the count never overflows, even if lots of sender clones are leaked. + assert!(count < usize::MAX / 2, "too many senders"); + + Sender { + channel: self.channel.clone(), + } + } +} + +/// The receiving side of a channel. +/// +/// When the receiver is dropped, the channel will be closed. +/// +/// The channel can also be closed manually by calling [`Receiver::close()`]. +pub struct Receiver { + // Inner channel state. + channel: Arc>, +} + +impl Drop for Receiver { + fn drop(&mut self) { + let _ = self.channel.close(); + } +} + +impl Receiver { + /// Attempts to receive a message from the front of the channel. + /// + /// If the channel is empty, or empty and closed, this method returns an error. + #[inline] + pub fn try_recv_front(&self) -> Result, Closed> { + self.channel.queue.pop_front() + } + + /// Attempts to receive a message from the back of the channel. + /// + /// If the channel is empty, or empty and closed, this method returns an error. + #[inline] + pub fn try_recv_back(&self) -> Result, Closed> { + self.channel.queue.pop_back() + } + + /// Receives a message from the front of the channel. + /// + /// If the channel is empty, this method waits until there is a message. + /// + /// If the channel is closed, this method receives a message or returns an error if there are + /// no more messages. + #[inline] + pub async fn recv_front(&self) -> Result { + core::future::poll_fn(|cx| self.poll_recv_front(cx)).await + } + + /// Receives a message from the front of the channel + #[inline] + pub fn poll_recv_front(&self, cx: &mut core::task::Context<'_>) -> Poll> { + self.channel.queue.poll_pop_front(cx) + } + + /// Receives a message from the back of the channel. + /// + /// If the channel is empty, this method waits until there is a message. + /// + /// If the channel is closed, this method receives a message or returns an error if there are + /// no more messages. + #[inline] + pub async fn recv_back(&self) -> Result { + core::future::poll_fn(|cx| self.poll_recv_back(cx)).await + } + + /// Receives a message from the back of the channel. + #[inline] + pub fn poll_recv_back(&self, cx: &mut core::task::Context<'_>) -> Poll> { + self.channel.queue.poll_pop_back(cx) + } + + /// Swaps the contents of the channel with the given deque. + /// + /// If the channel is closed, this method returns an error. + #[inline] + pub async fn swap(&self, out: &mut std::collections::VecDeque) -> Result<(), Closed> { + core::future::poll_fn(|cx| self.poll_swap(cx, out)).await + } + + /// Swaps the contents of the channel with the given deque. + /// + /// If the channel is closed, this method returns an error. If the channel is currently + /// empty, `Pending` will be returned. + #[inline] + pub fn poll_swap( + &self, + cx: &mut core::task::Context<'_>, + out: &mut std::collections::VecDeque, + ) -> Poll> { + self.channel.queue.poll_swap(cx, out) + } + + /// Closes the channel for receiving + #[inline] + pub fn close(&self) -> Result<(), Closed> { + self.channel.close() + } +} + +impl fmt::Debug for Receiver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Receiver {{ .. }}") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::testing::{ext::*, sim, task}; + use std::time::Duration; + + #[test] + fn test_unlimited() { + sim(|| { + let (tx, rx) = new(2); + + async move { + for v in 0u64.. { + if tx.send_back(v).is_err() { + return; + }; + // let the receiver read from the task + task::yield_now().await; + } + } + .primary() + .spawn(); + + async move { + for expected in 0u64..10 { + let actual = rx.recv_front().await.unwrap(); + assert_eq!(actual, expected); + } + } + .primary() + .spawn(); + }); + } + + #[test] + fn test_send_limited() { + sim(|| { + let (tx, rx) = new(2); + + async move { + for v in 0u64.. { + if tx.send_back(v).is_err() { + return; + }; + Duration::from_millis(1).sleep().await; + } + } + .primary() + .spawn(); + + async move { + for expected in 0u64..10 { + let actual = rx.recv_front().await.unwrap(); + assert_eq!(actual, expected); + } + } + .primary() + .spawn(); + }); + } + + #[test] + fn test_recv_limited() { + sim(|| { + let (tx, rx) = new(2); + + async move { + for v in 0u64.. { + match tx.send_back(v) { + Ok(Some(_old)) => { + // the channel doesn't provide backpressure so we'll need to sleep + Duration::from_millis(1).sleep().await; + } + Ok(None) => { + continue; + } + Err(_) => { + // the receiver is done + return; + } + } + } + } + .primary() + .spawn(); + + async move { + let mut min = 0; + for _ in 0u64..10 { + let actual = rx.recv_front().await.unwrap(); + assert!(actual > min); + min = actual; + Duration::from_millis(1).sleep().await; + } + } + .primary() + .spawn(); + }); + } +} diff --git a/dc/s2n-quic-dc/src/sync/ring_deque.rs b/dc/s2n-quic-dc/src/sync/ring_deque.rs index bd952ace4..3724f428c 100644 --- a/dc/s2n-quic-dc/src/sync/ring_deque.rs +++ b/dc/s2n-quic-dc/src/sync/ring_deque.rs @@ -5,6 +5,7 @@ use s2n_quic_core::ensure; use std::{ collections::VecDeque, sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, }; #[cfg(test)] @@ -20,11 +21,11 @@ pub enum Priority { Optional, } -pub struct RingDeque { - inner: Arc>>, +pub struct RingDeque { + inner: Arc>>, } -impl Clone for RingDeque { +impl Clone for RingDeque { #[inline] fn clone(&self) -> Self { Self { @@ -33,12 +34,23 @@ impl Clone for RingDeque { } } -#[allow(dead_code)] // TODO remove this once the module is public -impl RingDeque { +impl RingDeque { #[inline] pub fn new(capacity: usize) -> Self { + let waker = W::default(); + Self::with_waker(capacity, waker) + } +} + +impl RingDeque { + #[inline] + pub fn with_waker(capacity: usize, recv_waker: W) -> Self { let queue = VecDeque::with_capacity(capacity); - let inner = Inner { open: true, queue }; + let inner = Inner { + open: true, + queue, + recv_waker, + }; let inner = Arc::new(Mutex::new(inner)); RingDeque { inner } } @@ -54,6 +66,11 @@ impl RingDeque { }; inner.queue.push_back(value); + let waker = inner.recv_waker.take(); + drop(inner); + if let Some(waker) = waker { + waker.wake(); + } Ok(prev) } @@ -69,10 +86,39 @@ impl RingDeque { }; inner.queue.push_front(value); + let waker = inner.recv_waker.take(); + drop(inner); + if let Some(waker) = waker { + waker.wake(); + } Ok(prev) } + #[inline] + pub fn poll_swap(&self, cx: &mut Context, out: &mut VecDeque) -> Poll> { + debug_assert!(out.is_empty()); + let mut inner = self.lock()?; + if inner.queue.is_empty() { + inner.recv_waker.update(cx); + Poll::Pending + } else { + core::mem::swap(&mut inner.queue, out); + Ok(()).into() + } + } + + #[inline] + pub fn poll_pop_back(&self, cx: &mut Context) -> Poll> { + let mut inner = self.lock()?; + if let Some(item) = inner.queue.pop_back() { + Ok(item).into() + } else { + inner.recv_waker.update(cx); + Poll::Pending + } + } + #[inline] pub fn pop_back(&self) -> Result, Closed> { let mut inner = self.lock()?; @@ -104,6 +150,17 @@ impl RingDeque { } } + #[inline] + pub fn poll_pop_front(&self, cx: &mut Context) -> Poll> { + let mut inner = self.lock()?; + if let Some(item) = inner.queue.pop_front() { + Ok(item).into() + } else { + inner.recv_waker.update(cx); + Poll::Pending + } + } + #[inline] pub fn pop_front(&self) -> Result, Closed> { let mut inner = self.lock()?; @@ -139,18 +196,23 @@ impl RingDeque { pub fn close(&self) -> Result<(), Closed> { let mut inner = self.lock()?; inner.open = false; + let waker = inner.recv_waker.take(); + drop(inner); + if let Some(waker) = waker { + waker.wake(); + } Ok(()) } #[inline] - fn lock(&self) -> Result>, Closed> { + fn lock(&self) -> Result>, Closed> { let inner = self.inner.lock().unwrap(); ensure!(inner.open, Err(Closed)); Ok(inner) } #[inline] - fn try_lock(&self) -> Result>>, Closed> { + fn try_lock(&self) -> Result>>, Closed> { use std::sync::TryLockError; let inner = match self.inner.try_lock() { Ok(inner) => inner, @@ -162,7 +224,53 @@ impl RingDeque { } } -struct Inner { +struct Inner { open: bool, queue: VecDeque, + recv_waker: W, +} + +/// An interface for storing a waker in the synchronized queue +/// +/// This can be used for implementing single consumer queues without +/// additional machinery for storing wakers. +pub trait RecvWaker { + /// Takes the current waker and returns it, if set + /// + /// This is to avoid calling `wake` while holding the lock on the queue + /// to avoid contention. + fn take(&mut self) -> Option; + fn update(&mut self, cx: &mut core::task::Context); +} + +impl RecvWaker for () { + #[inline(always)] + fn take(&mut self) -> Option { + None + } + + #[inline(always)] + fn update(&mut self, _cx: &mut core::task::Context) { + panic!("polling is disabled"); + } +} + +impl RecvWaker for Option { + #[inline(always)] + fn take(&mut self) -> Option { + self.take() + } + + #[inline(always)] + fn update(&mut self, cx: &mut core::task::Context) { + let new_waker = cx.waker(); + match self { + Some(waker) => { + if !waker.will_wake(new_waker) { + *self = Some(new_waker.clone()); + } + } + None => *self = Some(new_waker.clone()), + } + } } diff --git a/dc/s2n-quic-dc/src/testing.rs b/dc/s2n-quic-dc/src/testing.rs index 59de18536..c891ecfd4 100644 --- a/dc/s2n-quic-dc/src/testing.rs +++ b/dc/s2n-quic-dc/src/testing.rs @@ -1,6 +1,13 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +pub use bach::{ext, rand}; + +pub mod task { + pub use bach::task::*; + pub use tokio::task::yield_now; +} + pub fn assert_debug(_v: &T) {} pub fn assert_send(_v: &T) {} pub fn assert_sync(_v: &T) {} @@ -39,3 +46,10 @@ pub fn init_tracing() { .init(); }); } + +/// Runs a function in a deterministic, discrete event simulation environment +pub fn sim(f: impl FnOnce()) { + init_tracing(); + + bach::environment::default::Runtime::new().run(f); +}