Skip to content

Commit

Permalink
feat(introspection_cache): allow to use introspection caches in axum
Browse files Browse the repository at this point in the history
  • Loading branch information
sprudel committed Aug 25, 2024
1 parent d4da97d commit 61c82a4
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 23 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ tonic-types = { version = "0.11", optional = true }
chrono = "0.4.38"
tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] }
tower = { version = "0.4.13" }
http-body-util = "0.1.0"

[package.metadata.docs.rs]
all-features = true
4 changes: 4 additions & 0 deletions src/axum/introspection/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use openidconnect::IntrospectionUrl;
use std::sync::Arc;

#[cfg(feature = "introspection_cache")]
use crate::oidc::introspection::cache::IntrospectionCache;
use crate::oidc::introspection::AuthorityAuthentication;

/// State which must be present for extractor to work,
Expand Down Expand Up @@ -32,4 +34,6 @@ pub(crate) struct IntrospectionConfig {
pub(crate) authority: String,
pub(crate) authentication: AuthorityAuthentication,
pub(crate) introspection_uri: IntrospectionUrl,
#[cfg(feature = "introspection_cache")]
pub(crate) cache: Option<Box<dyn IntrospectionCache>>,
}
20 changes: 20 additions & 0 deletions src/axum/introspection/state_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ use crate::credentials::Application;
use crate::oidc::discovery::{discover, DiscoveryError};
use crate::oidc::introspection::AuthorityAuthentication;

#[cfg(feature = "introspection_cache")]
use crate::oidc::introspection::cache::IntrospectionCache;

use super::state::IntrospectionState;

custom_error! {
Expand All @@ -19,6 +22,8 @@ custom_error! {
pub struct IntrospectionStateBuilder {
authority: String,
authentication: Option<AuthorityAuthentication>,
#[cfg(feature = "introspection_cache")]
cache: Option<Box<dyn IntrospectionCache>>,
}

/// Builder for [IntrospectionConfig]
Expand All @@ -27,6 +32,8 @@ impl IntrospectionStateBuilder {
Self {
authority: authority.to_string(),
authentication: None,
#[cfg(feature = "introspection_cache")]
cache: None,
}
}

Expand All @@ -49,6 +56,17 @@ impl IntrospectionStateBuilder {
self
}

/// Set the [IntrospectionCache] to use for caching introspection responses.
#[cfg(feature = "introspection_cache")]
pub fn with_introspection_cache(
&mut self,
cache: impl IntrospectionCache + 'static,
) -> &mut IntrospectionStateBuilder {
self.cache = Some(Box::new(cache));

self
}

pub async fn build(&mut self) -> Result<IntrospectionState, IntrospectionStateBuilderError> {
if self.authentication.is_none() {
return Err(IntrospectionStateBuilderError::NoAuthSchema);
Expand All @@ -72,6 +90,8 @@ impl IntrospectionStateBuilder {
authority: self.authority.clone(),
introspection_uri: introspection_uri.unwrap(),
authentication: self.authentication.as_ref().unwrap().clone(),
#[cfg(feature = "introspection_cache")]
cache: self.cache.take(),
}),
})
}
Expand Down
181 changes: 159 additions & 22 deletions src/axum/introspection/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,38 @@ where
let state = IntrospectionState::from_ref(state);
let config = &state.config;

#[cfg(feature = "introspection_cache")]
let res = {
match state.config.cache.as_deref() {
None => {
introspect(
&config.introspection_uri,
&config.authority,
&config.authentication,
bearer.token(),
)
.await
}
Some(cache) => match cache.get(bearer.token()).await {
Some(cached_response) => Ok(cached_response),
None => {
let res = introspect(
&config.introspection_uri,
&config.authority,
&config.authentication,
bearer.token(),
)
.await;
if let Ok(res) = &res {
cache.set(bearer.token(), res.clone()).await;
}
res
}
},
}
};

