diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs index b16caadf8..9581ead82 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs @@ -16,9 +16,15 @@ use s2n_quic_core::{ recovery::RttEstimator, time::{Clock, Timestamp}, }; -use std::{collections::VecDeque, io}; +use std::io; use tracing::trace; +mod list; +#[cfg(test)] +mod tests; + +use list::List; + pub struct Manager where W: Worker, @@ -34,22 +40,48 @@ where W: Worker, { /// A set of worker entries which process newly-accepted streams - workers: Box<[(W, Waker)]>, + workers: Box<[Entry]>, /// FIFO queue for tracking free [`Worker`] entries /// /// None of the indices in this queue have associated sockets and are waiting to be assigned /// for work. - free: VecDeque, + free: List, /// A list of [`Worker`] entries that are in order of sojourn time, starting with the oldest. /// /// The front will be the first to be reclaimed in the case of overload. - by_sojourn_time: VecDeque, + by_sojourn_time: List, /// Tracks the [sojourn time](https://en.wikipedia.org/wiki/Mean_sojourn_time) of processing /// streams in worker entries. sojourn_time: RttEstimator, - /// The number of `by_sojourn_time` list entries that have completed but haven't yet - /// moved to the `free` list - gc_count: usize, +} + +struct Entry +where + W: Worker, +{ + worker: W, + waker: Waker, + link: list::Link, +} + +impl AsRef for Entry +where + W: Worker, +{ + #[inline] + fn as_ref(&self) -> &list::Link { + &self.link + } +} + +impl AsMut for Entry +where + W: Worker, +{ + #[inline] + fn as_mut(&mut self) -> &mut list::Link { + &mut self.link + } } impl Manager @@ -59,15 +91,29 @@ where #[inline] pub fn new(workers: impl IntoIterator) -> Self { let mut waker_set = waker::Set::default(); - let workers: Box<[_]> = workers + let mut workers: Box<[_]> = workers .into_iter() .enumerate() - .map(|(idx, worker)| (worker, waker_set.waker(idx))) + .map(|(idx, worker)| { + let waker = waker_set.waker(idx); + let link = list::Link::default(); + Entry { + worker, + waker, + link, + } + }) .collect(); let capacity = workers.len(); - let mut free = VecDeque::with_capacity(capacity); - free.extend(0..capacity); - let by_sojourn_time = VecDeque::with_capacity(capacity); + let mut free = List::default(); + for idx in 0..capacity { + unsafe { + // SAFETY: idx is in bounds + free.push(&mut workers, idx); + } + } + + let by_sojourn_time = List::default(); let inner = Inner { workers, @@ -75,7 +121,6 @@ where by_sojourn_time, // set the initial estimate high to avoid backlog churn before we get stable samples sojourn_time: RttEstimator::new(Duration::from_secs(30)), - gc_count: 0, }; Self { @@ -87,13 +132,12 @@ where #[inline] pub fn active_slots(&self) -> usize { - // don't include the pending GC streams - self.inner.by_sojourn_time.len() - self.inner.gc_count + self.inner.by_sojourn_time.len() } #[inline] pub fn free_slots(&self) -> usize { - self.inner.free.len() + self.inner.gc_count + self.inner.free.len() } #[inline] @@ -151,14 +195,20 @@ where return false; }; - self.inner.workers[idx].0.replace( + self.inner.workers[idx].worker.replace( remote_address, stream, connection_context, publisher, clock, ); - self.inner.by_sojourn_time.push_back(idx); + + unsafe { + // SAFETY: the idx is in bounds + self.inner + .by_sojourn_time + .push(&mut self.inner.workers, idx); + } // kick off the initial poll to register wakers with the socket self.inner.poll_worker(idx, cx, publisher, clock); @@ -218,9 +268,10 @@ where { let mut cf = ControlFlow::Continue(()); - let (worker, waker) = &mut self.workers[idx]; - let mut task_cx = task::Context::from_waker(waker); - let Poll::Ready(res) = worker.poll(&mut task_cx, cx, publisher, clock) else { + let entry = &mut self.workers[idx]; + let mut task_cx = task::Context::from_waker(&entry.waker); + let Poll::Ready(res) = entry.worker.poll(&mut task_cx, cx, publisher, clock) else { + debug_assert!(entry.worker.is_active()); return cf; }; @@ -230,7 +281,7 @@ where // update the accept_time estimate self.sojourn_time.update_rtt( Duration::ZERO, - worker.sojourn_time(&now), + entry.worker.sojourn_time(&now), now, true, PacketNumberSpace::ApplicationData, @@ -245,7 +296,11 @@ where } // the worker is all done so indicate we have another free slot - self.gc_count += 1; + unsafe { + // SAFETY: list entries are managed by the list impl; idx is in bounds + self.by_sojourn_time.remove(&mut self.workers, idx); + self.free.push(&mut self.workers, idx); + } cf } @@ -255,39 +310,25 @@ where where C: Clock, { - // if we're out of free workers and GC has been requested, then do a scan - if self.free.is_empty() && self.gc_count > 0 { - self.by_sojourn_time.retain(|idx| { - let idx = *idx; - let (worker, _waker) = &self.workers[idx]; - - // check if the worker is active - let is_active = worker.is_active(); - - // if the worker isn't active it means it's ready to move to the free list - if !is_active { - self.free.push_back(idx); - } - - is_active - }); - // we did a full scan so reset the value - self.gc_count = 0; - } - // if we have a free worker then use that - if let Some(idx) = self.free.pop_front() { + if let Some(idx) = unsafe { + // SAFETY: free list manages `workers` linked status + self.free.pop(&mut self.workers) + } { trace!(op = %"next_worker", free = idx); return Some(idx); } - let idx = *self.by_sojourn_time.front().unwrap(); - let sojourn = self.workers[idx].0.sojourn_time(clock); + let idx = self.by_sojourn_time.front().unwrap(); + let sojourn = self.workers[idx].worker.sojourn_time(clock); // if the worker's sojourn time exceeds the maximum, then reclaim it if sojourn >= self.max_sojourn_time() { trace!(op = %"next_worker", injected = idx, ?sojourn); - return self.by_sojourn_time.pop_front(); + return unsafe { + // SAFETY: by_sojourn_time list manages `workers` linked status + self.by_sojourn_time.pop(&mut self.workers) + }; } trace!(op = %"next_worker", ?sojourn, max_sojourn_time = ?self.max_sojourn_time()); @@ -301,32 +342,35 @@ where #[cfg(debug_assertions)] fn invariants(&self) { for idx in 0..self.workers.len() { - let in_ready = self.free.contains(&idx); - let in_working = self.by_sojourn_time.contains(&idx); assert!( - in_working ^ in_ready, - "worker should either be in ready ({in_ready}) or working ({in_working}) list" + self.free + .iter(&self.workers) + .chain(self.by_sojourn_time.iter(&self.workers)) + .filter(|v| *v == idx) + .count() + == 1, + "worker {idx} should be linked at all times\n{:?}", + self.workers[idx].link, ); } - for idx in self.free.iter().copied() { - let (worker, _waker) = &self.workers[idx]; - assert!(!worker.is_active()); + let mut expected_free_len = 0usize; + for idx in self.free.iter(&self.workers) { + let entry = &self.workers[idx]; + assert!(!entry.worker.is_active()); + expected_free_len += 1; } - - let mut expected_gc_count = 0; + assert_eq!(self.free.len(), expected_free_len, "{:?}", self.free); let mut prev_queue_time = None; - for idx in self.by_sojourn_time.iter().copied() { - let (worker, _waker) = &self.workers[idx]; + let mut active_len = 0usize; + for idx in self.by_sojourn_time.iter(&self.workers) { + let entry = &self.workers[idx]; - // if the worker doesn't have a stream then it should be marked for GC - if !worker.is_active() { - expected_gc_count += 1; - continue; - } + assert!(entry.worker.is_active()); + active_len += 1; - let queue_time = worker.queue_time(); + let queue_time = entry.worker.queue_time(); if let Some(prev) = prev_queue_time { assert!( prev <= queue_time, @@ -336,7 +380,12 @@ where prev_queue_time = Some(queue_time); } - assert_eq!(self.gc_count, expected_gc_count); + assert_eq!( + active_len, + self.by_sojourn_time.len(), + "{:?}", + self.by_sojourn_time + ); } } @@ -379,324 +428,3 @@ pub trait Worker { fn is_active(&self) -> bool; } - -#[cfg(test)] -mod tests { - use super::{Worker as _, *}; - use crate::event::{self, IntoEvent}; - use bolero::{check, TypeGenerator}; - use core::time::Duration; - use std::io; - - const WORKER_COUNT: usize = 4; - - #[derive(Clone, Copy, Debug, TypeGenerator)] - enum Op { - Insert, - Wake { - #[generator(0..WORKER_COUNT)] - idx: usize, - }, - Ready { - #[generator(0..WORKER_COUNT)] - idx: usize, - error: bool, - }, - Advance { - #[generator(1..=10)] - millis: u8, - }, - } - - enum State { - Idle, - Active, - Ready, - Error(io::ErrorKind), - } - - struct Worker { - queue_time: Timestamp, - state: State, - epoch: u64, - poll_count: u64, - } - - impl Worker { - fn new(clock: &C) -> Self - where - C: Clock, - { - Self { - queue_time: clock.get_time(), - state: State::Idle, - epoch: 0, - poll_count: 0, - } - } - } - - impl super::Worker for Worker { - type Context = (); - type ConnectionContext = (); - type Stream = (); - - fn replace( - &mut self, - _remote_address: SocketAddress, - _stream: Self::Stream, - _connection_context: Self::ConnectionContext, - _publisher: &Pub, - clock: &C, - ) where - Pub: EndpointPublisher, - C: Clock, - { - self.queue_time = clock.get_time(); - self.state = State::Active; - self.epoch += 1; - self.poll_count = 0; - } - - fn poll( - &mut self, - _task_cx: &mut task::Context, - _cx: &mut Self::Context, - _publisher: &Pub, - _clock: &C, - ) -> Poll, Option>> - where - Pub: EndpointPublisher, - C: Clock, - { - self.poll_count += 1; - match self.state { - State::Idle => { - unreachable!("shouldn't be polled when idle") - } - State::Active => Poll::Pending, - State::Ready => { - self.state = State::Idle; - Poll::Ready(Ok(ControlFlow::Continue(()))) - } - State::Error(err) => { - self.state = State::Idle; - Poll::Ready(Err(Some(err.into()))) - } - } - } - - fn queue_time(&self) -> Timestamp { - self.queue_time - } - - fn is_active(&self) -> bool { - matches!(self.state, State::Active | State::Ready | State::Error(_)) - } - } - - struct Harness { - manager: Manager, - clock: Timestamp, - subscriber: event::tracing::Subscriber, - } - - impl core::ops::Deref for Harness { - type Target = Manager; - - fn deref(&self) -> &Self::Target { - &self.manager - } - } - - impl core::ops::DerefMut for Harness { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.manager - } - } - - impl Default for Harness { - fn default() -> Self { - let clock = unsafe { Timestamp::from_duration(Duration::from_secs(1)) }; - let manager = Manager::::new((0..WORKER_COUNT).map(|_| Worker::new(&clock))); - let subscriber = event::tracing::Subscriber::default(); - Self { - manager, - clock, - subscriber, - } - } - } - - impl Harness { - pub fn poll(&mut self) { - self.manager.poll( - &mut (), - &publisher(&self.subscriber, &self.clock), - &self.clock, - ); - } - - pub fn insert(&mut self) -> bool { - self.manager.insert( - SocketAddress::default(), - (), - &mut (), - (), - &publisher(&self.subscriber, &self.clock), - &self.clock, - ) - } - - pub fn wake(&mut self, idx: usize) -> bool { - let (worker, waker) = &mut self.manager.inner.workers[idx]; - let is_active = worker.is_active(); - - if is_active { - waker.wake_by_ref(); - } - - is_active - } - - pub fn ready(&mut self, idx: usize) -> bool { - let (worker, waker) = &mut self.manager.inner.workers[idx]; - let is_active = worker.is_active(); - - if is_active { - worker.state = State::Ready; - waker.wake_by_ref(); - } - - is_active - } - - pub fn error(&mut self, idx: usize, error: io::ErrorKind) -> bool { - let (worker, waker) = &mut self.manager.inner.workers[idx]; - let is_active = worker.is_active(); - - if is_active { - worker.state = State::Error(error); - waker.wake_by_ref(); - } - - is_active - } - - pub fn advance(&mut self, time: Duration) { - self.clock += time; - } - - #[track_caller] - pub fn assert_epoch(&self, idx: usize, expected: u64) { - let (worker, _waker) = &self.manager.inner.workers[idx]; - assert_eq!(worker.epoch, expected); - } - - #[track_caller] - pub fn assert_poll_count(&self, idx: usize, expected: u64) { - let (worker, _waker) = &self.manager.inner.workers[idx]; - assert_eq!(worker.poll_count, expected); - } - } - - fn publisher<'a>( - subscriber: &'a event::tracing::Subscriber, - clock: &Timestamp, - ) -> event::EndpointPublisherSubscriber<'a, event::tracing::Subscriber> { - event::EndpointPublisherSubscriber::new( - crate::event::builder::EndpointMeta { - timestamp: clock.into_event(), - }, - None, - subscriber, - ) - } - - #[test] - fn invariants_test() { - check!().with_type::>().for_each(|ops| { - let mut harness = Harness::default(); - - for op in ops { - match op { - Op::Insert => { - harness.insert(); - } - Op::Wake { idx } => { - harness.wake(*idx); - } - Op::Ready { idx, error } => { - if *error { - harness.error(*idx, io::ErrorKind::ConnectionReset); - } else { - harness.ready(*idx); - } - } - Op::Advance { millis } => { - harness.advance(Duration::from_millis(*millis as u64)); - harness.poll(); - } - } - } - - harness.poll(); - }); - } - - #[test] - fn replace_test() { - let mut harness = Harness::default(); - assert_eq!(harness.active_slots(), 0); - assert_eq!(harness.capacity(), WORKER_COUNT); - - for idx in 0..4 { - assert!(harness.insert()); - assert_eq!(harness.active_slots(), 1 + idx); - harness.assert_epoch(idx, 1); - } - - // manager should not replace a slot if sojourn_time hasn't passed - assert!(!harness.insert()); - - // advance the clock by max_sojourn_time - harness.advance(harness.max_sojourn_time()); - harness.poll(); - assert_eq!(harness.active_slots(), WORKER_COUNT); - - for idx in 0..4 { - assert!(harness.insert()); - assert_eq!(harness.active_slots(), WORKER_COUNT); - harness.assert_epoch(idx, 2); - } - } - - #[test] - fn wake_test() { - let mut harness = Harness::default(); - assert!(harness.insert()); - // workers should be polled on insertion - harness.assert_poll_count(0, 1); - // workers should not be polled until woken - harness.poll(); - harness.assert_poll_count(0, 1); - - harness.wake(0); - harness.assert_poll_count(0, 1); - harness.poll(); - harness.assert_poll_count(0, 2); - } - - #[test] - fn ready_test() { - let mut harness = Harness::default(); - - assert_eq!(harness.active_slots(), 0); - assert!(harness.insert()); - assert_eq!(harness.active_slots(), 1); - harness.ready(0); - assert_eq!(harness.active_slots(), 1); - harness.poll(); - assert_eq!(harness.active_slots(), 0); - } -} diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/list.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/list.rs new file mode 100644 index 000000000..b3e3ee806 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/list.rs @@ -0,0 +1,366 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +/// List which manages the status of a slice of entries +/// +/// This implementation avoids allocation or shuffling by storing list links +/// inline with the entries. +/// +/// # Time complexity +/// +/// | [push] | [pop] | [remove] | +/// |---------|---------|----------| +/// | *O*(1) | *O*(1) | *O*(1) | +#[derive(Debug)] +pub struct List { + head: usize, + tail: usize, + len: usize, + /// Tracks if a node is linked or not but only when debug assertions are enabled + #[cfg(debug_assertions)] + linked: Vec, +} + +impl Default for List { + #[inline] + fn default() -> Self { + Self { + head: usize::MAX, + tail: usize::MAX, + len: 0, + #[cfg(debug_assertions)] + linked: vec![], + } + } +} + +impl List { + #[inline] + pub fn len(&self) -> usize { + self.len + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// # Safety + /// + /// Callers must ensure: + /// + /// * `entries` is only managed by [`List`] + /// * `idx` is less than `usize::MAX` + #[inline] + pub unsafe fn pop(&mut self, entries: &mut [L]) -> Option + where + L: AsMut, + { + if self.len == 0 { + return None; + } + + let idx = self.head; + let link = entries.get_unchecked_mut(idx).as_mut(); + self.head = link.next; + link.reset(); + + if self.head == usize::MAX { + self.tail = usize::MAX; + } else { + entries.get_unchecked_mut(self.head).as_mut().prev = usize::MAX; + } + + self.set_linked_status(idx, false); + + Some(idx) + } + + #[inline] + pub fn front(&self) -> Option { + if self.head == usize::MAX { + None + } else { + Some(self.head) + } + } + + /// # Safety + /// + /// Callers must ensure: + /// + /// * `entries` is only managed by [`List`] + /// * `idx` is in bounds of `entries` + /// * `idx` is less than `usize::MAX` + #[inline] + pub unsafe fn push(&mut self, entries: &mut [L], idx: usize) + where + L: AsMut, + { + let tail = self.tail; + if tail != usize::MAX { + entries.get_unchecked_mut(tail).as_mut().next = idx; + } else { + debug_assert!(self.is_empty()); + self.head = idx; + } + self.tail = idx; + + let link = entries.get_unchecked_mut(idx).as_mut(); + link.prev = tail; + link.next = usize::MAX; + + self.set_linked_status(idx, true); + } + + /// # Safety + /// + /// Callers must ensure: + /// + /// * `entries` is only managed by [`List`] + /// * `idx` is in bounds of `entries` + /// * `idx` must be less that `usize::MAX` + #[inline] + pub unsafe fn remove(&mut self, entries: &mut [L], idx: usize) + where + L: AsMut, + { + debug_assert!(!self.is_empty()); + + let link = entries.get_unchecked_mut(idx).as_mut(); + let next = link.next; + let prev = link.prev; + link.reset(); + + if prev != usize::MAX { + entries.get_unchecked_mut(prev).as_mut().next = next; + } else { + debug_assert!(self.head == idx); + self.head = next; + } + + if next != usize::MAX { + entries.get_unchecked_mut(next).as_mut().prev = prev; + } else { + debug_assert!(self.tail == idx); + self.tail = prev; + } + + self.set_linked_status(idx, false); + } + + #[inline] + #[cfg_attr(not(debug_assertions), allow(dead_code))] + pub fn iter<'a, L>(&'a self, entries: &'a [L]) -> impl Iterator + '_ + where + L: AsRef, + { + let mut idx = self.head; + core::iter::from_fn(move || { + if idx == usize::MAX { + return None; + } + let res = idx; + idx = entries[idx].as_ref().next; + Some(res) + }) + } + + #[inline(always)] + fn set_linked_status(&mut self, idx: usize, linked: bool) { + if linked { + self.len += 1; + } else { + self.len -= 1; + } + + #[cfg(debug_assertions)] + { + if self.linked.len() <= idx { + self.linked.resize(idx + 1, false); + } + assert_eq!(self.linked[idx], !linked, "{self:?}"); + self.linked[idx] = linked; + let expected_len = self.linked.iter().filter(|&v| *v).count(); + assert_eq!(expected_len, self.len, "{self:?}"); + } + + let _ = idx; + + debug_assert_eq!(self.head == usize::MAX, self.is_empty(), "{self:?}"); + debug_assert_eq!(self.tail == usize::MAX, self.is_empty(), "{self:?}"); + debug_assert_eq!(self.head == usize::MAX, self.tail == usize::MAX, "{self:?}"); + } +} + +#[derive(Debug)] +pub struct Link { + next: usize, + prev: usize, +} + +impl Default for Link { + #[inline] + fn default() -> Self { + Self { + next: usize::MAX, + prev: usize::MAX, + } + } +} + +impl Link { + #[inline] + fn reset(&mut self) { + self.next = usize::MAX; + self.prev = usize::MAX; + } +} + +impl AsRef for Link { + #[inline] + fn as_ref(&self) -> &Link { + self + } +} + +impl AsMut for Link { + #[inline] + fn as_mut(&mut self) -> &mut Link { + self + } +} + +#[cfg(test)] +mod tests { + use bolero::{check, TypeGenerator}; + + use super::*; + use std::collections::VecDeque; + + const LEN: usize = 4; + + enum Location { + A, + B, + } + + #[derive(Default)] + struct CheckedList { + list: List, + oracle: VecDeque, + } + + impl CheckedList { + #[inline] + fn pop(&mut self, entries: &mut [Link]) -> Option { + let v = unsafe { self.list.pop(entries) }; + assert_eq!(v, self.oracle.pop_front()); + self.invariants(entries); + v + } + + #[inline] + fn push(&mut self, entries: &mut [Link], v: usize) { + unsafe { self.list.push(entries, v) }; + self.oracle.push_back(v); + self.invariants(entries); + } + + #[inline] + fn remove(&mut self, entries: &mut [Link], v: usize) { + unsafe { self.list.remove(entries, v) }; + let idx = self.oracle.iter().position(|&x| x == v).unwrap(); + self.oracle.remove(idx); + self.invariants(entries); + } + + #[inline] + fn invariants(&self, entries: &[Link]) { + let actual = self.list.iter(entries); + assert!(actual.eq(self.oracle.iter().copied())); + } + } + + struct Harness { + a: CheckedList, + b: CheckedList, + locations: Vec, + entries: Vec, + } + + impl Default for Harness { + fn default() -> Self { + let mut a = CheckedList::default(); + let mut entries: Vec = (0..LEN).map(|_| Link::default()).collect(); + let locations = (0..LEN).map(|_| Location::A).collect(); + + for idx in 0..LEN { + a.push(&mut entries, idx); + } + + Self { + a, + b: Default::default(), + locations, + entries, + } + } + } + + impl Harness { + #[inline] + fn transfer(&mut self, idx: usize) { + let location = &mut self.locations[idx]; + match location { + Location::A => { + self.a.remove(&mut self.entries, idx); + self.b.push(&mut self.entries, idx); + *location = Location::B; + } + Location::B => { + self.b.remove(&mut self.entries, idx); + self.a.push(&mut self.entries, idx); + *location = Location::A; + } + } + } + + #[inline] + fn pop_a(&mut self) { + if let Some(v) = self.a.pop(&mut self.entries) { + self.b.push(&mut self.entries, v); + self.locations[v] = Location::B; + } + } + + #[inline] + fn pop_b(&mut self) { + if let Some(v) = self.b.pop(&mut self.entries) { + self.a.push(&mut self.entries, v); + self.locations[v] = Location::A; + } + } + } + + #[derive(Clone, Copy, Debug, TypeGenerator)] + enum Op { + Transfer(#[generator(0..LEN)] usize), + PopA, + PopB, + } + + #[test] + fn invariants_test() { + check!().with_type::>().for_each(|ops| { + let mut harness = Harness::default(); + for op in ops { + match op { + Op::Transfer(idx) => harness.transfer(*idx), + Op::PopA => harness.pop_a(), + Op::PopB => harness.pop_b(), + } + } + }) + } +} diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/tests.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/tests.rs new file mode 100644 index 000000000..04bb79027 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager/tests.rs @@ -0,0 +1,320 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{Worker as _, *}; +use crate::event::{self, IntoEvent}; +use bolero::{check, TypeGenerator}; +use core::time::Duration; +use std::io; + +const WORKER_COUNT: usize = 4; + +#[derive(Clone, Copy, Debug, TypeGenerator)] +enum Op { + Insert, + Wake { + #[generator(0..WORKER_COUNT)] + idx: usize, + }, + Ready { + #[generator(0..WORKER_COUNT)] + idx: usize, + error: bool, + }, + Advance { + #[generator(1..=10)] + millis: u8, + }, +} + +enum State { + Idle, + Active, + Ready, + Error(io::ErrorKind), +} + +struct Worker { + queue_time: Timestamp, + state: State, + epoch: u64, + poll_count: u64, +} + +impl Worker { + fn new(clock: &C) -> Self + where + C: Clock, + { + Self { + queue_time: clock.get_time(), + state: State::Idle, + epoch: 0, + poll_count: 0, + } + } +} + +impl super::Worker for Worker { + type Context = (); + type ConnectionContext = (); + type Stream = (); + + fn replace( + &mut self, + _remote_address: SocketAddress, + _stream: Self::Stream, + _connection_context: Self::ConnectionContext, + _publisher: &Pub, + clock: &C, + ) where + Pub: EndpointPublisher, + C: Clock, + { + self.queue_time = clock.get_time(); + self.state = State::Active; + self.epoch += 1; + self.poll_count = 0; + } + + fn poll( + &mut self, + _task_cx: &mut task::Context, + _cx: &mut Self::Context, + _publisher: &Pub, + _clock: &C, + ) -> Poll, Option>> + where + Pub: EndpointPublisher, + C: Clock, + { + self.poll_count += 1; + match self.state { + State::Idle => { + unreachable!("shouldn't be polled when idle") + } + State::Active => Poll::Pending, + State::Ready => { + self.state = State::Idle; + Poll::Ready(Ok(ControlFlow::Continue(()))) + } + State::Error(err) => { + self.state = State::Idle; + Poll::Ready(Err(Some(err.into()))) + } + } + } + + fn queue_time(&self) -> Timestamp { + self.queue_time + } + + fn is_active(&self) -> bool { + matches!(self.state, State::Active | State::Ready | State::Error(_)) + } +} + +struct Harness { + manager: Manager, + clock: Timestamp, + subscriber: event::tracing::Subscriber, +} + +impl core::ops::Deref for Harness { + type Target = Manager; + + fn deref(&self) -> &Self::Target { + &self.manager + } +} + +impl core::ops::DerefMut for Harness { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.manager + } +} + +impl Default for Harness { + fn default() -> Self { + let clock = unsafe { Timestamp::from_duration(Duration::from_secs(1)) }; + let manager = Manager::::new((0..WORKER_COUNT).map(|_| Worker::new(&clock))); + let subscriber = event::tracing::Subscriber::default(); + Self { + manager, + clock, + subscriber, + } + } +} + +impl Harness { + pub fn poll(&mut self) { + self.manager.poll( + &mut (), + &publisher(&self.subscriber, &self.clock), + &self.clock, + ); + } + + pub fn insert(&mut self) -> bool { + self.manager.insert( + SocketAddress::default(), + (), + &mut (), + (), + &publisher(&self.subscriber, &self.clock), + &self.clock, + ) + } + + pub fn wake(&mut self, idx: usize) -> bool { + let Entry { worker, waker, .. } = &mut self.manager.inner.workers[idx]; + let is_active = worker.is_active(); + + if is_active { + waker.wake_by_ref(); + } + + is_active + } + + pub fn ready(&mut self, idx: usize) -> bool { + let Entry { worker, waker, .. } = &mut self.manager.inner.workers[idx]; + let is_active = worker.is_active(); + + if is_active { + worker.state = State::Ready; + waker.wake_by_ref(); + } + + is_active + } + + pub fn error(&mut self, idx: usize, error: io::ErrorKind) -> bool { + let Entry { worker, waker, .. } = &mut self.manager.inner.workers[idx]; + let is_active = worker.is_active(); + + if is_active { + worker.state = State::Error(error); + waker.wake_by_ref(); + } + + is_active + } + + pub fn advance(&mut self, time: Duration) { + self.clock += time; + } + + #[track_caller] + pub fn assert_epoch(&self, idx: usize, expected: u64) { + let Entry { worker, .. } = &self.manager.inner.workers[idx]; + assert_eq!(worker.epoch, expected); + } + + #[track_caller] + pub fn assert_poll_count(&self, idx: usize, expected: u64) { + let Entry { worker, .. } = &self.manager.inner.workers[idx]; + assert_eq!(worker.poll_count, expected); + } +} + +fn publisher<'a>( + subscriber: &'a event::tracing::Subscriber, + clock: &Timestamp, +) -> event::EndpointPublisherSubscriber<'a, event::tracing::Subscriber> { + event::EndpointPublisherSubscriber::new( + crate::event::builder::EndpointMeta { + timestamp: clock.into_event(), + }, + None, + subscriber, + ) +} + +#[test] +fn invariants_test() { + check!().with_type::>().for_each(|ops| { + let mut harness = Harness::default(); + + for op in ops { + match op { + Op::Insert => { + harness.insert(); + } + Op::Wake { idx } => { + harness.wake(*idx); + } + Op::Ready { idx, error } => { + if *error { + harness.error(*idx, io::ErrorKind::ConnectionReset); + } else { + harness.ready(*idx); + } + } + Op::Advance { millis } => { + harness.advance(Duration::from_millis(*millis as u64)); + harness.poll(); + } + } + } + + harness.poll(); + }); +} + +#[test] +fn replace_test() { + let mut harness = Harness::default(); + assert_eq!(harness.active_slots(), 0); + assert_eq!(harness.capacity(), WORKER_COUNT); + + for idx in 0..4 { + assert!(harness.insert()); + assert_eq!(harness.active_slots(), 1 + idx); + harness.assert_epoch(idx, 1); + } + + // manager should not replace a slot if sojourn_time hasn't passed + assert!(!harness.insert()); + + // advance the clock by max_sojourn_time + harness.advance(harness.max_sojourn_time()); + harness.poll(); + assert_eq!(harness.active_slots(), WORKER_COUNT); + + for idx in 0..4 { + assert!(harness.insert()); + assert_eq!(harness.active_slots(), WORKER_COUNT); + harness.assert_epoch(idx, 2); + } +} + +#[test] +fn wake_test() { + let mut harness = Harness::default(); + assert!(harness.insert()); + // workers should be polled on insertion + harness.assert_poll_count(0, 1); + // workers should not be polled until woken + harness.poll(); + harness.assert_poll_count(0, 1); + + harness.wake(0); + harness.assert_poll_count(0, 1); + harness.poll(); + harness.assert_poll_count(0, 2); +} + +#[test] +fn ready_test() { + let mut harness = Harness::default(); + + assert_eq!(harness.active_slots(), 0); + assert!(harness.insert()); + assert_eq!(harness.active_slots(), 1); + harness.ready(0); + assert_eq!(harness.active_slots(), 1); + harness.poll(); + assert_eq!(harness.active_slots(), 0); +}