diff --git a/Cargo.lock b/Cargo.lock index 7cc582e..c56738f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1076,6 +1076,7 @@ dependencies = [ "sqlx", "tokio", "tower", + "tower-http", "tracing", "tracing-subscriber", "url", diff --git a/Cargo.toml b/Cargo.toml index d1f7527..e097eab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ members = [ [workspace.lints.rust] # See https://doc.rust-lang.org/stable/rustc/lints/listing/allowed-by-default.html -unsafe_code = "forbid" # forbid cannot be ignored with an annotation +unsafe_code = "forbid" # forbid cannot be ignored with an annotation unstable_features = "forbid" macro_use_extern_crate = "forbid" let_underscore_drop = "deny" @@ -69,7 +69,7 @@ time = { version = "0.3.20", features = [ thiserror = { version = "1.0" } tokio = { version = "1.34.0", features = ["full"] } tower = "0.4.13" -tower-http = { version = "0.5.2", features = ["cors", "trace"] } +tower-http = { version = "0.5.2", features = ["cors", "limit", "trace"] } tracing = "0.1.40" tracing-subscriber = "0.3.18" url = { version = "2.5.0 " } diff --git a/hook-api/Cargo.toml b/hook-api/Cargo.toml index c3528d2..eb82438 100644 --- a/hook-api/Cargo.toml +++ b/hook-api/Cargo.toml @@ -19,6 +19,7 @@ serde_json = { workspace = true } sqlx = { workspace = true } tokio = { workspace = true } tower = { workspace = true } +tower-http = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } url = { workspace = true } diff --git a/hook-api/src/config.rs b/hook-api/src/config.rs index 55fa404..fe99de2 100644 --- a/hook-api/src/config.rs +++ b/hook-api/src/config.rs @@ -16,6 +16,9 @@ pub struct Config { #[envconfig(default = "100")] pub max_pg_connections: u32, + + #[envconfig(default = "5_000_000")] + pub max_body_size: usize, } impl Config { diff --git a/hook-api/src/handlers/app.rs b/hook-api/src/handlers/app.rs index fa2bcbc..7cbbc44 100644 --- a/hook-api/src/handlers/app.rs +++ b/hook-api/src/handlers/app.rs @@ -1,10 +1,11 @@ -use axum::{extract::DefaultBodyLimit, routing, Router}; +use axum::{routing, Router}; +use tower_http::limit::RequestBodyLimitLayer; use hook_common::pgqueue::PgQueue; use super::webhook; -pub fn add_routes(router: Router, pg_pool: PgQueue) -> Router { +pub fn add_routes(router: Router, pg_pool: PgQueue, max_body_size: usize) -> Router { router .route("/", routing::get(index)) .route("/_readiness", routing::get(index)) @@ -13,7 +14,7 @@ pub fn add_routes(router: Router, pg_pool: PgQueue) -> Router { "/webhook", routing::post(webhook::post) .with_state(pg_pool) - .layer(DefaultBodyLimit::disable()), + .layer(RequestBodyLimitLayer::new(max_body_size)), ) } @@ -37,7 +38,7 @@ mod tests { async fn index(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, 1_000_000); let response = app .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) diff --git a/hook-api/src/handlers/webhook.rs b/hook-api/src/handlers/webhook.rs index 47f21a6..808c948 100644 --- a/hook-api/src/handlers/webhook.rs +++ b/hook-api/src/handlers/webhook.rs @@ -9,8 +9,6 @@ use hook_common::pgqueue::{NewJob, PgQueue}; use serde::Serialize; use tracing::{debug, error}; -pub const MAX_BODY_SIZE: usize = 5_000_000; - #[derive(Serialize, Deserialize)] pub struct WebhookPostResponse { #[serde(skip_serializing_if = "Option::is_none")] @@ -37,15 +35,6 @@ pub async fn post( ) -> Result, (StatusCode, Json)> { debug!("received payload: {:?}", payload); - if payload.parameters.body.len() > MAX_BODY_SIZE { - return Err(( - StatusCode::BAD_REQUEST, - Json(WebhookPostResponse { - error: Some("body too large".to_owned()), - }), - )); - } - let url_hostname = get_hostname(&payload.parameters.url)?; // We could cast to i32, but this ensures we are not wrapping. let max_attempts = i32::try_from(payload.max_attempts).map_err(|_| { @@ -125,11 +114,13 @@ mod tests { use crate::handlers::app::add_routes; + const MAX_BODY_SIZE: usize = 1_000_000; + #[sqlx::test(migrations = "../migrations")] async fn webhook_success(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); let mut headers = collections::HashMap::new(); headers.insert("Content-Type".to_owned(), "application/json".to_owned()); @@ -171,7 +162,7 @@ mod tests { async fn webhook_bad_url(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); let response = app .oneshot( @@ -208,7 +199,7 @@ mod tests { async fn webhook_payload_missing_fields(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); let response = app .oneshot( @@ -229,7 +220,7 @@ mod tests { async fn webhook_payload_not_json(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); let response = app .oneshot( @@ -250,9 +241,9 @@ mod tests { async fn webhook_payload_body_too_large(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; - let app = add_routes(Router::new(), pg_queue); + let app = add_routes(Router::new(), pg_queue, MAX_BODY_SIZE); - let bytes: Vec = vec![b'a'; 5_000_000 * 2]; + let bytes: Vec = vec![b'a'; MAX_BODY_SIZE + 1]; let long_string = String::from_utf8_lossy(&bytes); let response = app @@ -283,6 +274,6 @@ mod tests { .await .unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); } } diff --git a/hook-api/src/main.rs b/hook-api/src/main.rs index 9a9a9fd..ad05ede 100644 --- a/hook-api/src/main.rs +++ b/hook-api/src/main.rs @@ -34,7 +34,7 @@ async fn main() { .await .expect("failed to initialize queue"); - let app = handlers::add_routes(Router::new(), pg_queue); + let app = handlers::add_routes(Router::new(), pg_queue, config.max_body_size); let app = setup_metrics_routes(app); match listen(app, config.bind()).await {