Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: request id + test more effective middleware ordering #696

Merged
merged 5 commits into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ cargo_metadata = "0.18.1"

cfg-if = "1"

uuid = { version = "1.6", features = ["v4"] }
uuid = { version = "1.10.0", features = ["v4", "fast-rng"] }
requestty = "0.5.0"

# A socket.io server implementation
Expand Down
170 changes: 83 additions & 87 deletions src/controller/app_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use tower_http::{
add_extension::AddExtensionLayer,
catch_panic::CatchPanicLayer,
compression::CompressionLayer,
cors,
services::{ServeDir, ServeFile},
set_header::SetResponseHeaderLayer,
timeout::TimeoutLayer,
Expand All @@ -20,9 +19,15 @@ use tower_http::{

#[cfg(feature = "channels")]
use super::channels::AppChannels;
use super::routes::Routes;
use super::{middleware::cors::cors_middleware, routes::Routes};
use crate::{
app::AppContext, config, controller::middleware::etag::EtagLayer, environment::Environment,
app::AppContext,
config,
controller::middleware::{
etag::EtagLayer,
request_id::{request_id_middleware, LocoRequestId},
},
environment::Environment,
errors, Result,
};

Expand Down Expand Up @@ -179,26 +184,86 @@ impl AppRoutes {
/// [`axum::Router`].
#[allow(clippy::cognitive_complexity)]
pub fn to_router(&self, ctx: AppContext, mut app: AXRouter<AppContext>) -> Result<AXRouter> {
//
// IMPORTANT: middleware ordering in this function is opposite to what you
// intuitively may think. when using `app.layer` to add individual middleware,
// the LAST middleware is the FIRST to meet the outside world (a user request
// starting), or "LIFO" order.
// We build the "onion" from the inside (start of this function),
// outwards (end of this function). This is why routes is first in coding order
// here (the core of the onion), and request ID is amongst the last
// (because every request is assigned with a unique ID, which starts its
// "life").
//
// NOTE: when using ServiceBuilder#layer the order is FIRST to LAST (but we
// don't use ServiceBuilder because it requires too complex generic typing for
// this function). ServiceBuilder is recommended to save compile times, but that
// may be a thing of the past as we don't notice any issues with compile times
// using the router directly, and ServiceBuilder has been reported to give
// issues in compile times itself (https://github.com/rust-lang/crates.io/pull/7443).
//
for router in self.collect() {
tracing::info!("{}", router.to_string());

app = app.route(&router.uri, router.method);
}

app = Self::add_powered_by_header(app, &ctx.config.server);
#[cfg(feature = "channels")]
if let Some(channels) = self.channels.as_ref() {
tracing::info!("[Middleware] Adding channels");
let channel_layer_app = tower::ServiceBuilder::new().layer(channels.layer.clone());
if let Some(cors) = &ctx
.config
.server
.middlewares
.cors
.as_ref()
.filter(|c| c.enable)
{
app = app.layer(
tower::ServiceBuilder::new()
.layer(cors_middleware(cors)?)
.layer(channel_layer_app),
);
} else {
app = app.layer(
tower::ServiceBuilder::new()
.layer(tower_http::cors::CorsLayer::permissive())
.layer(channel_layer_app),
);
}
}

if let Some(catch_panic) = &ctx.config.server.middlewares.catch_panic {
if catch_panic.enable {
app = Self::add_catch_panic(app);
}
}

if let Some(etag) = &ctx.config.server.middlewares.etag {
if etag.enable {
app = Self::add_etag_middleware(app);
}
}

if let Some(compression) = &ctx.config.server.middlewares.compression {
if compression.enable {
app = Self::add_compression_middleware(app);
}
}

if let Some(timeout_request) = &ctx.config.server.middlewares.timeout_request {
if timeout_request.enable {
app = Self::add_timeout_middleware(app, timeout_request);
}
}

if let Some(cors) = &ctx.config.server.middlewares.cors {
if cors.enable {
app = app.layer(cors_middleware(cors)?);
}
}

if let Some(limit) = &ctx.config.server.middlewares.limit_payload {
if limit.enable {
app = Self::add_limit_payload_middleware(app, limit)?;
Expand All @@ -211,62 +276,28 @@ impl AppRoutes {
}
}

if let Some(timeout_request) = &ctx.config.server.middlewares.timeout_request {
if timeout_request.enable {
app = Self::add_timeout_middleware(app, timeout_request);
}
}

let cors = ctx
.config
.server
.middlewares
.cors
.as_ref()
.filter(|cors| cors.enable)
.map(Self::get_cors_middleware)
.transpose()?;

if let Some(cors) = &cors {
app = app.layer(cors.clone());
tracing::info!("[Middleware] Adding cors");
}

if let Some(static_assets) = &ctx.config.server.middlewares.static_assets {
if static_assets.enable {
app = Self::add_static_asset_middleware(app, static_assets)?;
}
}

if let Some(etag) = &ctx.config.server.middlewares.etag {
if etag.enable {
app = Self::add_etag_middleware(app);
}
}
// XXX todo: remote IP middleware here

#[cfg(feature = "channels")]
if let Some(channels) = self.channels.as_ref() {
tracing::info!("[Middleware] Adding channels");
let channel_layer_app = tower::ServiceBuilder::new().layer(channels.layer.clone());
if let Some(cors) = cors {
app = app.layer(
tower::ServiceBuilder::new()
.layer(cors)
.layer(channel_layer_app),
);
} else {
app = app.layer(
tower::ServiceBuilder::new()
.layer(tower_http::cors::CorsLayer::permissive())
.layer(channel_layer_app),
);
}
}
app = Self::add_powered_by_header(app, &ctx.config.server);

app = Self::add_request_id_middleware(app);

let router = app.with_state(ctx);
Ok(router)
}

fn add_request_id_middleware(app: AXRouter<AppContext>) -> AXRouter<AppContext> {
let app = app.layer(axum::middleware::from_fn(request_id_middleware));
tracing::info!("[Middleware] Adding request_id middleware");
app
}

fn add_static_asset_middleware(
app: AXRouter<AppContext>,
config: &config::StaticAssetsMiddleware,
Expand Down Expand Up @@ -307,44 +338,6 @@ impl AppRoutes {
app
}

fn get_cors_middleware(config: &config::CorsMiddleware) -> Result<cors::CorsLayer> {
let mut cors: cors::CorsLayer = cors::CorsLayer::permissive();

if let Some(allow_origins) = &config.allow_origins {
// testing CORS, assuming https://example.com in the allow list:
// $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Access-Control-Request-Method: GET'
// look for '< access-control-allow-origin: https://example.com' in response.
// if it doesn't appear (test with a bogus domain), it is not allowed.
let mut list = vec![];
for origins in allow_origins {
list.push(origins.parse()?);
}
cors = cors.allow_origin(list);
}

if let Some(allow_headers) = &config.allow_headers {
let mut headers = vec![];
for header in allow_headers {
headers.push(header.parse()?);
}
cors = cors.allow_headers(headers);
}

if let Some(allow_methods) = &config.allow_methods {
let mut methods = vec![];
for method in allow_methods {
methods.push(method.parse()?);
}
cors = cors.allow_methods(methods);
}

if let Some(max_age) = config.max_age {
cors = cors.max_age(Duration::from_secs(max_age));
}

Ok(cors)
}

fn add_catch_panic(app: AXRouter<AppContext>) -> AXRouter<AppContext> {
app.layer(CatchPanicLayer::custom(handle_panic))
}
Expand Down Expand Up @@ -372,7 +365,10 @@ impl AppRoutes {
let app = app
.layer(
TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| {
let request_id = uuid::Uuid::new_v4();
let ext = request.extensions();
let request_id = ext
.get::<LocoRequestId>()
.map_or_else(|| "req-id-none".to_string(), |r| r.get().to_string());
let user_agent = request
.headers()
.get(axum::http::header::USER_AGENT)
Expand Down
48 changes: 48 additions & 0 deletions src/controller/middleware/cors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use std::time::Duration;

use tower_http::cors;

use crate::{config, Result};

/// Create a CORS layer
///
/// # Errors
///
/// This function will return an error if parsing of header config fail
pub fn cors_middleware(config: &config::CorsMiddleware) -> Result<cors::CorsLayer> {
let mut cors: cors::CorsLayer = cors::CorsLayer::permissive();

if let Some(allow_origins) = &config.allow_origins {
// testing CORS, assuming https://example.com in the allow list:
// $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Access-Control-Request-Method: GET'
// look for '< access-control-allow-origin: https://example.com' in response.
// if it doesn't appear (test with a bogus domain), it is not allowed.
let mut list = vec![];
for origins in allow_origins {
list.push(origins.parse()?);
}
cors = cors.allow_origin(list);
}

if let Some(allow_headers) = &config.allow_headers {
let mut headers = vec![];
for header in allow_headers {
headers.push(header.parse()?);
}
cors = cors.allow_headers(headers);
}

if let Some(allow_methods) = &config.allow_methods {
let mut methods = vec![];
for method in allow_methods {
methods.push(method.parse()?);
}
cors = cors.allow_methods(methods);
}

if let Some(max_age) = config.max_age {
cors = cors.max_age(Duration::from_secs(max_age));
}

Ok(cors)
}
2 changes: 2 additions & 0 deletions src/controller/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#[cfg(all(feature = "auth_jwt", feature = "with-db"))]
pub mod auth;
pub mod cors;
pub mod etag;
pub mod format;
pub mod request_id;
77 changes: 77 additions & 0 deletions src/controller/middleware/request_id.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use axum::{extract::Request, http::HeaderValue, middleware::Next, response::Response};
use lazy_static::lazy_static;
use regex::Regex;
use tracing::warn;
use uuid::Uuid;

#[derive(Debug, Clone)]
pub struct LocoRequestId(String);

impl LocoRequestId {
/// Get the request id
#[must_use]
pub fn get(&self) -> &str {
self.0.as_str()
}
}

const X_REQUEST_ID: &str = "x-request-id";
const MAX_LEN: usize = 255;
lazy_static! {
static ref ID_CLEANUP: Regex = Regex::new(r"[^\w\-@]").unwrap();
}

pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
let header_request_id = request.headers().get(X_REQUEST_ID).cloned();
let request_id = make_request_id(header_request_id);
request
.extensions_mut()
.insert(LocoRequestId(request_id.clone()));
let mut res = next.run(request).await;

if let Ok(v) = HeaderValue::from_str(request_id.as_str()) {
res.headers_mut().insert(X_REQUEST_ID, v);
} else {
warn!("could not set request ID into response headers: `{request_id}`",);
}
res
}

fn make_request_id(maybe_request_id: Option<HeaderValue>) -> String {
maybe_request_id
.and_then(|hdr| {
// see: https://github.com/rails/rails/blob/main/actionpack/lib/action_dispatch/middleware/request_id.rb#L39
let id: Option<String> = hdr.to_str().ok().map(|s| {
ID_CLEANUP
.replace_all(s, "")
.chars()
.take(MAX_LEN)
.collect()
});
id.filter(|s| !s.is_empty())
})
.unwrap_or_else(|| Uuid::new_v4().to_string())
}

#[cfg(test)]
mod tests {
use axum::http::HeaderValue;
use insta::assert_debug_snapshot;

use super::make_request_id;

#[test]
fn create_or_fetch_request_id() {
let id = make_request_id(Some(HeaderValue::from_static("foo-bar=baz")));
assert_debug_snapshot!(id);
let id = make_request_id(Some(HeaderValue::from_static("")));
assert_debug_snapshot!(id.len());
let id = make_request_id(Some(HeaderValue::from_static("==========")));
assert_debug_snapshot!(id.len());
let long_id = "x".repeat(1000);
let id = make_request_id(Some(HeaderValue::from_str(&long_id).unwrap()));
assert_debug_snapshot!(id.len());
let id = make_request_id(None);
assert_debug_snapshot!(id.len());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: src/controller/middleware/request_id.rs
expression: id.len()
---
36
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: src/controller/middleware/request_id.rs
expression: id.len()
---
36
Loading
Loading