From 4da1209cc8758375020e962c3ca5222fe8131ad3 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Mon, 10 Feb 2025 14:42:18 -0700 Subject: [PATCH] Include remainder of changes --- Cargo.lock | 1 + rust/gel-jwt/Cargo.toml | 1 + rust/gel-jwt/benches/encode.rs | 3 +- rust/gel-jwt/src/README.md | 4 +- rust/gel-jwt/src/bare_key.rs | 9 +- rust/gel-jwt/src/key.rs | 154 +++++++++++---- rust/gel-jwt/src/lib.rs | 176 +++++------------ rust/gel-jwt/src/python.rs | 351 ++++++++++++++++++++++++++------- rust/gel-jwt/src/registry.rs | 101 +++++++++- rust/gel-jwt/src/sig.rs | 206 +++++++++++++++++++ 10 files changed, 751 insertions(+), 255 deletions(-) create mode 100644 rust/gel-jwt/src/sig.rs diff --git a/Cargo.lock b/Cargo.lock index 904cc783f5f..8f7e6f03326 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1141,6 +1141,7 @@ dependencies = [ "serde_json", "sha2", "thiserror 2.0.3", + "tracing", "uuid", "zeroize", ] diff --git a/rust/gel-jwt/Cargo.toml b/rust/gel-jwt/Cargo.toml index 7ff8188e1e2..1839ea373eb 100644 --- a/rust/gel-jwt/Cargo.toml +++ b/rust/gel-jwt/Cargo.toml @@ -9,6 +9,7 @@ python_extension = ["pyo3/extension-module"] [dependencies] pyo3 = { workspace = true, optional = true } pyo3_util.workspace = true +tracing.workspace = true # This is required to be in sync w/jsonwebtoken rand = "0.8.5" diff --git a/rust/gel-jwt/benches/encode.rs b/rust/gel-jwt/benches/encode.rs index c68d8e30377..f1642de6340 100644 --- a/rust/gel-jwt/benches/encode.rs +++ b/rust/gel-jwt/benches/encode.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use gel_jwt::{KeyType, PrivateKey, SigningContext}; +use gel_jwt::{KeyType, PrivateKey, SigningContext, ValidationContext}; #[divan::bench(args = [&KeyType::ES256, &KeyType::RS256, &KeyType::HS256])] fn bench_jwt_signing(b: divan::Bencher, key_type: &KeyType) { @@ -17,6 +17,7 @@ fn bench_jwt_validation(b: divan::Bencher, key_type: &KeyType) { let claims = HashMap::from([("sub".to_string(), "test".into())]); let ctx = SigningContext::default(); let token = key.sign(claims, &ctx).unwrap(); + let ctx = ValidationContext::default(); b.bench_local(move || key.validate(&token, &ctx)); } diff --git a/rust/gel-jwt/src/README.md b/rust/gel-jwt/src/README.md index d1e14427fd0..a0438d55d23 100644 --- a/rust/gel-jwt/src/README.md +++ b/rust/gel-jwt/src/README.md @@ -1,6 +1,8 @@ # JWT support -This crate provides support for JWT tokens. +This crate provides support for JWT tokens. The JWT signing and verification is done +using the `jsonwebtoken` crate, while the key loading is performed here via the +`rsa`/`p256` crates. ## Key types diff --git a/rust/gel-jwt/src/bare_key.rs b/rust/gel-jwt/src/bare_key.rs index 1b09330b8da..4b69d08935b 100644 --- a/rust/gel-jwt/src/bare_key.rs +++ b/rust/gel-jwt/src/bare_key.rs @@ -1059,11 +1059,10 @@ fn handle_rsa_pubkey(key: &Pem) -> Result { /// Decode a base64 string with optional padding, since jwcrypto also seems to /// accept this. /// -/// :JWKs make use of the base64url encoding as defined in RFC 4648 [RFC4648]. -/// As allowed by Section 3.2 of the RFC, this specification mandates that -/// base64url encoding when used with JWKs MUST NOT use padding. Notes on -/// implementing base64url encoding can be found in the JWS [JWS] -/// specification."" +/// > JWKs make use of the base64url encoding as defined in RFC 4648 As allowed +/// > by Section 3.2 of the RFC, this specification mandates that base64url +/// > encoding when used with JWKs MUST NOT use padding. Notes on implementing +/// > base64url encoding can be found in the JWS specification. fn b64_decode(s: &str) -> Result>, KeyError> { let vec = if s.ends_with('=') { base64ct::Base64Url::decode_vec(s).map_err(|_| KeyError::DecodeError)? diff --git a/rust/gel-jwt/src/key.rs b/rust/gel-jwt/src/key.rs index 7b5c8c885f6..9ef5b3ca0c3 100644 --- a/rust/gel-jwt/src/key.rs +++ b/rust/gel-jwt/src/key.rs @@ -1,17 +1,12 @@ use jsonwebtoken::{Algorithm, Header, Validation}; use serde::{Deserialize, Serialize}; -use std::{ - collections::{HashMap, HashSet}, - fmt::Debug, - sync::Arc, - time::Duration, -}; +use std::{collections::HashMap, fmt::Debug, sync::Arc}; use crate::{ bare_key::{BareKeyInner, SerializedKey}, registry::IsKey, Any, BareKey, BarePrivateKey, BarePublicKey, KeyError, OpaqueValidationFailureReason, - SignatureError, ValidationError, + SignatureError, SigningContext, ValidationContext, ValidationError, ValidationType, }; #[derive(Clone, Copy, Debug, derive_more::Display, PartialEq, Eq)] @@ -21,16 +16,6 @@ pub enum KeyType { HS256, } -#[derive(Clone, Serialize, Deserialize, Default)] -pub struct SigningContext { - pub expiry: Option, - pub issuer: Option, - pub audience: Option, - pub allow: HashMap>, - pub deny: HashMap>, - pub not_before: Option, -} - #[derive(Serialize, Deserialize)] struct Token { #[serde(rename = "exp", default, skip_serializing_if = "Option::is_none")] @@ -106,7 +91,7 @@ impl PrivateKey { pub fn validate( &self, token: &str, - ctx: &SigningContext, + ctx: &ValidationContext, ) -> Result, ValidationError> { validate_token( self.key_type(), @@ -202,8 +187,18 @@ pub(crate) fn sign_token( (None, None) }; + let expiry = ctx.expiry.map(|d| d.as_secs() as isize); + let expiry = if expiry == Some(0) { + // Ensure that a token that expires now expires with enough notice for + // the leeway option to be ignored. This isn't a great solution, but + // it's challenging to test expiring tokens otherwise. + Some(now.saturating_sub(120)) + } else { + expiry.map(|d| now.saturating_add_signed(d)) + }; + let token = Token { - expiry: ctx.expiry.map(|d| now.saturating_add(d.as_secs() as _)), + expiry, issuer: ctx.issuer.clone(), audience: ctx.audience.clone(), issued_at, @@ -222,7 +217,7 @@ pub(crate) fn validate_token( decoding_key: &jsonwebtoken::DecodingKey, kid: Option<&str>, token: &str, - ctx: &SigningContext, + ctx: &ValidationContext, ) -> Result, ValidationError> { let mut validation = Validation::new(match key_type { KeyType::ES256 => Algorithm::ES256, @@ -230,17 +225,40 @@ pub(crate) fn validate_token( KeyType::RS256 => Algorithm::RS256, }); - if ctx.expiry.is_none() { - validation.required_spec_claims.remove("exp"); - } - if ctx.not_before.is_none() { - validation.required_spec_claims.remove("nbf"); - } - if let Some(aud) = &ctx.audience { - validation.set_audience(&[aud]); + validation.validate_aud = false; + + match ctx.expiry { + ValidationType::Ignore => { + validation.required_spec_claims.remove("exp"); + validation.validate_exp = false; + } + ValidationType::Allow => { + validation.required_spec_claims.remove("exp"); + validation.validate_exp = true; + } + ValidationType::Reject => { + validation.required_spec_claims.remove("exp"); + validation.validate_exp = false; + } + ValidationType::Require => { + // The default + } } - if let Some(iss) = &ctx.issuer { - validation.set_issuer(&[iss]); + + match ctx.not_before { + ValidationType::Ignore => { + validation.validate_nbf = false; + } + ValidationType::Allow => { + validation.validate_nbf = true; + } + ValidationType::Reject => { + validation.validate_nbf = false; + } + ValidationType::Require => { + validation.required_spec_claims.insert("nbf".to_string()); + validation.validate_nbf = true; + } } let token = jsonwebtoken::decode::>(token, decoding_key, &validation) @@ -262,7 +280,7 @@ pub(crate) fn validate_token( } } - for (claim, values) in &ctx.allow { + for (claim, values) in &ctx.allow_list { let value = token.claims.get(claim); match value { Some(Any::String(value)) => { @@ -274,6 +292,25 @@ pub(crate) fn validate_token( .into()); } } + Some(Any::Array(array_values)) => { + for v in array_values.iter() { + if let Any::String(v) = v { + if !values.contains(v.as_ref()) { + return Err(OpaqueValidationFailureReason::InvalidClaimValue( + claim.to_string(), + Some(v.to_string()), + ) + .into()); + } + } else { + return Err(OpaqueValidationFailureReason::InvalidClaimValue( + claim.to_string(), + None, + ) + .into()); + } + } + } _ => { return Err(OpaqueValidationFailureReason::InvalidClaimValue( claim.to_string(), @@ -284,7 +321,7 @@ pub(crate) fn validate_token( } } - for (claim, values) in &ctx.deny { + for (claim, values) in &ctx.deny_list { let value = token.claims.get(claim); match value { Some(Any::String(value)) => { @@ -296,6 +333,25 @@ pub(crate) fn validate_token( .into()); } } + Some(Any::Array(array_values)) => { + for v in array_values.iter() { + if let Any::String(v) = v { + if values.contains(v.as_ref()) { + return Err(OpaqueValidationFailureReason::InvalidClaimValue( + claim.to_string(), + Some(v.to_string()), + ) + .into()); + } + } else { + return Err(OpaqueValidationFailureReason::InvalidClaimValue( + claim.to_string(), + None, + ) + .into()); + } + } + } _ => { return Err(OpaqueValidationFailureReason::InvalidClaimValue( claim.to_string(), @@ -306,19 +362,31 @@ pub(crate) fn validate_token( } } - // Remove any claims that were validated automatically + // Remove any claims that were validated automatically and reject any that should not + // be present. let mut claims = token.claims; - if ctx.audience.is_some() { - claims.remove("aud"); - } - if ctx.issuer.is_some() { - claims.remove("iss"); + claims.remove("exp"); + for claim in ctx.claims.iter() { + claims.remove(claim.0); } - if ctx.expiry.is_some() { - claims.remove("exp"); + + if ctx.expiry == ValidationType::Reject { + if let Some(exp) = claims.remove("exp") { + return Err(OpaqueValidationFailureReason::InvalidClaimValue( + "exp".to_string(), + Some(format!("{exp:?}")), + ) + .into()); + } } - if ctx.not_before.is_some() { - claims.remove("nbf"); + if ctx.not_before == ValidationType::Reject { + if let Some(nbf) = claims.remove("nbf") { + return Err(OpaqueValidationFailureReason::InvalidClaimValue( + "nbf".to_string(), + Some(format!("{nbf:?}")), + ) + .into()); + } } Ok(claims) @@ -379,7 +447,7 @@ impl PublicKey { pub fn validate( &self, token: &str, - ctx: &SigningContext, + ctx: &ValidationContext, ) -> Result, ValidationError> { validate_token( self.key_type(), diff --git a/rust/gel-jwt/src/lib.rs b/rust/gel-jwt/src/lib.rs index 5cfda4d93b8..8f60e44df58 100644 --- a/rust/gel-jwt/src/lib.rs +++ b/rust/gel-jwt/src/lib.rs @@ -1,16 +1,18 @@ #[cfg(feature = "python_extension")] pub mod python; -use std::{borrow::Cow, collections::HashMap, fmt::Debug}; +use std::fmt::Debug; use thiserror::Error; mod bare_key; mod key; mod registry; +mod sig; pub use bare_key::{BareKey, BarePrivateKey, BarePublicKey}; -pub use key::{Key, KeyType, PrivateKey, PublicKey, SigningContext}; +pub use key::{Key, KeyType, PrivateKey, PublicKey}; pub use registry::KeyRegistry; +pub use sig::{Any, SigningContext, ValidationContext, ValidationType}; #[derive(Error, Debug, Eq, PartialEq)] pub enum ValidationError { @@ -107,117 +109,9 @@ pub enum KeyError { #[derive(Debug, Eq, PartialEq)] pub struct KeyValidationError(String); -#[derive(Clone, serde::Serialize, serde::Deserialize, Debug, PartialEq)] -#[serde(untagged)] -pub enum Any { - None, - String(Cow<'static, str>), - Bool(bool), - Number(isize), - Array(Vec), - Object(HashMap, Any>), -} - -impl From for Any { - fn from(value: bool) -> Self { - Any::Bool(value) - } -} - -impl From<&'static str> for Any { - fn from(value: &'static str) -> Self { - Any::String(Cow::Borrowed(value)) - } -} - -impl From for Any { - fn from(value: String) -> Self { - Any::String(Cow::Owned(value)) - } -} - -impl From> for Any -where - T: Into, -{ - fn from(value: Option) -> Self { - value.map(T::into).unwrap_or(Any::None) - } -} - -impl From> for Any -where - T: Into, -{ - fn from(value: Vec) -> Self { - Any::Array(value.into_iter().map(T::into).collect()) - } -} - -#[cfg(feature = "python_extension")] -impl<'py> pyo3::FromPyObject<'py> for Any { - fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { - use pyo3::types::PyAnyMethods; - if ob.is_none() { - return Ok(Any::None); - } - if let Ok(value) = ob.extract::() { - return Ok(Any::Bool(value)); - } - if let Ok(value) = ob.extract::() { - return Ok(Any::Number(value)); - } - if let Ok(value) = ob.extract::() { - return Ok(Any::String(Cow::Owned(value))); - } - let res: Result, pyo3::PyErr> = ob.extract(); - if let Ok(list) = res { - let mut items = Vec::new(); - for item in list { - items.push(Any::extract_bound(&item)?); - } - return Ok(Any::Array(items)); - } - let res: Result, pyo3::PyErr> = ob.extract(); - if let Ok(dict) = res { - let mut items = HashMap::new(); - for (k, v) in dict { - items.insert(Cow::Owned(k.extract::()?), Any::extract_bound(&v)?); - } - return Ok(Any::Object(items)); - } - Err(pyo3::PyErr::new::( - "Invalid Any value", - )) - } -} - -#[cfg(feature = "python_extension")] -impl<'py> pyo3::IntoPyObject<'py> for Any { - type Target = pyo3::PyAny; - type Output = pyo3::Bound<'py, pyo3::PyAny>; - type Error = pyo3::PyErr; - fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { - use pyo3::IntoPyObjectExt; - - Ok(match self { - Any::None => py.None(), - Any::String(s) => s.as_ref().into_py_any(py)?, - Any::Bool(b) => b.into_py_any(py)?, - Any::Number(n) => n.into_py_any(py)?, - Any::Array(a) => a.into_py_any(py)?, - Any::Object(o) => o.into_py_any(py)?, - } - .into_bound(py)) - } -} - #[cfg(test)] mod tests { - use std::{ - collections::{HashMap, HashSet}, - time::Duration, - }; + use std::{collections::HashMap, time::Duration}; use super::*; @@ -278,14 +172,18 @@ mod tests { let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap(); let claims = HashMap::from([("hello".to_owned(), "world".into())]); let signing_ctx = SigningContext { - expiry: Some(Duration::from_secs(10)), + expiry: Some(Duration::from_secs(600)), issuer: Some("issuer".to_owned()), audience: Some("audience".to_owned()), ..Default::default() }; + let mut validation_ctx = ValidationContext::default(); + validation_ctx.require_claim("aud"); + validation_ctx.require_claim_with_allow_list("iss", &["issuer"]); + let token = key.sign(claims.clone(), &signing_ctx).unwrap(); println!("token: {}", token); - let decoded = key.validate(&token, &signing_ctx).unwrap(); + let decoded = key.validate(&token, &validation_ctx).unwrap(); assert_eq!(decoded, claims); } @@ -299,7 +197,13 @@ mod tests { ..Default::default() }; let token = key.sign(claims.clone(), &signing_ctx).unwrap(); - let decoded = key.validate(&token, &signing_ctx).unwrap(); + let mut validation_ctx = ValidationContext::default(); + validation_ctx.require_claim("aud"); + validation_ctx.require_claim_with_allow_list("iss", &["issuer"]); + let decoded = key + .validate(&token, &validation_ctx) + .map_err(|e| e.error_string_not_for_user()) + .unwrap(); assert_eq!(decoded, claims); } @@ -397,11 +301,14 @@ mod tests { audience: Some("test-audience".to_owned()), ..Default::default() }; + let mut validation_ctx = ValidationContext::default(); + validation_ctx.require_claim_with_allow_list("iss", &["test-issuer"]); + validation_ctx.require_claim_with_allow_list("aud", &["test-audience"]); // Generate and validate a token with each key for key in &keys { let token = key.sign(claims.clone(), &signing_ctx).unwrap(); - let decoded = registry.validate(&token, &signing_ctx).unwrap(); + let decoded = registry.validate(&token, &validation_ctx).unwrap(); assert_eq!(decoded, claims); } @@ -412,7 +319,7 @@ mod tests { .unwrap(); for key in &keys { let token = key.sign(claims.clone(), &signing_ctx).unwrap(); - let decoded = registry.validate(&token, &signing_ctx).unwrap(); + let decoded = registry.validate(&token, &validation_ctx).unwrap(); assert_eq!(decoded, claims); } } @@ -428,6 +335,7 @@ mod tests { audience: Some("test-audience".to_owned()), ..Default::default() }; + let validation_ctx = ValidationContext::default(); let token = key1.sign(claims, &signing_ctx).unwrap(); // Swap the keys so the signature is no longer valid with the specified kid @@ -438,7 +346,7 @@ mod tests { registry.add_key(key1); registry.add_key(key2); - let decoded = registry.validate(&token, &signing_ctx).unwrap_err(); + let decoded = registry.validate(&token, &validation_ctx).unwrap_err(); assert_eq!( decoded, ValidationError::Invalid(OpaqueValidationFailureReason::InvalidSignature), @@ -454,19 +362,35 @@ mod tests { registry.add_key(key); let claims = HashMap::from([("jti".to_owned(), "1234".into())]); - let signing_ctx = SigningContext { - allow: HashMap::from([("jti".to_owned(), HashSet::from(["1234".to_owned()]))]), - ..Default::default() - }; - + let signing_ctx = SigningContext::default(); + let mut validation_ctx = ValidationContext::default(); let token = registry.sign(claims.clone(), &signing_ctx).unwrap(); - let decoded = registry.validate(&token, &signing_ctx).unwrap(); - assert_eq!(decoded, claims); + + // With no claim validation, the token should be valid + let res = registry.validate(&token, &validation_ctx); + assert!( + matches!(res, Ok(_)), + "{}", + res.unwrap_err().error_string_not_for_user() + ); + + validation_ctx.require_claim_with_allow_list("jti", &["1234"]); + let decoded = registry.validate(&token, &validation_ctx).unwrap(); + assert_eq!(decoded, Default::default()); let claims = HashMap::from([("jti".to_owned(), "bad".into())]); let token = registry.sign(claims, &signing_ctx).unwrap(); - let decoded = registry.validate(&token, &signing_ctx).unwrap_err(); + let decoded = registry.validate(&token, &validation_ctx).unwrap_err(); + assert_eq!( + decoded, + ValidationError::Invalid(OpaqueValidationFailureReason::InvalidClaimValue( + "jti".to_string(), + Some("bad".to_string()) + )) + ); + validation_ctx.require_claim_with_deny_list("jti", &["bad"]); + let decoded = registry.validate(&token, &validation_ctx).unwrap_err(); assert_eq!( decoded, ValidationError::Invalid(OpaqueValidationFailureReason::InvalidClaimValue( @@ -485,7 +409,7 @@ mod tests { ("number".to_owned(), Any::Number(123)), ( "array".to_owned(), - Any::Array(vec![Any::String("1".into()), Any::String("2".into())].into()), + Any::Array(vec![Any::String("1".into()), Any::String("2".into())]), ), ]); let json = serde_json::to_string(&map).unwrap(); diff --git a/rust/gel-jwt/src/python.rs b/rust/gel-jwt/src/python.rs index 341bd2a5903..633e1c94d77 100644 --- a/rust/gel-jwt/src/python.rs +++ b/rust/gel-jwt/src/python.rs @@ -1,18 +1,20 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, time::{Duration, Instant}, }; use crate::{ bare_key::SerializedKey, Any, BarePrivateKey, Key, KeyError, KeyRegistry, KeyType, - SignatureError, ValidationError, + SignatureError, ValidationError, ValidationType, }; +use base64ct::{Base64Unpadded, Encoding}; use pyo3::{ exceptions::PyValueError, prelude::*, - types::{PyBytes, PyDict, PyList}, + types::{PyBytes, PyDict}, }; use serde::{Deserialize, Serialize}; +use tracing::warn; use uuid::Uuid; impl From for PyErr { @@ -29,7 +31,7 @@ impl From for PyErr { impl From for PyErr { fn from(value: ValidationError) -> Self { - PyValueError::new_err(value.to_string() + ":" + value.error_string_not_for_user().as_str()) + PyValueError::new_err(format!("{}: {}", value, value.error_string_not_for_user())) } } @@ -59,50 +61,78 @@ impl SigningCtx { self.context.not_before = Some(Duration::from_secs(not_before as u64)); } - pub fn set_expiry(&mut self, expiry: usize) { - self.context.expiry = Some(Duration::from_secs(expiry as u64)); + pub fn set_expiry(&mut self, expiry: isize) { + self.context.expiry = Some(Duration::from_secs(expiry.max(0) as u64)); } +} + +#[pyclass] +pub struct ValidationCtx { + context: crate::ValidationContext, +} - pub fn allow(&mut self, claim: &str, values: Bound) -> PyResult<()> { +#[pymethods] +impl ValidationCtx { + #[new] + pub fn new() -> PyResult { + Ok(Self { + context: crate::ValidationContext::default(), + }) + } + + pub fn allow(&mut self, claim: &str, values: Bound) -> PyResult<()> { + let values = vec_from_list_or_tuple(values)?; self.context - .allow - .insert(claim.to_string(), values.extract()?); + .allow_list + .insert(claim.to_string(), values.into_iter().collect()); Ok(()) } - pub fn deny(&mut self, claim: &str, values: Bound) -> PyResult<()> { + pub fn deny(&mut self, claim: &str, values: Bound) -> PyResult<()> { + let values = vec_from_list_or_tuple(values)?; self.context - .deny - .insert(claim.to_string(), values.extract()?); + .deny_list + .insert(claim.to_string(), values.into_iter().collect()); Ok(()) } + + pub fn require_expiry(&mut self) { + self.context.expiry = ValidationType::Require; + } + + pub fn ignore_expiry(&mut self) { + self.context.expiry = ValidationType::Ignore; + } } #[pyclass] pub struct JWKSet { registry: KeyRegistry, - context: crate::SigningContext, + default_signing_ctx: Py, + default_validation_ctx: Py, } #[pymethods] impl JWKSet { #[new] - pub fn new() -> PyResult { + pub fn new(py: Python) -> PyResult { let registry = KeyRegistry::::default(); Ok(Self { registry, - context: crate::SigningContext::default(), + default_signing_ctx: Py::new(py, SigningCtx::new()?)?, + default_validation_ctx: Py::new(py, ValidationCtx::new()?)?, }) } #[staticmethod] - pub fn from_hs256_key(key: Bound) -> PyResult { + pub fn from_hs256_key(py: Python, key: Bound) -> PyResult { let key = BarePrivateKey::from_raw_oct(key.as_bytes())?; let mut registry = KeyRegistry::::default(); registry.add_key(Key::from_bare_private_key(None, key)?); Ok(Self { registry, - context: crate::SigningContext::default(), + default_signing_ctx: Py::new(py, SigningCtx::new()?)?, + default_validation_ctx: Py::new(py, ValidationCtx::new()?)?, }) } @@ -119,6 +149,16 @@ impl JWKSet { Ok(()) } + #[getter] + pub fn default_signing_context(&self, py: Python) -> Py { + self.default_signing_ctx.clone_ref(py) + } + + #[getter] + pub fn default_validation_context(&self, py: Python) -> Py { + self.default_validation_ctx.clone_ref(py) + } + #[pyo3(signature = (*, kid, kty, **kwargs))] pub fn add( &mut self, @@ -170,36 +210,6 @@ impl JWKSet { Ok(count) } - pub fn set_issuer(&mut self, issuer: &str) { - self.context.issuer = Some(issuer.to_string()); - } - - pub fn set_audience(&mut self, audience: &str) { - self.context.audience = Some(audience.to_string()); - } - - pub fn set_not_before(&mut self, not_before: usize) { - self.context.not_before = Some(Duration::from_secs(not_before as u64)); - } - - pub fn set_expiry(&mut self, expiry: usize) { - self.context.expiry = Some(Duration::from_secs(expiry as u64)); - } - - pub fn allow(&mut self, claim: &str, values: Bound) -> PyResult<()> { - self.context - .allow - .insert(claim.to_string(), values.extract()?); - Ok(()) - } - - pub fn deny(&mut self, claim: &str, values: Bound) -> PyResult<()> { - self.context - .deny - .insert(claim.to_string(), values.extract()?); - Ok(()) - } - #[pyo3(signature = (*, private_keys=true))] pub fn export_pem(&self, private_keys: bool) -> PyResult> { if private_keys { @@ -219,28 +229,66 @@ impl JWKSet { .into_bytes()) } - pub fn can_sign(&self) -> bool { - self.registry.can_sign() - } - /// Sign a claims object with the default or given signing context. #[pyo3(signature = (claims, *, ctx=None))] - pub fn sign(&self, claims: Bound, ctx: Option<&SigningCtx>) -> PyResult { + pub fn sign( + &self, + py: Python, + claims: Bound, + ctx: Option<&SigningCtx>, + ) -> PyResult { let claims = claims.extract()?; - let token = self - .registry - .sign(claims, ctx.map(|c| &c.context).unwrap_or(&self.context))?; + let token = self.registry.sign( + claims, + ctx.map(|c| &c.context) + .unwrap_or(&self.default_signing_ctx.borrow(py).context), + )?; Ok(token) } - pub fn validate(&self, token: &str) -> PyResult> { - let claims = self.registry.validate(token, &self.context)?; + /// Validate a token with the default or given validation context. + #[pyo3(signature = (token, *, ctx=None))] + pub fn validate( + &self, + py: Python, + token: &str, + ctx: Option<&ValidationCtx>, + ) -> PyResult> { + let claims = self.registry.validate( + token, + ctx.map(|c| &c.context) + .unwrap_or(&self.default_validation_ctx.borrow(py).context), + )?; Ok(claims) } + pub fn can_sign(&self) -> bool { + self.registry.can_sign() + } + + pub fn can_validate(&self) -> bool { + self.registry.can_validate() + } + + pub fn has_public_keys(&self) -> bool { + self.registry.has_public_keys() + } + + pub fn has_private_keys(&self) -> bool { + self.registry.has_private_keys() + } + + pub fn has_symmetric_keys(&self) -> bool { + self.registry.has_symmetric_keys() + } + pub fn __repr__(&self) -> String { format!("JWKSet(keys={})", self.registry.len()) } + + pub fn __len__(&self) -> usize { + self.registry.len() + } } #[derive(Debug, Default, Serialize, Deserialize)] @@ -324,13 +372,16 @@ impl JWKSetCache { } } +/// Generate a token with optional additional claims. #[pyfunction] -#[pyo3(signature = (registry, *, instances=None, roles=None, databases=None))] +#[pyo3(signature = (registry, *, instances=None, roles=None, databases=None, **kwargs))] fn generate_gel_token( + py: Python, registry: &JWKSet, instances: Option>, roles: Option>, databases: Option>, + kwargs: Option>, ) -> PyResult { let mut claims = GelClaims::default(); @@ -352,27 +403,189 @@ fn generate_gel_token( claims.all_databases = true; } - claims.jti = Uuid::new_v4(); + let mut claims_map = HashMap::new(); + if claims.all_instances { + claims_map.insert("edb.i.all".to_string(), Any::from(true)); + } else if let Some(instances) = claims.instances { + claims_map.insert("edb.i".to_string(), Any::from(instances)); + } + + if claims.all_roles { + claims_map.insert("edb.r.all".to_string(), Any::from(true)); + } else if let Some(roles) = claims.roles { + claims_map.insert("edb.r".to_string(), Any::from(roles)); + } + + if claims.all_databases { + claims_map.insert("edb.d.all".to_string(), Any::from(true)); + } else if let Some(databases) = claims.databases { + claims_map.insert("edb.d".to_string(), Any::from(databases)); + } + + if let Some(kwargs) = kwargs { + for (key, value) in kwargs.iter() { + let key = key.extract::()?; + let value = value.extract::()?; + claims_map.insert(key, value); + } + } - let claims = HashMap::from([ - ("edb.i".to_string(), Any::from(claims.instances)), - ("edb.i.all".to_string(), Any::from(claims.all_instances)), - ("edb.r".to_string(), Any::from(claims.roles)), - ("edb.r.all".to_string(), Any::from(claims.all_roles)), - ("edb.d".to_string(), Any::from(claims.databases)), - ("edb.d.all".to_string(), Any::from(claims.all_databases)), - ("jti".to_string(), Any::from(claims.jti.to_string())), - ]); + // Add a JTI if and only if it's not already present. + if !claims_map.contains_key("jti") { + claims.jti = Uuid::new_v4(); + // Encode UUID as base64 to make the token shorter + let jti_base64 = Base64Unpadded::encode_string(claims.jti.as_bytes()); + claims_map.insert("jti".to_string(), Any::from(jti_base64)); + } - let token = registry.registry.sign(claims, ®istry.context)?; + let token = registry + .registry + .sign(claims_map, ®istry.default_signing_ctx.borrow(py).context)?; Ok(format!("edbt1_{}", token)) } +#[derive(Debug, Default)] +enum TokenMatch { + #[default] + None, + All, + Some(HashSet), +} + +impl TokenMatch { + fn from_claims( + claims: &HashMap, + all_key: &str, + array_key: &str, + ) -> PyResult { + if claims.contains_key(all_key) { + Ok(TokenMatch::All) + } else { + let Some(array) = claims.get(array_key).and_then(|v| v.as_array()) else { + warn!("Missing claims array key: {array_key}"); + return Err(PyErr::new::( + "authentication failed: malformed JWT", + )); + }; + Ok(TokenMatch::Some( + array + .iter() + .map(|v| v.as_str().unwrap_or_default().to_string()) + .collect::>(), + )) + } + } + + fn matches(&self, value: &str) -> bool { + match self { + TokenMatch::All => true, + TokenMatch::Some(set) => set.contains(value), + TokenMatch::None => false, + } + } +} + +#[derive(Debug, Default)] +struct TokenClaims { + instances: TokenMatch, + roles: TokenMatch, + databases: TokenMatch, +} + +#[pyfunction] +#[pyo3(signature = (registry, token, user, dbname, instance_name))] +fn validate_gel_token( + py: Python, + registry: &JWKSet, + token: &str, + user: &str, + dbname: &str, + instance_name: &str, +) -> PyResult> { + let mut token_version = 0; + let encoded_token = if let Some(stripped) = token.strip_prefix("nbwt1_") { + token_version = 1; + stripped + } else if let Some(stripped) = token.strip_prefix("nbwt_") { + stripped + } else if let Some(stripped) = token.strip_prefix("edbt1_") { + token_version = 1; + stripped + } else if let Some(stripped) = token.strip_prefix("edbt_") { + stripped + } else { + warn!( + "Invalid token prefix: [{}...]", + &token[0..token.len().min(7)] + ); + return Ok(Some("authentication failed: malformed JWT".to_string())); + }; + + // Validate and decode the JWT + let decoded = match registry.registry.validate( + encoded_token, + ®istry.default_validation_ctx.borrow(py).context, + ) { + Ok(claims) => claims, + Err(e) => { + warn!("Invalid token: {}", e.error_string_not_for_user()); + return Ok(Some( + "authentication failed: Verification failed".to_string(), + )); + } + }; + + let claims = if token_version == 0 { + // Legacy v0 token: "edgedb.server.any_role" is a boolean, "edgedb.server.roles" is an array of strings + let roles = + TokenMatch::from_claims(&decoded, "edgedb.server.any_role", "edgedb.server.roles")?; + TokenClaims { + roles, + instances: TokenMatch::All, + databases: TokenMatch::All, + } + } else { + // New v1 token: "edb.{i,r,d}.all" are booleans, "edb.{i,r,d}" are arrays of strings + let instances = TokenMatch::from_claims(&decoded, "edb.i.all", "edb.i")?; + let roles = TokenMatch::from_claims(&decoded, "edb.r.all", "edb.r")?; + let databases = TokenMatch::from_claims(&decoded, "edb.d.all", "edb.d")?; + TokenClaims { + instances, + roles, + databases, + } + }; + + if !claims.instances.matches(instance_name) { + warn!("Instance not in token: {instance_name}"); + return Ok(Some( + "authentication failed: secret key does not authorize access to this instance" + .to_string(), + )); + } + if !claims.roles.matches(user) { + warn!("Role not in token: {user}"); + return Ok(Some(format!( + "authentication failed: secret key does not authorize access in role {user:?}" + ))); + } + if !claims.databases.matches(dbname) { + warn!("Database not in token: {dbname}"); + return Ok(Some(format!( + "authentication failed: secret key does not authorize access to database {dbname:?}" + ))); + } + + Ok(None) +} + #[pymodule] pub fn _jwt(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(generate_gel_token, m)?)?; + m.add_function(wrap_pyfunction!(validate_gel_token, m)?)?; Ok(()) } diff --git a/rust/gel-jwt/src/registry.rs b/rust/gel-jwt/src/registry.rs index 9daddf7d2db..384f2edd0c0 100644 --- a/rust/gel-jwt/src/registry.rs +++ b/rust/gel-jwt/src/registry.rs @@ -1,7 +1,8 @@ use crate::{ bare_key::{SerializedKey, SerializedKeys}, key::*, - Any, KeyError, OpaqueValidationFailureReason, SignatureError, ValidationError, + Any, KeyError, OpaqueValidationFailureReason, SignatureError, SigningContext, + ValidationContext, ValidationError, }; use std::{ collections::{BTreeSet, HashMap, HashSet}, @@ -224,7 +225,7 @@ impl KeyRegistry { pub fn validate( &self, token: &str, - ctx: &SigningContext, + ctx: &ValidationContext, ) -> Result, ValidationError> { // If we have a named key that matches, use that. if !self.named_keys.is_empty() { @@ -257,12 +258,6 @@ impl KeyRegistry { Err(result.unwrap_or(OpaqueValidationFailureReason::NoAppropriateKey.into())) } - pub fn can_sign(&self) -> bool { - self.active_key() - .map(|(_, k)| K::encoding_key(k).is_some()) - .unwrap_or(false) - } - pub fn sign( &self, claims: HashMap, @@ -274,11 +269,97 @@ impl KeyRegistry { } } -impl KeyRegistry {} +impl KeyRegistry { + pub fn can_sign(&self) -> bool { + self.has_private_keys() || self.has_symmetric_keys() + } + + pub fn can_validate(&self) -> bool { + self.has_public_keys() || self.has_symmetric_keys() + } + + pub fn has_private_keys(&self) -> bool { + !self.is_empty() + } + + pub fn has_public_keys(&self) -> bool { + self.key_to_ordinal + .iter() + .any(|(k, _)| k.bare_key.key_type() != KeyType::HS256) + } + + pub fn has_symmetric_keys(&self) -> bool { + self.key_to_ordinal + .iter() + .any(|(k, _)| k.bare_key.key_type() == KeyType::HS256) + } +} + +impl KeyRegistry { + pub fn can_sign(&self) -> bool { + self.has_private_keys() || self.has_symmetric_keys() + } + + pub fn can_validate(&self) -> bool { + self.has_public_keys() || self.has_symmetric_keys() + } + + pub fn has_public_keys(&self) -> bool { + !self.is_empty() + } + + pub fn has_private_keys(&self) -> bool { + false + } -impl KeyRegistry {} + pub fn has_symmetric_keys(&self) -> bool { + false + } +} impl KeyRegistry { + pub fn can_sign(&self) -> bool { + self.has_private_keys() || self.has_symmetric_keys() + } + + pub fn can_validate(&self) -> bool { + self.has_public_keys() || self.has_symmetric_keys() + } + + pub fn has_private_keys(&self) -> bool { + for k in self.key_to_ordinal.keys() { + if let KeyInner::Private(_) = k { + return true; + } + } + false + } + + pub fn has_public_keys(&self) -> bool { + for k in self.key_to_ordinal.keys() { + if let KeyInner::Public(_) = k { + return true; + } + if let KeyInner::Private(k) = k { + if k.bare_key.key_type() != KeyType::HS256 { + return true; + } + } + } + false + } + + pub fn has_symmetric_keys(&self) -> bool { + for k in self.key_to_ordinal.keys() { + if let KeyInner::Private(k) = k { + if k.bare_key.key_type() == KeyType::HS256 { + return true; + } + } + } + false + } + /// Export the registry as a PEM file containing only the public keys. /// This will fail if the registry contains symmetric keys. pub fn to_pem_public(&self) -> Result { diff --git a/rust/gel-jwt/src/sig.rs b/rust/gel-jwt/src/sig.rs new file mode 100644 index 00000000000..65c81d74e72 --- /dev/null +++ b/rust/gel-jwt/src/sig.rs @@ -0,0 +1,206 @@ +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + time::Duration, +}; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Serialize, Deserialize, Default)] +pub struct SigningContext { + pub expiry: Option, + pub issuer: Option, + pub audience: Option, + pub not_before: Option, +} + +#[derive(Clone, Serialize, Deserialize, Default, Debug, PartialEq, Eq)] +pub enum ValidationType { + /// Require the claim to be absent and fail if it is present. + Reject, + /// Ignore the claim. + Ignore, + /// If the claim is present, it must be valid. + #[default] + Allow, + /// Require the claim to be present and be valid. + Require, +} + +#[derive(Clone, Serialize, Deserialize, Default)] +pub struct ValidationContext { + pub allow_list: HashMap>, + pub deny_list: HashMap>, + pub claims: HashMap, + pub expiry: ValidationType, + pub not_before: ValidationType, +} + +impl ValidationContext { + pub fn require_claim_with_allow_list(&mut self, claim: &str, values: &[&str]) { + self.claims + .insert(claim.to_string(), ValidationType::Require); + self.allow_list.insert( + claim.to_string(), + values.iter().map(|s| s.to_string()).collect(), + ); + } + + pub fn require_claim_with_deny_list(&mut self, claim: &str, values: &[&str]) { + self.claims + .insert(claim.to_string(), ValidationType::Require); + self.deny_list.insert( + claim.to_string(), + values.iter().map(|s| s.to_string()).collect(), + ); + } + + pub fn require_claim(&mut self, claim: &str) { + self.claims + .insert(claim.to_string(), ValidationType::Require); + } + + pub fn reject_claim(&mut self, claim: &str) { + self.claims + .insert(claim.to_string(), ValidationType::Reject); + } + + pub fn ignore_claim(&mut self, claim: &str) { + self.claims + .insert(claim.to_string(), ValidationType::Ignore); + } + + pub fn allow_claim(&mut self, claim: &str) { + self.claims.insert(claim.to_string(), ValidationType::Allow); + } +} + +/// A type similar to `serde_json::Value` that can be serialized and deserialized +/// from a JWT token. +#[derive(Clone, serde::Serialize, serde::Deserialize, Debug, PartialEq)] +#[serde(untagged)] +pub enum Any { + None, + String(Cow<'static, str>), + Bool(bool), + Number(isize), + Array(Vec), + Object(HashMap, Any>), +} + +impl Any { + pub fn as_str(&self) -> Option<&str> { + match self { + Any::String(s) => Some(s.as_ref()), + _ => None, + } + } + + pub fn as_array(&self) -> Option<&[Any]> { + match self { + Any::Array(a) => Some(a), + _ => None, + } + } + + pub fn as_object(&self) -> Option<&HashMap, Any>> { + match self { + Any::Object(o) => Some(o), + _ => None, + } + } +} + +impl From for Any { + fn from(value: bool) -> Self { + Any::Bool(value) + } +} + +impl From<&'static str> for Any { + fn from(value: &'static str) -> Self { + Any::String(Cow::Borrowed(value)) + } +} + +impl From for Any { + fn from(value: String) -> Self { + Any::String(Cow::Owned(value)) + } +} + +impl From> for Any +where + T: Into, +{ + fn from(value: Option) -> Self { + value.map(T::into).unwrap_or(Any::None) + } +} + +impl From> for Any +where + T: Into, +{ + fn from(value: Vec) -> Self { + Any::Array(value.into_iter().map(T::into).collect()) + } +} + +#[cfg(feature = "python_extension")] +impl<'py> pyo3::FromPyObject<'py> for Any { + fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult { + use pyo3::types::PyAnyMethods; + if ob.is_none() { + return Ok(Any::None); + } + if let Ok(value) = ob.extract::() { + return Ok(Any::Bool(value)); + } + if let Ok(value) = ob.extract::() { + return Ok(Any::Number(value)); + } + if let Ok(value) = ob.extract::() { + return Ok(Any::String(Cow::Owned(value))); + } + let res: Result, pyo3::PyErr> = ob.extract(); + if let Ok(list) = res { + let mut items = Vec::new(); + for item in list { + items.push(Any::extract_bound(&item)?); + } + return Ok(Any::Array(items)); + } + let res: Result, pyo3::PyErr> = ob.extract(); + if let Ok(dict) = res { + let mut items = HashMap::new(); + for (k, v) in dict { + items.insert(Cow::Owned(k.extract::()?), Any::extract_bound(&v)?); + } + return Ok(Any::Object(items)); + } + Err(pyo3::PyErr::new::( + "Invalid Any value", + )) + } +} + +#[cfg(feature = "python_extension")] +impl<'py> pyo3::IntoPyObject<'py> for Any { + type Target = pyo3::PyAny; + type Output = pyo3::Bound<'py, pyo3::PyAny>; + type Error = pyo3::PyErr; + fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { + use pyo3::IntoPyObjectExt; + + Ok(match self { + Any::None => py.None(), + Any::String(s) => s.as_ref().into_py_any(py)?, + Any::Bool(b) => b.into_py_any(py)?, + Any::Number(n) => n.into_py_any(py)?, + Any::Array(a) => a.into_py_any(py)?, + Any::Object(o) => o.into_py_any(py)?, + } + .into_bound(py)) + } +}