From 8ad12a72096ca51acca8681034ab980a89728dad Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Fri, 19 Jan 2024 11:22:23 +0800 Subject: [PATCH] add rate-limit feature and http-rate crate. (#895) * add rate-limit feature and http-rate crate. * fix fmt. * clippy fix. * add package meta. * update change log. --- Cargo.toml | 2 + http-rate/Cargo.toml | 19 ++ http-rate/LICENSE | 21 ++ http-rate/README.md | 1 + http-rate/src/error.rs | 76 +++++ http-rate/src/gcra.rs | 208 ++++++++++++ http-rate/src/lib.rs | 474 +++++++++++++++++++++++++++ http-rate/src/nanos.rs | 137 ++++++++ http-rate/src/quota.rs | 172 ++++++++++ http-rate/src/snapshot.rs | 106 ++++++ http-rate/src/state.rs | 81 +++++ http-rate/src/state/direct.rs | 111 +++++++ http-rate/src/state/in_memory.rs | 133 ++++++++ http-rate/src/state/keyed.rs | 226 +++++++++++++ http-rate/src/state/keyed/hashmap.rs | 70 ++++ http-rate/src/timer.rs | 194 +++++++++++ web/CHANGES.md | 2 + web/Cargo.toml | 6 + web/src/middleware/mod.rs | 2 + web/src/middleware/rate_limit.rs | 101 ++++++ 20 files changed, 2142 insertions(+) create mode 100644 http-rate/Cargo.toml create mode 100644 http-rate/LICENSE create mode 100644 http-rate/README.md create mode 100644 http-rate/src/error.rs create mode 100644 http-rate/src/gcra.rs create mode 100644 http-rate/src/lib.rs create mode 100644 http-rate/src/nanos.rs create mode 100644 http-rate/src/quota.rs create mode 100644 http-rate/src/snapshot.rs create mode 100644 http-rate/src/state.rs create mode 100644 http-rate/src/state/direct.rs create mode 100644 http-rate/src/state/in_memory.rs create mode 100644 http-rate/src/state/keyed.rs create mode 100644 http-rate/src/state/keyed/hashmap.rs create mode 100644 http-rate/src/timer.rs create mode 100644 web/src/middleware/rate_limit.rs diff --git a/Cargo.toml b/Cargo.toml index 0f776826..fb9d266c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "http-encoding", "http-file", "http-multipart", + "http-rate", "http-ws", ] @@ -34,6 +35,7 @@ xitca-web = { path = "./web" } http-encoding = { path = "./http-encoding" } http-file = { path = "http-file" } http-multipart = { path = "./http-multipart" } +http-rate = { path = "./http-rate" } http-ws = { path = "./http-ws" } [profile.release] diff --git a/http-rate/Cargo.toml b/http-rate/Cargo.toml new file mode 100644 index 00000000..11b0c78d --- /dev/null +++ b/http-rate/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "http-rate" +version = "0.1.0" +edition = "2021" +authors = ["fakeshadow "] +license = "MIT" +description = "rate limit for http crate types" +repository = "https://github.com/HFQR/xitca-web" +keywords = ["http", "rate-limit"] +readme= "README.md" + +[dependencies] +http = "1" + +[dev-dependencies] +crossbeam = "0.8.0" +libc = "0.2.70" +proptest = "1.0.0" +all_asserts = "2.2.0" diff --git a/http-rate/LICENSE b/http-rate/LICENSE new file mode 100644 index 00000000..51c6a41e --- /dev/null +++ b/http-rate/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Andreas Fuchs + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included 
in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/http-rate/README.md b/http-rate/README.md new file mode 100644 index 00000000..c6cf95d9 --- /dev/null +++ b/http-rate/README.md @@ -0,0 +1 @@ +# rate limit for http types \ No newline at end of file diff --git a/http-rate/src/error.rs b/http-rate/src/error.rs new file mode 100644 index 00000000..42ab0165 --- /dev/null +++ b/http-rate/src/error.rs @@ -0,0 +1,76 @@ +use core::fmt; + +use std::{error, time::Instant}; + +use http::{HeaderName, HeaderValue, Response, StatusCode}; + +use crate::{ + gcra::NotUntil, + timer::{DefaultTimer, Timer}, +}; + +/// Error happen when client exceeds rate limit. +#[derive(Debug)] +pub struct TooManyRequests { + after_seconds: u64, +} + +impl fmt::Display for TooManyRequests { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "too many requests. wait for {}", self.after_seconds) + } +} + +impl error::Error for TooManyRequests {} + +impl From> for TooManyRequests { + fn from(e: NotUntil) -> Self { + let after_seconds = e.wait_time_from(DefaultTimer.now()).as_secs(); + + Self { after_seconds } + } +} + +const X_RT_AFTER: HeaderName = HeaderName::from_static("x-ratelimit-after"); + +impl TooManyRequests { + /// extend response headers with status code and headers + /// StatusCode: 429 + /// Header: `x-ratelimit-after: ` + pub fn extend_response(&self, res: &mut Response) { + *res.status_mut() = StatusCode::TOO_MANY_REQUESTS; + res.headers_mut() + .insert(X_RT_AFTER, HeaderValue::from(self.after_seconds)); + } +} + +/// Error indicating that the number of cells tested (the first +/// argument) is larger than the bucket's capacity. +/// +/// This means the decision can never have a conforming result. The +/// argument gives the maximum number of cells that could ever have a +/// conforming result. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct InsufficientCapacity(pub u32); + +impl fmt::Display for InsufficientCapacity { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "required number of cells {} exceeds bucket's capacity", self.0) + } +} + +impl std::error::Error for InsufficientCapacity {} + +#[cfg(test)] +mod test { + use super::InsufficientCapacity; + + #[test] + fn coverage() { + let display_output = format!("{}", InsufficientCapacity(3)); + assert!(display_output.contains('3')); + let debug_output = format!("{:?}", InsufficientCapacity(3)); + assert!(debug_output.contains('3')); + assert_eq!(InsufficientCapacity(3), InsufficientCapacity(3)); + } +} diff --git a/http-rate/src/gcra.rs b/http-rate/src/gcra.rs new file mode 100644 index 00000000..c4b76d91 --- /dev/null +++ b/http-rate/src/gcra.rs @@ -0,0 +1,208 @@ +use core::{cmp, fmt, time::Duration}; + +use crate::{nanos::Nanos, quota::Quota, snapshot::RateSnapshot, state::StateStore, timer}; + +#[cfg(test)] +use core::num::NonZeroU32; + +#[cfg(test)] +use crate::error::InsufficientCapacity; + +/// A negative rate-limiting outcome. 
+///
+/// `NotUntil`'s methods indicate when a caller can expect the next positive
+/// rate-limiting result.
+#[derive(Debug, PartialEq, Eq)]
+pub struct NotUntil<P: timer::Reference> {
+    state: RateSnapshot,
+    start: P,
+}
+
+impl<P: timer::Reference> NotUntil<P>
{ + /// Create a `NotUntil` as a negative rate-limiting result. + #[inline] + pub(crate) fn new(state: RateSnapshot, start: P) -> Self { + Self { state, start } + } + + /// Returns the earliest time at which a decision could be + /// conforming (excluding conforming decisions made by the Decider + /// that are made in the meantime). + #[inline] + pub fn earliest_possible(&self) -> P { + let tat: Nanos = self.state.tat; + self.start + tat + } + + /// Returns the minimum amount of time from the time that the + /// decision was made that must pass before a + /// decision can be conforming. + /// + /// If the time of the next expected positive result is in the past, + /// `wait_time_from` returns a zero `Duration`. + #[inline] + pub fn wait_time_from(&self, from: P) -> Duration { + let earliest = self.earliest_possible(); + earliest.duration_since(earliest.min(from)).into() + } + + /// Returns the rate limiting [`Quota`] used to reach the decision. + #[inline] + pub fn quota(&self) -> Quota { + self.state.quota() + } +} + +impl fmt::Display for NotUntil

{ + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!(f, "rate-limited until {:?}", self.start + self.state.tat) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct Gcra { + /// The "weight" of a single packet in units of time. + t: Nanos, + + /// The "burst capacity" of the bucket. + tau: Nanos, +} + +impl Gcra { + pub(crate) fn new(quota: Quota) -> Self { + let tau: Nanos = (cmp::max(quota.replenish_1_per, Duration::from_nanos(1)) * quota.max_burst.get()).into(); + let t: Nanos = quota.replenish_1_per.into(); + Gcra { t, tau } + } + + /// Computes and returns a new ratelimiter state if none exists yet. + fn starting_state(&self, t0: Nanos) -> Nanos { + t0 + self.t + } + + /// Tests a single cell against the rate limiter state and updates it at the given key. + pub(crate) fn test_and_update>( + &self, + start: P, + key: &K, + state: &S, + t0: P, + ) -> Result> { + let t0 = t0.duration_since(start); + let tau = self.tau; + let t = self.t; + state.measure_and_replace(key, |tat| { + let tat = tat.unwrap_or_else(|| self.starting_state(t0)); + let earliest_time = tat.saturating_sub(tau); + if t0 < earliest_time { + let state = RateSnapshot::new(self.t, self.tau, earliest_time, earliest_time); + Err(NotUntil::new(state, start)) + } else { + let next = cmp::max(tat, t0) + t; + Ok((RateSnapshot::new(self.t, self.tau, t0, next), next)) + } + }) + } + + #[cfg(test)] + /// Tests whether all `n` cells could be accommodated and updates the rate limiter state, if so. + pub(crate) fn test_n_all_and_update>( + &self, + start: P, + key: &K, + n: NonZeroU32, + state: &S, + t0: P, + ) -> Result>, InsufficientCapacity> { + let t0 = t0.duration_since(start); + let tau = self.tau; + let t = self.t; + let additional_weight = t * (n.get() - 1) as u64; + + // check that we can allow enough cells through. Note that `additional_weight` is the + // value of the cells *in addition* to the first cell - so add that first cell back. 
+ if additional_weight + t > tau { + return Err(InsufficientCapacity((tau.as_u64() / t.as_u64()) as u32)); + } + Ok(state.measure_and_replace(key, |tat| { + let tat = tat.unwrap_or_else(|| self.starting_state(t0)); + let earliest_time = (tat + additional_weight).saturating_sub(tau); + if t0 < earliest_time { + let state = RateSnapshot::new(self.t, self.tau, earliest_time, earliest_time); + Err(NotUntil::new(state, start)) + } else { + let next = cmp::max(tat, t0) + t + additional_weight; + Ok((RateSnapshot::new(self.t, self.tau, t0, next), next)) + } + })) + } +} + +#[cfg(test)] +mod test { + use proptest::prelude::*; + + use crate::quota::Quota; + + use super::*; + + /// Exercise derives and convenience impls on Gcra to make coverage happy + #[test] + fn gcra_derives() { + use all_asserts::assert_gt; + + let g = Gcra::new(Quota::per_second(1)); + let g2 = Gcra::new(Quota::per_second(2)); + assert_eq!(g, g); + assert_ne!(g, g2); + assert_gt!(format!("{:?}", g).len(), 0); + } + + /// Exercise derives and convenience impls on NotUntil to make coverage happy + #[test] + fn notuntil_impls() { + use crate::state::RateLimiter; + use all_asserts::assert_gt; + use timer::FakeRelativeClock; + + let clock = FakeRelativeClock::default(); + let quota = Quota::per_second(1); + let lb = RateLimiter::direct_with_clock(quota, &clock); + assert!(lb.check().is_ok()); + assert!(lb + .check() + .map_err(|nu| { + assert_eq!(nu, nu); + assert_gt!(format!("{:?}", nu).len(), 0); + assert_eq!(format!("{}", nu), "rate-limited until Nanos(1s)"); + assert_eq!(nu.quota(), quota); + }) + .is_err()); + } + + #[derive(Debug)] + struct Count(NonZeroU32); + impl Arbitrary for Count { + type Parameters = (); + fn arbitrary_with(_args: ()) -> Self::Strategy { + (1..10000u32).prop_map(|x| Count(NonZeroU32::new(x).unwrap())).boxed() + } + + type Strategy = BoxedStrategy; + } + + #[test] + fn cover_count_derives() { + assert_eq!(format!("{:?}", Count(NonZeroU32::new(1).unwrap())), "Count(1)"); + } + + #[test] + fn roundtrips_quota() { + proptest!(ProptestConfig::default(), |(per_second: Count, burst: Count)| { + let quota = Quota::per_second(per_second.0).allow_burst(burst.0); + let gcra = Gcra::new(quota); + let back = Quota::from_gcra_parameters(gcra.t, gcra.tau); + assert_eq!(quota, back); + }) + } +} diff --git a/http-rate/src/lib.rs b/http-rate/src/lib.rs new file mode 100644 index 00000000..3769b341 --- /dev/null +++ b/http-rate/src/lib.rs @@ -0,0 +1,474 @@ +#![allow(clippy::declare_interior_mutable_const)] + +mod error; +mod gcra; +mod nanos; +mod quota; +mod snapshot; +mod state; +mod timer; + +pub use error::TooManyRequests; +pub use quota::Quota; +pub use snapshot::RateSnapshot; + +use core::net::{IpAddr, SocketAddr}; + +use std::sync::Arc; + +use http::header::{HeaderMap, HeaderName, FORWARDED}; + +use crate::state::{keyed::DefaultKeyedStateStore, RateLimiter}; + +#[derive(Clone)] +pub struct RateLimit { + limit: Arc>>, +} + +impl RateLimit { + /// Construct a new RateLimit with given quota. + pub fn new(quota: Quota) -> Self { + Self { + limit: Arc::new(RateLimiter::hashmap(quota)), + } + } + + /// Rate limit [Request] based on it's [HeaderMap] state and given client [SocketAddr] + /// "x-real-ip", "x-forwarded-for" and "forwarded" are checked in order start from left to + /// determine client's socket address. Received [SocketAddr] will be used as fallback when + /// all headers are absent or can't provide valid client address. 
+ /// + /// [Request]: http::Request + pub fn rate_limit(&self, headers: &HeaderMap, addr: &SocketAddr) -> Result { + let addr = maybe_x_forwarded_for(headers) + .or_else(|| maybe_x_real_ip(headers)) + .or_else(|| maybe_forwarded(headers)) + .unwrap_or_else(|| addr.ip()); + self.limit.check_key(&addr).map_err(TooManyRequests::from) + } +} + +const X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip"); +const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); + +fn maybe_x_forwarded_for(headers: &HeaderMap) -> Option { + headers + .get(X_FORWARDED_FOR) + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.split(',').find_map(|s| s.trim().parse::().ok())) +} + +fn maybe_x_real_ip(headers: &HeaderMap) -> Option { + headers + .get(X_REAL_IP) + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.parse::().ok()) +} + +fn maybe_forwarded(headers: &HeaderMap) -> Option { + let mut res = None; + + for mut val in headers + .get_all(FORWARDED) + .iter() + .filter_map(|h| h.to_str().ok()) + .flat_map(|val| val.split(';')) + .flat_map(|p| p.split(',')) + .map(|val| val.trim().splitn(2, '=')) + { + if let (Some(name), Some(val)) = (val.next(), val.next()) { + if name.eq_ignore_ascii_case("for") { + let val = val.trim(); + match val.parse::() { + Ok(addr) => res = Some(addr.ip()), + Err(_) => res = val.parse::().ok(), + } + } + } + } + + res +} + +#[cfg(test)] +type DefaultDirectRateLimiter = RateLimiter; + +#[cfg(test)] +mod test { + use core::{num::NonZeroU32, time::Duration}; + + use std::thread; + + use all_asserts::*; + use http::header::HeaderValue; + + use crate::{ + error::InsufficientCapacity, + quota::Quota, + state::RateLimiter, + timer::{DefaultTimer, FakeRelativeClock, Timer}, + DefaultDirectRateLimiter, + }; + + use super::*; + + #[test] + fn forwarded_header() { + let mut headers = HeaderMap::new(); + headers.insert( + FORWARDED, + HeaderValue::from_static("for=192.0.2.60;proto=http;by=203.0.113.43"), + ); + assert_eq!(maybe_forwarded(&headers).unwrap().to_string(), "192.0.2.60"); + } + + #[test] + fn rejects_too_many() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(2), &clock); + let ms = Duration::from_millis(1); + + // use up our burst capacity (2 in the first second): + assert!(lb.check().is_ok(), "Now: {:?}", clock.now()); + clock.advance(ms); + assert!(lb.check().is_ok(), "Now: {:?}", clock.now()); + + clock.advance(ms); + assert!(lb.check().is_err(), "Now: {:?}", clock.now()); + + // should be ok again in 1s: + clock.advance(ms * 1000); + assert!(lb.check().is_ok(), "Now: {:?}", clock.now()); + clock.advance(ms); + assert!(lb.check().is_ok()); + + clock.advance(ms); + assert!(lb.check().is_err(), "{:?}", lb); + } + + #[test] + fn all_1_identical_to_1() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(2), &clock); + let ms = Duration::from_millis(1); + let one = NonZeroU32::new(1).unwrap(); + + // use up our burst capacity (2 in the first second): + assert!(lb.check_n(one).unwrap().is_ok(), "Now: {:?}", clock.now()); + clock.advance(ms); + assert!(lb.check_n(one).unwrap().is_ok(), "Now: {:?}", clock.now()); + + clock.advance(ms); + assert!(lb.check_n(one).unwrap().is_err(), "Now: {:?}", clock.now()); + + // should be ok again in 1s: + clock.advance(ms * 1000); + assert!(lb.check_n(one).unwrap().is_ok(), "Now: {:?}", clock.now()); + clock.advance(ms); + assert!(lb.check_n(one).unwrap().is_ok()); + + clock.advance(ms); + 
assert!(lb.check_n(one).unwrap().is_err(), "{:?}", lb); + } + + #[test] + fn never_allows_more_than_capacity_all() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(4), &clock); + let ms = Duration::from_millis(1); + + let num = NonZeroU32::new(2).unwrap(); + + // Use up the burst capacity: + assert!(lb.check_n(num).unwrap().is_ok()); + assert!(lb.check_n(num).unwrap().is_ok()); + + clock.advance(ms); + assert!(lb.check_n(num).unwrap().is_err()); + + // should be ok again in 1s: + clock.advance(ms * 1000); + assert!(lb.check_n(num).unwrap().is_ok()); + clock.advance(ms); + assert!(lb.check_n(num).unwrap().is_ok()); + + clock.advance(ms); + assert!(lb.check_n(num).unwrap().is_err(), "{:?}", lb); + } + + #[test] + fn rejects_too_many_all() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(5), &clock); + let ms = Duration::from_millis(1); + + let num = NonZeroU32::new(15).unwrap(); + + // Should not allow the first 15 cells on a capacity 5 bucket: + assert!(lb.check_n(num).is_err()); + + // After 3 and 20 seconds, it should not allow 15 on that bucket either: + clock.advance(ms * 3 * 1000); + assert!(lb.check_n(num).is_err()); + } + + #[test] + fn all_capacity_check_rejects_excess() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(5), &clock); + + assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(NonZeroU32::new(15).unwrap())); + assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(NonZeroU32::new(6).unwrap())); + assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(NonZeroU32::new(7).unwrap())); + } + + #[test] + fn correct_wait_time() { + let clock = FakeRelativeClock::default(); + // Bucket adding a new element per 200ms: + let lb = RateLimiter::direct_with_clock(Quota::per_second(5), &clock); + let ms = Duration::from_millis(1); + let mut conforming = 0; + for _i in 0..20 { + clock.advance(ms); + let res = lb.check(); + match res { + Ok(_) => { + conforming += 1; + } + Err(wait) => { + clock.advance(wait.wait_time_from(clock.now())); + assert!(lb.check().is_ok()); + conforming += 1; + } + } + } + assert_eq!(20, conforming); + } + + #[test] + fn actual_threadsafety() { + use crossbeam; + + let clock = FakeRelativeClock::default(); + let lim = RateLimiter::direct_with_clock(Quota::per_second(20), &clock); + let ms = Duration::from_millis(1); + + crossbeam::scope(|scope| { + for _i in 0..20 { + scope.spawn(|_| { + assert!(lim.check().is_ok()); + }); + } + }) + .unwrap(); + + clock.advance(ms * 2); + assert!(lim.check().is_err()); + clock.advance(ms * 998); + assert!(lim.check().is_ok()); + } + + #[test] + fn default_direct() { + let limiter = RateLimiter::direct_with_clock(Quota::per_second(20), &DefaultTimer); + assert!(limiter.check().is_ok()); + } + + #[test] + fn stresstest_large_quotas() { + use std::{sync::Arc, thread}; + + let quota = Quota::per_second(1_000_000_001); + let rate_limiter = Arc::new(RateLimiter::direct(quota)); + + fn rlspin(rl: Arc) { + for _ in 0..1_000_000 { + rl.check().map_err(|e| dbg!(e)).unwrap(); + } + } + + let rate_limiter2 = rate_limiter.clone(); + thread::spawn(move || { + rlspin(rate_limiter2); + }); + rlspin(rate_limiter); + } + + const KEYS: &[u32] = &[1u32, 2u32]; + + #[test] + fn accepts_first_cell() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::hashmap_with_clock(Quota::per_second(5), &clock); + for key in KEYS { + assert!(lb.check_key(&key).is_ok(), "key {:?}", 
key); + } + } + + use crate::state::keyed::HashMapStateStore; + use core::hash::Hash; + + fn retained_keys( + limiter: RateLimiter, FakeRelativeClock>, + ) -> Vec { + let state = limiter.into_state_store(); + let map = state.lock().unwrap(); + let mut keys: Vec = map.keys().copied().collect(); + keys.sort(); + keys + } + + #[test] + fn expiration() { + let clock = FakeRelativeClock::default(); + let ms = Duration::from_millis(1); + + let make_bucket = || { + let lim = RateLimiter::hashmap_with_clock(Quota::per_second(1), &clock); + lim.check_key(&"foo").unwrap(); + clock.advance(ms * 200); + lim.check_key(&"bar").unwrap(); + clock.advance(ms * 600); + lim.check_key(&"baz").unwrap(); + lim + }; + let keys = &["bar", "baz", "foo"]; + + // clean up all keys that are indistinguishable from unoccupied keys: + let lim_shrunk = make_bucket(); + lim_shrunk.retain_recent(); + assert_eq!(retained_keys(lim_shrunk), keys); + + let lim_later = make_bucket(); + clock.advance(ms * 1200); + lim_later.retain_recent(); + assert_eq!(retained_keys(lim_later), vec!["bar", "baz"]); + + let lim_later = make_bucket(); + clock.advance(ms * (1200 + 200)); + lim_later.retain_recent(); + assert_eq!(retained_keys(lim_later), vec!["baz"]); + + let lim_later = make_bucket(); + clock.advance(ms * (1200 + 200 + 600)); + lim_later.retain_recent(); + assert_eq!(retained_keys(lim_later), Vec::<&str>::new()); + } + + #[test] + fn hashmap_length() { + let lim = RateLimiter::hashmap(Quota::per_second(1)); + assert_eq!(lim.len(), 0); + assert!(lim.is_empty()); + + lim.check_key(&"foo").unwrap(); + assert_eq!(lim.len(), 1); + assert!(!lim.is_empty(),); + + lim.check_key(&"bar").unwrap(); + assert_eq!(lim.len(), 2); + assert!(!lim.is_empty()); + + lim.check_key(&"baz").unwrap(); + assert_eq!(lim.len(), 3); + assert!(!lim.is_empty()); + } + + #[test] + fn hashmap_shrink_to_fit() { + let clock = FakeRelativeClock::default(); + // a steady rate of 3ms between elements: + let lim = RateLimiter::hashmap_with_clock(Quota::per_second(20), &clock); + let ms = Duration::from_millis(1); + + assert!(lim + .check_key_n(&"long-lived".to_string(), NonZeroU32::new(10).unwrap()) + .unwrap() + .is_ok(),); + assert!(lim.check_key(&"short-lived".to_string()).is_ok()); + + // Move the clock forward far enough that the short-lived key gets dropped: + clock.advance(ms * 300); + lim.retain_recent(); + lim.shrink_to_fit(); + + assert_eq!(lim.len(), 1); + } + + fn resident_memory_size() -> i64 { + let mut out: libc::rusage = unsafe { std::mem::zeroed() }; + assert!(unsafe { libc::getrusage(libc::RUSAGE_SELF, &mut out) } == 0); + out.ru_maxrss + } + + const LEAK_TOLERANCE: i64 = 1024 * 1024 * 10; + + struct LeakCheck { + usage_before: i64, + n_iter: usize, + } + + impl Drop for LeakCheck { + fn drop(&mut self) { + let usage_after = resident_memory_size(); + assert_le!(usage_after, self.usage_before + LEAK_TOLERANCE); + } + } + + impl LeakCheck { + fn new(n_iter: usize) -> Self { + LeakCheck { + n_iter, + usage_before: resident_memory_size(), + } + } + } + + #[test] + fn memleak_gcra() { + let bucket = RateLimiter::direct(Quota::per_second(1_000_000)); + + let leak_check = LeakCheck::new(500_000); + + for _i in 0..leak_check.n_iter { + drop(bucket.check()); + } + } + + #[test] + fn memleak_gcra_multi() { + let bucket = RateLimiter::direct(Quota::per_second(1_000_000)); + let leak_check = LeakCheck::new(500_000); + + for _i in 0..leak_check.n_iter { + drop(bucket.check_n(NonZeroU32::new(2).unwrap())); + } + } + + #[test] + fn memleak_gcra_threaded() { + let 
bucket = Arc::new(RateLimiter::direct(Quota::per_second(1_000_000))); + let leak_check = LeakCheck::new(5_000); + + for _i in 0..leak_check.n_iter { + let bucket = Arc::clone(&bucket); + thread::spawn(move || { + assert!(bucket.check().is_ok()); + }) + .join() + .unwrap(); + } + } + + #[test] + fn memleak_keyed() { + let bucket = RateLimiter::keyed(Quota::per_second(50)); + + let leak_check = LeakCheck::new(500_000); + + for i in 0..leak_check.n_iter { + drop(bucket.check_key(&(i % 1000))); + } + } +} diff --git a/http-rate/src/nanos.rs b/http-rate/src/nanos.rs new file mode 100644 index 00000000..fe798de7 --- /dev/null +++ b/http-rate/src/nanos.rs @@ -0,0 +1,137 @@ +//! A time-keeping abstraction (nanoseconds) that works for storing in an atomic integer. + +use core::{ + convert::TryInto, + fmt, + ops::{Add, Div, Mul}, + time::Duration, +}; + +use crate::timer::Reference; + +/// A number of nanoseconds from a reference point. +/// +/// Nanos can not represent durations >584 years, but hopefully that +/// should not be a problem in real-world applications. +#[derive(PartialEq, Eq, Default, Clone, Copy, PartialOrd, Ord)] +pub struct Nanos(u64); + +impl Nanos { + pub fn as_u64(self) -> u64 { + self.0 + } +} + +/// Nanos as used by Jitter and other std-only features. +impl Nanos { + pub const fn new(u: u64) -> Self { + Nanos(u) + } +} + +impl From for Nanos { + fn from(d: Duration) -> Self { + // This will panic: + Nanos(d.as_nanos().try_into().expect("Duration is longer than 584 years")) + } +} + +impl fmt::Debug for Nanos { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + let d = Duration::from_nanos(self.0); + write!(f, "Nanos({:?})", d) + } +} + +impl Add for Nanos { + type Output = Nanos; + + fn add(self, rhs: Nanos) -> Self::Output { + Nanos(self.0 + rhs.0) + } +} + +impl Mul for Nanos { + type Output = Nanos; + + fn mul(self, rhs: u64) -> Self::Output { + Nanos(self.0 * rhs) + } +} + +impl Div for Nanos { + type Output = u64; + + fn div(self, rhs: Nanos) -> Self::Output { + self.0 / rhs.0 + } +} + +impl From for Nanos { + fn from(u: u64) -> Self { + Nanos(u) + } +} + +impl From for u64 { + fn from(n: Nanos) -> Self { + n.0 + } +} + +impl From for Duration { + fn from(n: Nanos) -> Self { + Duration::from_nanos(n.0) + } +} + +impl Nanos { + #[inline] + pub fn saturating_sub(self, rhs: Nanos) -> Nanos { + Nanos(self.0.saturating_sub(rhs.0)) + } +} + +impl Reference for Nanos { + #[inline] + fn duration_since(&self, earlier: Self) -> Nanos { + (*self as Nanos).saturating_sub(earlier) + } + + #[inline] + fn saturating_sub(&self, duration: Nanos) -> Self { + (*self as Nanos).saturating_sub(duration) + } +} + +impl Add for Nanos { + type Output = Self; + + fn add(self, other: Duration) -> Self { + let other: Nanos = other.into(); + self + other + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn nanos_impls() { + let n = Nanos::new(20); + assert_eq!("Nanos(20ns)", format!("{:?}", n)); + } + + #[test] + fn nanos_arith_coverage() { + let n = Nanos::new(20); + let n_half = Nanos::new(10); + assert_eq!(n / n_half, 2); + assert_eq!(30, (n + Duration::from_nanos(10)).as_u64()); + + assert_eq!(n_half.saturating_sub(n), Nanos::new(0)); + assert_eq!(n.saturating_sub(n_half), n_half); + assert_eq!(Reference::saturating_sub(&n_half, n), Nanos::new(0)); + } +} diff --git a/http-rate/src/quota.rs b/http-rate/src/quota.rs new file mode 100644 index 00000000..36158706 --- /dev/null +++ b/http-rate/src/quota.rs @@ -0,0 +1,172 @@ +use core::{convert::TryInto, fmt, 
num::NonZeroU32, time::Duration}; + +use crate::nanos::Nanos; + +/// A rate-limiting quota. +/// +/// Quotas are expressed in a positive number of "cells" (the maximum number of positive decisions / +/// allowed items until the rate limiter needs to replenish) and the amount of time for the rate +/// limiter to replenish a single cell. +/// +/// Neither the number of cells nor the replenishment unit of time may be zero. +/// +/// # Burst sizes +/// There are multiple ways of expressing the same quota: a quota given as `Quota::per_second(1)` +/// allows, on average, the same number of cells through as a quota given as `Quota::per_minute(60)`. +/// However, the quota of `Quota::per_minute(60)` has a burst size of 60 cells, meaning it is +/// possible to accommodate 60 cells in one go, after which the equivalent of a minute of inactivity +/// is required for the burst allowance to be fully restored. +/// +/// Burst size gets really important when you construct a rate limiter that should allow multiple +/// elements through at one time (using [`RateLimiter.check_n`](struct.RateLimiter.html#method.check_n) +/// and its related functions): Only +/// at most as many cells can be let through in one call as are given as the burst size. +/// +/// In other words, the burst size is the maximum number of cells that the rate limiter will ever +/// allow through without replenishing them. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct Quota { + pub(crate) max_burst: NonZeroU32, + pub(crate) replenish_1_per: Duration, +} + +impl Quota { + /// Construct a quota for a number of cells per second. The given number of cells is also + /// assumed to be the maximum burst size. + /// + /// # Panics + /// - When max_burst is zero. + pub fn per_second(max_burst: B) -> Self + where + B: TryInto, + B::Error: fmt::Debug, + { + let max_burst = max_burst.try_into().unwrap(); + let replenish_interval_ns = Duration::from_secs(1).as_nanos() / (max_burst.get() as u128); + Self { + max_burst, + replenish_1_per: Duration::from_nanos(replenish_interval_ns as u64), + } + } + + /// Construct a quota for a number of cells per 60-second period. The given number of cells is + /// also assumed to be the maximum burst size. + /// + /// # Panics + /// - When max_burst is zero. + pub fn per_minute(max_burst: B) -> Self + where + B: TryInto, + B::Error: fmt::Debug, + { + let max_burst = max_burst.try_into().unwrap(); + let replenish_interval_ns = Duration::from_secs(60).as_nanos() / (max_burst.get() as u128); + Quota { + max_burst, + replenish_1_per: Duration::from_nanos(replenish_interval_ns as u64), + } + } + + /// Construct a quota for a number of cells per 60-minute (3600-second) period. The given number + /// of cells is also assumed to be the maximum burst size. + /// + /// # Panics + /// - When max_burst is zero. + pub fn per_hour(max_burst: B) -> Self + where + B: TryInto, + B::Error: fmt::Debug, + { + let max_burst = max_burst.try_into().unwrap(); + let replenish_interval_ns = Duration::from_secs(60 * 60).as_nanos() / (max_burst.get() as u128); + Self { + max_burst, + replenish_1_per: Duration::from_nanos(replenish_interval_ns as u64), + } + } + + /// Construct a quota that replenishes one cell in a given + /// interval. + /// + /// This constructor is meant to replace [`::new`](#method.new), + /// in cases where a longer refresh period than 1 cell/hour is + /// necessary. + /// + /// If the time interval is zero, returns `None`. 
+ pub fn with_period(replenish_1_per: Duration) -> Option { + if replenish_1_per.as_nanos() == 0 { + None + } else { + Some(Quota { + max_burst: NonZeroU32::new(1).unwrap(), + replenish_1_per, + }) + } + } + + /// Adjusts the maximum burst size for a quota to construct a rate limiter with a capacity + /// for at most the given number of cells. + /// + /// # Panics + /// - When max_burst is zero. + pub fn allow_burst(self, max_burst: B) -> Self + where + B: TryInto, + B::Error: fmt::Debug, + { + let max_burst = max_burst.try_into().unwrap(); + Self { max_burst, ..self } + } +} + +impl Quota { + // The maximum number of cells that can be allowed in one burst. + pub(crate) const fn burst_size(&self) -> NonZeroU32 { + self.max_burst + } + + #[cfg(test)] + // The time it takes for a rate limiter with an exhausted burst budget to replenish + // a single element. + const fn replenish_interval(&self) -> Duration { + self.replenish_1_per + } + + // The time it takes to replenish the entire maximum burst size. + // const fn burst_size_replenished_in(&self) -> Duration { + // let fill_in_ns = self.replenish_1_per.as_nanos() * self.max_burst.get() as u128; + // Duration::from_nanos(fill_in_ns as u64) + // } +} + +impl Quota { + // A way to reconstruct a Quota from an in-use Gcra. + pub(crate) fn from_gcra_parameters(t: Nanos, tau: Nanos) -> Quota { + let max_burst = NonZeroU32::new((tau.as_u64() / t.as_u64()) as u32).unwrap(); + let replenish_1_per = t.into(); + Quota { + max_burst, + replenish_1_per, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn time_multiples() { + let hourly = Quota::per_hour(1); + let minutely = Quota::per_minute(1); + let secondly = Quota::per_second(1); + + assert_eq!(hourly.replenish_interval() / 60, minutely.replenish_interval()); + assert_eq!(minutely.replenish_interval() / 60, secondly.replenish_interval()); + } + + #[test] + fn period_error_cases() { + assert!(Quota::with_period(Duration::from_secs(0)).is_none()); + } +} diff --git a/http-rate/src/snapshot.rs b/http-rate/src/snapshot.rs new file mode 100644 index 00000000..c38fa3c8 --- /dev/null +++ b/http-rate/src/snapshot.rs @@ -0,0 +1,106 @@ +use core::cmp; + +use http::{ + header::{HeaderName, HeaderValue}, + Response, +}; + +use crate::{nanos::Nanos, quota::Quota}; + +/// Information about the rate-limiting state used to reach a decision. +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct RateSnapshot { + // The "weight" of a single packet in units of time. + t: Nanos, + // The "burst capacity" of the bucket. + tau: Nanos, + // The time at which the measurement was taken. 
+ pub(crate) time_of_measurement: Nanos, + // The next time a cell is expected to arrive + pub(crate) tat: Nanos, +} + +const X_RT_LIMIT: HeaderName = HeaderName::from_static("x-ratelimit-limit"); +const X_RT_REMAINING: HeaderName = HeaderName::from_static("x-ratelimit-remaining"); + +impl RateSnapshot { + /// extend response headers with headers + /// Header: `x-ratelimit-limit: ` + /// Header: `x-ratelimit-remaining: ` + pub fn extend_response(&self, res: &mut Response) { + let burst_size = self.quota().burst_size().get(); + let remaining_burst_capacity = self.remaining_burst_capacity(); + let headers = res.headers_mut(); + headers.insert(X_RT_LIMIT, HeaderValue::from(burst_size)); + headers.insert(X_RT_REMAINING, HeaderValue::from(remaining_burst_capacity)); + } + + pub(crate) const fn new(t: Nanos, tau: Nanos, time_of_measurement: Nanos, tat: Nanos) -> Self { + Self { + t, + tau, + time_of_measurement, + tat, + } + } + + /// Returns the quota used to make the rate limiting decision. + pub(crate) fn quota(&self) -> Quota { + Quota::from_gcra_parameters(self.t, self.tau) + } + + fn remaining_burst_capacity(&self) -> u32 { + let t0 = self.time_of_measurement + self.t; + (cmp::min((t0 + self.tau).saturating_sub(self.tat).as_u64(), self.tau.as_u64()) / self.t.as_u64()) as u32 + } +} + +#[cfg(test)] +mod test { + use core::time::Duration; + + use crate::{quota::Quota, state::RateLimiter, timer::FakeRelativeClock}; + + #[test] + fn state_information() { + let clock = FakeRelativeClock::default(); + let lim = RateLimiter::direct_with_clock(Quota::per_second(4), &clock); + assert_eq!(Ok(3), lim.check().map(|outcome| outcome.remaining_burst_capacity())); + assert_eq!(Ok(2), lim.check().map(|outcome| outcome.remaining_burst_capacity())); + assert_eq!(Ok(1), lim.check().map(|outcome| outcome.remaining_burst_capacity())); + assert_eq!(Ok(0), lim.check().map(|outcome| outcome.remaining_burst_capacity())); + assert!(lim.check().is_err()); + } + + #[test] + fn state_snapshot_tracks_quota_accurately() { + let period = Duration::from_millis(90); + let quota = Quota::with_period(period).unwrap().allow_burst(2); + + let clock = FakeRelativeClock::default(); + + // First test + let lim = RateLimiter::direct_with_clock(quota, &clock); + + assert_eq!(lim.check().unwrap().remaining_burst_capacity(), 1); + assert_eq!(lim.check().unwrap().remaining_burst_capacity(), 0); + assert_eq!(lim.check().map_err(|_| ()), Err(()), "should rate limit"); + + clock.advance(Duration::from_secs(120)); + assert_eq!(lim.check().map(|s| s.remaining_burst_capacity()), Ok(2)); + assert_eq!(lim.check().map(|s| s.remaining_burst_capacity()), Ok(1)); + assert_eq!(lim.check().map(|s| s.remaining_burst_capacity()), Ok(0)); + assert_eq!(lim.check().map_err(|_| ()), Err(()), "should rate limit"); + } + + #[test] + fn state_snapshot_tracks_quota_accurately_with_real_clock() { + let period = Duration::from_millis(90); + let quota = Quota::with_period(period).unwrap().allow_burst(2); + let lim = RateLimiter::direct(quota); + + assert_eq!(lim.check().unwrap().remaining_burst_capacity(), 1); + assert_eq!(lim.check().unwrap().remaining_burst_capacity(), 0); + assert_eq!(lim.check().map_err(|_| ()), Err(()), "should rate limit"); + } +} diff --git a/http-rate/src/state.rs b/http-rate/src/state.rs new file mode 100644 index 00000000..5d45f133 --- /dev/null +++ b/http-rate/src/state.rs @@ -0,0 +1,81 @@ +//! 
State stores for rate limiters + +pub mod direct; +pub mod keyed; + +mod in_memory; + +pub(crate) use self::{direct::*, in_memory::InMemoryState}; + +use crate::{ + gcra::Gcra, + nanos::Nanos, + quota::Quota, + timer::{DefaultTimer, Timer}, +}; + +/// A way for rate limiters to keep state. +/// +/// There are two important kinds of state stores: Direct and keyed. The direct kind have only +/// one state, and are useful for "global" rate limit enforcement (e.g. a process should never +/// do more than N tasks a day). The keyed kind allows one rate limit per key (e.g. an API +/// call budget per client API key). +/// +/// A direct state store is expressed as [`StateStore::Key`] = [`NotKeyed`]. +/// Keyed state stores have a +/// type parameter for the key and set their key to that. +pub(crate) trait StateStore { + /// The type of key that the state store can represent. + type Key; + + /// Updates a state store's rate limiting state for a given key, using the given closure. + /// + /// The closure parameter takes the old value (`None` if this is the first measurement) of the + /// state store at the key's location, checks if the request an be accommodated and: + /// + /// * If the request is rate-limited, returns `Err(E)`. + /// * If the request can make it through, returns `Ok(T)` (an arbitrary positive return + /// value) and the updated state. + /// + /// It is `measure_and_replace`'s job then to safely replace the value at the key - it must + /// only update the value if the value hasn't changed. The implementations in this + /// crate use `AtomicU64` operations for this. + fn measure_and_replace(&self, key: &Self::Key, f: F) -> Result + where + F: Fn(Option) -> Result<(T, Nanos), E>; +} + +#[derive(Debug)] +pub(crate) struct RateLimiter +where + S: StateStore, + C: Timer, +{ + state: S, + gcra: Gcra, + clock: C, + start: C::Instant, +} + +impl RateLimiter +where + S: StateStore, + C: Timer, +{ + pub(crate) fn new(quota: Quota, state: S, clock: &C) -> Self { + let gcra = Gcra::new(quota); + let start = clock.now(); + let clock = clock.clone(); + RateLimiter { + state, + clock, + gcra, + start, + } + } + + #[cfg(test)] + pub(crate) fn into_state_store(self) -> S { + self.state + } +} diff --git a/http-rate/src/state/direct.rs b/http-rate/src/state/direct.rs new file mode 100644 index 00000000..605fc02c --- /dev/null +++ b/http-rate/src/state/direct.rs @@ -0,0 +1,111 @@ +#![allow(unused)] + +use core::num::NonZeroU32; + +use crate::{ + error::InsufficientCapacity, gcra::NotUntil, quota::Quota, snapshot::RateSnapshot, state::InMemoryState, timer, + timer::DefaultTimer, +}; + +/// The "this state store does not use keys" key type. +/// +/// It's possible to use this to create a "direct" rate limiter. It explicitly does not implement +/// [`Hash`][std::hash::Hash] so that it is possible to tell apart from "hashable" key types. +#[derive(PartialEq, Debug, Eq)] +pub enum NotKeyed { + /// The value given to state stores' methods. + NonKey, +} + +/// A trait for state stores that only keep one rate limiting state. +/// +/// This is blanket-implemented by all [`StateStore`]s with [`NotKeyed`] key associated types. +pub(crate) trait DirectStateStore: StateStore {} + +impl DirectStateStore for T where T: StateStore {} + +/// # Direct in-memory rate limiters - Constructors +/// +/// Here we construct an in-memory rate limiter that makes direct (un-keyed) +/// rate-limiting decisions. Direct rate limiters can be used to +/// e.g. 
regulate the transmission of packets on a single connection, +/// or to ensure that an API client stays within a service's rate +/// limit. +impl RateLimiter { + /// Constructs a new in-memory direct rate limiter for a quota with the default real-time clock. + pub(crate) fn direct(quota: Quota) -> RateLimiter { + let clock = DefaultTimer; + Self::direct_with_clock(quota, &clock) + } +} + +impl RateLimiter +where + C: timer::Timer, +{ + /// Constructs a new direct rate limiter for a quota with a custom clock. + pub(crate) fn direct_with_clock(quota: Quota, clock: &C) -> Self { + let state: InMemoryState = Default::default(); + RateLimiter::new(quota, state, clock) + } +} + +/// # Direct rate limiters - Manually checking cells +impl RateLimiter +where + S: DirectStateStore, + C: timer::Timer, +{ + /// Allow a single cell through the rate limiter. + /// + /// If the rate limit is reached, `check` returns information about the earliest + /// time that a cell might be allowed through again. + pub(crate) fn check(&self) -> Result> { + self.gcra.test_and_update::( + self.start, + &NotKeyed::NonKey, + &self.state, + self.clock.now(), + ) + } + + #[cfg(test)] + /// Allow *only all* `n` cells through the rate limiter. + /// + /// This method can succeed in only one way and fail in two ways: + /// * Success: If all `n` cells can be accommodated, it returns `Ok(())`. + /// * Failure (but ok): Not all cells can make it through at the current time. + /// The result is `Err(NegativeMultiDecision::BatchNonConforming(NotUntil))`, which can + /// be interrogated about when the batch might next conform. + /// * Failure (the batch can never go through): The rate limit quota's burst size is too low + /// for the given number of cells to ever be allowed through. + /// + /// ### Performance + /// This method diverges a little from the GCRA algorithm, using + /// multiplication to determine the next theoretical arrival time, and so + /// is not as fast as checking a single cell. + pub(crate) fn check_n( + &self, + n: NonZeroU32, + ) -> Result>, InsufficientCapacity> { + self.gcra.test_n_all_and_update::( + self.start, + &NotKeyed::NonKey, + n, + &self.state, + self.clock.now(), + ) + } +} + +use crate::state::{RateLimiter, StateStore}; + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn not_keyed_impls_coverage() { + assert_eq!(NotKeyed::NonKey, NotKeyed::NonKey); + } +} diff --git a/http-rate/src/state/in_memory.rs b/http-rate/src/state/in_memory.rs new file mode 100644 index 00000000..3ed8b698 --- /dev/null +++ b/http-rate/src/state/in_memory.rs @@ -0,0 +1,133 @@ +use std::prelude::v1::*; + +use crate::nanos::Nanos; +use crate::state::{NotKeyed, StateStore}; +use std::fmt; +use std::fmt::Debug; +use std::num::NonZeroU64; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; +use std::time::Duration; + +/// An in-memory representation of a GCRA's rate-limiting state. +/// +/// Implemented using [`AtomicU64`] operations, this state representation can be used to +/// construct rate limiting states for other in-memory states: e.g., this crate uses +/// `InMemoryState` as the states it tracks in the keyed rate limiters it implements. +/// +/// Internally, the number tracked here is the theoretical arrival time (a GCRA term) in number of +/// nanoseconds since the rate limiter was created. 
+#[derive(Default)] +pub(crate) struct InMemoryState(AtomicU64); + +impl InMemoryState { + pub(crate) fn measure_and_replace_one(&self, mut f: F) -> Result + where + F: FnMut(Option) -> Result<(T, Nanos), E>, + { + let mut prev = self.0.load(Ordering::Acquire); + let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into())); + while let Ok((result, new_data)) = decision { + match self + .0 + .compare_exchange_weak(prev, new_data.into(), Ordering::Release, Ordering::Relaxed) + { + Ok(_) => return Ok(result), + Err(next_prev) => prev = next_prev, + } + decision = f(NonZeroU64::new(prev).map(|n| n.get().into())); + } + // This map shouldn't be needed, as we only get here in the error case, but the compiler + // can't see it. + decision.map(|(result, _)| result) + } + + pub(crate) fn is_older_than(&self, nanos: Nanos) -> bool { + self.0.load(Ordering::Relaxed) <= nanos.into() + } +} + +/// The InMemoryState is the canonical "direct" state store. +impl StateStore for InMemoryState { + type Key = NotKeyed; + + fn measure_and_replace(&self, _key: &Self::Key, f: F) -> Result + where + F: Fn(Option) -> Result<(T, Nanos), E>, + { + self.measure_and_replace_one(f) + } +} + +impl Debug for InMemoryState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + let d = Duration::from_nanos(self.0.load(Ordering::Relaxed)); + write!(f, "InMemoryState({:?})", d) + } +} + +#[cfg(test)] +#[allow(clippy::needless_collect)] +mod test { + + use all_asserts::assert_gt; + + use super::*; + + fn try_triggering_collisions(n_threads: u64, tries_per_thread: u64) -> (u64, u64) { + use std::sync::Arc; + use std::thread; + + let mut state = Arc::new(InMemoryState(AtomicU64::new(0))); + let threads: Vec> = (0..n_threads) + .map(|_| { + thread::spawn({ + let state = Arc::clone(&state); + move || { + let mut hits = 0; + for _ in 0..tries_per_thread { + assert!(state + .measure_and_replace_one(|old| { + hits += 1; + Ok::<((), Nanos), ()>(((), Nanos::from(old.map(Nanos::as_u64).unwrap_or(0) + 1))) + }) + .is_ok()); + } + hits + } + }) + }) + .collect(); + let hits: u64 = threads.into_iter().map(|t| t.join().unwrap()).sum(); + let value = Arc::get_mut(&mut state).unwrap().0.get_mut(); + (*value, hits) + } + + #[test] + /// Checks that many threads running simultaneously will collide, + /// but result in the correct number being recorded in the state. 
+ fn stresstest_collisions() { + use all_asserts::assert_gt; + + const THREADS: u64 = 8; + const MAX_TRIES: u64 = 20_000_000; + let (mut value, mut hits) = (0, 0); + for tries in (0..MAX_TRIES).step_by((MAX_TRIES / 100) as usize) { + let attempt = try_triggering_collisions(THREADS, tries); + value = attempt.0; + hits = attempt.1; + assert_eq!(value, tries * THREADS); + if hits > value { + break; + } + println!("Didn't trigger a collision in {} iterations", tries); + } + assert_gt!(hits, value); + } + + #[test] + fn in_memory_state_impls() { + let state = InMemoryState(AtomicU64::new(0)); + assert_gt!(format!("{:?}", state).len(), 0); + } +} diff --git a/http-rate/src/state/keyed.rs b/http-rate/src/state/keyed.rs new file mode 100644 index 00000000..503f92b1 --- /dev/null +++ b/http-rate/src/state/keyed.rs @@ -0,0 +1,226 @@ +use core::hash::Hash; + +use crate::{ + gcra::NotUntil, nanos::Nanos, quota::Quota, snapshot::RateSnapshot, state::RateLimiter, state::StateStore, timer, +}; + +#[cfg(test)] +use core::num::NonZeroU32; + +#[cfg(test)] +use crate::{error::InsufficientCapacity, timer::Reference}; + +// A trait for state stores with one rate limiting state per key. +// +// This is blanket-implemented by all `StateStore`s with hashable `Eq + Hash + Clone` key +// associated types. +pub(crate) trait KeyedStateStore: StateStore {} + +impl KeyedStateStore for T +where + T: StateStore, + K: Eq + Clone + Hash, +{ +} + +/// # Keyed rate limiters - default constructors +impl RateLimiter, timer::DefaultTimer> +where + K: Clone + Hash + Eq, +{ + #[cfg(test)] + // Constructs a new keyed rate limiter backed by + // the [`DefaultKeyedStateStore`]. + pub(crate) fn keyed(quota: Quota) -> Self { + let state = DefaultKeyedStateStore::default(); + let clock = timer::DefaultTimer; + RateLimiter::new(quota, state, &clock) + } + + /// Constructs a new keyed rate limiter explicitly backed by a + /// [`HashMap`][std::collections::HashMap]. + pub(crate) fn hashmap(quota: Quota) -> Self { + let state = HashMapStateStore::default(); + let timer = timer::DefaultTimer; + RateLimiter::new(quota, state, &timer) + } +} + +/// # Keyed rate limiters - Manually checking cells +impl RateLimiter +where + S: KeyedStateStore, + K: Hash, + C: timer::Timer, +{ + /// Allow a single cell through the rate limiter for the given key. + /// + /// If the rate limit is reached, `check_key` returns information about the earliest + /// time that a cell might be allowed through again under that key. + pub fn check_key(&self, key: &K) -> Result> { + self.gcra + .test_and_update::(self.start, key, &self.state, self.clock.now()) + } + + #[cfg(test)] + /// Allow *only all* `n` cells through the rate limiter for the given key. + /// + /// This method can succeed in only one way and fail in two ways: + /// * Success: If all `n` cells can be accommodated, it returns `Ok(Ok(()))`. + /// * Failure (but ok): Not all cells can make it through at the current time. + /// The result is `Ok(Err(NotUntil))`, which can + /// be interrogated about when the batch might next conform. + /// * Failure (the batch can never go through): The rate limit is too low for the given number + /// of cells. The result is `Err(InsufficientCapacity)` + /// + /// ### Performance + /// This method diverges a little from the GCRA algorithm, using + /// multiplication to determine the next theoretical arrival time, and so + /// is not as fast as checking a single cell. 
+ pub(crate) fn check_key_n( + &self, + key: &K, + n: NonZeroU32, + ) -> Result>, InsufficientCapacity> { + self.gcra + .test_n_all_and_update::(self.start, key, n, &self.state, self.clock.now()) + } +} + +/// Keyed rate limiters that can be "cleaned up". +/// +/// Any keyed state store implementing this trait allows users to evict elements that are +/// indistinguishable from fresh rate-limiting states (that is, if a key hasn't been used for +/// rate-limiting decisions for as long as the bucket capacity). +/// +/// As this does not make sense for not all keyed state stores (e.g. stores that auto-expire like +/// memcache), this is an optional trait. All the keyed state stores in this crate implement +/// shrinking. +pub(crate) trait ShrinkableKeyedStateStore: KeyedStateStore { + /// Remove those keys with state older than `drop_below`. + fn retain_recent(&self, drop_below: Nanos); + + /// Shrinks the capacity of the state store, if possible. + /// + /// If the state store does not support shrinking, this method is a no-op. + fn shrink_to_fit(&self) {} + + /// Returns the number of "live" keys stored in the state store. + /// + /// Depending on how the state store is implemented, this may + /// return an estimate or an out-of-date result. + fn len(&self) -> usize; + + /// Returns `true` if `self` has no keys stored in it. + /// + /// As with [`len`](#tymethod.len), this method may return + /// imprecise results (indicating that the state store is empty + /// while a concurrent rate-limiting operation is taking place). + fn is_empty(&self) -> bool; +} + +#[cfg(test)] +/// # Keyed rate limiters - Housekeeping +/// +/// As the inputs to a keyed rate-limiter can be arbitrary keys, the set of retained keys retained +/// grows, while the number of active keys may stay smaller. To save on space, a keyed rate-limiter +/// allows removing those keys that are "stale", i.e., whose values are no different from keys' that +/// aren't present in the rate limiter state store. +impl RateLimiter +where + S: ShrinkableKeyedStateStore, + K: Hash, + C: timer::Timer, +{ + // Retains all keys in the rate limiter that were used recently enough. + // + // Any key whose rate limiting state is indistinguishable from a "fresh" state (i.e., the + // theoretical arrival time lies in the past). + pub(crate) fn retain_recent(&self) { + // calculate the minimum retention parameter: Any key whose state store's theoretical + // arrival time is larger than a starting state for the bucket gets to stay, everything + // else (that's indistinguishable from a starting state) goes. + let now = self.clock.now(); + let drop_below = now.duration_since(self.start); + + self.state.retain_recent(drop_below); + } + + // Shrinks the capacity of the rate limiter's state store, if possible. + pub(crate) fn shrink_to_fit(&self) { + self.state.shrink_to_fit(); + } + + // Returns the number of "live" keys in the rate limiter's state store. + // + // Depending on how the state store is implemented, this may + // return an estimate or an out-of-date result. + pub(crate) fn len(&self) -> usize { + self.state.len() + } + + // Returns `true` if the rate limiter has no keys in it. + // + // As with [`len`](#method.len), this method may return + // imprecise results (indicating that the state store is empty + // while a concurrent rate-limiting operation is taking place). 
+ pub(crate) fn is_empty(&self) -> bool { + self.state.is_empty() + } +} + +mod hashmap; + +pub(crate) use hashmap::HashMapStateStore; + +pub(crate) type DefaultKeyedStateStore = HashMapStateStore; + +#[cfg(test)] +mod test { + use core::{marker::PhantomData, num::NonZeroU32}; + + use crate::timer::FakeRelativeClock; + + use super::*; + + #[test] + fn default_nonshrinkable_state_store_coverage() { + #[derive(Default)] + struct NaiveKeyedStateStore(PhantomData); + + impl StateStore for NaiveKeyedStateStore { + type Key = K; + + fn measure_and_replace(&self, _key: &Self::Key, f: F) -> Result + where + F: Fn(Option) -> Result<(T, Nanos), E>, + { + f(None).map(|(res, _)| res) + } + } + + impl ShrinkableKeyedStateStore for NaiveKeyedStateStore { + fn retain_recent(&self, _drop_below: Nanos) { + // nothing to do + } + + fn len(&self) -> usize { + 0 + } + fn is_empty(&self) -> bool { + true + } + } + + let lim: RateLimiter, FakeRelativeClock> = RateLimiter::new( + Quota::per_second(NonZeroU32::new(1).unwrap()), + NaiveKeyedStateStore::default(), + &FakeRelativeClock::default(), + ); + assert!(lim.check_key(&1u32).is_ok()); + assert!(lim.is_empty()); + assert_eq!(lim.len(), 0); + lim.retain_recent(); + lim.shrink_to_fit(); + } +} diff --git a/http-rate/src/state/keyed/hashmap.rs b/http-rate/src/state/keyed/hashmap.rs new file mode 100644 index 00000000..6965d94a --- /dev/null +++ b/http-rate/src/state/keyed/hashmap.rs @@ -0,0 +1,70 @@ +use core::hash::Hash; + +use std::{collections::HashMap, sync::Mutex}; + +use crate::{ + nanos::Nanos, + state::{keyed::ShrinkableKeyedStateStore, InMemoryState, StateStore}, +}; + +#[cfg(test)] +use crate::{quota::Quota, state::RateLimiter, timer}; + +/// A thread-safe (but not very performant) implementation of a keyed rate limiter state +/// store using [`HashMap`]. +/// +/// The `HashMapStateStore` is the default state store in `std` when no other thread-safe +/// features are enabled. +pub(crate) type HashMapStateStore = Mutex>; + +impl StateStore for HashMapStateStore { + type Key = K; + + fn measure_and_replace(&self, key: &Self::Key, f: F) -> Result + where + F: Fn(Option) -> Result<(T, Nanos), E>, + { + let mut map = self.lock().unwrap(); + if let Some(v) = (*map).get(key) { + // fast path: a rate limiter is already present for the key. + return v.measure_and_replace_one(f); + } + // not-so-fast path: make a new entry and measure it. + let entry = (*map).entry(key.clone()).or_default(); + entry.measure_and_replace_one(f) + } +} + +impl ShrinkableKeyedStateStore for HashMapStateStore { + fn retain_recent(&self, drop_below: Nanos) { + let mut map = self.lock().unwrap(); + map.retain(|_, v| !v.is_older_than(drop_below)); + } + + fn shrink_to_fit(&self) { + let mut map = self.lock().unwrap(); + map.shrink_to_fit(); + } + + fn len(&self) -> usize { + let map = self.lock().unwrap(); + (*map).len() + } + fn is_empty(&self) -> bool { + let map = self.lock().unwrap(); + (*map).is_empty() + } +} + +#[cfg(test)] +impl RateLimiter, C> +where + K: Hash + Eq + Clone, + C: timer::Timer, +{ + /// Constructs a new rate limiter with a custom clock, backed by a [`HashMap`]. 
+ pub(crate) fn hashmap_with_clock(quota: Quota, clock: &C) -> Self { + let state = Mutex::new(HashMap::new()); + RateLimiter::new(quota, state, clock) + } +} diff --git a/http-rate/src/timer.rs b/http-rate/src/timer.rs new file mode 100644 index 00000000..144fa2df --- /dev/null +++ b/http-rate/src/timer.rs @@ -0,0 +1,194 @@ +use core::{ + fmt, + ops::Add, + sync::atomic::{AtomicU64, Ordering}, + time::Duration, +}; + +use std::{sync::Arc, time::Instant}; + +use super::nanos::Nanos; + +/// A measurement from a clock. +pub trait Reference: + Sized + Add + PartialEq + Eq + Ord + Copy + Clone + Send + Sync + fmt::Debug +{ + /// Determines the time that separates two measurements of a + /// clock. Implementations of this must perform a saturating + /// subtraction - if the `earlier` timestamp should be later, + /// `duration_since` must return the zero duration. + fn duration_since(&self, earlier: Self) -> Nanos; + + /// Returns a reference point that lies at most `duration` in the + /// past from the current reference. If an underflow should occur, + /// returns the current reference. + fn saturating_sub(&self, duration: Nanos) -> Self; +} + +/// A time source used by rate limiters. +pub trait Timer: Clone { + /// A measurement of a monotonically increasing clock. + type Instant: Reference; + + /// Returns a measurement of the clock. + fn now(&self) -> Self::Instant; +} + +impl Reference for Duration { + fn duration_since(&self, earlier: Self) -> Nanos { + self.checked_sub(earlier).unwrap_or_else(|| Duration::new(0, 0)).into() + } + + fn saturating_sub(&self, duration: Nanos) -> Self { + self.checked_sub(duration.into()).unwrap_or(*self) + } +} + +impl Add for Duration { + type Output = Self; + + fn add(self, other: Nanos) -> Self { + let other: Duration = other.into(); + self + other + } +} + +/// A mock implementation of a clock. All it does is keep track of +/// what "now" is (relative to some point meaningful to the program), +/// and returns that. +/// +/// # Thread safety +/// The mock time is represented as an atomic u64 count of nanoseconds, behind an [`Arc`]. +/// Clones of this clock will all show the same time, even if the original advances. +#[derive(Debug, Clone, Default)] +pub struct FakeRelativeClock { + now: Arc, +} + +impl FakeRelativeClock { + #[cfg(test)] + // Advances the fake clock by the given amount. 
+    pub(crate) fn advance(&self, by: Duration) {
+        let by: u64 = by
+            .as_nanos()
+            .try_into()
+            .expect("Can not represent times past ~584 years");
+
+        let mut prev = self.now.load(Ordering::Acquire);
+        let mut next = prev + by;
+        while let Err(next_prev) = self
+            .now
+            .compare_exchange_weak(prev, next, Ordering::Release, Ordering::Relaxed)
+        {
+            prev = next_prev;
+            next = prev + by;
+        }
+    }
+}
+
+impl PartialEq for FakeRelativeClock {
+    fn eq(&self, other: &Self) -> bool {
+        self.now.load(Ordering::Relaxed) == other.now.load(Ordering::Relaxed)
+    }
+}
+
+impl Timer for FakeRelativeClock {
+    type Instant = Nanos;
+
+    fn now(&self) -> Self::Instant {
+        self.now.load(Ordering::Relaxed).into()
+    }
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct DefaultTimer;
+
+impl Add<Nanos> for Instant {
+    type Output = Instant;
+
+    fn add(self, other: Nanos) -> Instant {
+        let other: Duration = other.into();
+        self + other
+    }
+}
+
+impl Reference for Instant {
+    fn duration_since(&self, earlier: Self) -> Nanos {
+        if earlier < *self {
+            (*self - earlier).into()
+        } else {
+            Nanos::from(Duration::new(0, 0))
+        }
+    }
+
+    fn saturating_sub(&self, duration: Nanos) -> Self {
+        self.checked_sub(duration.into()).unwrap_or(*self)
+    }
+}
+
+impl Timer for DefaultTimer {
+    type Instant = Instant;
+
+    fn now(&self) -> Self::Instant {
+        Instant::now()
+    }
+}
+
+pub trait ReasonablyRealtime: Timer {
+    fn reference_point(&self) -> Self::Instant {
+        self.now()
+    }
+}
+
+impl ReasonablyRealtime for DefaultTimer {}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::nanos::Nanos;
+    use std::iter::repeat;
+    use std::sync::Arc;
+    use std::thread;
+    use std::time::Duration;
+
+    #[test]
+    fn fake_clock_parallel_advances() {
+        let clock = Arc::new(FakeRelativeClock::default());
+        let threads = repeat(())
+            .take(10)
+            .map(move |_| {
+                let clock = Arc::clone(&clock);
+                thread::spawn(move || {
+                    for _ in 0..1000000 {
+                        let now = clock.now();
+                        clock.advance(Duration::from_nanos(1));
+                        assert!(clock.now() > now);
+                    }
+                })
+            })
+            .collect::<Vec<_>>();
+        for t in threads {
+            t.join().unwrap();
+        }
+    }
+
+    #[test]
+    fn duration_addition_coverage() {
+        let d = Duration::from_secs(1);
+        let one_ns = Nanos::new(1);
+        assert!(d + one_ns > d);
+    }
+
+    #[cfg(not(all(target_arch = "aarch64", target_os = "macos")))]
+    #[test]
+    fn instant_impls_coverage() {
+        let one_ns = Nanos::new(1);
+        let c = DefaultTimer;
+        let now = c.now();
+        let ns_dur = Duration::from(one_ns);
+        assert_ne!(now + ns_dur, now, "{:?} + {:?}", ns_dur, now);
+        assert_eq!(one_ns, Reference::duration_since(&(now + one_ns), now));
+        assert_eq!(Nanos::new(0), Reference::duration_since(&now, now + one_ns));
+        assert_eq!(Reference::saturating_sub(&(now + Duration::from_nanos(1)), one_ns), now);
+    }
+}
diff --git a/web/CHANGES.md b/web/CHANGES.md
index 6070466f..423295cc 100644
--- a/web/CHANGES.md
+++ b/web/CHANGES.md
@@ -1,6 +1,7 @@
 # unreleased version 0.2.0
 
 ## Add
+- `RateLimit` middleware with optional feature `rate-limit`.
 - implement `Responder` trait for `serde_json::Value`.
 - re-export `http_ws::{ResponseSender, ResponseWeakSender}` types in `xitca_web::handler::websocket`
 module.
@@ -23,6 +24,7 @@
 ```
 - update `xitca-http` to version `0.2.0`.
 - update `http-encoding` to version `0.2.0`.
+- update `http-ws` to version `0.2.0`.
 
 ## Fix
 - fix nested App routing.
`App::new().at("/foo", App::new().at("/"))` would be successfully matching against `/foo/` diff --git a/web/Cargo.toml b/web/Cargo.toml index 13bc2250..a67d7a22 100644 --- a/web/Cargo.toml +++ b/web/Cargo.toml @@ -50,6 +50,9 @@ websocket = ["http-ws/stream", "tokio/time"] # static file serving file = ["http-file", "nightly"] +# rate-limit middleware +rate-limit = ["http-rate"] + # nightly rust required feature nightly = [] @@ -107,6 +110,9 @@ http-ws = { version = "0.2", optional = true } # static file http-file = { version = "0.1", optional = true } +# rate limit +http-rate = { version = "0.1", optional = true } + # codegen xitca-codegen = { version = "0.1", optional = true } diff --git a/web/src/middleware/mod.rs b/web/src/middleware/mod.rs index 31ce3c52..877e1c2a 100644 --- a/web/src/middleware/mod.rs +++ b/web/src/middleware/mod.rs @@ -282,6 +282,8 @@ pub mod compress; #[cfg(any(feature = "compress-br", feature = "compress-gz", feature = "compress-de"))] pub mod decompress; +#[cfg(feature = "rate-limit")] +pub mod rate_limit; #[cfg(feature = "tower-http-compat")] pub mod tower_http_compat; diff --git a/web/src/middleware/rate_limit.rs b/web/src/middleware/rate_limit.rs new file mode 100644 index 00000000..d9a60bb8 --- /dev/null +++ b/web/src/middleware/rate_limit.rs @@ -0,0 +1,101 @@ +//! client ip address based rate limiting. + +use core::{convert::Infallible, time::Duration}; + +use crate::{ + body::ResponseBody, + error::Error, + http::WebResponse, + service::{ready::ReadyService, Service}, + WebContext, +}; + +use http_rate::Quota; + +pub struct RateLimit(Quota); + +macro_rules! constructor { + ($method: tt) => { + #[doc = concat!("Construct a RateLimit for a number of cells ",stringify!($method)," period. The given number of cells is")] + /// also assumed to be the maximum burst size. + /// + /// # Panics + /// - When max_burst is zero. + pub fn $method(max_burst: u32) -> Self { + Self(Quota::$method(max_burst)) + } + }; +} + +impl RateLimit { + constructor!(per_second); + constructor!(per_minute); + constructor!(per_hour); + + /// Construct a RateLimit that replenishes one cell in a given + /// interval. + /// + /// # Panics + /// - When the Duration is zero. 
+    pub fn with_period(replenish_1_per: Duration) -> Self {
+        Self(Quota::with_period(replenish_1_per).unwrap())
+    }
+}
+
+impl<S, E> Service<Result<S, E>> for RateLimit {
+    type Response = RateLimitService<S>;
+    type Error = E;
+
+    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
+        res.map(|service| RateLimitService {
+            service,
+            rate_limit: http_rate::RateLimit::new(self.0),
+        })
+    }
+}
+
+pub struct RateLimitService<S> {
+    service: S,
+    rate_limit: http_rate::RateLimit,
+}
+
+impl<'r, C, B, S, ResB> Service<WebContext<'r, C, B>> for RateLimitService<S>
+where
+    S: for<'r2> Service<WebContext<'r2, C, B>, Response = WebResponse<ResB>, Error = Error<C>>,
+{
+    type Response = WebResponse<ResB>;
+    type Error = Error<C>;
+
+    async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
+        let headers = ctx.req().headers();
+        let addr = ctx.req().body().socket_addr();
+        let snap = self.rate_limit.rate_limit(headers, addr).map_err(Error::from_service)?;
+        self.service.call(ctx).await.map(|mut res| {
+            snap.extend_response(&mut res);
+            res
+        })
+    }
+}
+
+impl<'r, C, B> Service<WebContext<'r, C, B>> for http_rate::TooManyRequests {
+    type Response = WebResponse;
+    type Error = Infallible;
+
+    async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
+        let mut res = ctx.into_response(ResponseBody::empty());
+        self.extend_response(&mut res);
+        Ok(res)
+    }
+}
+
+impl<S> ReadyService for RateLimitService<S>
+where
+    S: ReadyService,
+{
+    type Ready = S::Ready;
+
+    #[inline]
+    async fn ready(&self) -> Self::Ready {
+        self.service.ready().await
+    }
+}
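
A minimal usage sketch of the new middleware, for context only. It is not part of the patch: it assumes `xitca-web` is compiled with the new `rate-limit` feature, and it borrows `App`, `handler_service`, `get`, `enclosed` and the `serve().bind().run().wait()` bootstrap chain from the repository's existing examples, so the exact method names may differ from the version at hand.

```rust
use xitca_web::{handler::handler_service, middleware::rate_limit::RateLimit, route::get, App};

fn main() -> std::io::Result<()> {
    App::new()
        // plain handler returning a static string body.
        .at("/", get(handler_service(|| async { "Hello,World!" })))
        // allow a burst of at most 60 requests per minute per client ip address;
        // requests over the quota are answered through `http_rate::TooManyRequests`.
        .enclosed(RateLimit::per_minute(60))
        // server bootstrap as in the crate's examples; adjust to the current App API.
        .serve()
        .bind("127.0.0.1:8080")?
        .run()
        .wait()
}
```

On the success path the middleware calls `snap.extend_response(&mut res)`, so accepted responses are annotated by `http-rate` as well (presumably with the usual rate-limit headers taken from the snapshot).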
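The quota semantics used throughout the new crate (`max_burst` cells plus one cell replenished every `replenish_1_per`) come from the generic cell rate algorithm implemented in `http-rate/src/gcra.rs`. The self-contained sketch below only illustrates that bookkeeping in its simplest single-threaded form; the `Gcra` struct, its field names and the `check` method are invented for this example, while the crate's own implementation works in `Nanos` and plugs into the `StateStore` abstraction shown above.

```rust
use std::time::{Duration, Instant};

/// Single-threaded GCRA ("leaky bucket as a meter") sketch. `tat` is the
/// theoretical arrival time of the next conforming request.
struct Gcra {
    emission_interval: Duration, // time to replenish one cell
    tolerance: Duration,         // burst allowance: (max_burst - 1) * emission_interval
    tat: Option<Instant>,
}

impl Gcra {
    fn new(max_burst: u32, replenish_1_per: Duration) -> Self {
        assert!(max_burst > 0, "burst size must be non-zero");
        Self {
            emission_interval: replenish_1_per,
            tolerance: replenish_1_per * (max_burst - 1),
            tat: None,
        }
    }

    /// Returns true when a request arriving at `now` conforms to the quota.
    fn check(&mut self, now: Instant) -> bool {
        // Never let the theoretical arrival time fall behind the present.
        let tat = self.tat.unwrap_or(now).max(now);
        if tat.duration_since(now) > self.tolerance {
            // Non-conforming requests do not consume any quota.
            return false;
        }
        self.tat = Some(tat + self.emission_interval);
        true
    }
}

fn main() {
    // Two cells per second: a burst of two passes, the third is rejected,
    // and half a second later one cell has been replenished again.
    let mut gcra = Gcra::new(2, Duration::from_millis(500));
    let start = Instant::now();
    assert!(gcra.check(start));
    assert!(gcra.check(start));
    assert!(!gcra.check(start));
    assert!(gcra.check(start + Duration::from_millis(500)));
    println!("gcra sketch behaves as expected");
}
```

One theoretical-arrival-time value per client is all the state the limiter needs, which is why a keyed store such as `HashMapStateStore` only has to keep a single small `InMemoryState` entry per key.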