diff --git a/governor/src/gcra.rs b/governor/src/gcra.rs index cf2a0b0b..ea5beba9 100644 --- a/governor/src/gcra.rs +++ b/governor/src/gcra.rs @@ -174,6 +174,125 @@ impl Gcra { } })) } + + /// Tests whether all `n` cells could be accommodated. + /// No update would be made. + pub(crate) fn test_n_all_without_update< + K, + P: clock::Reference, + S: StateStore, + MW: RateLimitingMiddleware

, + >( + &self, + start: P, + key: &K, + n: NonZeroU32, + state: &S, + t0: P, + ) -> Result, InsufficientCapacity> { + let t0 = t0.duration_since(start); + let tau = self.tau; + let t = self.t; + let additional_weight = t * (n.get() - 1) as u64; + + // check that we can allow enough cells through. Note that `additional_weight` is the + // value of the cells *in addition* to the first cell - so add that first cell back. + if additional_weight + t > tau { + return Err(InsufficientCapacity((tau.as_u64() / t.as_u64()) as u32)); + } + Ok(state.measure(key, |tat| { + let tat = tat.unwrap_or_else(|| self.starting_state(t0)); + let earliest_time = (tat + additional_weight).saturating_sub(tau); + if t0 < earliest_time { + Err(MW::disallow( + key, + StateSnapshot::new(self.t, self.tau, earliest_time, earliest_time), + start, + )) + } else { + let next = cmp::max(tat, t0) + t + additional_weight; + Ok(MW::allow( + key, + StateSnapshot::new(self.t, self.tau, t0, next), + )) + } + })) + } + + /// Tests a single cell against the rate limiter state. + /// No update would be made. + pub(crate) fn test_without_update< + K, + P: clock::Reference, + S: StateStore, + MW: RateLimitingMiddleware

, + >( + &self, + start: P, + key: &K, + state: &S, + t0: P, + ) -> Result { + let t0 = t0.duration_since(start); + let tau = self.tau; + let t = self.t; + state.measure(key, |tat| { + let tat = tat.unwrap_or_else(|| self.starting_state(t0)); + let earliest_time = tat.saturating_sub(tau); + if t0 < earliest_time { + Err(MW::disallow( + key, + StateSnapshot::new(self.t, self.tau, earliest_time, earliest_time), + start, + )) + } else { + let next = cmp::max(tat, t0) + t; + Ok(MW::allow( + key, + StateSnapshot::new(self.t, self.tau, t0, next), + )) + } + }) + } + + /// Update a single cell against the rate limiter state at the given key. + pub(crate) fn update>( + &self, + start: P, + key: &K, + state: &S, + t0: P, + ) { + let t0 = t0.duration_since(start); + let t = self.t; + let _ = state.measure_and_replace(key, |tat| { + let tat = tat.unwrap_or_else(|| self.starting_state(t0)); + let next = cmp::max(tat, t0) + t; + // always ask state to update + Ok::<((), Nanos), ()>(((), next)) + }); + } + + /// Update `n` cells for the rate limiter state. + pub(crate) fn update_n>( + &self, + start: P, + key: &K, + n: NonZeroU32, + state: &S, + t0: P, + ) { + let t0 = t0.duration_since(start); + let t = self.t; + let additional_weight = t * (n.get() - 1) as u64; + + let _ = state.measure_and_replace(key, |tat| { + let tat = tat.unwrap_or_else(|| self.starting_state(t0)); + + let next = cmp::max(tat, t0) + t + additional_weight; + Ok::<((), Nanos), ()>(((), next)) + }); + } } #[cfg(test)] diff --git a/governor/src/state.rs b/governor/src/state.rs index 329851d2..17d2ffa4 100644 --- a/governor/src/state.rs +++ b/governor/src/state.rs @@ -46,6 +46,11 @@ pub trait StateStore { fn measure_and_replace(&self, key: &Self::Key, f: F) -> Result where F: Fn(Option) -> Result<(T, Nanos), E>; + + /// Same as [`measure_and_replace`](`StateStore::measure_and_replace`), but it would not replace the value at the key + fn measure(&self, key: &Self::Key, f: F) -> Result + where + F: Fn(Option) -> Result; } /// A rate limiter. diff --git a/governor/src/state/direct.rs b/governor/src/state/direct.rs index 9310a6c8..c48f923e 100644 --- a/governor/src/state/direct.rs +++ b/governor/src/state/direct.rs @@ -108,6 +108,57 @@ where self.clock.now(), ) } + + /// same as `check`, but will not update internal state. + /// It would only query if the rate limit is reached. + pub fn check_only(&self) -> Result { + self.gcra + .test_without_update::( + self.start, + &NotKeyed::NonKey, + &self.state, + self.clock.now(), + ) + } + + /// same as `check_n`, but will not update internal state. + /// It would only query if all `n` cells can be accommodated. + pub fn check_n_only( + &self, + n: NonZeroU32, + ) -> Result, InsufficientCapacity> { + self.gcra + .test_n_all_without_update::( + self.start, + &NotKeyed::NonKey, + n, + &self.state, + self.clock.now(), + ) + } + + /// Consume a cell through the rate limiter. + /// If no cell is available, it would "borrow" from future cells. + pub fn consume(&self) { + self.gcra.update::( + self.start, + &NotKeyed::NonKey, + &self.state, + self.clock.now(), + ) + } + + /// Consume n cells through the rate limiter. + /// If no cell is available, it would "borrow" from future cells. + pub fn consume_n(&self, n: NonZeroU32) { + self.gcra.update_n::( + self.start, + &NotKeyed::NonKey, + n, + &self.state, + self.clock.now(), + ) + } } #[cfg(feature = "std")] diff --git a/governor/src/state/in_memory.rs b/governor/src/state/in_memory.rs index 100b0569..060f7110 100644 --- a/governor/src/state/in_memory.rs +++ b/governor/src/state/in_memory.rs @@ -44,6 +44,14 @@ impl InMemoryState { decision.map(|(result, _)| result) } + pub(crate) fn measure(&self, mut f: F) -> Result + where + F: FnMut(Option) -> Result, + { + let prev = self.0.load(Ordering::Acquire); + f(NonZeroU64::new(prev).map(|n| n.get().into())) + } + pub(crate) fn is_older_than(&self, nanos: Nanos) -> bool { self.0.load(Ordering::Relaxed) <= nanos.into() } @@ -59,6 +67,13 @@ impl StateStore for InMemoryState { { self.measure_and_replace_one(f) } + + fn measure(&self, _key: &Self::Key, f: F) -> Result + where + F: Fn(Option) -> Result, + { + self.measure(f) + } } impl Debug for InMemoryState { diff --git a/governor/src/state/keyed.rs b/governor/src/state/keyed.rs index 1ff92e77..b3730650 100644 --- a/governor/src/state/keyed.rs +++ b/governor/src/state/keyed.rs @@ -258,6 +258,13 @@ mod test { { f(None).map(|(res, _)| res) } + + fn measure(&self, _key: &Self::Key, f: F) -> Result + where + F: Fn(Option) -> Result, + { + f(None) + } } impl ShrinkableKeyedStateStore for NaiveKeyedStateStore { diff --git a/governor/src/state/keyed/dashmap.rs b/governor/src/state/keyed/dashmap.rs index fe3e1c6d..97b0d4d8 100755 --- a/governor/src/state/keyed/dashmap.rs +++ b/governor/src/state/keyed/dashmap.rs @@ -27,6 +27,19 @@ impl StateStore for DashMapStateStore { let entry = self.entry(key.clone()).or_default(); (*entry).measure_and_replace_one(f) } + + fn measure(&self, key: &Self::Key, f: F) -> Result + where + F: Fn(Option) -> Result, + { + if let Some(v) = self.get(key) { + // fast path: measure existing entry + return v.measure(f); + } + // make an entry and measure that: + let entry = self.entry(key.clone()).or_default(); + (*entry).measure(f) + } } /// # Keyed rate limiters - [`DashMap`]-backed diff --git a/governor/src/state/keyed/hashmap.rs b/governor/src/state/keyed/hashmap.rs index c95155b3..00d51aa4 100644 --- a/governor/src/state/keyed/hashmap.rs +++ b/governor/src/state/keyed/hashmap.rs @@ -35,6 +35,20 @@ impl StateStore for HashMapStateStore { let entry = (*map).entry(key.clone()).or_default(); entry.measure_and_replace_one(f) } + + fn measure(&self, key: &Self::Key, f: F) -> Result + where + F: Fn(Option) -> Result, + { + let mut map = self.lock(); + if let Some(v) = (*map).get(key) { + // fast path: a rate limiter is already present for the key. + return v.measure(f); + } + // not-so-fast path: make a new entry and measure it. + let entry = (*map).entry(key.clone()).or_default(); + entry.measure(f) + } } impl ShrinkableKeyedStateStore for HashMapStateStore { diff --git a/governor/tests/direct.rs b/governor/tests/direct.rs index ca21f13d..802bcaf4 100644 --- a/governor/tests/direct.rs +++ b/governor/tests/direct.rs @@ -12,6 +12,24 @@ fn accepts_first_cell() { assert_eq!(Ok(()), lb.check()); } +#[test] +fn accepts_first_cell_check_only() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(1u32)), &clock); + assert_eq!(Ok(()), lb.check_only()); + // last call does not consume cells + assert_eq!(Ok(()), lb.check_only()); +} + +#[test] +fn accepts_cells_check_n_only() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(3u32)), &clock); + assert_eq!(Ok(Ok(())), lb.check_n_only(nonzero!(3u32))); + // last call does not consume cells + assert_eq!(Ok(Ok(())), lb.check_n_only(nonzero!(3u32))); +} + #[test] fn rejects_too_many() { let clock = FakeRelativeClock::default(); @@ -36,6 +54,58 @@ fn rejects_too_many() { assert_ne!(Ok(()), lb.check(), "{:?}", lb); } +#[test] +fn rejects_too_many_check_only() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(2u32)), &clock); + let ms = Duration::from_millis(1); + + // use up our burst capacity (2 in the first second): + assert_eq!(Ok(()), lb.check(), "Now: {:?}", clock.now()); + clock.advance(ms); + assert_eq!(Ok(()), lb.check(), "Now: {:?}", clock.now()); + + clock.advance(ms); + assert_ne!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + + // should be ok again in 1s: + clock.advance(ms * 1000); + assert_eq!(Ok(()), lb.check(), "Now: {:?}", clock.now()); + clock.advance(ms); + assert_eq!(Ok(()), lb.check()); + + clock.advance(ms); + assert_ne!(Ok(()), lb.check_only(), "{:?}", lb); +} + +#[test] +fn rejects_too_many_with_consume_and_check_only() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(2u32)), &clock); + let ms = Duration::from_millis(1); + + // use up our burst capacity (2 in the first second): + assert_eq!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + lb.consume(); + clock.advance(ms); + assert_eq!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + lb.consume(); + + clock.advance(ms); + assert_ne!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + + // should be ok again in 1s: + clock.advance(ms * 1000); + assert_eq!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + lb.consume(); + clock.advance(ms); + assert_eq!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + lb.consume(); + + clock.advance(ms); + assert_ne!(Ok(()), lb.check_only(), "{:?}", lb); +} + #[test] fn all_1_identical_to_1() { let clock = FakeRelativeClock::default(); @@ -61,6 +131,135 @@ fn all_1_identical_to_1() { assert_ne!(Ok(Ok(())), lb.check_n(one), "{:?}", lb); } +#[test] +fn all_1_identical_to_1_check_only() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(2u32)), &clock); + let ms = Duration::from_millis(1); + let one = nonzero!(1u32); + + // use up our burst capacity (2 in the first second): + assert_eq!(Ok(Ok(())), lb.check_n(one), "Now: {:?}", clock.now()); + clock.advance(ms); + assert_eq!(Ok(Ok(())), lb.check_n(one), "Now: {:?}", clock.now()); + + clock.advance(ms); + assert_ne!(Ok(Ok(())), lb.check_n_only(one), "Now: {:?}", clock.now()); + + // should be ok again in 1s: + clock.advance(ms * 1000); + assert_eq!(Ok(Ok(())), lb.check_n(one), "Now: {:?}", clock.now()); + clock.advance(ms); + assert_eq!(Ok(Ok(())), lb.check_n(one)); + + clock.advance(ms); + assert_ne!(Ok(Ok(())), lb.check_n_only(one), "{:?}", lb); +} + +#[test] +fn all_1_identical_to_1_consume_and_check_only() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(2u32)), &clock); + let ms = Duration::from_millis(1); + let one = nonzero!(1u32); + + // use up our burst capacity (2 in the first second): + assert_eq!(Ok(Ok(())), lb.check_n_only(one), "Now: {:?}", clock.now()); + lb.consume_n(one); + clock.advance(ms); + assert_eq!(Ok(Ok(())), lb.check_n_only(one), "Now: {:?}", clock.now()); + lb.consume_n(one); + + clock.advance(ms); + assert_ne!(Ok(Ok(())), lb.check_n_only(one), "Now: {:?}", clock.now()); + + // should be ok again in 1s: + clock.advance(ms * 1000); + assert_eq!(Ok(Ok(())), lb.check_n_only(one), "Now: {:?}", clock.now()); + lb.consume_n(one); + clock.advance(ms); + assert_eq!(Ok(Ok(())), lb.check_n_only(one)); + lb.consume_n(one); + + clock.advance(ms); + assert_ne!(Ok(Ok(())), lb.check_n_only(one), "{:?}", lb); +} + +#[test] +fn consume_n_borrow_from_future() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(2u32)), &clock); + let ms = Duration::from_millis(1); + + assert_eq!( + Ok(Ok(())), + lb.check_n_only(nonzero!(2u32)), + "Now: {:?}", + clock.now() + ); + lb.consume_n(nonzero!(2u32)); + // consumed all cells + assert_ne!( + Ok(Ok(())), + lb.check_n_only(nonzero!(2u32)), + "Now: {:?}", + clock.now() + ); + + // borrow from future + lb.consume_n(nonzero!(2u32)); + // consumed all cells + assert_ne!( + Ok(Ok(())), + lb.check_n_only(nonzero!(2u32)), + "Now: {:?}", + clock.now() + ); + + // borrowed cells not paid off + clock.advance(1000 * ms); + assert_ne!( + Ok(Ok(())), + lb.check_n_only(nonzero!(2u32)), + "Now: {:?}", + clock.now() + ); + + // paid off + clock.advance(1000 * ms); + assert_eq!( + Ok(Ok(())), + lb.check_n_only(nonzero!(2u32)), + "Now: {:?}", + clock.now() + ); +} + +#[test] +fn consume_borrow_from_future() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(1u32)), &clock); + let ms = Duration::from_millis(1); + + assert_eq!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + lb.consume(); + // consumed all cells + assert_ne!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + + // borrow from future + lb.consume(); + // consumed all cells + assert_ne!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + + // borrowed cells not paid off + clock.advance(1000 * ms); + assert_ne!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); + + // paid off + clock.advance(1000 * ms); + assert_eq!(Ok(()), lb.check_only(), "Now: {:?}", clock.now()); +} + #[test] fn never_allows_more_than_capacity_all() { let clock = FakeRelativeClock::default(); @@ -84,6 +283,29 @@ fn never_allows_more_than_capacity_all() { assert_ne!(Ok(Ok(())), lb.check_n(nonzero!(2u32)), "{:?}", lb); } +#[test] +fn never_allows_more_than_capacity_all_check_only() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(4u32)), &clock); + let ms = Duration::from_millis(1); + + // Use up the burst capacity: + assert_eq!(Ok(Ok(())), lb.check_n(nonzero!(2u32))); + assert_eq!(Ok(Ok(())), lb.check_n(nonzero!(2u32))); + + clock.advance(ms); + assert_ne!(Ok(Ok(())), lb.check_n_only(nonzero!(2u32))); + + // should be ok again in 1s: + clock.advance(ms * 1000); + assert_eq!(Ok(Ok(())), lb.check_n(nonzero!(2u32))); + clock.advance(ms); + assert_eq!(Ok(Ok(())), lb.check_n(nonzero!(2u32))); + + clock.advance(ms); + assert_ne!(Ok(Ok(())), lb.check_n_only(nonzero!(2u32)), "{:?}", lb); +} + #[test] fn rejects_too_many_all() { let clock = FakeRelativeClock::default(); @@ -98,6 +320,20 @@ fn rejects_too_many_all() { assert_ne!(Ok(Ok(())), lb.check_n(nonzero!(15u32))); } +#[test] +fn rejects_too_many_all_check_n_only() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(5u32)), &clock); + let ms = Duration::from_millis(1); + + // Should not allow the first 15 cells on a capacity 5 bucket: + assert_ne!(Ok(Ok(())), lb.check_n_only(nonzero!(15u32))); + + // After 3 and 20 seconds, it should not allow 15 on that bucket either: + clock.advance(ms * 3 * 1000); + assert_ne!(Ok(Ok(())), lb.check_n_only(nonzero!(15u32))); +} + #[test] fn all_capacity_check_rejects_excess() { let clock = FakeRelativeClock::default(); @@ -108,6 +344,25 @@ fn all_capacity_check_rejects_excess() { assert_eq!(Err(InsufficientCapacity(5)), lb.check_n(nonzero!(7u32))); } +#[test] +fn all_capacity_check_rejects_excess_check_n_only() { + let clock = FakeRelativeClock::default(); + let lb = RateLimiter::direct_with_clock(Quota::per_second(nonzero!(5u32)), &clock); + + assert_eq!( + Err(InsufficientCapacity(5)), + lb.check_n_only(nonzero!(15u32)) + ); + assert_eq!( + Err(InsufficientCapacity(5)), + lb.check_n_only(nonzero!(6u32)) + ); + assert_eq!( + Err(InsufficientCapacity(5)), + lb.check_n_only(nonzero!(7u32)) + ); +} + #[test] fn correct_wait_time() { let clock = FakeRelativeClock::default();