diff --git a/Cargo.toml b/Cargo.toml index fec06b028..52c84e562 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -109,7 +109,7 @@ cargo_metadata = "0.18.1" cfg-if = "1" -uuid = { version = "1.6", features = ["v4"] } +uuid = { version = "1.10.0", features = ["v4", "fast-rng"] } requestty = "0.5.0" # A socket.io server implementation diff --git a/src/controller/app_routes.rs b/src/controller/app_routes.rs index 157156f09..9fe25638e 100644 --- a/src/controller/app_routes.rs +++ b/src/controller/app_routes.rs @@ -11,7 +11,6 @@ use tower_http::{ add_extension::AddExtensionLayer, catch_panic::CatchPanicLayer, compression::CompressionLayer, - cors, services::{ServeDir, ServeFile}, set_header::SetResponseHeaderLayer, timeout::TimeoutLayer, @@ -20,9 +19,15 @@ use tower_http::{ #[cfg(feature = "channels")] use super::channels::AppChannels; -use super::routes::Routes; +use super::{middleware::cors::cors_middleware, routes::Routes}; use crate::{ - app::AppContext, config, controller::middleware::etag::EtagLayer, environment::Environment, + app::AppContext, + config, + controller::middleware::{ + etag::EtagLayer, + request_id::{request_id_middleware, LocoRequestId}, + }, + environment::Environment, errors, Result, }; @@ -179,13 +184,55 @@ impl AppRoutes { /// [`axum::Router`]. #[allow(clippy::cognitive_complexity)] pub fn to_router(&self, ctx: AppContext, mut app: AXRouter) -> Result { + // + // IMPORTANT: middleware ordering in this function is opposite to what you + // intuitively may think. when using `app.layer` to add individual middleware, + // the LAST middleware is the FIRST to meet the outside world (a user request + // starting), or "LIFO" order. + // We build the "onion" from the inside (start of this function), + // outwards (end of this function). This is why routes is first in coding order + // here (the core of the onion), and request ID is amongst the last + // (because every request is assigned with a unique ID, which starts its + // "life"). + // + // NOTE: when using ServiceBuilder#layer the order is FIRST to LAST (but we + // don't use ServiceBuilder because it requires too complex generic typing for + // this function). ServiceBuilder is recommended to save compile times, but that + // may be a thing of the past as we don't notice any issues with compile times + // using the router directly, and ServiceBuilder has been reported to give + // issues in compile times itself (https://github.com/rust-lang/crates.io/pull/7443). + // for router in self.collect() { tracing::info!("{}", router.to_string()); app = app.route(&router.uri, router.method); } - app = Self::add_powered_by_header(app, &ctx.config.server); + #[cfg(feature = "channels")] + if let Some(channels) = self.channels.as_ref() { + tracing::info!("[Middleware] Adding channels"); + let channel_layer_app = tower::ServiceBuilder::new().layer(channels.layer.clone()); + if let Some(cors) = &ctx + .config + .server + .middlewares + .cors + .as_ref() + .filter(|c| c.enable) + { + app = app.layer( + tower::ServiceBuilder::new() + .layer(cors_middleware(cors)?) + .layer(channel_layer_app), + ); + } else { + app = app.layer( + tower::ServiceBuilder::new() + .layer(tower_http::cors::CorsLayer::permissive()) + .layer(channel_layer_app), + ); + } + } if let Some(catch_panic) = &ctx.config.server.middlewares.catch_panic { if catch_panic.enable { @@ -193,12 +240,30 @@ impl AppRoutes { } } + if let Some(etag) = &ctx.config.server.middlewares.etag { + if etag.enable { + app = Self::add_etag_middleware(app); + } + } + if let Some(compression) = &ctx.config.server.middlewares.compression { if compression.enable { app = Self::add_compression_middleware(app); } } + if let Some(timeout_request) = &ctx.config.server.middlewares.timeout_request { + if timeout_request.enable { + app = Self::add_timeout_middleware(app, timeout_request); + } + } + + if let Some(cors) = &ctx.config.server.middlewares.cors { + if cors.enable { + app = app.layer(cors_middleware(cors)?); + } + } + if let Some(limit) = &ctx.config.server.middlewares.limit_payload { if limit.enable { app = Self::add_limit_payload_middleware(app, limit)?; @@ -211,62 +276,28 @@ impl AppRoutes { } } - if let Some(timeout_request) = &ctx.config.server.middlewares.timeout_request { - if timeout_request.enable { - app = Self::add_timeout_middleware(app, timeout_request); - } - } - - let cors = ctx - .config - .server - .middlewares - .cors - .as_ref() - .filter(|cors| cors.enable) - .map(Self::get_cors_middleware) - .transpose()?; - - if let Some(cors) = &cors { - app = app.layer(cors.clone()); - tracing::info!("[Middleware] Adding cors"); - } - if let Some(static_assets) = &ctx.config.server.middlewares.static_assets { if static_assets.enable { app = Self::add_static_asset_middleware(app, static_assets)?; } } - if let Some(etag) = &ctx.config.server.middlewares.etag { - if etag.enable { - app = Self::add_etag_middleware(app); - } - } + // XXX todo: remote IP middleware here - #[cfg(feature = "channels")] - if let Some(channels) = self.channels.as_ref() { - tracing::info!("[Middleware] Adding channels"); - let channel_layer_app = tower::ServiceBuilder::new().layer(channels.layer.clone()); - if let Some(cors) = cors { - app = app.layer( - tower::ServiceBuilder::new() - .layer(cors) - .layer(channel_layer_app), - ); - } else { - app = app.layer( - tower::ServiceBuilder::new() - .layer(tower_http::cors::CorsLayer::permissive()) - .layer(channel_layer_app), - ); - } - } + app = Self::add_powered_by_header(app, &ctx.config.server); + + app = Self::add_request_id_middleware(app); let router = app.with_state(ctx); Ok(router) } + fn add_request_id_middleware(app: AXRouter) -> AXRouter { + let app = app.layer(axum::middleware::from_fn(request_id_middleware)); + tracing::info!("[Middleware] Adding request_id middleware"); + app + } + fn add_static_asset_middleware( app: AXRouter, config: &config::StaticAssetsMiddleware, @@ -307,44 +338,6 @@ impl AppRoutes { app } - fn get_cors_middleware(config: &config::CorsMiddleware) -> Result { - let mut cors: cors::CorsLayer = cors::CorsLayer::permissive(); - - if let Some(allow_origins) = &config.allow_origins { - // testing CORS, assuming https://example.com in the allow list: - // $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Access-Control-Request-Method: GET' - // look for '< access-control-allow-origin: https://example.com' in response. - // if it doesn't appear (test with a bogus domain), it is not allowed. - let mut list = vec![]; - for origins in allow_origins { - list.push(origins.parse()?); - } - cors = cors.allow_origin(list); - } - - if let Some(allow_headers) = &config.allow_headers { - let mut headers = vec![]; - for header in allow_headers { - headers.push(header.parse()?); - } - cors = cors.allow_headers(headers); - } - - if let Some(allow_methods) = &config.allow_methods { - let mut methods = vec![]; - for method in allow_methods { - methods.push(method.parse()?); - } - cors = cors.allow_methods(methods); - } - - if let Some(max_age) = config.max_age { - cors = cors.max_age(Duration::from_secs(max_age)); - } - - Ok(cors) - } - fn add_catch_panic(app: AXRouter) -> AXRouter { app.layer(CatchPanicLayer::custom(handle_panic)) } @@ -372,7 +365,10 @@ impl AppRoutes { let app = app .layer( TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| { - let request_id = uuid::Uuid::new_v4(); + let ext = request.extensions(); + let request_id = ext + .get::() + .map_or_else(|| "req-id-none".to_string(), |r| r.get().to_string()); let user_agent = request .headers() .get(axum::http::header::USER_AGENT) diff --git a/src/controller/middleware/cors.rs b/src/controller/middleware/cors.rs new file mode 100644 index 000000000..b53e37a45 --- /dev/null +++ b/src/controller/middleware/cors.rs @@ -0,0 +1,48 @@ +use std::time::Duration; + +use tower_http::cors; + +use crate::{config, Result}; + +/// Create a CORS layer +/// +/// # Errors +/// +/// This function will return an error if parsing of header config fail +pub fn cors_middleware(config: &config::CorsMiddleware) -> Result { + let mut cors: cors::CorsLayer = cors::CorsLayer::permissive(); + + if let Some(allow_origins) = &config.allow_origins { + // testing CORS, assuming https://example.com in the allow list: + // $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Access-Control-Request-Method: GET' + // look for '< access-control-allow-origin: https://example.com' in response. + // if it doesn't appear (test with a bogus domain), it is not allowed. + let mut list = vec![]; + for origins in allow_origins { + list.push(origins.parse()?); + } + cors = cors.allow_origin(list); + } + + if let Some(allow_headers) = &config.allow_headers { + let mut headers = vec![]; + for header in allow_headers { + headers.push(header.parse()?); + } + cors = cors.allow_headers(headers); + } + + if let Some(allow_methods) = &config.allow_methods { + let mut methods = vec![]; + for method in allow_methods { + methods.push(method.parse()?); + } + cors = cors.allow_methods(methods); + } + + if let Some(max_age) = config.max_age { + cors = cors.max_age(Duration::from_secs(max_age)); + } + + Ok(cors) +} diff --git a/src/controller/middleware/mod.rs b/src/controller/middleware/mod.rs index c361d4aaf..f3d9cecdc 100644 --- a/src/controller/middleware/mod.rs +++ b/src/controller/middleware/mod.rs @@ -1,4 +1,6 @@ #[cfg(all(feature = "auth_jwt", feature = "with-db"))] pub mod auth; +pub mod cors; pub mod etag; pub mod format; +pub mod request_id; diff --git a/src/controller/middleware/request_id.rs b/src/controller/middleware/request_id.rs new file mode 100644 index 000000000..56783dba4 --- /dev/null +++ b/src/controller/middleware/request_id.rs @@ -0,0 +1,77 @@ +use axum::{extract::Request, http::HeaderValue, middleware::Next, response::Response}; +use lazy_static::lazy_static; +use regex::Regex; +use tracing::warn; +use uuid::Uuid; + +#[derive(Debug, Clone)] +pub struct LocoRequestId(String); + +impl LocoRequestId { + /// Get the request id + #[must_use] + pub fn get(&self) -> &str { + self.0.as_str() + } +} + +const X_REQUEST_ID: &str = "x-request-id"; +const MAX_LEN: usize = 255; +lazy_static! { + static ref ID_CLEANUP: Regex = Regex::new(r"[^\w\-@]").unwrap(); +} + +pub async fn request_id_middleware(mut request: Request, next: Next) -> Response { + let header_request_id = request.headers().get(X_REQUEST_ID).cloned(); + let request_id = make_request_id(header_request_id); + request + .extensions_mut() + .insert(LocoRequestId(request_id.clone())); + let mut res = next.run(request).await; + + if let Ok(v) = HeaderValue::from_str(request_id.as_str()) { + res.headers_mut().insert(X_REQUEST_ID, v); + } else { + warn!("could not set request ID into response headers: `{request_id}`",); + } + res +} + +fn make_request_id(maybe_request_id: Option) -> String { + maybe_request_id + .and_then(|hdr| { + // see: https://github.com/rails/rails/blob/main/actionpack/lib/action_dispatch/middleware/request_id.rb#L39 + let id: Option = hdr.to_str().ok().map(|s| { + ID_CLEANUP + .replace_all(s, "") + .chars() + .take(MAX_LEN) + .collect() + }); + id.filter(|s| !s.is_empty()) + }) + .unwrap_or_else(|| Uuid::new_v4().to_string()) +} + +#[cfg(test)] +mod tests { + use axum::http::HeaderValue; + use insta::assert_debug_snapshot; + + use super::make_request_id; + + #[test] + fn create_or_fetch_request_id() { + let id = make_request_id(Some(HeaderValue::from_static("foo-bar=baz"))); + assert_debug_snapshot!(id); + let id = make_request_id(Some(HeaderValue::from_static(""))); + assert_debug_snapshot!(id.len()); + let id = make_request_id(Some(HeaderValue::from_static("=========="))); + assert_debug_snapshot!(id.len()); + let long_id = "x".repeat(1000); + let id = make_request_id(Some(HeaderValue::from_str(&long_id).unwrap())); + assert_debug_snapshot!(id.len()); + let id = make_request_id(None); + assert_debug_snapshot!(id.len()); + } +} diff --git a/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-2.snap b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-2.snap new file mode 100644 index 000000000..ccd9f88cb --- /dev/null +++ b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-2.snap @@ -0,0 +1,5 @@ +--- +source: src/controller/middleware/request_id.rs +expression: id.len() +--- +36 diff --git a/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-3.snap b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-3.snap new file mode 100644 index 000000000..ccd9f88cb --- /dev/null +++ b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-3.snap @@ -0,0 +1,5 @@ +--- +source: src/controller/middleware/request_id.rs +expression: id.len() +--- +36 diff --git a/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-4.snap b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-4.snap new file mode 100644 index 000000000..3a3e3f6ca --- /dev/null +++ b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-4.snap @@ -0,0 +1,5 @@ +--- +source: src/controller/middleware/request_id.rs +expression: id.len() +--- +255 diff --git a/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-5.snap b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-5.snap new file mode 100644 index 000000000..ccd9f88cb --- /dev/null +++ b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id-5.snap @@ -0,0 +1,5 @@ +--- +source: src/controller/middleware/request_id.rs +expression: id.len() +--- +36 diff --git a/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id.snap b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id.snap new file mode 100644 index 000000000..0ce2fa840 --- /dev/null +++ b/src/controller/middleware/snapshots/loco_rs__controller__middleware__request_id__tests__create_or_fetch_request_id.snap @@ -0,0 +1,5 @@ +--- +source: src/controller/middleware/request_id.rs +expression: id +--- +"foo-barbaz"