Skip to content

Commit

Permalink
Include remainder of changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Feb 10, 2025
1 parent 0bff987 commit 4da1209
Show file tree
Hide file tree
Showing 10 changed files with 751 additions and 255 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/gel-jwt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ python_extension = ["pyo3/extension-module"]
[dependencies]
pyo3 = { workspace = true, optional = true }
pyo3_util.workspace = true
tracing.workspace = true

# This is required to be in sync w/jsonwebtoken
rand = "0.8.5"
Expand Down
3 changes: 2 additions & 1 deletion rust/gel-jwt/benches/encode.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;

use gel_jwt::{KeyType, PrivateKey, SigningContext};
use gel_jwt::{KeyType, PrivateKey, SigningContext, ValidationContext};

#[divan::bench(args = [&KeyType::ES256, &KeyType::RS256, &KeyType::HS256])]
fn bench_jwt_signing(b: divan::Bencher, key_type: &KeyType) {
Expand All @@ -17,6 +17,7 @@ fn bench_jwt_validation(b: divan::Bencher, key_type: &KeyType) {
let claims = HashMap::from([("sub".to_string(), "test".into())]);
let ctx = SigningContext::default();
let token = key.sign(claims, &ctx).unwrap();
let ctx = ValidationContext::default();

b.bench_local(move || key.validate(&token, &ctx));
}
Expand Down
4 changes: 3 additions & 1 deletion rust/gel-jwt/src/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# JWT support

This crate provides support for JWT tokens.
This crate provides support for JWT tokens. The JWT signing and verification is done
using the `jsonwebtoken` crate, while the key loading is performed here via the
`rsa`/`p256` crates.

## Key types

Expand Down
9 changes: 4 additions & 5 deletions rust/gel-jwt/src/bare_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1059,11 +1059,10 @@ fn handle_rsa_pubkey(key: &Pem) -> Result<BarePublicKeyInner, KeyError> {
/// Decode a base64 string with optional padding, since jwcrypto also seems to
/// accept this.
///
/// :JWKs make use of the base64url encoding as defined in RFC 4648 [RFC4648].
/// As allowed by Section 3.2 of the RFC, this specification mandates that
/// base64url encoding when used with JWKs MUST NOT use padding. Notes on
/// implementing base64url encoding can be found in the JWS [JWS]
/// specification.""
/// > JWKs make use of the base64url encoding as defined in RFC 4648 As allowed
/// > by Section 3.2 of the RFC, this specification mandates that base64url
/// > encoding when used with JWKs MUST NOT use padding. Notes on implementing
/// > base64url encoding can be found in the JWS specification.
fn b64_decode(s: &str) -> Result<zeroize::Zeroizing<Vec<u8>>, KeyError> {
let vec = if s.ends_with('=') {
base64ct::Base64Url::decode_vec(s).map_err(|_| KeyError::DecodeError)?
Expand Down
154 changes: 111 additions & 43 deletions rust/gel-jwt/src/key.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
use jsonwebtoken::{Algorithm, Header, Validation};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
fmt::Debug,
sync::Arc,
time::Duration,
};
use std::{collections::HashMap, fmt::Debug, sync::Arc};

use crate::{
bare_key::{BareKeyInner, SerializedKey},
registry::IsKey,
Any, BareKey, BarePrivateKey, BarePublicKey, KeyError, OpaqueValidationFailureReason,
SignatureError, ValidationError,
SignatureError, SigningContext, ValidationContext, ValidationError, ValidationType,
};

#[derive(Clone, Copy, Debug, derive_more::Display, PartialEq, Eq)]
Expand All @@ -21,16 +16,6 @@ pub enum KeyType {
HS256,
}

#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SigningContext {
pub expiry: Option<Duration>,
pub issuer: Option<String>,
pub audience: Option<String>,
pub allow: HashMap<String, HashSet<String>>,
pub deny: HashMap<String, HashSet<String>>,
pub not_before: Option<Duration>,
}

#[derive(Serialize, Deserialize)]
struct Token {
#[serde(rename = "exp", default, skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -106,7 +91,7 @@ impl PrivateKey {
pub fn validate(
&self,
token: &str,
ctx: &SigningContext,
ctx: &ValidationContext,
) -> Result<HashMap<String, Any>, ValidationError> {
validate_token(
self.key_type(),
Expand Down Expand Up @@ -202,8 +187,18 @@ pub(crate) fn sign_token(
(None, None)
};

let expiry = ctx.expiry.map(|d| d.as_secs() as isize);
let expiry = if expiry == Some(0) {
// Ensure that a token that expires now expires with enough notice for
// the leeway option to be ignored. This isn't a great solution, but
// it's challenging to test expiring tokens otherwise.
Some(now.saturating_sub(120))
} else {
expiry.map(|d| now.saturating_add_signed(d))
};

let token = Token {
expiry: ctx.expiry.map(|d| now.saturating_add(d.as_secs() as _)),
expiry,
issuer: ctx.issuer.clone(),
audience: ctx.audience.clone(),
issued_at,
Expand All @@ -222,25 +217,48 @@ pub(crate) fn validate_token(
decoding_key: &jsonwebtoken::DecodingKey,
kid: Option<&str>,
token: &str,
ctx: &SigningContext,
ctx: &ValidationContext,
) -> Result<HashMap<String, Any>, ValidationError> {
let mut validation = Validation::new(match key_type {
KeyType::ES256 => Algorithm::ES256,
KeyType::HS256 => Algorithm::HS256,
KeyType::RS256 => Algorithm::RS256,
});

if ctx.expiry.is_none() {
validation.required_spec_claims.remove("exp");
}
if ctx.not_before.is_none() {
validation.required_spec_claims.remove("nbf");
}
if let Some(aud) = &ctx.audience {
validation.set_audience(&[aud]);
validation.validate_aud = false;

match ctx.expiry {
ValidationType::Ignore => {
validation.required_spec_claims.remove("exp");
validation.validate_exp = false;
}
ValidationType::Allow => {
validation.required_spec_claims.remove("exp");
validation.validate_exp = true;
}
ValidationType::Reject => {
validation.required_spec_claims.remove("exp");
validation.validate_exp = false;
}
ValidationType::Require => {
// The default
}
}
if let Some(iss) = &ctx.issuer {
validation.set_issuer(&[iss]);

match ctx.not_before {
ValidationType::Ignore => {
validation.validate_nbf = false;
}
ValidationType::Allow => {
validation.validate_nbf = true;
}
ValidationType::Reject => {
validation.validate_nbf = false;
}
ValidationType::Require => {
validation.required_spec_claims.insert("nbf".to_string());
validation.validate_nbf = true;
}
}

let token = jsonwebtoken::decode::<HashMap<String, Any>>(token, decoding_key, &validation)
Expand All @@ -262,7 +280,7 @@ pub(crate) fn validate_token(
}
}

for (claim, values) in &ctx.allow {
for (claim, values) in &ctx.allow_list {
let value = token.claims.get(claim);
match value {
Some(Any::String(value)) => {
Expand All @@ -274,6 +292,25 @@ pub(crate) fn validate_token(
.into());
}
}
Some(Any::Array(array_values)) => {
for v in array_values.iter() {
if let Any::String(v) = v {
if !values.contains(v.as_ref()) {
return Err(OpaqueValidationFailureReason::InvalidClaimValue(
claim.to_string(),
Some(v.to_string()),
)
.into());
}
} else {
return Err(OpaqueValidationFailureReason::InvalidClaimValue(
claim.to_string(),
None,
)
.into());
}
}
}
_ => {
return Err(OpaqueValidationFailureReason::InvalidClaimValue(
claim.to_string(),
Expand All @@ -284,7 +321,7 @@ pub(crate) fn validate_token(
}
}

for (claim, values) in &ctx.deny {
for (claim, values) in &ctx.deny_list {
let value = token.claims.get(claim);
match value {
Some(Any::String(value)) => {
Expand All @@ -296,6 +333,25 @@ pub(crate) fn validate_token(
.into());
}
}
Some(Any::Array(array_values)) => {
for v in array_values.iter() {
if let Any::String(v) = v {
if values.contains(v.as_ref()) {
return Err(OpaqueValidationFailureReason::InvalidClaimValue(
claim.to_string(),
Some(v.to_string()),
)
.into());
}
} else {
return Err(OpaqueValidationFailureReason::InvalidClaimValue(
claim.to_string(),
None,
)
.into());
}
}
}
_ => {
return Err(OpaqueValidationFailureReason::InvalidClaimValue(
claim.to_string(),
Expand All @@ -306,19 +362,31 @@ pub(crate) fn validate_token(
}
}

// Remove any claims that were validated automatically
// Remove any claims that were validated automatically and reject any that should not
// be present.
let mut claims = token.claims;
if ctx.audience.is_some() {
claims.remove("aud");
}
if ctx.issuer.is_some() {
claims.remove("iss");
claims.remove("exp");
for claim in ctx.claims.iter() {
claims.remove(claim.0);
}
if ctx.expiry.is_some() {
claims.remove("exp");

if ctx.expiry == ValidationType::Reject {
if let Some(exp) = claims.remove("exp") {
return Err(OpaqueValidationFailureReason::InvalidClaimValue(
"exp".to_string(),
Some(format!("{exp:?}")),
)
.into());
}
}
if ctx.not_before.is_some() {
claims.remove("nbf");
if ctx.not_before == ValidationType::Reject {
if let Some(nbf) = claims.remove("nbf") {
return Err(OpaqueValidationFailureReason::InvalidClaimValue(
"nbf".to_string(),
Some(format!("{nbf:?}")),
)
.into());
}
}

Ok(claims)
Expand Down Expand Up @@ -379,7 +447,7 @@ impl PublicKey {
pub fn validate(
&self,
token: &str,
ctx: &SigningContext,
ctx: &ValidationContext,
) -> Result<HashMap<String, Any>, ValidationError> {
validate_token(
self.key_type(),
Expand Down
Loading

0 comments on commit 4da1209

Please sign in to comment.