#[cfg(not(feature = "introspection_cache"))]
let res = introspect(
&config.introspection_uri,
&config.authority,
Expand Down Expand Up @@ -131,14 +163,12 @@ impl From<ZitadelIntrospectionResponse> for IntrospectedUser {
mod tests {
#![allow(clippy::all)]

use std::thread;

use axum::body::Body;
use axum::http::Request;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;
use tokio::runtime::Builder;

use tower::ServiceExt;

use crate::axum::introspection::{IntrospectionState, IntrospectionStateBuilder};
Expand Down Expand Up @@ -169,33 +199,39 @@ mod tests {
"Hello unauthorized"
}

fn get_config() -> IntrospectionState {
let config = thread::spawn(move || {
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
rt.block_on(async {
IntrospectionStateBuilder::new(ZITADEL_URL)
.with_jwt_profile(Application::load_from_json(APPLICATION).unwrap())
.build()
.await
.unwrap()
})
});
#[derive(Clone)]
struct SomeUserState {
introspection_state: IntrospectionState,
}

config.join().unwrap()
impl FromRef<SomeUserState> for IntrospectionState {
fn from_ref(input: &SomeUserState) -> Self {
input.introspection_state.clone()
}
}

fn app() -> Router {
async fn app() -> Router {
let introspection_state = IntrospectionStateBuilder::new(ZITADEL_URL)
.with_jwt_profile(Application::load_from_json(APPLICATION).unwrap())
.build()
.await
.unwrap();

let state = SomeUserState {
introspection_state,
};

let app = Router::new()
.route("/unauthed", get(unauthed))
.route("/authed", get(authed))
.with_state(get_config());
.with_state(state);

return app;
}

#[tokio::test]
async fn can_guard() {
let app = app();
let app = app().await;

let resp = app
.oneshot(
Expand All @@ -212,7 +248,7 @@ mod tests {

#[tokio::test]
async fn guard_protects_if_non_bearer_present() {
let app = app();
let app = app().await;

let resp = app
.oneshot(
Expand All @@ -230,7 +266,7 @@ mod tests {

#[tokio::test]
async fn guard_protects_if_multiple_auth_headers_present() {
let app = app();
let app = app().await;

let resp = app
.oneshot(
Expand All @@ -249,7 +285,7 @@ mod tests {

#[tokio::test]
async fn guard_protects_if_invalid_token() {
let app = app();
let app = app().await;

let resp = app
.oneshot(
Expand All @@ -267,7 +303,7 @@ mod tests {

#[tokio::test]
async fn guard_allows_valid_token() {
let app = app();
let app = app().await;

let resp = app
.oneshot(
Expand All @@ -282,4 +318,105 @@ mod tests {

assert_eq!(resp.status(), StatusCode::OK);
}

#[cfg(feature = "introspection_cache")]
mod introspection_cache {
use super::*;
use crate::oidc::introspection::cache::in_memory::InMemoryIntrospectionCache;
use crate::oidc::introspection::cache::IntrospectionCache;
use crate::oidc::introspection::ZitadelIntrospectionExtraTokenFields;
use chrono::{TimeDelta, Utc};
use http_body_util::BodyExt;
use std::ops::Add;
use std::sync::Arc;

async fn app_witch_cache(cache: impl IntrospectionCache + 'static) -> Router {
let introspection_state = IntrospectionStateBuilder::new(ZITADEL_URL)
.with_jwt_profile(Application::load_from_json(APPLICATION).unwrap())
.with_introspection_cache(cache)
.build()
.await
.unwrap();

let state = SomeUserState {
introspection_state,
};

let app = Router::new()
.route("/unauthed", get(unauthed))
.route("/authed", get(authed))
.with_state(state);

return app;
}

#[tokio::test]
async fn guard_uses_cached_response() {
let cache = Arc::new(InMemoryIntrospectionCache::default());
let app = app_witch_cache(cache.clone()).await;

let mut res = ZitadelIntrospectionResponse::new(
true,
ZitadelIntrospectionExtraTokenFields::default(),
);
res.set_sub(Some("cached_sub".to_string()));
res.set_exp(Some(Utc::now().add(TimeDelta::days(1))));
cache.set(PERSONAL_ACCESS_TOKEN, res).await;

let response = app
.oneshot(
Request::builder()
.uri("/authed")
.header("authorization", format!("Bearer {PERSONAL_ACCESS_TOKEN}"))
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();

let text = String::from_utf8(
response
.into_body()
.collect()
.await
.unwrap()
.to_bytes()
.to_vec(),
)
.unwrap();
assert!(text.contains("cached_sub"));
}

#[tokio::test]
async fn guard_caches_response() {
let cache = Arc::new(InMemoryIntrospectionCache::default());
let app = app_witch_cache(cache.clone()).await;

let response = app
.oneshot(
Request::builder()
.uri("/authed")
.header("authorization", format!("Bearer {PERSONAL_ACCESS_TOKEN}"))
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();

let text = String::from_utf8(
response
.into_body()
.collect()
.await
.unwrap()
.to_bytes()
.to_vec(),
)
.unwrap();

let cached_response = cache.get(PERSONAL_ACCESS_TOKEN).await.unwrap();

assert!(text.contains(cached_response.sub().unwrap()));
}
}
}
25 changes: 24 additions & 1 deletion src/oidc/introspection/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
//! ZITADEL server. Depending on the enabled features, the cache can be persisted
//! or only be kept in memory.
use async_trait::async_trait;
use std::fmt::Debug;
use std::ops::Deref;

pub mod in_memory;

type Response = super::ZitadelIntrospectionResponse;
Expand All @@ -17,7 +21,7 @@ type Response = super::ZitadelIntrospectionResponse;
/// ZITADEL will always set the `exp` field, if the token is "active".
///
/// Non-active tokens SHOULD not be cached.
#[async_trait::async_trait]
#[async_trait]
pub trait IntrospectionCache: Send + Sync + std::fmt::Debug {
/// Retrieves the cached introspection result for the given token, if it exists.
async fn get(&self, token: &str) -> Option<Response>;
Expand All @@ -29,3 +33,22 @@ pub trait IntrospectionCache: Send + Sync + std::fmt::Debug {
/// Clears the cache.
async fn clear(&self);
}

#[async_trait]
impl<T, V> IntrospectionCache for T
where
T: Deref<Target = V> + Send + Sync + Debug,
V: IntrospectionCache,
{
async fn get(&self, token: &str) -> Option<Response> {
self.deref().get(token).await
}

async fn set(&self, token: &str, response: Response) {
self.deref().set(token, response).await
}

async fn clear(&self) {
self.deref().clear().await
}
}

0 comments on commit 61c82a4

Please sign in to comment.