Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(http-ratelimiting): add a bucket for global rate limits #2159

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 32 additions & 33 deletions twilight-http-ratelimiting/src/in_memory/bucket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
//! consumed by the [`BucketQueueTask`] that manages the ratelimit for the bucket
//! and respects the global ratelimit.

use super::GlobalLockPair;
use super::GlobalBucket;
use crate::ticket::TicketSender;
use crate::{headers::RatelimitHeaders, request::Path, ticket::TicketNotifier};
use std::{
collections::HashMap,
Expand All @@ -13,6 +14,7 @@ use std::{
},
time::{Duration, Instant},
};
use tokio::sync::oneshot::error::RecvError;
use tokio::{
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
Expand Down Expand Up @@ -55,11 +57,11 @@ impl Bucket {
/// Create a new bucket for the specified [`Path`].
pub fn new(path: Path) -> Self {
Self {
limit: AtomicU64::new(u64::max_value()),
limit: AtomicU64::new(u64::MAX),
path,
queue: BucketQueue::default(),
remaining: AtomicU64::new(u64::max_value()),
reset_after: AtomicU64::new(u64::max_value()),
remaining: AtomicU64::new(u64::MAX),
reset_after: AtomicU64::new(u64::MAX),
started_at: Mutex::new(None),
}
}
Expand Down Expand Up @@ -134,7 +136,7 @@ impl Bucket {
}

if let Some((limit, remaining, reset_after)) = ratelimits {
if bucket_limit != limit && bucket_limit == u64::max_value() {
if bucket_limit != limit && bucket_limit == u64::MAX {
self.reset_after.store(reset_after, Ordering::SeqCst);
self.limit.store(limit, Ordering::SeqCst);
}
Expand Down Expand Up @@ -162,10 +164,10 @@ impl BucketQueue {
}

/// Receive the first incoming ratelimit request.
pub async fn pop(&self, timeout_duration: Duration) -> Option<TicketNotifier> {
pub async fn pop(&self) -> Option<TicketNotifier> {
let mut rx = self.rx.lock().await;

timeout(timeout_duration, rx.recv()).await.ok().flatten()
rx.recv().await
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am somewhat certain that this will never return None (as at least one sender is tied to each BucketQueue), which is otherwise necessary for the bucket to be finished and freed (see BucketQueueTask::run).

}
}

Expand All @@ -189,7 +191,7 @@ pub(super) struct BucketQueueTask {
/// All buckets managed by the associated [`super::InMemoryRatelimiter`].
buckets: Arc<Mutex<HashMap<Path, Arc<Bucket>>>>,
/// Global ratelimit data.
global: Arc<GlobalLockPair>,
global: GlobalBucket,
/// The [`Path`] this [`Bucket`] belongs to.
path: Path,
}
Expand All @@ -202,7 +204,7 @@ impl BucketQueueTask {
pub fn new(
bucket: Arc<Bucket>,
buckets: Arc<Mutex<HashMap<Path, Arc<Bucket>>>>,
global: Arc<GlobalLockPair>,
global: GlobalBucket,
path: Path,
) -> Self {
Self {
Expand All @@ -218,9 +220,8 @@ impl BucketQueueTask {
#[tracing::instrument(name = "background queue task", skip(self), fields(path = ?self.path))]
pub async fn run(self) {
while let Some(queue_tx) = self.next().await {
if self.global.is_locked() {
drop(self.global.0.lock().await);
}
// Do not lock up if the global rate limiter crashes for any reason
let global_ticket_tx = self.wait_for_global().await.ok();

let ticket_headers = if let Some(ticket_headers) = queue_tx.available() {
ticket_headers
Expand All @@ -231,7 +232,10 @@ impl BucketQueueTask {
tracing::debug!("starting to wait for response headers");

match timeout(Self::WAIT, ticket_headers).await {
Ok(Ok(Some(headers))) => self.handle_headers(&headers).await,
Ok(Ok(Some(headers))) => {
self.handle_headers(&headers);
global_ticket_tx.and_then(|tx| tx.headers(Some(headers)).ok());
}
Ok(Ok(None)) => {
tracing::debug!("request aborted");
}
Expand All @@ -252,43 +256,38 @@ impl BucketQueueTask {
.remove(&self.path);
}

#[tracing::instrument(name = "waiting for global bucket", skip_all)]
async fn wait_for_global(&self) -> Result<TicketSender, RecvError> {
let (tx, rx) = super::ticket::channel();
self.global.queue().push(tx);

tracing::debug!("waiting for global rate limit");
let res = rx.await;
tracing::debug!("done waiting for global rate limit");

res
}

/// Update the bucket's ratelimit state.
async fn handle_headers(&self, headers: &RatelimitHeaders) {
fn handle_headers(&self, headers: &RatelimitHeaders) {
let ratelimits = match headers {
RatelimitHeaders::Global(global) => {
self.lock_global(Duration::from_secs(global.retry_after()))
.await;

None
}
RatelimitHeaders::None => return,
RatelimitHeaders::Present(present) => {
Some((present.limit(), present.remaining(), present.reset_after()))
}
_ => return,
};

tracing::debug!(path=?self.path, "updating bucket");
self.bucket.update(ratelimits);
}

/// Lock the global ratelimit for a specified duration.
async fn lock_global(&self, wait: Duration) {
tracing::debug!(path=?self.path, "request got global ratelimited");
self.global.lock();
let lock = self.global.0.lock().await;
sleep(wait).await;
self.global.unlock();

drop(lock);
}

/// Get the next [`TicketNotifier`] in the queue.
async fn next(&self) -> Option<TicketNotifier> {
tracing::debug!(path=?self.path, "starting to get next in queue");

self.wait_if_needed().await;

self.bucket.queue.pop(Self::WAIT).await
self.bucket.queue.pop().await
}

/// Wait for this bucket to refresh if it isn't ready yet.
Expand Down
153 changes: 153 additions & 0 deletions twilight-http-ratelimiting/src/in_memory/global_bucket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//! Bucket implementation for a global ratelimit.

use super::bucket::BucketQueue;
use crate::ticket::TicketNotifier;
use crate::RatelimitHeaders;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{mpsc, Mutex, Semaphore};
use tokio::time::Instant;

/// Length of the refill period, in seconds.
const PERIOD: u64 = 1;
/// Number of requests allowed per period.
const REQUESTS: u32 = 50;

/// Global bucket. Keeps track of the global rate limit.
///
/// Cloning is cheap: all clones share the same inner state through an [`Arc`].
#[derive(Debug, Clone)]
pub struct GlobalBucket(Arc<InnerGlobalBucket>);

impl GlobalBucket {
    /// Creates a new global bucket using custom ratelimit values.
    ///
    /// `period` is given in seconds.
    ///
    /// `requests` indicates the amount of requests per period.
    ///
    /// NOTE(review): a `requests` value of 0 will panic inside the queue task
    /// (`period / requests` and `mpsc::channel(0)`) — consider validating here.
    #[must_use]
    pub fn with_ratelimit(period: u64, requests: u32) -> Self {
        // Spawns the background task that drains the queue.
        Self(InnerGlobalBucket::new(period, requests))
    }

    /// Queue of global ratelimit requests.
    pub fn queue(&self) -> &BucketQueue {
        &self.0.queue
    }

    /// Whether the global ratelimit is exhausted.
    ///
    /// True while the background task holds the `is_locked` mutex, i.e. while
    /// it is sleeping off an exhausted bucket or a ratelimit penalty.
    pub fn is_locked(&self) -> bool {
        // try_lock fails exactly when the queue task currently holds the lock.
        self.0.is_locked.try_lock().is_err()
    }
}

impl Default for GlobalBucket {
fn default() -> Self {
Self(InnerGlobalBucket::new(PERIOD, REQUESTS))
}
}

/// Inner struct to allow [`GlobalBucket`] to return an [`Arc`].
#[derive(Debug)]
struct InnerGlobalBucket {
    /// Queue to receive rate limit requests.
    pub queue: BucketQueue,
    /// Held by the queue task while it sleeps off a ratelimit; a locked
    /// mutex therefore means requests are currently waiting for capacity.
    is_locked: Mutex<()>,
}

impl InnerGlobalBucket {
    /// Creates a new bucket and starts a task processing incoming requests.
    fn new(period: u64, requests: u32) -> Arc<Self> {
        let bucket = Arc::new(Self {
            queue: BucketQueue::default(),
            is_locked: Mutex::default(),
        });

        // The task owns a second handle to the bucket and runs until the
        // queue is closed.
        tokio::spawn(run_global_queue_task(Arc::clone(&bucket), period, requests));

        bucket
    }
}

/// Drains the global queue, pacing tickets to the configured
/// requests-per-period rate before handing each off to its own task.
#[tracing::instrument(name = "background global queue task", skip_all)]
async fn run_global_queue_task(bucket: Arc<InnerGlobalBucket>, period: u64, requests: u32) {
    // Virtual timestamp tracking how much of the bucket is consumed.
    let mut refill_time = Instant::now();
    // Caps how many handed-out tickets may be awaiting headers at once.
    let in_flight = Arc::new(Semaphore::new(requests as usize));
    // Finished requests report global-ratelimit penalties over this channel.
    let (penalty_tx, mut penalty_rx) = mpsc::channel(requests as usize);

    while let Some(ticket) = bucket.queue.pop().await {
        wait_if_needed(
            bucket.as_ref(),
            &mut refill_time,
            period,
            requests,
            &mut penalty_rx,
        )
        .await;

        tokio::spawn(process_request(
            Arc::clone(&in_flight),
            ticket,
            penalty_tx.clone(),
        ));
    }
}

/// Releases one ticket to a waiting request and, once its response headers
/// arrive, reports any global ratelimit penalty back to the queue task.
#[tracing::instrument(name = "process request", skip_all)]
async fn process_request(
    semaphore: Arc<Semaphore>,
    queue_tx: TicketNotifier,
    penalties: Sender<Instant>,
) {
    // This error should never occur, but if it does, do not lock up
    let _permit = semaphore.acquire().await;

    // A closed notifier means the requester gave up; nothing to do.
    let ticket_headers = match queue_tx.available() {
        Some(rx) => rx,
        None => return,
    };

    if let Ok(Some(RatelimitHeaders::Global(headers))) = ticket_headers.await {
        tracing::debug!(seconds = headers.retry_after(), "globally ratelimited");

        let deadline = Instant::now() + Duration::from_secs(headers.retry_after());
        // Receiver may already be gone on shutdown; ignore the send error.
        penalties.send(deadline).await.ok();
    }
}

/// Checks and sleeps in case a request needs to wait before proceeding.
///
/// Implements a leaky-bucket style pacer: `time` is a virtual timestamp that
/// advances by one `fill_rate` per request and may lag `now` by at most one
/// full period (a burst of one period's worth of requests).
async fn wait_if_needed(
    bucket: &InnerGlobalBucket,
    time: &mut Instant,
    period: u64,
    requests: u32,
    penalties: &mut Receiver<Instant>,
) {
    let period = Duration::from_secs(period);
    // Refill time attributed to a single request.
    let fill_rate = period / requests;

    let now = Instant::now();
    // The bucket never holds more than one period's worth of capacity, so the
    // virtual timestamp is clamped to at most one period in the past.
    let oldest_allowed = now - period;
    *time = (*time).max(oldest_allowed);

    // Consume one request's worth of capacity.
    *time += fill_rate;

    // A virtual timestamp in the future means the bucket is exhausted; sleep
    // it off while holding the lock so `is_locked` reports exhaustion.
    if *time > now {
        let _guard = bucket.is_locked.lock().await;
        tokio::time::sleep_until(*time).await;
    }

    // Serve penalties reported by earlier requests, also under the lock.
    while let Ok(deadline) = penalties.try_recv() {
        let _guard = bucket.is_locked.lock().await;
        tokio::time::sleep_until(deadline).await;
    }
}
Loading