This repository has been archived by the owner on Feb 8, 2024. It is now read-only.

Commit

Add billing limiter (#33)
* Add Redis lib

* `cargo update`

* fmt

* Add base implementation of billing limiter

Supports

1. A fixed set of limits, with no redis update
2. A fixed set, subsequently updated from redis
3. No fixed set, updates from redis

I still need to figure out how to nicely mock the redis connection in a way
that still leaves enough unmocked to be worth testing.

I really don't want integration tests on it :(

It also still needs connecting to the API. Reading through the Python for
this is like 😵‍💫
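
For reference, a rough sketch of how the limiter ends up being used (this
mirrors the main.rs and capture.rs changes further down; the redis address
and token below are placeholders, not real values):

```rust
use std::sync::Arc;

use capture::billing_limits::{BillingLimiter, QuotaResource};
use capture::redis::RedisClient;
use time::Duration;

async fn example() {
    // Placeholder address; capture-server reads this from the REDIS env var.
    let redis = Arc::new(
        RedisClient::new("redis://127.0.0.1:6379/".to_string())
            .expect("failed to create redis client"),
    );

    // Cache the limited-token set in memory, refreshing from redis once the
    // configured interval has elapsed since the last update.
    let billing = BillingLimiter::new(Duration::seconds(5), redis)
        .expect("failed to create billing limiter");

    if billing
        .is_limited("phc_placeholder_token", QuotaResource::Events)
        .await
    {
        // Reject (or, for now, silently drop) events for this token.
    }
}
```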

* Rework

I've reworked it a bunch. Honestly, the background loop worked, but it became
really horrible and the locking behaviour was a little sketchy. While this
will slow some requests down a bit, unless it becomes measurably slow let's
keep it this way rather than introduce a somewhat horrible pattern.

* hook it all up

* Add redis read timeout

* Add non-cluster client

* Respond to feedback
Ellie Huxtable authored Oct 19, 2023
1 parent 7e74df8 commit c03638b
Showing 11 changed files with 527 additions and 86 deletions.
257 changes: 179 additions & 78 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions capture-server/Cargo.toml
@@ -9,3 +9,4 @@ axum = { workspace = true }
tokio = { workspace = true }
tracing-subscriber = { workspace = true }
tracing = { workspace = true }
time = { workspace = true }
27 changes: 24 additions & 3 deletions capture-server/src/main.rs
@@ -1,7 +1,9 @@
use std::env;
use std::net::SocketAddr;
use std::sync::Arc;

use capture::{router, sink, time};
use capture::{billing_limits::BillingLimiter, redis::RedisClient, router, sink};
use time::Duration;
use tokio::signal;

async fn shutdown() {
@@ -23,16 +25,35 @@ async fn shutdown() {
async fn main() {
let use_print_sink = env::var("PRINT_SINK").is_ok();
let address = env::var("ADDRESS").unwrap_or(String::from("127.0.0.1:3000"));
let redis_addr = env::var("REDIS").expect("redis required; please set the REDIS env var");

let redis_client =
Arc::new(RedisClient::new(redis_addr).expect("failed to create redis client"));

let billing = BillingLimiter::new(Duration::seconds(5), redis_client.clone())
.expect("failed to create billing limiter");

let app = if use_print_sink {
router::router(time::SystemTime {}, sink::PrintSink {}, true)
router::router(
capture::time::SystemTime {},
sink::PrintSink {},
redis_client,
billing,
true,
)
} else {
let brokers = env::var("KAFKA_BROKERS").expect("Expected KAFKA_BROKERS");
let topic = env::var("KAFKA_TOPIC").expect("Expected KAFKA_TOPIC");

let sink = sink::KafkaSink::new(topic, brokers).unwrap();

router::router(time::SystemTime {}, sink, true)
router::router(
capture::time::SystemTime {},
sink,
redis_client,
billing,
true,
)
};

// initialize tracing
2 changes: 2 additions & 0 deletions capture/Cargo.toml
@@ -29,8 +29,10 @@ rdkafka = { workspace = true }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
thiserror = { workspace = true }
redis = { version="0.23.3", features=["tokio-comp", "cluster", "cluster-async"] }

[dev-dependencies]
assert-json-diff = "2.0.2"
axum-test-helper = "0.2.0"
mockall = "0.11.2"
redis-test = "0.2.3"
12 changes: 12 additions & 0 deletions capture/src/api.rs
@@ -52,6 +52,12 @@ pub enum CaptureError {
EventTooBig,
#[error("invalid event could not be processed")]
NonRetryableSinkError,

#[error("billing limit reached")]
BillingLimit,

#[error("rate limited")]
RateLimited,
}

impl IntoResponse for CaptureError {
@@ -64,10 +70,16 @@ impl IntoResponse for CaptureError {
| CaptureError::MissingDistinctId
| CaptureError::EventTooBig
| CaptureError::NonRetryableSinkError => (StatusCode::BAD_REQUEST, self.to_string()),

CaptureError::NoTokenError
| CaptureError::MultipleTokensError
| CaptureError::TokenValidationError(_) => (StatusCode::UNAUTHORIZED, self.to_string()),

CaptureError::RetryableSinkError => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()),

CaptureError::BillingLimit | CaptureError::RateLimited => {
(StatusCode::TOO_MANY_REQUESTS, self.to_string())
}
}
.into_response()
}
188 changes: 188 additions & 0 deletions capture/src/billing_limits.rs
@@ -0,0 +1,188 @@
use std::{collections::HashSet, ops::Sub, sync::Arc};

use crate::redis::Client;

/// Limit accounts by team ID if they hit a billing limit
///
/// We have an async celery worker that regularly checks on accounts + assesses if they are beyond
/// a billing limit. If this is the case, a key is set in redis.
///
/// Requirements
///
/// 1. Updates from the celery worker should be reflected in capture within a short period of time
/// 2. Capture should cope with redis being _totally down_, and fail open
/// 3. We should not hit redis for every single request
///
/// The solution here is to read from the cache until a time interval is hit, and then fetch new
/// data. The write requires taking a lock that stalls all readers, though so long as redis reads
/// stay fast we're ok.
///
/// Some small delay between an account being limited and the limit taking effect is acceptable.
/// However, ideally we should not allow requests from some pods but 429 from others.
use thiserror::Error;
use time::{Duration, OffsetDateTime};
use tokio::sync::RwLock;

// todo: fetch from env
const QUOTA_LIMITER_CACHE_KEY: &str = "@posthog/quota-limits/";

pub enum QuotaResource {
Events,
Recordings,
}

impl QuotaResource {
fn as_str(&self) -> &'static str {
match self {
Self::Events => "events",
Self::Recordings => "recordings",
}
}
}

#[derive(Error, Debug)]
pub enum LimiterError {
#[error("updater already running - there can only be one")]
UpdaterRunning,
}

#[derive(Clone)]
pub struct BillingLimiter {
limited: Arc<RwLock<HashSet<String>>>,
redis: Arc<dyn Client + Send + Sync>,
interval: Duration,
updated: Arc<RwLock<time::OffsetDateTime>>,
}

impl BillingLimiter {
/// Create a new BillingLimiter.
///
/// Pass in a shared redis client (anything implementing `Client`) and a refresh interval.
///
/// The cached set of limited tokens starts empty and is populated from redis the
/// first time `is_limited` is called, then refreshed whenever the interval has
/// elapsed since the last update.
pub fn new(
interval: Duration,
redis: Arc<dyn Client + Send + Sync>,
) -> anyhow::Result<BillingLimiter> {
let limited = Arc::new(RwLock::new(HashSet::new()));

// Force an update immediately if we have any reasonable interval
let updated = OffsetDateTime::from_unix_timestamp(0)?;
let updated = Arc::new(RwLock::new(updated));

Ok(BillingLimiter {
interval,
limited,
updated,
redis,
})
}

async fn fetch_limited(
client: &Arc<dyn Client + Send + Sync>,
resource: QuotaResource,
) -> anyhow::Result<Vec<String>> {
let now = time::OffsetDateTime::now_utc().unix_timestamp();

client
.zrangebyscore(
format!("{QUOTA_LIMITER_CACHE_KEY}{}", resource.as_str()),
now.to_string(),
String::from("+Inf"),
)
.await
}

pub async fn is_limited(&self, key: &str, resource: QuotaResource) -> bool {
// hold the read lock to clone it, very briefly. clone is ok because it's very small 🤏
// rwlock can have many readers, but one writer. the writer will wait in a queue with all
// the readers, so we want to hold read locks for the smallest time possible to avoid
// writers waiting for too long. and vice versa.
let updated = {
let updated = self.updated.read().await;
*updated
};

let now = OffsetDateTime::now_utc();
let since_update = now.sub(updated);

// If an update is due, fetch the set from redis + cache it until the next update is due.
// Otherwise, return a value from the cache
//
// This update will block readers! Keep it fast.
if since_update > self.interval {
let span = tracing::debug_span!("updating billing cache from redis");
let _span = span.enter();

// a few requests might end up in here concurrently, but I don't think a few extra will
// be a big problem. If it is, we can rework the concurrency a bit.
// On prod atm we call this around 15 times per second at peak times, and it usually
// completes in <1ms.

let set = Self::fetch_limited(&self.redis, resource).await;

tracing::debug!("fetched set from redis, caching");

if let Ok(set) = set {
let set = HashSet::from_iter(set.iter().cloned());

let mut limited = self.limited.write().await;
*limited = set;

tracing::debug!("updated cache from redis");

limited.contains(key)
} else {
tracing::error!("failed to fetch from redis in time, failing open");
// If we fail to fetch the set, something really wrong is happening. To avoid
// dropping events that we don't mean to drop, fail open and accept data. Better
// than angry customers :)
//
// TODO: Consider backing off our redis checks
false
}
} else {
let l = self.limited.read().await;

l.contains(key)
}
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
use time::Duration;

use crate::{
billing_limits::{BillingLimiter, QuotaResource},
redis::MockRedisClient,
};

#[tokio::test]
async fn test_dynamic_limited() {
let client = MockRedisClient::new().zrangebyscore_ret(vec![String::from("banana")]);
let client = Arc::new(client);

let limiter = BillingLimiter::new(Duration::microseconds(1), client)
.expect("Failed to create billing limiter");

assert_eq!(
limiter
.is_limited("idk it doesn't matter", QuotaResource::Events)
.await,
false
);

assert_eq!(
limiter
.is_limited("some_org_hit_limits", QuotaResource::Events)
.await,
false
);
assert!(limiter.is_limited("banana", QuotaResource::Events).await);
}
}
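
For context on the data this reads: the `@posthog/quota-limits/<resource>` sorted set
appears to hold team tokens as members, with the score being the unix timestamp until
which each token is limited, so the `zrangebyscore(key, now, "+Inf")` call above only
returns tokens whose limit is still active. Below is a minimal sketch of how such an
entry could be written with the redis crate; the worker that actually writes it is the
Python/celery job, and the token and time window here are made up.

```rust
use redis::AsyncCommands;

// Sketch only: member = team token, score = unix timestamp until which the
// token stays limited. Entries whose score is already in the past fall outside
// fetch_limited's `zrangebyscore(key, now, "+Inf")` range and are ignored.
async fn mark_limited(
    conn: &mut redis::aio::MultiplexedConnection,
) -> redis::RedisResult<()> {
    // Hypothetical token, limited for one hour from now.
    let until = time::OffsetDateTime::now_utc().unix_timestamp() + 3600;
    conn.zadd("@posthog/quota-limits/events", "phc_hypothetical_token", until)
        .await
}
```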
24 changes: 22 additions & 2 deletions capture/src/capture.rs
@@ -12,6 +12,7 @@ use axum_client_ip::InsecureClientIp;
use base64::Engine;
use time::OffsetDateTime;

use crate::billing_limits::QuotaResource;
use crate::event::ProcessingContext;
use crate::token::validate_token;
use crate::{
@@ -44,7 +45,7 @@ pub async fn event(
_ => RawEvent::from_bytes(&meta, body),
}?;

println!("Got events {:?}", &events);
tracing::debug!("got events {:?}", &events);

if events.is_empty() {
return Err(CaptureError::EmptyBatch);
@@ -61,6 +62,7 @@
}
None
});

let context = ProcessingContext {
lib_version: meta.lib_version.clone(),
sent_at,
@@ -69,7 +71,25 @@
client_ip: ip.to_string(),
};

println!("Got context {:?}", &context);
let limited = state
.billing
.is_limited(context.token.as_str(), QuotaResource::Events)
.await;

if limited {
// for v0 we want to just return ok 🙃
// this is because the clients are pretty dumb and will just retry over and over and
// over...
//
// for v1, we'll return a meaningful error code and error, so that the clients can do
// something meaningful with that error

return Ok(Json(CaptureResponse {
status: CaptureResponseCode::Ok,
}));
}

tracing::debug!("got context {:?}", &context);

process_events(state.sink.clone(), &events, &context).await?;

2 changes: 2 additions & 0 deletions capture/src/lib.rs
@@ -1,7 +1,9 @@
pub mod api;
pub mod billing_limits;
pub mod capture;
pub mod event;
pub mod prometheus;
pub mod redis;
pub mod router;
pub mod sink;
pub mod time;
