Skip to content

Commit

Permalink
feature: enable rate limit for connections api (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
sagojez authored May 24, 2024
1 parent e71ad30 commit 310cf7f
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 123 deletions.
93 changes: 0 additions & 93 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion integrationos-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ strum.workspace = true
tokio.workspace = true
tower = { version = "0.4.13", features = ["filter"] }
tower-http.workspace = true
tower_governor = "0.3.2"
tracing-subscriber.workspace = true
tracing.workspace = true
validator.workspace = true
Expand Down
5 changes: 4 additions & 1 deletion integrationos-api/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ pub struct Config {
pub openai_config: OpenAiConfig,
#[envconfig(nested = true)]
pub redis_config: RedisConfig,
#[envconfig(from = "RATE_LIMIT_ENABLED", default = "true")]
pub rate_limit_enabled: bool,
}

impl Display for Config {
Expand Down Expand Up @@ -117,7 +119,8 @@ impl Display for Config {
writeln!(f, "{}", self.headers)?;
writeln!(f, "{}", self.db_config)?;
writeln!(f, "{}", self.openai_config)?;
writeln!(f, "{}", self.redis_config)
writeln!(f, "{}", self.redis_config)?;
writeln!(f, "RATE_LIMIT_ENABLED: {}", self.rate_limit_enabled)
}
}

Expand Down
137 changes: 123 additions & 14 deletions integrationos-api/src/middleware/extractor.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,130 @@
use anyhow::Result;
use integrationos_domain::event_access::EventAccess;
use serde::{Deserialize, Serialize};
use crate::{metrics::Metric, server::AppState, too_many_requests};
use anyhow::{Context, Result};
use axum::{
body::Body,
extract::State,
middleware::Next,
response::{IntoResponse, Response},
Extension,
};
use http::{HeaderName, Request};
use integrationos_domain::{event_access::EventAccess, RedisCache};
use redis::AsyncCommands;
use std::sync::Arc;
use tower_governor::{errors::GovernorError, key_extractor::KeyExtractor};
use tokio::sync::{
mpsc::{channel, Sender},
oneshot,
};
use tracing::warn;

#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub struct OwnershipId;
#[derive(Debug, Clone)]
pub struct RateLimiter {
tx: Sender<(Arc<str>, oneshot::Sender<u64>)>,
key_header_name: HeaderName,
limit_header_name: HeaderName,
remaining_header_name: HeaderName,
reset_header_name: HeaderName,
metric_tx: Sender<Metric>,
}

impl RateLimiter {
pub async fn new(state: Arc<AppState>) -> Result<Self> {
if state.config.rate_limit_enabled {
return Err(anyhow::anyhow!("Rate limiting is disabled"));
};

let mut redis = RedisCache::new(&state.config.redis_config, 0)
.await
.with_context(|| "Could not connect to redis")?;

let (tx, mut rx) = channel::<(Arc<str>, oneshot::Sender<u64>)>(1024);

let throughput_key = state.config.redis_config.api_throughput_key.clone();

tokio::spawn(async move {
while let Some((id, tx)) = rx.recv().await {
let count: u64 = redis
.hincr(&throughput_key, id.as_ref(), 1)
.await
.unwrap_or_default();
let _ = tx.send(count);
}
});

let key_header_name =
HeaderName::from_lowercase(state.config.headers.connection_header.as_bytes()).unwrap();

let limit_header_name =
HeaderName::from_lowercase(state.config.headers.rate_limit_limit.as_bytes()).unwrap();

let remaining_header_name =
HeaderName::from_lowercase(state.config.headers.rate_limit_remaining.as_bytes())
.unwrap();

let reset_header_name =
HeaderName::from_lowercase(state.config.headers.rate_limit_reset.as_bytes()).unwrap();

Ok(RateLimiter {
tx,
metric_tx: state.metric_tx.clone(),
key_header_name,
limit_header_name,
remaining_header_name,
reset_header_name,
})
}

pub async fn get_request_count(&self, id: Arc<str>) -> u64 {
let (tx, rx) = oneshot::channel();
match self.tx.send((id, tx)).await {
Ok(()) => rx.await.unwrap_or_default(),
Err(e) => {
warn!("Could not send to redis task: {e}");
0
}
}
}
}

pub async fn rate_limit(
Extension(event_access): Extension<Arc<EventAccess>>,
State(state): State<Arc<RateLimiter>>,
req: Request<Body>,
next: Next,
) -> Result<Response, Response> {
let throughput = event_access.throughput;

let count = state
.get_request_count(event_access.ownership.id.clone())
.await;

if count >= throughput {
let _ = state
.metric_tx
.send(Metric::rate_limited(
event_access.clone(),
req.headers().get(&state.key_header_name).cloned(),
))
.await;
let mut res = too_many_requests!().into_response();

let headers = res.headers_mut();

impl KeyExtractor for OwnershipId {
type Key = String;
headers.insert(state.limit_header_name.clone(), throughput.into());
headers.insert(state.remaining_header_name.clone(), 0.into());
headers.insert(state.reset_header_name.clone(), 60.into());

fn extract<T>(&self, req: &http::request::Request<T>) -> Result<Self::Key, GovernorError> {
let event_access = req
.extensions()
.get::<Arc<EventAccess>>()
.ok_or_else(|| GovernorError::UnableToExtractKey)?;
Err(res)
} else {
let mut res = next.run(req).await;
let headers = res.headers_mut();

Ok(event_access.ownership.id.to_string())
headers.insert(state.limit_header_name.clone(), throughput.into());
headers.insert(
state.remaining_header_name.clone(),
(throughput - count).into(),
);
headers.insert(state.reset_header_name.clone(), 60.into());
Ok(res)
}
}
26 changes: 12 additions & 14 deletions integrationos-api/src/routes/protected.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
middleware::{
auth,
blocker::{handle_blocked_error, BlockInvalidHeaders},
extractor::OwnershipId,
extractor::{rate_limit, RateLimiter},
},
server::AppState,
};
Expand All @@ -24,8 +24,8 @@ use http::HeaderName;
use integrationos_domain::connection_model_schema::PublicConnectionModelSchema;
use std::{iter::once, sync::Arc};
use tower::{filter::FilterLayer, ServiceBuilder};
use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer};
use tower_http::{sensitive_headers::SetSensitiveRequestHeadersLayer, trace::TraceLayer};
use tracing::warn;

pub async fn get_router(state: &Arc<AppState>) -> Router<Arc<AppState>> {
let routes = Router::new()
Expand All @@ -51,20 +51,18 @@ pub async fn get_router(state: &Arc<AppState>) -> Router<Arc<AppState>> {
.layer(TraceLayer::new_for_http())
.nest("/metrics", metrics::get_router());

let config = Box::new(
GovernorConfigBuilder::default()
.per_second(state.config.burst_rate_limit)
.burst_size(state.config.burst_size)
.key_extractor(OwnershipId)
.use_headers()
.finish()
.expect("Failed to build GovernorConfig"),
);
let routes = match RateLimiter::new(state.clone()).await {
Ok(rate_limiter) => routes.layer(axum::middleware::from_fn_with_state(
Arc::new(rate_limiter),
rate_limit,
)),
Err(e) => {
warn!("Could not connect to redis: {e}");
routes
}
};

routes
.layer(GovernorLayer {
config: Box::leak(config),
})
.layer(from_fn_with_state(state.clone(), auth::auth))
.layer(TraceLayer::new_for_http())
.layer(SetSensitiveRequestHeadersLayer::new(once(
Expand Down

0 comments on commit 310cf7f

Please sign in to comment.