diff --git a/api/src/middleware/jwt_auth.rs b/api/src/middleware/jwt_auth.rs index 353edb8c..3b0585c6 100644 --- a/api/src/middleware/jwt_auth.rs +++ b/api/src/middleware/jwt_auth.rs @@ -1,6 +1,6 @@ -use crate::{api_payloads::ErrorResponse, server::AppState, unauthorized}; -use axum::{body::Body, extract::State, middleware::Next, response::Response, Json}; -use http::{Request, StatusCode}; +use crate::{endpoints::ApiError, server::AppState, unauthorized}; +use axum::{body::Body, extract::State, middleware::Next, response::Response}; +use http::Request; use integrationos_domain::Claims; use jsonwebtoken::{DecodingKey, Validation}; use std::sync::Arc; @@ -27,7 +27,7 @@ pub async fn jwt_auth( State(state): State>, mut req: Request, next: Next, -) -> Result)> { +) -> Result { let Some(auth_header) = req.headers().get(http::header::AUTHORIZATION) else { info!("missing authorization header"); return Err(unauthorized!()); diff --git a/api/src/routes/mod.rs b/api/src/routes/mod.rs index e7d6ff8d..0cf7b3fa 100644 --- a/api/src/routes/mod.rs +++ b/api/src/routes/mod.rs @@ -1,3 +1,4 @@ +pub mod private; pub mod protected; pub mod public; @@ -16,6 +17,7 @@ pub async fn get_router(state: &Arc) -> Router> { Router::new() .nest(&path, protected::get_router(state).await) .nest(&public_path, public::get_router(state)) + .nest(&path, private::get_router(state).await) .route("/", get(get_root)) .fallback(not_found_handler) .layer(CorsLayer::permissive()) diff --git a/api/src/routes/private.rs b/api/src/routes/private.rs new file mode 100644 index 00000000..918f5702 --- /dev/null +++ b/api/src/routes/private.rs @@ -0,0 +1,62 @@ +use crate::{ + endpoints::{ + common_model, connection_definition, + connection_model_definition::{self, test_connection_model_definition}, + connection_model_schema, connection_oauth_definition, openapi, + }, + middleware::{ + extractor::OwnershipId, + jwt_auth::{self, JwtState}, + }, + server::AppState, +}; +use axum::{middleware::from_fn_with_state, routing::post, Router}; +use std::sync::Arc; +use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer}; +use tower_http::trace::TraceLayer; + +pub async fn get_router(state: &Arc) -> Router> { + let routes = Router::new() + .route( + "/connection-model-definitions/test/:id", + post(test_connection_model_definition), + ) + .nest( + "/connection-definitions", + connection_definition::get_router(), + ) + .nest( + "/connection-oauth-definitions", + connection_oauth_definition::get_router(), + ) + .nest( + "/connection-model-definitions", + connection_model_definition::get_router(), + ) + .route("/openapi", post(openapi::refresh_openapi)) + .nest( + "/connection-model-schemas", + connection_model_schema::get_router(), + ) + .nest("/common-models", common_model::get_router()); + + let config = Box::new( + GovernorConfigBuilder::default() + .per_second(state.config.burst_rate_limit) + .burst_size(state.config.burst_size) + .key_extractor(OwnershipId) + .use_headers() + .finish() + .expect("Failed to build GovernorConfig"), + ); + + routes + .layer(GovernorLayer { + config: Box::leak(config), + }) + .layer(from_fn_with_state( + Arc::new(JwtState::new(state)), + jwt_auth::jwt_auth, + )) + .layer(TraceLayer::new_for_http()) +} diff --git a/api/src/routes/protected.rs b/api/src/routes/protected.rs index c974e438..49c080cf 100644 --- a/api/src/routes/protected.rs +++ b/api/src/routes/protected.rs @@ -1,9 +1,7 @@ use crate::{ endpoints::{ - common_model, connection, connection_definition, - connection_model_definition::{self, test_connection_model_definition}, - connection_model_schema, connection_oauth_definition, event_access, events, metrics, oauth, - openapi, passthrough, pipeline, transactions, unified, + connection, event_access, events, metrics, oauth, passthrough, pipeline, transactions, + unified, }, middleware::{ auth, @@ -13,7 +11,7 @@ use crate::{ server::AppState, }; use axum::{ - error_handling::HandleErrorLayer, middleware::from_fn_with_state, routing::post, Router, + error_handling::HandleErrorLayer, middleware::from_fn_with_state, Router, }; use http::HeaderName; use std::{iter::once, sync::Arc}; @@ -27,32 +25,10 @@ pub async fn get_router(state: &Arc) -> Router> { .nest("/events", events::get_router()) .nest("/transactions", transactions::get_router()) .nest("/connections", connection::get_router()) - .route( - "/connection-model-definitions/test/:id", - post(test_connection_model_definition), - ) .nest("/event-access", event_access::get_router()) .nest("/passthrough", passthrough::get_router()) .nest("/oauth", oauth::get_router()) .nest("/unified", unified::get_router()) - .nest( - "/connection-definitions", - connection_definition::get_router(), - ) - .nest( - "/connection-oauth-definitions", - connection_oauth_definition::get_router(), - ) - .nest( - "/connection-model-definitions", - connection_model_definition::get_router(), - ) - .route("/openapi", post(openapi::refresh_openapi)) - .nest( - "/connection-model-schemas", - connection_model_schema::get_router(), - ) - .nest("/common-models", common_model::get_router()) .layer(TraceLayer::new_for_http()) .nest("/metrics", metrics::get_router());