diff --git a/twilight-gateway/src/ratelimiter.rs b/twilight-gateway/src/ratelimiter.rs index bb88d21424e..94ce47a7768 100644 --- a/twilight-gateway/src/ratelimiter.rs +++ b/twilight-gateway/src/ratelimiter.rs @@ -18,7 +18,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tokio::time::{sleep, Duration, Instant, Sleep}; +use tokio::time::{sleep_until, Duration, Instant, Sleep}; /// Number of commands allowed in a [`PERIOD`]. const COMMANDS_PER_PERIOD: u8 = 120; @@ -37,13 +37,14 @@ pub struct CommandRatelimiter { impl CommandRatelimiter { /// Create a new ratelimiter with some capacity reserved for heartbeating. - pub(crate) async fn new(heartbeat_interval: Duration) -> Self { + pub(crate) fn new(heartbeat_interval: Duration) -> Self { let allotted = nonreserved_commands_per_reset(heartbeat_interval); - let mut delay = Box::pin(sleep(Duration::ZERO)); + let now = Instant::now(); + let mut delay = Box::pin(sleep_until(now)); // Hack to register the timer. - (&mut delay).await; + delay.as_mut().reset(now); Self { delay, @@ -180,7 +181,7 @@ mod tests { #[tokio::test(start_paused = true)] async fn full_reset() { - let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL).await; + let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL); assert_eq!(ratelimiter.available(), ratelimiter.max()); for _ in 0..ratelimiter.max() { @@ -199,7 +200,7 @@ mod tests { #[tokio::test(start_paused = true)] async fn half_reset() { - let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL).await; + let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL); assert_eq!(ratelimiter.available(), ratelimiter.max()); for _ in 0..ratelimiter.max() / 2 { @@ -226,7 +227,7 @@ mod tests { #[tokio::test(start_paused = true)] async fn constant_capacity() { - let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL).await; + let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL); let max = ratelimiter.max(); for _ in 0..max { @@ -240,7 +241,7 @@ mod tests { #[tokio::test(start_paused = true)] async fn spurious_poll() { - let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL).await; + let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL); for _ in 0..ratelimiter.max() { ratelimiter.acquire().await; diff --git a/twilight-gateway/src/shard.rs b/twilight-gateway/src/shard.rs index 1fde728c93f..84e35c4d08f 100644 --- a/twilight-gateway/src/shard.rs +++ b/twilight-gateway/src/shard.rs @@ -1062,7 +1062,7 @@ impl Shard { tracing::debug!(?heartbeat_interval, ?jitter, "received hello"); if self.config().ratelimit_messages() { - self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval).await); + self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval)); } let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);