From 23ca605eabc3cdde14de7daca0d403000132230a Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Tue, 27 Jun 2023 21:38:19 +0200 Subject: [PATCH] refactor(gateway-queue)!: rewrite crate --- book/src/chapter_2_multi-serviced_approach.md | 15 +- twilight-gateway-queue/Cargo.toml | 14 +- twilight-gateway-queue/README.md | 59 ++-- twilight-gateway-queue/src/day_limiter.rs | 124 ------- twilight-gateway-queue/src/large_bot_queue.rs | 113 ------- twilight-gateway-queue/src/lib.rs | 310 ++++++++++++------ twilight-gateway-queue/tests/queue.rs | 113 +++++++ twilight-gateway/src/config.rs | 16 +- twilight-gateway/src/shard.rs | 50 ++- twilight-gateway/src/stream.rs | 3 +- 10 files changed, 372 insertions(+), 445 deletions(-) delete mode 100644 twilight-gateway-queue/src/day_limiter.rs delete mode 100644 twilight-gateway-queue/src/large_bot_queue.rs create mode 100644 twilight-gateway-queue/tests/queue.rs diff --git a/book/src/chapter_2_multi-serviced_approach.md b/book/src/chapter_2_multi-serviced_approach.md index dcd1ff39a67..0a305bdecd1 100644 --- a/book/src/chapter_2_multi-serviced_approach.md +++ b/book/src/chapter_2_multi-serviced_approach.md @@ -14,19 +14,18 @@ One of the popular design choices when creating a multi-serviced application is to have a service that only connects shards to the gateway and sends the events to a broker to be processed. As bots grow into hundreds or thousands of shards, multiple instances of the application can be created and groups of shards can be -managed by each. Twilight is an excellent choice for this use case: you can -receive either events that come in in a loop and send the payloads to the -appropriate broker stream, or you can loop over received payloads' bytes to send -off. +managed by each. Twilight is an excellent choice for this use case: just receive +and send the payloads to the appropriate broker stream. Twilight shards need +only partially deserialize payloads to function. ## Gateway session ratelimiting If multiple shard groups are used, then they need to be queued and their session -initialization ratelimited. The Gateway includes a Queue trait which can be +initialization ratelimited. The gateway includes a Queue trait which can be implemented; the gateway will submit a request to the queue before starting a -session. Twilight comes with a queue that supports sharding and Large Bot -sharding, but when multiple shard groups are in use then a custom queue will -need to be implemented. Refer to [gateway-queue] for an example of this. +session. Twilight comes with a queue that supports Large Bot sharding, but when +multiple shard groups are in use then a custom queue will need to be implemented. +Refer to [gateway-queue] for an example of this. ## HTTP proxy ratelimiting diff --git a/twilight-gateway-queue/Cargo.toml b/twilight-gateway-queue/Cargo.toml index 513d60f7023..891394ef613 100644 --- a/twilight-gateway-queue/Cargo.toml +++ b/twilight-gateway-queue/Cargo.toml @@ -13,18 +13,8 @@ rust-version.workspace = true version = "0.15.2" [dependencies] -tokio = { default-features = false, features = ["rt", "sync", "time"], version = "1.0" } -tracing = { default-features = false, features = ["std", "attributes"], version = "0.1" } - -# Optional dependencies. 
-twilight-http = { default-features = false, optional = true, path = "../twilight-http", version = "0.15.2" }
+tokio = { default-features = false, features = ["macros", "rt", "sync", "time"], version = "1.0" }

 [dev-dependencies]
 static_assertions = { default-features = false, version = "1" }
-
-[features]
-default = ["twilight-http"]
-
-[package.metadata.docs.rs]
-all-features = true
-rustdoc-args = ["--cfg", "docsrs"]
+tokio = { default-features = false, features = ["test-util"], version = "1.0" }
diff --git a/twilight-gateway-queue/README.md b/twilight-gateway-queue/README.md
index edd8fc2bb6c..3c0c3f2149a 100644
--- a/twilight-gateway-queue/README.md
+++ b/twilight-gateway-queue/README.md
@@ -1,42 +1,21 @@
 # twilight-gateway-queue

-Ratelimiting functionality for queueing new gateway sessions.
-
-The gateway ratelimits how often clients can initialize new sessions.
-Instances of a queue are given to shards so that they can request to
-initialize a session.
-
-Queue implementations must point to the same broker so that all shards
-across all shard groups, processes, and other forms of multi-serviced
-applications, can work together and use the same ratelimiting source. That is,
-if two shard groups are in two different processes, then the the two processes
-must use some unified form of ratelimiting: this can either mean using IPC to
-communicate ratelimiting or a broker.
-
-## Provided queues
-
-Most users only need the [`LocalQueue`]: it's a single-process queue for
-smaller bots. Larger bots need the [`LargeBotQueue`], which supports
-single-process [Sharding for Large Bots] through the use of bucket
-releasing.
-
-By default, the gateway's `stream` module and `Shard`s use the [`LocalQueue`].
-This can be overridden via the `ShardBuilder::queue` configuration method.
-
-## Advanced use cases
-
-Large bots, and smaller bots out of design, may need to implement their own
-queue. The most common reason to need this is if you have shard groups in
-multiple processes. A broker to manage ratelimiting across shard groups is
-required, so a [`Queue`] trait is provided that shards can use to make requests
-to create sessions.
-
-## Features
-
-### Twilight-HTTP
-
-The `twilight-http` feature brings in support for [`LargeBotQueue`].
-
-This is enabled by default.
-
-[Sharding for Large Bots]: https://discord.com/developers/docs/topics/gateway#sharding-for-large-bots
+Rate limiting functionality for gateway `IDENTIFY` commands.
+
+Discord allows a bot's shards to send a limited number of `IDENTIFY` commands
+every 5 seconds, with a daily limit of 1000 to 2000 commands, and invalidates
+*all* shard sessions upon exceeding it. Each identify window is called a bucket
+and is filled by shard IDs modulo `max_concurrency`. Buckets must furthermore
+be started in order. See [Discord Docs/Sharding].
+
+To coordinate this, a [`Queue`] should process each identify request, and
+shards should wait for its signal to proceed before continuing, retrying
+otherwise. The provided [`InMemoryQueue`] never fails or cancels requests and
+is therefore a good starting point for custom implementations. For most cases,
+simply wrapping [`InMemoryQueue`] is enough to add new capabilities such as
+multi-process support; see [`gateway-queue`] and [`gateway-queue-http`].
+Integration tests can be found [here](https://github.com/twilight-rs/twilight/blob/main/twilight-gateway-queue/tests/queue.rs).
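+
+The sketch below is not part of the crate; it illustrates the wrapping
+pattern with a hypothetical `LoggingQueue` that records each request before
+delegating to an `InMemoryQueue`:
+
+```rust
+use tokio::sync::oneshot;
+use twilight_gateway_queue::{InMemoryQueue, Queue};
+
+/// Hypothetical wrapper around an [`InMemoryQueue`]; a real multi-process
+/// queue would forward the request to a central process instead.
+#[derive(Debug)]
+struct LoggingQueue {
+    inner: InMemoryQueue,
+}
+
+impl Queue for LoggingQueue {
+    fn enqueue(&self, id: u64) -> oneshot::Receiver<()> {
+        println!("shard {id} requested an identify permit");
+        // Delegate the actual rate limiting to the in-memory queue.
+        self.inner.enqueue(id)
+    }
+}
+```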
+
+[Discord Docs/Sharding]: https://discord.com/developers/docs/topics/gateway#sharding
+[`gateway-queue`]: https://github.com/twilight-rs/gateway-queue
+[`gateway-queue-http`]: https://github.com/twilight-rs/twilight/blob/main/examples/gateway-queue-http.rs
diff --git a/twilight-gateway-queue/src/day_limiter.rs b/twilight-gateway-queue/src/day_limiter.rs
deleted file mode 100644
index f9fd8f11e02..00000000000
--- a/twilight-gateway-queue/src/day_limiter.rs
+++ /dev/null
@@ -1,124 +0,0 @@
-use std::{
-    error::Error,
-    fmt::{Display, Formatter, Result as FmtResult},
-    sync::Arc,
-    time::Duration,
-};
-use tokio::{
-    sync::Mutex,
-    time::{self, Instant},
-};
-use twilight_http::Client;
-
-/// Creating a day limiter queue failed.
-#[derive(Debug)]
-pub struct DayLimiterError {
-    kind: DayLimiterErrorType,
-    source: Option<Box<dyn Error + Send + Sync>>,
-}
-
-impl Display for DayLimiterError {
-    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
-        match &self.kind {
-            DayLimiterErrorType::RetrievingSessionAvailability { .. } => {
-                f.write_str("retrieving the bot's gateway session availability failed")
-            }
-        }
-    }
-}
-
-impl Error for DayLimiterError {
-    fn source(&self) -> Option<&(dyn Error + 'static)> {
-        self.source
-            .as_ref()
-            .map(|source| &**source as &(dyn Error + 'static))
-    }
-}
-
-/// Type of [`DayLimiterError`] that occurred.
-#[derive(Debug)]
-#[non_exhaustive]
-pub enum DayLimiterErrorType {
-    /// Retrieving the bot's available gateway session initiation information
-    /// via the HTTP API failed.
-    RetrievingSessionAvailability,
-}
-
-#[derive(Debug)]
-pub(crate) struct DayLimiter(pub(crate) Mutex<DayLimiterInner>);
-
-#[derive(Debug)]
-pub(crate) struct DayLimiterInner {
-    pub http: Arc<Client>,
-    pub last_check: Instant,
-    pub next_reset: Duration,
-    pub total: u64,
-    pub current: u64,
-}
-
-impl DayLimiter {
-    pub async fn new(http: Arc<Client>) -> Result<Self, DayLimiterError> {
-        let info = http
-            .gateway()
-            .authed()
-            .await
-            .map_err(|source| DayLimiterError {
-                kind: DayLimiterErrorType::RetrievingSessionAvailability,
-                source: Some(Box::new(source)),
-            })?
-            .model()
-            .await
-            .map_err(|source| DayLimiterError {
-                kind: DayLimiterErrorType::RetrievingSessionAvailability,
-                source: Some(Box::new(source)),
-            })?;
-
-        let last_check = Instant::now();
-
-        let next_reset = Duration::from_millis(info.session_start_limit.reset_after);
-        let total = info.session_start_limit.total;
-        let remaining = info.session_start_limit.remaining;
-        debug_assert!(total >= remaining);
-        let current = total - remaining;
-        Ok(DayLimiter(Mutex::new(DayLimiterInner {
-            http,
-            last_check,
-            next_reset,
-            total: info.session_start_limit.total,
-            current,
-        })))
-    }
-
-    pub async fn get(&self) {
-        let mut lock = self.0.lock().await;
-        if lock.current < lock.total {
-            lock.current += 1;
-        } else {
-            let wait = lock.last_check + lock.next_reset;
-            time::sleep_until(wait).await;
-            if let Ok(res) = lock.http.gateway().authed().await {
-                if let Ok(info) = res.model().await {
-                    let last_check = Instant::now();
-                    let next_reset = Duration::from_millis(info.session_start_limit.reset_after);
-
-                    tracing::info!("next session start limit reset in: {next_reset:.2?}");
-
-                    let total = info.session_start_limit.total;
-                    let remaining = info.session_start_limit.remaining;
-                    assert!(total >= remaining);
-                    let current = total - remaining;
-                    lock.last_check = last_check;
-                    lock.next_reset = next_reset;
-                    lock.total = total;
-                    lock.current = current + 1;
-
-                    return;
-                }
-            }
-
-            tracing::warn!(
-                "unable to get new session limits, skipping (this may cause bad things)"
-            );
-        }
-    }
-}
diff --git a/twilight-gateway-queue/src/large_bot_queue.rs b/twilight-gateway-queue/src/large_bot_queue.rs
deleted file mode 100644
index 1cc66952b05..00000000000
--- a/twilight-gateway-queue/src/large_bot_queue.rs
+++ /dev/null
@@ -1,113 +0,0 @@
-use super::{day_limiter::DayLimiter, Queue};
-use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc, time::Duration};
-use tokio::{
-    sync::{
-        mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
-        oneshot::{self, Sender},
-    },
-    time::sleep,
-};
-use twilight_http::Client;
-
-/// Queue built for single-process groups of shards that require identifying via
-/// [Sharding for Large Bots].
-///
-/// Usage with other processes will cause inconsistencies between each process's
-/// ratelimit buckets. If using multiple processes for shard groups, then refer
-/// to the [module-level] documentation.
-///
-/// [Sharding for Large Bots]: https://discord.com/developers/docs/topics/gateway#sharding-for-very-large-bots
-/// [module-level]: crate
-#[derive(Debug)]
-pub struct LargeBotQueue {
-    buckets: Vec<UnboundedSender<Sender<()>>>,
-    limiter: DayLimiter,
-}
-
-impl LargeBotQueue {
-    /// Create a new large bot queue.
-    ///
-    /// You must provide the number of buckets Discord requires your bot to
-    /// connect with.
-    ///
-    /// The number of buckets is provided via Discord as `max_concurrency`
-    /// which can be fetched with [`Client::gateway`].
-    pub async fn new(buckets: usize, http: Arc<Client>) -> Self {
-        let mut queues = Vec::with_capacity(buckets);
-        for _ in 0..buckets {
-            let (tx, rx) = unbounded_channel();
-
-            tokio::spawn(waiter(rx));
-
-            queues.push(tx);
-        }
-
-        let limiter = DayLimiter::new(http).await.expect(
-            "Getting the first session limits failed, \
-             Is network connection available?",
-        );
-
-        // The level_enabled macro does not turn off with the dynamic
-        // tracing levels. It is made for the static_max_level_xxx features
-        // And will return false if you do not use those features of if
-        // You use the feature but then dynamically set a lower feature.
-        if tracing::level_enabled!(tracing::Level::INFO) {
-            let lock = limiter.0.lock().await;
-
-            tracing::info!(
-                "{}/{} identifies used before next reset in {:.2?}",
-                lock.current,
-                lock.total,
-                lock.next_reset
-            );
-        }
-
-        Self {
-            buckets: queues,
-            limiter,
-        }
-    }
-}
-
-async fn waiter(mut rx: UnboundedReceiver<Sender<()>>) {
-    const DUR: Duration = Duration::from_secs(6);
-    while let Some(req) = rx.recv().await {
-        if let Err(source) = req.send(()) {
-            tracing::warn!("skipping, send failed with: {source:?}");
-        } else {
-            sleep(DUR).await;
-        }
-    }
-}
-
-impl Queue for LargeBotQueue {
-    /// Request to be able to identify with the gateway. This will place this
-    /// request behind all other requests, and the returned future will resolve
-    /// once the request has been completed.
-    fn request(&'_ self, shard_id: [u64; 2]) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
-        #[allow(clippy::cast_possible_truncation)]
-        let bucket = (shard_id[0] % (self.buckets.len() as u64)) as usize;
-        let (tx, rx) = oneshot::channel();
-
-        Box::pin(async move {
-            self.limiter.get().await;
-            if let Err(source) = self.buckets[bucket].send(tx) {
-                tracing::warn!("skipping, send failed with: {source:?}");
-                return;
-            }
-
-            tracing::info!("waiting for allowance on shard {}", shard_id[0]);
-
-            _ = rx.await;
-        })
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::{LargeBotQueue, Queue};
-    use static_assertions::assert_impl_all;
-    use std::fmt::Debug;
-
-    assert_impl_all!(LargeBotQueue: Debug, Queue, Send, Sync);
-}
diff --git a/twilight-gateway-queue/src/lib.rs b/twilight-gateway-queue/src/lib.rs
index 65db5f79f46..e97a0d4da3d 100644
--- a/twilight-gateway-queue/src/lib.rs
+++ b/twilight-gateway-queue/src/lib.rs
@@ -1,153 +1,245 @@
-#![cfg_attr(docsrs, feature(doc_auto_cfg))]
 #![doc = include_str!("../README.md")]
 #![warn(
     clippy::missing_const_for_fn,
+    clippy::missing_docs_in_private_items,
     clippy::pedantic,
     missing_docs,
     unsafe_code
 )]
-#![allow(
-    clippy::module_name_repetitions,
-    clippy::must_use_candidate,
-    clippy::unnecessary_wraps
-)]
-
-#[cfg(feature = "twilight-http")]
-mod day_limiter;
-#[cfg(feature = "twilight-http")]
-mod large_bot_queue;
-
-#[cfg(feature = "twilight-http")]
-pub use large_bot_queue::LargeBotQueue;
+#![allow(clippy::must_use_candidate)]

-use std::{
-    fmt::Debug,
-    future::{self, Future},
-    pin::Pin,
-    time::Duration,
-};
+use std::{collections::VecDeque, fmt::Debug, iter};
 use tokio::{
-    sync::{
-        mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
-        oneshot::{self, Sender},
-    },
-    time::sleep,
+    sync::{mpsc, oneshot},
+    time::{sleep_until, Duration, Instant},
 };

-/// Queue for shards to request the ability to initialize new sessions with the
-/// gateway.
-///
-/// This will usually only need to be implemented when you have a multi-process
-/// sharding setup. Refer to the [module-level] documentation for more
-/// information.
+/// Delay between buckets.
+pub const IDENTIFY_WINDOW: Duration = Duration::from_secs(5);
+
+/// Abstraction for types processing gateway identify requests.
 ///
-/// [module-level]: crate
+/// For convenience in twilight-gateway, implementors must also implement
+/// [`Debug`], [`Send`], and [`Sync`].
 pub trait Queue: Debug + Send + Sync {
-    /// A shard has requested the ability to request a session initialization
-    /// with the gateway.
+    /// Enqueue a shard with this ID.
+    ///
+    /// Send `()` to signal the shard to proceed. Note that shards may drop the
+    /// receiver beforehand.
     ///
-    /// The returned future must resolve only when the shard can initiate the
-    /// session.
-    fn request<'a>(&'a self, shard_id: [u64; 2]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
+    /// Closing the channel should cause the shard to requeue.
+    fn enqueue(&self, id: u64) -> oneshot::Receiver<()>;
 }

-/// A local, in-process implementation of a [`Queue`] which manages the
-/// connection attempts of one or more shards.
-///
-/// The queue will take incoming requests and then queue them, releasing one of
-/// the requests every 6 seconds. The queue is necessary because there's a
-/// ratelimit on how often shards can initiate sessions.
-///
-/// Handling shard queues usually won't need to be manually handled due to the
-/// gateway having built-in queueing when managing multiple shards.
+impl<T> Queue for &T
+where
+    T: Queue,
+{
+    fn enqueue(&self, shard: u64) -> oneshot::Receiver<()> {
+        (**self).enqueue(shard)
+    }
+}
+
+/// Possible messages from the [`InMemoryQueue`] to the [`runner`].
+#[derive(Debug)]
+enum Message {
+    /// Request a permit.
+    Request {
+        /// For this shard.
+        shard: u64,
+        /// Indicate readiness through this sender.
+        tx: oneshot::Sender<()>,
+    },
+    /// Update the runner's settings.
+    Update(Settings),
+}
+
+/// [`runner`]'s settings.
+#[derive(Debug)]
+struct Settings {
+    /// The maximum number of concurrent permits to grant. `0` instantly grants
+    /// all permits.
+    max_concurrency: u8,
+    /// Remaining daily permits.
+    remaining: u16,
+    /// Time until the daily permits reset.
+    reset_at: Instant,
+    /// The number of permits to reset to.
+    total: u16,
+}
+
+/// [`InMemoryQueue`]'s background task runner.
+async fn runner(
+    mut rx: mpsc::UnboundedReceiver<Message>,
+    Settings {
+        mut max_concurrency,
+        mut remaining,
+        reset_at,
+        mut total,
+    }: Settings,
+) {
+    let reset_at = sleep_until(reset_at);
+    let identify_window_tracker = sleep_until(Instant::now());
+    tokio::pin! {
+        reset_at,
+        identify_window_tracker
+    };
+    let create_queues = |max_concurrency: u8| {
+        iter::repeat_with(VecDeque::new)
+            .take(max_concurrency.into())
+            .collect::<Vec<_>>()
+    };
+    let mut queues = create_queues(max_concurrency);
+
+    'outer: loop {
+        tokio::select! {
+            biased;
+            _ = &mut reset_at => {
+                remaining = total;
+                reset_at.as_mut().reset(Instant::now() + Duration::from_secs(60 * 60 * 24));
+            }
+            message = rx.recv() => {
+                match message {
+                    Some(Message::Request{shard, tx}) => {
+                        if max_concurrency == 0 {
+                            _ = tx.send(());
+                        } else {
+                            #[allow(clippy::cast_possible_truncation)]
+                            queues[(shard % u64::from(max_concurrency)) as usize].push_back((shard, tx));
+                        }
+                    }
+                    Some(Message::Update(update)) => {
+                        let deadline;
+                        Settings {max_concurrency, remaining, reset_at: deadline, total} = update;
+
+                        if queues.len() != max_concurrency as usize {
+                            let unbalanced = queues.into_iter().flatten();
+                            queues = create_queues(max_concurrency);
+                            for (shard, tx) in unbalanced {
+                                #[allow(clippy::cast_possible_truncation)]
+                                queues[(shard % u64::from(max_concurrency)) as usize].push_back((shard, tx));
+                            }
+                        }
+                        reset_at.as_mut().reset(deadline);
+                    }
+                    None => break,
+                }
+            }
+            _ = &mut identify_window_tracker, if queues.iter().any(|b| !b.is_empty()) => {
+                for queue in &mut queues {
+                    if remaining == 0 {
+                        (&mut reset_at).await;
+                        let now = Instant::now();
+                        remaining = total;
+                        reset_at.as_mut().reset(now + Duration::from_secs(60 * 60 * 24));
+
+                        // Go to the next window.
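+                        // The daily reset may have consumed more than the
+                        // current 5-second window; wait only for whatever
+                        // remains of it (possibly nothing) before granting
+                        // the next round of permits.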
+                        let left_in_window = IDENTIFY_WINDOW.saturating_sub(now - identify_window_tracker.deadline());
+                        identify_window_tracker.as_mut().reset(now + left_in_window);
+
+                        continue 'outer;
+                    }
+                    while let Some((_, tx)) = queue.pop_front() {
+                        if tx.is_closed() {
+                            continue;
+                        }
+                        _ = tx.send(());
+                        remaining -= 1;
+                        break;
+                    }
+                }
+                identify_window_tracker.as_mut().reset(Instant::now() + IDENTIFY_WINDOW);
+            }
+        }
+    }
+}
+
+/// Memory-based [`Queue`] backed by an efficient background task.
 ///
-/// # When not to use this
+/// [`InMemoryQueue::update`] allows for dynamically changing the queue's
+/// settings.
 ///
-/// This queue implementation is "local", meaning it's intended to be used if
-/// you manage shards only in this process. If you run shards in multiple
-/// different processes (do you utilize microservices a lot?), then you **must
-/// not** use this implementation. Shards across multiple processes may
-/// create new sessions at the same time, which is bad.
+/// Cloning the queue is cheap and just increments a reference counter.
 ///
-/// It should also not be used for very large sharding, for that the
-/// [`LargeBotQueue`] can be used.
+/// # Settings
 ///
-/// If you can't use this, look into an alternative implementation of the
-/// [`Queue`], such as the [`gateway-queue`] broker.
+/// `remaining` is reset to `total` after `reset_after` and then every 24 hours.
 ///
-/// [`gateway-queue`]: https://github.com/twilight-rs/gateway-queue
+/// A `max_concurrency` of `0` processes all requests instantly, effectively
+/// disabling the queue.
 #[derive(Clone, Debug)]
-pub struct LocalQueue(UnboundedSender<Sender<()>>);
-
-impl Default for LocalQueue {
-    fn default() -> Self {
-        Self::new()
-    }
+pub struct InMemoryQueue {
+    /// Sender to communicate with the background [task runner].
+    ///
+    /// [task runner]: runner
+    tx: mpsc::UnboundedSender<Message>,
 }

-impl LocalQueue {
-    /// Creates a new local queue.
-    pub fn new() -> Self {
-        let (tx, rx) = unbounded_channel();
-
-        tokio::spawn(waiter(rx));
-
-        Self(tx)
+impl InMemoryQueue {
+    /// Creates a new `InMemoryQueue` with custom settings.
+    pub fn new(max_concurrency: u8, remaining: u16, reset_after: Duration, total: u16) -> Self {
+        let (tx, rx) = mpsc::unbounded_channel();
+
+        tokio::spawn(runner(
+            rx,
+            Settings {
+                max_concurrency,
+                remaining,
+                reset_at: Instant::now() + reset_after,
+                total,
+            },
+        ));
+
+        Self { tx }
     }
-}

-async fn waiter(mut rx: UnboundedReceiver<Sender<()>>) {
-    const DUR: Duration = Duration::from_secs(6);
-    while let Some(req) = rx.recv().await {
-        if let Err(source) = req.send(()) {
-            tracing::warn!("skipping, send failed: {source:?}");
-        } else {
-            sleep(DUR).await;
-        }
+    /// Update the queue with new info from the [Get Gateway Bot] endpoint.
+    ///
+    /// May be regularly called as the bot joins/leaves guilds.
+    ///
+    /// [Get Gateway Bot]: https://discord.com/developers/docs/topics/gateway#get-gateway-bot
+    pub fn update(&self, max_concurrency: u8, remaining: u16, reset_after: Duration, total: u16) {
+        self.tx
+            .send(Message::Update(Settings {
+                max_concurrency,
+                remaining,
+                reset_at: Instant::now() + reset_after,
+                total,
+            }))
+            .expect("receiver dropped after sender");
     }
 }

-impl Queue for LocalQueue {
-    /// Request to be able to identify with the gateway. This will place this
-    /// request behind all other requests, and the returned future will resolve
-    /// once the request has been completed.
-    fn request(&'_ self, [id, total]: [u64; 2]) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
-        Box::pin(async move {
-            let (tx, rx) = oneshot::channel();
-
-            if let Err(source) = self.0.send(tx) {
-                tracing::warn!("skipping, send failed: {source:?}");
-                return;
-            }
-
-            tracing::info!("shard {id}/{total} waiting for allowance");
-
-            _ = rx.await;
-        })
+impl Default for InMemoryQueue {
+    /// Creates a new queue with the default settings of:
+    ///
+    /// * `max_concurrency`: 1
+    /// * `remaining`: 1000
+    /// * `reset_after`: 24 hours
+    /// * `total`: 1000
+    fn default() -> Self {
+        Self::new(1, 1000, Duration::from_secs(60 * 60 * 24), 1000)
     }
 }

-/// An implementation of [`Queue`] that instantly allows requests.
-///
-/// Useful when running behind a proxy gateway. Running without a
-/// functional queue **will** get you ratelimited.
-#[derive(Debug)]
-pub struct NoOpQueue;
+impl Queue for InMemoryQueue {
+    fn enqueue(&self, shard: u64) -> oneshot::Receiver<()> {
+        let (tx, rx) = oneshot::channel();
+        self.tx
+            .send(Message::Request { shard, tx })
+            .expect("receiver dropped after sender");

-impl Queue for NoOpQueue {
-    fn request(&'_ self, [_id, _total]: [u64; 2]) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
-        Box::pin(future::ready(()))
+        rx
     }
 }

 #[cfg(test)]
 mod tests {
-    use super::{LocalQueue, NoOpQueue, Queue};
+    use super::{InMemoryQueue, Queue};
     use static_assertions::{assert_impl_all, assert_obj_safe};
     use std::fmt::Debug;

-    assert_impl_all!(LocalQueue: Clone, Debug, Queue, Send, Sync);
-    assert_impl_all!(NoOpQueue: Debug, Queue, Send, Sync);
+    assert_impl_all!(InMemoryQueue: Clone, Debug, Default, Send, Sync);
     assert_impl_all!(dyn Queue: Debug, Send, Sync);
     assert_obj_safe!(Queue);
 }
diff --git a/twilight-gateway-queue/tests/queue.rs b/twilight-gateway-queue/tests/queue.rs
new file mode 100644
index 00000000000..da0ddb4e83b
--- /dev/null
+++ b/twilight-gateway-queue/tests/queue.rs
@@ -0,0 +1,113 @@
+use std::{
+    future::{poll_fn, Future},
+    pin::Pin,
+};
+use tokio::time::{advance, timeout, Duration, Instant};
+use twilight_gateway_queue::{InMemoryQueue, Queue, IDENTIFY_WINDOW};
+
+async fn same_id_is_serial(queue: impl Queue) {
+    let now = Instant::now();
+
+    let mut t1 = queue.enqueue(0);
+    let mut t2 = queue.enqueue(0);
+
+    _ = poll_fn(|cx| Pin::new(&mut t1).poll(cx)).await;
+    _ = poll_fn(|cx| Pin::new(&mut t2).poll(cx)).await;
+
+    assert!(now.elapsed() >= IDENTIFY_WINDOW, "ran concurrently");
+}
+
+/// Requires a queue with `max_concurrency` > 1.
+async fn different_id_is_parallel(queue: impl Queue) {
+    let now = Instant::now();
+
+    let mut t1 = queue.enqueue(0);
+    let mut t2 = queue.enqueue(1);
+
+    _ = poll_fn(|cx| Pin::new(&mut t1).poll(cx)).await;
+    _ = poll_fn(|cx| Pin::new(&mut t2).poll(cx)).await;
+
+    assert!(now.elapsed() < IDENTIFY_WINDOW, "ran serially");
+}
+
+/// Requires a queue with `total` >= 1.
+async fn reset_after_refills(queue: impl Queue, reset_after: Duration) {
+    let mut t1 = queue.enqueue(0);
+
+    let duration = reset_after + Duration::from_nanos(1);
+    assert!(
+        timeout(duration, poll_fn(|cx| Pin::new(&mut t1).poll(cx)))
+            .await
+            .is_ok(),
+        "did not refill in time"
+    );
+}
+
+/// Requires a queue with `max_concurrency` >= 4.
+async fn multi_bucket(queue: impl Queue) {
+    let now = Instant::now();
+
+    let mut t1 = queue.enqueue(0);
+    let mut t2 = queue.enqueue(1);
+    let mut t3 = queue.enqueue(3);
+    let mut t4 = queue.enqueue(3);
+
+    _ = poll_fn(|cx| Pin::new(&mut t1).poll(cx)).await;
+    _ = poll_fn(|cx| Pin::new(&mut t2).poll(cx)).await;
+    _ = poll_fn(|cx| Pin::new(&mut t3).poll(cx)).await;
+
+    assert!(now.elapsed() < IDENTIFY_WINDOW, "ran serially");
+
+    _ = poll_fn(|cx| Pin::new(&mut t4).poll(cx)).await;
+
+    assert!(now.elapsed() >= IDENTIFY_WINDOW, "ran concurrently");
+}
+
+#[tokio::test]
+async fn memory_disabled_is_instant() {
+    let queue = InMemoryQueue::new(0, 1000, Duration::from_secs(60 * 60 * 24), 1000);
+    let now = Instant::now();
+
+    let mut t1 = queue.enqueue(0);
+    let mut t2 = queue.enqueue(0);
+
+    _ = poll_fn(|cx| Pin::new(&mut t1).poll(cx)).await;
+    _ = poll_fn(|cx| Pin::new(&mut t2).poll(cx)).await;
+
+    assert!(now.elapsed() < IDENTIFY_WINDOW, "did not run instantly");
+}
+
+#[tokio::test]
+async fn memory_update_fills_bucket() {
+    let queue = InMemoryQueue::new(1, 1000, Duration::from_secs(60 * 60 * 24), 1000);
+    let now = Instant::now();
+
+    // Background task not run due to single-threaded runtime.
+    let mut t1 = queue.enqueue(0);
+    let mut t2 = queue.enqueue(1);
+    queue.update(2, 1000, Duration::from_secs(60 * 60 * 24), 1000);
+
+    _ = poll_fn(|cx| Pin::new(&mut t1).poll(cx)).await;
+    _ = poll_fn(|cx| Pin::new(&mut t2).poll(cx)).await;
+
+    assert!(now.elapsed() < IDENTIFY_WINDOW, "ran serially");
+}
+
+#[tokio::test(start_paused = true)]
+async fn memory_integration() {
+    let queue = InMemoryQueue::new(1, 1000, Duration::from_secs(60 * 60 * 24), 1000);
+
+    same_id_is_serial(queue.clone()).await;
+
+    advance(IDENTIFY_WINDOW).await;
+    queue.update(2, 1000, Duration::from_secs(60 * 60 * 24), 1000);
+    different_id_is_parallel(queue.clone()).await;
+
+    advance(IDENTIFY_WINDOW).await;
+    queue.update(1, 0, Duration::from_secs(60), 1);
+    reset_after_refills(queue.clone(), Duration::from_secs(60)).await;
+
+    advance(IDENTIFY_WINDOW).await;
+    queue.update(4, 1000, Duration::from_secs(60 * 60 * 24), 1000);
+    multi_bucket(queue).await;
+}
diff --git a/twilight-gateway/src/config.rs b/twilight-gateway/src/config.rs
index 2fe7e17f60c..e4a9ce65db4 100644
--- a/twilight-gateway/src/config.rs
+++ b/twilight-gateway/src/config.rs
@@ -1,11 +1,14 @@
 //! User configuration for shards.

-use crate::{tls::TlsContainer, EventTypeFlags, Session};
+use crate::{
+    queue::{InMemoryQueue, Queue},
+    tls::TlsContainer,
+    EventTypeFlags, Session,
+};
 use std::{
     fmt::{Debug, Formatter, Result as FmtResult},
     sync::Arc,
 };
-use twilight_gateway_queue::{LocalQueue, Queue};
 use twilight_model::gateway::{
     payload::outgoing::{identify::IdentifyProperties, update_presence::UpdatePresencePayload},
     Intents,
@@ -190,7 +193,7 @@ impl ConfigBuilder {
             large_threshold: 50,
             presence: None,
             proxy_url: None,
-            queue: Arc::new(LocalQueue::new()),
+            queue: Arc::new(InMemoryQueue::default()),
             ratelimit_messages: true,
             session: None,
             tls: TlsContainer::new().unwrap(),
@@ -347,12 +350,7 @@ impl ConfigBuilder {
     }

     /// Set the queue to use for queueing shard sessions.
-    ///
-    /// Defaults to a [`LocalQueue`].
-    ///
-    /// Refer to the [`queue`] module for more information.
-    ///
-    /// [`queue`]: crate::queue
+    #[allow(clippy::missing_const_for_fn)]
     pub fn queue(mut self, queue: Arc<dyn Queue>) -> Self {
         self.inner.queue = queue;

diff --git a/twilight-gateway/src/shard.rs b/twilight-gateway/src/shard.rs
index 04f6a122541..d4cba751da5 100644
--- a/twilight-gateway/src/shard.rs
+++ b/twilight-gateway/src/shard.rs
@@ -30,8 +30,8 @@
 //!    which is then forwarded via [`Shard::close`]; or
 //! b. the interval for the shard to send the next heartbeat occurs, in which
 //!    case [`Shard::heartbeat`] is called; or
-//! c. the background identify queue task finishes, in which case
-//!    [`Shard::send`] is called with the identify payload; or
+//! c. the identify permit is ready, in which case [`Shard::send`] is called
+//!    with the identify payload; or
 //! d. the shard receives a command from the user over the [user channel],
 //!    which is then forwarded via [`Shard::send`]; or
 //! e. the shard receives a message from Discord via the websocket connection.
@@ -86,7 +86,7 @@ use std::{
     task::{Context, Poll},
 };
 use tokio::{
-    task::JoinHandle,
+    sync::oneshot,
     time::{self, Duration, Instant, Interval, MissedTickBehavior},
 };
 use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage};
@@ -354,8 +354,8 @@ pub struct Shard {
     heartbeat_interval_event: bool,
     /// ID of the shard.
     id: ShardId,
-    /// Identify queue background task handle.
-    identify_handle: Option<JoinHandle<()>>,
+    /// Oneshot receiver from the identify queue.
+    identify_permit: Option<oneshot::Receiver<()>>,
     /// Zlib decompressor.
     #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
     inflater: Inflater,
@@ -401,7 +401,7 @@ impl Shard {
             heartbeat_interval: None,
             heartbeat_interval_event: false,
             id: shard_id,
-            identify_handle: None,
+            identify_permit: None,
             #[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
             inflater: Inflater::new(),
             latency: Latency::new(),
@@ -556,7 +556,7 @@ impl Shard {
     ///
     /// Returns a [`ReceiveMessageErrorType::SendingMessage`] error type if the
     /// shard failed to send a message to the gateway, such as a heartbeat.
-    #[tracing::instrument(fields(id = %self.id()), name = "shard", skip(self))]
+    #[tracing::instrument(fields(id = %self.id), name = "shard", skip(self))]
     pub async fn next_message(&mut self) -> Result<Message, ReceiveMessageError> {
         /// Actions the shard might take.
         enum Action {
@@ -611,13 +611,18 @@ impl Shard {
                 ratelimiter.poll_available(cx).is_pending()
             });

-            if !ratelimited
-                && self
-                    .identify_handle
+            if !ratelimited {
+                if let Some(Poll::Ready(canceled)) = self
+                    .identify_permit
                     .as_mut()
-                    .map_or(false, |handle| Pin::new(handle).poll(cx).is_ready())
-            {
-                return Poll::Ready(Action::Identify);
+                    .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
+                {
+                    if !canceled {
+                        self.identify_permit = None;
+                        return Poll::Ready(Action::Identify);
+                    }
+                    self.identify_permit = Some(self.config.queue().enqueue(self.id.number()));
+                }
             }

             if !ratelimited && self.status.is_identified() {
@@ -726,8 +731,6 @@ impl Shard {
                     continue;
                 }
                 Action::Identify => {
-                    self.identify_handle = None;
-
                     tracing::debug!("sending identify");
                     let identify = Identify::new(IdentifyInfo {
                         compress: false,
@@ -739,7 +742,7 @@ impl Shard {
                             .identify_properties()
                             .cloned()
                             .unwrap_or_else(default_identify_properties),
-                        shard: Some(self.id()),
+                        shard: Some(self.id),
                         token: self.config.token().to_owned(),
                     });
                     let json =
@@ -1107,7 +1110,7 @@ impl Shard {
         let jitter = heartbeat_interval.mul_f64(rand::random());
         tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

-        if self.config().ratelimit_messages() {
+        if self.config.ratelimit_messages() {
             self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval).await);
         }

@@ -1123,21 +1126,12 @@ impl Shard {
             Some(session) => {
                 tracing::debug!(sequence = session.sequence(), "sending resume");
                 let resume =
-                    Resume::new(session.sequence(), session.id(), self.config().token());
+                    Resume::new(session.sequence(), session.id(), self.config.token());
                 let json = command::prepare(&resume).map_err(ProcessError::from_send)?;
                 self.send(json).await.map_err(ProcessError::from_send)?;
             }
             None => {
-                // Can not use `MessageSender` since it is only polled
-                // after the shard is identified.
-                self.identify_handle = Some(tokio::spawn({
-                    let shard_id = self.id();
-                    let queue = self.config().queue().clone();
-
-                    async move {
-                        queue.request([shard_id.number(), shard_id.total()]).await;
-                    }
-                }));
+                self.identify_permit = Some(self.config.queue().enqueue(self.id.number()));
             }
         }
     }
diff --git a/twilight-gateway/src/stream.rs b/twilight-gateway/src/stream.rs
index 6711fd1024b..e77cac9f376 100644
--- a/twilight-gateway/src/stream.rs
+++ b/twilight-gateway/src/stream.rs
@@ -27,7 +27,6 @@
 //!
 //! See the [gateway-parallel] example for how to implement this.
 //!
-//! [`ConfigBuilder::queue`]: crate::ConfigBuilder::queue
 //! [gateway-parallel]: https://github.com/twilight-rs/twilight/blob/main/examples/gateway-parallel.rs
 //! [session queue]: crate::queue

@@ -419,7 +418,7 @@ pub fn create_bucket<F: Fn(ShardId, ConfigBuilder) -> Config>(
 ///
 /// ```no_run
 /// use std::{collections::HashMap, env, sync::Arc};
-/// use twilight_gateway::{queue::LocalQueue, stream, Config, Intents};
+/// use twilight_gateway::{stream, Config, Intents};
 ///
 /// # #[tokio::main]
 /// # async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {