From 61c82a480c9c73194bf75a355e569c3910a1c86e Mon Sep 17 00:00:00 2001 From: Hannes Herrmann Date: Mon, 26 Aug 2024 00:18:33 +0200 Subject: [PATCH] feat(introspection_cache): allow to use introspection caches in axum --- Cargo.toml | 1 + src/axum/introspection/state.rs | 4 + src/axum/introspection/state_builder.rs | 20 +++ src/axum/introspection/user.rs | 181 +++++++++++++++++++++--- src/oidc/introspection/cache/mod.rs | 25 +++- 5 files changed, 208 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index af5c289b..4f169c14 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,6 +88,7 @@ tonic-types = { version = "0.11", optional = true } chrono = "0.4.38" tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] } tower = { version = "0.4.13" } +http-body-util = "0.1.0" [package.metadata.docs.rs] all-features = true diff --git a/src/axum/introspection/state.rs b/src/axum/introspection/state.rs index a804a59a..964ce034 100644 --- a/src/axum/introspection/state.rs +++ b/src/axum/introspection/state.rs @@ -1,6 +1,8 @@ use openidconnect::IntrospectionUrl; use std::sync::Arc; +#[cfg(feature = "introspection_cache")] +use crate::oidc::introspection::cache::IntrospectionCache; use crate::oidc::introspection::AuthorityAuthentication; /// State which must be present for extractor to work, @@ -32,4 +34,6 @@ pub(crate) struct IntrospectionConfig { pub(crate) authority: String, pub(crate) authentication: AuthorityAuthentication, pub(crate) introspection_uri: IntrospectionUrl, + #[cfg(feature = "introspection_cache")] + pub(crate) cache: Option>, } diff --git a/src/axum/introspection/state_builder.rs b/src/axum/introspection/state_builder.rs index f45b7e5a..7d0fcae3 100644 --- a/src/axum/introspection/state_builder.rs +++ b/src/axum/introspection/state_builder.rs @@ -6,6 +6,9 @@ use crate::credentials::Application; use crate::oidc::discovery::{discover, DiscoveryError}; use crate::oidc::introspection::AuthorityAuthentication; +#[cfg(feature = "introspection_cache")] +use crate::oidc::introspection::cache::IntrospectionCache; + use super::state::IntrospectionState; custom_error! { @@ -19,6 +22,8 @@ custom_error! { pub struct IntrospectionStateBuilder { authority: String, authentication: Option, + #[cfg(feature = "introspection_cache")] + cache: Option>, } /// Builder for [IntrospectionConfig] @@ -27,6 +32,8 @@ impl IntrospectionStateBuilder { Self { authority: authority.to_string(), authentication: None, + #[cfg(feature = "introspection_cache")] + cache: None, } } @@ -49,6 +56,17 @@ impl IntrospectionStateBuilder { self } + /// Set the [IntrospectionCache] to use for caching introspection responses. + #[cfg(feature = "introspection_cache")] + pub fn with_introspection_cache( + &mut self, + cache: impl IntrospectionCache + 'static, + ) -> &mut IntrospectionStateBuilder { + self.cache = Some(Box::new(cache)); + + self + } + pub async fn build(&mut self) -> Result { if self.authentication.is_none() { return Err(IntrospectionStateBuilderError::NoAuthSchema); @@ -72,6 +90,8 @@ impl IntrospectionStateBuilder { authority: self.authority.clone(), introspection_uri: introspection_uri.unwrap(), authentication: self.authentication.as_ref().unwrap().clone(), + #[cfg(feature = "introspection_cache")] + cache: self.cache.take(), }), }) } diff --git a/src/axum/introspection/user.rs b/src/axum/introspection/user.rs index f68ad13c..9ad6e29a 100644 --- a/src/axum/introspection/user.rs +++ b/src/axum/introspection/user.rs @@ -89,6 +89,38 @@ where let state = IntrospectionState::from_ref(state); let config = &state.config; + #[cfg(feature = "introspection_cache")] + let res = { + match state.config.cache.as_deref() { + None => { + introspect( + &config.introspection_uri, + &config.authority, + &config.authentication, + bearer.token(), + ) + .await + } + Some(cache) => match cache.get(bearer.token()).await { + Some(cached_response) => Ok(cached_response), + None => { + let res = introspect( + &config.introspection_uri, + &config.authority, + &config.authentication, + bearer.token(), + ) + .await; + if let Ok(res) = &res { + cache.set(bearer.token(), res.clone()).await; + } + res + } + }, + } + }; + + #[cfg(not(feature = "introspection_cache"))] let res = introspect( &config.introspection_uri, &config.authority, @@ -131,14 +163,12 @@ impl From for IntrospectedUser { mod tests { #![allow(clippy::all)] - use std::thread; - use axum::body::Body; use axum::http::Request; use axum::response::IntoResponse; use axum::routing::get; use axum::Router; - use tokio::runtime::Builder; + use tower::ServiceExt; use crate::axum::introspection::{IntrospectionState, IntrospectionStateBuilder}; @@ -169,33 +199,39 @@ mod tests { "Hello unauthorized" } - fn get_config() -> IntrospectionState { - let config = thread::spawn(move || { - let rt = Builder::new_multi_thread().enable_all().build().unwrap(); - rt.block_on(async { - IntrospectionStateBuilder::new(ZITADEL_URL) - .with_jwt_profile(Application::load_from_json(APPLICATION).unwrap()) - .build() - .await - .unwrap() - }) - }); + #[derive(Clone)] + struct SomeUserState { + introspection_state: IntrospectionState, + } - config.join().unwrap() + impl FromRef for IntrospectionState { + fn from_ref(input: &SomeUserState) -> Self { + input.introspection_state.clone() + } } - fn app() -> Router { + async fn app() -> Router { + let introspection_state = IntrospectionStateBuilder::new(ZITADEL_URL) + .with_jwt_profile(Application::load_from_json(APPLICATION).unwrap()) + .build() + .await + .unwrap(); + + let state = SomeUserState { + introspection_state, + }; + let app = Router::new() .route("/unauthed", get(unauthed)) .route("/authed", get(authed)) - .with_state(get_config()); + .with_state(state); return app; } #[tokio::test] async fn can_guard() { - let app = app(); + let app = app().await; let resp = app .oneshot( @@ -212,7 +248,7 @@ mod tests { #[tokio::test] async fn guard_protects_if_non_bearer_present() { - let app = app(); + let app = app().await; let resp = app .oneshot( @@ -230,7 +266,7 @@ mod tests { #[tokio::test] async fn guard_protects_if_multiple_auth_headers_present() { - let app = app(); + let app = app().await; let resp = app .oneshot( @@ -249,7 +285,7 @@ mod tests { #[tokio::test] async fn guard_protects_if_invalid_token() { - let app = app(); + let app = app().await; let resp = app .oneshot( @@ -267,7 +303,7 @@ mod tests { #[tokio::test] async fn guard_allows_valid_token() { - let app = app(); + let app = app().await; let resp = app .oneshot( @@ -282,4 +318,105 @@ mod tests { assert_eq!(resp.status(), StatusCode::OK); } + + #[cfg(feature = "introspection_cache")] + mod introspection_cache { + use super::*; + use crate::oidc::introspection::cache::in_memory::InMemoryIntrospectionCache; + use crate::oidc::introspection::cache::IntrospectionCache; + use crate::oidc::introspection::ZitadelIntrospectionExtraTokenFields; + use chrono::{TimeDelta, Utc}; + use http_body_util::BodyExt; + use std::ops::Add; + use std::sync::Arc; + + async fn app_witch_cache(cache: impl IntrospectionCache + 'static) -> Router { + let introspection_state = IntrospectionStateBuilder::new(ZITADEL_URL) + .with_jwt_profile(Application::load_from_json(APPLICATION).unwrap()) + .with_introspection_cache(cache) + .build() + .await + .unwrap(); + + let state = SomeUserState { + introspection_state, + }; + + let app = Router::new() + .route("/unauthed", get(unauthed)) + .route("/authed", get(authed)) + .with_state(state); + + return app; + } + + #[tokio::test] + async fn guard_uses_cached_response() { + let cache = Arc::new(InMemoryIntrospectionCache::default()); + let app = app_witch_cache(cache.clone()).await; + + let mut res = ZitadelIntrospectionResponse::new( + true, + ZitadelIntrospectionExtraTokenFields::default(), + ); + res.set_sub(Some("cached_sub".to_string())); + res.set_exp(Some(Utc::now().add(TimeDelta::days(1)))); + cache.set(PERSONAL_ACCESS_TOKEN, res).await; + + let response = app + .oneshot( + Request::builder() + .uri("/authed") + .header("authorization", format!("Bearer {PERSONAL_ACCESS_TOKEN}")) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + let text = String::from_utf8( + response + .into_body() + .collect() + .await + .unwrap() + .to_bytes() + .to_vec(), + ) + .unwrap(); + assert!(text.contains("cached_sub")); + } + + #[tokio::test] + async fn guard_caches_response() { + let cache = Arc::new(InMemoryIntrospectionCache::default()); + let app = app_witch_cache(cache.clone()).await; + + let response = app + .oneshot( + Request::builder() + .uri("/authed") + .header("authorization", format!("Bearer {PERSONAL_ACCESS_TOKEN}")) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + let text = String::from_utf8( + response + .into_body() + .collect() + .await + .unwrap() + .to_bytes() + .to_vec(), + ) + .unwrap(); + + let cached_response = cache.get(PERSONAL_ACCESS_TOKEN).await.unwrap(); + + assert!(text.contains(cached_response.sub().unwrap())); + } + } } diff --git a/src/oidc/introspection/cache/mod.rs b/src/oidc/introspection/cache/mod.rs index 41fed335..10139f48 100644 --- a/src/oidc/introspection/cache/mod.rs +++ b/src/oidc/introspection/cache/mod.rs @@ -3,6 +3,10 @@ //! ZITADEL server. Depending on the enabled features, the cache can be persisted //! or only be kept in memory. +use async_trait::async_trait; +use std::fmt::Debug; +use std::ops::Deref; + pub mod in_memory; type Response = super::ZitadelIntrospectionResponse; @@ -17,7 +21,7 @@ type Response = super::ZitadelIntrospectionResponse; /// ZITADEL will always set the `exp` field, if the token is "active". /// /// Non-active tokens SHOULD not be cached. -#[async_trait::async_trait] +#[async_trait] pub trait IntrospectionCache: Send + Sync + std::fmt::Debug { /// Retrieves the cached introspection result for the given token, if it exists. async fn get(&self, token: &str) -> Option; @@ -29,3 +33,22 @@ pub trait IntrospectionCache: Send + Sync + std::fmt::Debug { /// Clears the cache. async fn clear(&self); } + +#[async_trait] +impl IntrospectionCache for T +where + T: Deref + Send + Sync + Debug, + V: IntrospectionCache, +{ + async fn get(&self, token: &str) -> Option { + self.deref().get(token).await + } + + async fn set(&self, token: &str, response: Response) { + self.deref().set(token, response).await + } + + async fn clear(&self) { + self.deref().clear().await + } +}