-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Better JWT claim type dynamic specialization.
- Loading branch information
1 parent
56fb310
commit 1ce1093
Showing
8 changed files
with
418 additions
and
570 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
use std::{borrow::Cow, collections::BTreeMap}; | ||
|
||
use ssi_claims_core::{ClaimsValidity, DateTimeEnvironment, Validate}; | ||
|
||
use crate::{Claim, ClaimSet}; | ||
|
||
/// Any set of JWT claims. | ||
#[derive(Debug, Default, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] | ||
#[serde(transparent)] | ||
pub struct AnyClaims(BTreeMap<String, serde_json::Value>); | ||
|
||
impl AnyClaims { | ||
pub fn get(&self, key: &str) -> Option<&serde_json::Value> { | ||
self.0.get(key) | ||
} | ||
|
||
pub fn set(&mut self, key: String, value: serde_json::Value) -> Option<serde_json::Value> { | ||
self.0.insert(key, value) | ||
} | ||
|
||
pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> { | ||
self.0.remove(key) | ||
} | ||
|
||
pub fn iter(&self) -> std::collections::btree_map::Iter<String, serde_json::Value> { | ||
self.0.iter() | ||
} | ||
} | ||
|
||
impl<'a> IntoIterator for &'a AnyClaims { | ||
type IntoIter = std::collections::btree_map::Iter<'a, String, serde_json::Value>; | ||
type Item = (&'a String, &'a serde_json::Value); | ||
|
||
fn into_iter(self) -> Self::IntoIter { | ||
self.iter() | ||
} | ||
} | ||
|
||
impl IntoIterator for AnyClaims { | ||
type IntoIter = std::collections::btree_map::IntoIter<String, serde_json::Value>; | ||
type Item = (String, serde_json::Value); | ||
|
||
fn into_iter(self) -> Self::IntoIter { | ||
self.0.into_iter() | ||
} | ||
} | ||
|
||
impl ClaimSet for AnyClaims { | ||
type Error = serde_json::Error; | ||
|
||
fn try_get<C: Claim>(&self) -> Result<Option<Cow<C>>, Self::Error> { | ||
self.get(C::JWT_CLAIM_NAME) | ||
.cloned() | ||
.map(serde_json::from_value) | ||
.transpose() | ||
} | ||
|
||
fn try_set<C: Claim>(&mut self, claim: C) -> Result<Result<(), C>, Self::Error> { | ||
self.set(C::JWT_CLAIM_NAME.to_owned(), serde_json::to_value(claim)?); | ||
Ok(Ok(())) | ||
} | ||
|
||
fn try_remove<C: Claim>(&mut self) -> Result<Option<C>, Self::Error> { | ||
self.remove(C::JWT_CLAIM_NAME) | ||
.map(serde_json::from_value) | ||
.transpose() | ||
} | ||
} | ||
|
||
impl<E> Validate<E> for AnyClaims | ||
where | ||
E: DateTimeEnvironment, | ||
{ | ||
fn validate(&self, env: &E) -> ClaimsValidity { | ||
ClaimSet::validate_registered_claims(self, env) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
use std::borrow::Cow; | ||
|
||
use crate::Claim; | ||
|
||
/// Dynamic claim type matching. | ||
/// | ||
/// # Usage | ||
/// | ||
/// There are two ways to use this macro. | ||
/// The first one is to simply match on the value of a claim type parameter: | ||
/// ```ignore | ||
/// match_claim_type! { | ||
/// match MyClaimTypeParameter { | ||
/// TypeA => { ... }, | ||
/// TypeB => { ... }, | ||
/// _ => { ... } | ||
/// } | ||
/// } | ||
/// ``` | ||
/// | ||
/// The second one also allows you to properly cast a claim variable. | ||
/// ```ignore | ||
/// match_claim_type! { | ||
/// match claim: MyClaimTypeParameter { | ||
/// TypeA => { | ||
/// // In this block, `claim` has type `TypeA`. | ||
/// ... | ||
/// }, | ||
/// TypeB => { | ||
/// // In this block, `claim` has type `TypeB`. | ||
/// ... | ||
/// }, | ||
/// _ => { | ||
/// // In this block, `claim` has type `MyClaimTypeParameter`. | ||
/// ... | ||
/// }, | ||
/// } | ||
/// } | ||
/// ``` | ||
#[macro_export] | ||
macro_rules! match_claim_type { | ||
{ | ||
match $id:ident { | ||
$($ty:ident => $e:expr,)* | ||
_ => $default_case:expr | ||
} | ||
} => { | ||
$( | ||
if std::any::TypeId::of::<$id>() == std::any::TypeId::of::<$ty>() { | ||
let result = $e; | ||
return unsafe { | ||
// SAFETY: We just checked that `$ty` is equal to `$id`. | ||
$crate::CastClaim::<$ty, $id>::cast_claim(result) | ||
}; | ||
} | ||
)* | ||
|
||
$default_case | ||
}; | ||
{ | ||
match $x:ident: $id:ident { | ||
$($ty:ident => $e:expr,)* | ||
_ => $default_case:expr | ||
} | ||
} => { | ||
$( | ||
if std::any::TypeId::of::<$id>() == std::any::TypeId::of::<$ty>() { | ||
let $x: $ty = unsafe { | ||
// SAFETY: We just checked that `$ty` is equal to `$id`. | ||
$crate::CastClaim::<$id, $ty>::cast_claim($x) | ||
}; | ||
let result = $e; | ||
return unsafe { | ||
// SAFETY: We just checked that `$ty` is equal to `$id`. | ||
$crate::CastClaim::<$ty, $id>::cast_claim(result) | ||
}; | ||
} | ||
)* | ||
|
||
$default_case | ||
}; | ||
} | ||
|
||
/// Cast claim type `A` into `B`. | ||
pub trait CastClaim<A, B>: Sized { | ||
type Target; | ||
|
||
/// Cast claim type `A` into `B`. | ||
/// | ||
/// # Safety | ||
/// | ||
/// `A` **must** be equal to `B`. | ||
unsafe fn cast_claim(value: Self) -> Self::Target; | ||
} | ||
|
||
impl<A: Claim, B: Claim> CastClaim<A, B> for A { | ||
type Target = B; | ||
|
||
unsafe fn cast_claim(value: Self) -> Self::Target { | ||
// SAFETY: The precondition to this function is that `A` (`Self`) is | ||
// equal to `B`, meaning that the transmute does nothing. | ||
// We are just copying `value`, and forgetting the original. | ||
let result = std::mem::transmute_copy(&value); | ||
std::mem::forget(value); | ||
result | ||
} | ||
} | ||
|
||
impl<'a, A: Claim, B: Claim> CastClaim<A, B> for &'a A { | ||
type Target = &'a B; | ||
|
||
unsafe fn cast_claim(value: Self) -> Self::Target { | ||
std::mem::transmute_copy(&value) | ||
} | ||
} | ||
|
||
impl<A, B> CastClaim<A, B> for () { | ||
type Target = Self; | ||
|
||
unsafe fn cast_claim(value: Self) -> Self::Target { | ||
value | ||
} | ||
} | ||
|
||
impl<A, B, T: CastClaim<A, B>> CastClaim<A, B> for Option<T> { | ||
type Target = Option<T::Target>; | ||
|
||
unsafe fn cast_claim(value: Self) -> Self::Target { | ||
value.map(|t| T::cast_claim(t)) | ||
} | ||
} | ||
|
||
impl<A, B, T: CastClaim<A, B>, E> CastClaim<A, B> for Result<T, E> { | ||
type Target = Result<T::Target, E>; | ||
|
||
unsafe fn cast_claim(value: Self) -> Self::Target { | ||
value.map(|t| T::cast_claim(t)) | ||
} | ||
} | ||
|
||
impl<'a, A: Claim, B: Claim> CastClaim<A, B> for Cow<'a, A> { | ||
type Target = Cow<'a, B>; | ||
|
||
unsafe fn cast_claim(value: Self) -> Self::Target { | ||
match value { | ||
Self::Owned(value) => Cow::Owned(CastClaim::cast_claim(value)), | ||
Self::Borrowed(value) => Cow::Borrowed(CastClaim::cast_claim(value)), | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use serde::{Deserialize, Serialize}; | ||
use std::borrow::Cow; | ||
|
||
use crate::{AnyClaims, Claim, ClaimSet}; | ||
|
||
#[derive(Clone, Serialize, Deserialize)] | ||
struct CustomClaim; | ||
|
||
impl Claim for CustomClaim { | ||
const JWT_CLAIM_NAME: &'static str = "custom"; | ||
} | ||
|
||
struct CustomClaimSet { | ||
custom: Option<CustomClaim>, | ||
other_claims: AnyClaims, | ||
} | ||
|
||
impl ClaimSet for CustomClaimSet { | ||
type Error = serde_json::Error; | ||
|
||
fn try_get<C: Claim>(&self) -> Result<Option<Cow<C>>, Self::Error> { | ||
match_claim_type! { | ||
match C { | ||
CustomClaim => { | ||
Ok(self.custom.as_ref().map(Cow::Borrowed)) | ||
}, | ||
_ => { | ||
ClaimSet::try_get::<C>(&self.other_claims) | ||
} | ||
} | ||
} | ||
} | ||
|
||
fn try_set<C: Claim>(&mut self, claim: C) -> Result<Result<(), C>, Self::Error> { | ||
match_claim_type! { | ||
match claim: C { | ||
CustomClaim => { | ||
self.custom = Some(claim); | ||
Ok(Ok(())) | ||
}, | ||
_ => { | ||
ClaimSet::try_set(&mut self.other_claims, claim) | ||
} | ||
} | ||
} | ||
} | ||
|
||
fn try_remove<C: Claim>(&mut self) -> Result<Option<C>, Self::Error> { | ||
match_claim_type! { | ||
match C { | ||
CustomClaim => { | ||
Ok(self.custom.take()) | ||
}, | ||
_ => { | ||
ClaimSet::try_remove::<C>(&mut self.other_claims) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.