-
Notifications
You must be signed in to change notification settings - Fork 409
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
37 changed files
with
4,234 additions
and
15 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import TYPE_CHECKING, Iterable, List, Optional, Any | ||
|
||
if TYPE_CHECKING: | ||
class SigningCtx: | ||
def __init__(self) -> None: ... | ||
def set_issuer(self, issuer: str) -> None: ... | ||
def set_audience(self, audience: str) -> None: ... | ||
def set_expiry(self, expiry: int) -> None: ... | ||
def set_not_before(self, not_before: int) -> None: ... | ||
def allow(self, claim: str, values: List[str]) -> None: ... | ||
|
||
class JWKSet: | ||
@staticmethod | ||
def from_hs256_key(key: bytes) -> "JWKSet": ... | ||
def __init__(self) -> None: ... | ||
def generate(self, *, kid: Optional[str], kty: str) -> None: ... | ||
def add(self, **kwargs: Any) -> None: ... | ||
def load(self, keys: str) -> int: ... | ||
def load_json(self, keys: str) -> int: ... | ||
def set_issuer(self, issuer: str) -> None: ... | ||
def set_audience(self, audience: str) -> None: ... | ||
def set_expiry(self, expiry: int) -> None: ... | ||
def set_not_before(self, not_before: int) -> None: ... | ||
def allow(self, claim: str, values: List[str]) -> None: ... | ||
def deny(self, claim: str, values: List[str]) -> None: ... | ||
def export_pem(self, *, private_keys: bool) -> bytes: ... | ||
def export_json(self, *, private_keys: bool) -> bytes: ... | ||
def can_sign(self) -> bool: ... | ||
def sign( | ||
self, claims: dict[str, Any], *, ctx: Optional[SigningCtx] = None | ||
) -> str: ... | ||
def validate(self, token: str) -> dict[str, Any]: ... | ||
def to_json(self, *, private_keys: bool) -> str: ... | ||
def to_pem(self, *, private_keys: bool) -> str: ... | ||
|
||
class JWKSetCache: | ||
def __init__(self, expiry_seconds: int) -> None: ... | ||
# Returns a tuple of (is_fresh, registry) | ||
def get(self, key: str) -> tuple[bool, Optional[JWKSet]]: ... | ||
def set(self, key: str, registry: JWKSet) -> None: ... | ||
|
||
def generate_gel_token( | ||
registry: JWKSet, | ||
*, | ||
instances: Optional[List[str] | Iterable[str]] = None, | ||
roles: Optional[List[str] | Iterable[str]] = None, | ||
databases: Optional[List[str] | Iterable[str]] = None, | ||
) -> str: ... | ||
else: | ||
from edb.server._rust_native._jwt import JWKSet, JWKSetCache, generate_gel_token, SigningCtx # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
[package] | ||
name = "gel-jwt" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[features] | ||
python_extension = ["pyo3/extension-module"] | ||
|
||
[dependencies] | ||
pyo3 = { workspace = true, optional = true } | ||
pyo3_util.workspace = true | ||
|
||
# This is required to be in sync w/jsonwebtoken | ||
rand = "0.8.5" | ||
|
||
md5 = "0.7.0" | ||
sha2 = "0.10.8" | ||
constant_time_eq = "0.3" | ||
base64 = "0.22" | ||
thiserror = "2" | ||
hmac = "0.12.1" | ||
derive_more = { version = "2", features = ["debug", "from", "display"] } | ||
|
||
rustls-pki-types = "1" | ||
serde = "1" | ||
serde_derive = "1" | ||
serde_json = "1" | ||
jsonwebtoken = { version = "9", default-features = false } | ||
ring = { version = "0.17", default-features = false } | ||
rsa = { version = "0.9.7", default-features = false, features = ["std"] } | ||
pkcs1 = "0.7.5" | ||
pkcs8 = "0.10.2" | ||
sec1 = { version = "0.7.3", features = ["der", "pkcs8", "alloc"] } | ||
pem = "3" | ||
const-oid = { version ="0.9.6", features = ["db"] } | ||
p256 = { version = "0.13.2", features = ["jwk"] } | ||
base64ct = { version = "1", features = ["alloc"] } | ||
der = "0.7.9" | ||
libc = "0.2" | ||
elliptic-curve = { version = "0.13.8", features = ["arithmetic"] } | ||
num-bigint-dig = "0.8.4" | ||
zeroize = { version = "1", features = ["derive", "serde"] } | ||
uuid = { version = "1", features = ["v4", "serde"] } | ||
|
||
[dev-dependencies] | ||
pretty_assertions = "1" | ||
rstest = "0.24.0" | ||
hex-literal = "0.4.1" | ||
divan = "0.1.17" | ||
|
||
[[bench]] | ||
name = "encode" | ||
harness = false | ||
|
||
[lib] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from jwcrypto import jwt, jwk | ||
import time | ||
import statistics | ||
|
||
def generate_key(key_type): | ||
if key_type == "ES256": | ||
return jwk.JWK.generate(kty='EC', crv='P-256') | ||
elif key_type == "RS256": | ||
return jwk.JWK.generate(kty='RSA', size=2048) | ||
elif key_type == "HS256": | ||
return jwk.JWK.generate(kty='oct', size=256) | ||
raise ValueError(f"Unsupported key type: {key_type}") | ||
|
||
def benchmark_encode(key_type, iterations=100): | ||
# Generate key outside the loop | ||
key = generate_key(key_type) | ||
|
||
# Benchmark full encoding process including claims creation | ||
times = [] | ||
for _ in range(iterations): | ||
start = time.perf_counter_ns() | ||
|
||
# Create claims and sign in the timed section | ||
claims = {"sub": "test"} | ||
token = jwt.JWT( | ||
header={"alg": key_type}, | ||
claims=claims | ||
) | ||
token.make_signed_token(key) | ||
|
||
end = time.perf_counter_ns() | ||
times.append(end - start) | ||
|
||
mean = statistics.mean(times) / 1000 # Convert to microseconds | ||
median = statistics.median(times) / 1000 | ||
return mean, median | ||
|
||
def benchmark_signing(key_type, iterations=100): | ||
# Generate key outside the loop | ||
key = generate_key(key_type) | ||
claims = {"sub": "test"} | ||
|
||
# Benchmark signing | ||
times = [] | ||
for _ in range(iterations): | ||
start = time.perf_counter_ns() | ||
|
||
# Signing | ||
token = jwt.JWT( | ||
header={"alg": key_type}, | ||
claims=claims | ||
) | ||
token.make_signed_token(key) | ||
|
||
end = time.perf_counter_ns() | ||
times.append(end - start) | ||
|
||
mean = statistics.mean(times) / 1000 | ||
median = statistics.median(times) / 1000 | ||
return mean, median | ||
|
||
def benchmark_validation(key_type, iterations=100): | ||
# Generate key and token outside the loop | ||
key = generate_key(key_type) | ||
token = jwt.JWT( | ||
header={"alg": key_type}, | ||
claims={"sub": "test"} | ||
) | ||
token.make_signed_token(key) | ||
token_string = token.serialize() | ||
|
||
# Benchmark validation | ||
times = [] | ||
for _ in range(iterations): | ||
start = time.perf_counter_ns() | ||
|
||
# Validation | ||
jwt.JWT(jwt=token_string, key=key) | ||
|
||
end = time.perf_counter_ns() | ||
times.append(end - start) | ||
|
||
mean = statistics.mean(times) / 1000 | ||
median = statistics.median(times) / 1000 | ||
return mean, median | ||
|
||
def main(): | ||
key_types = ["ES256", "RS256", "HS256"] | ||
iterations = 100 | ||
|
||
print(f"Running {iterations} iterations for each algorithm") | ||
|
||
print("\nFull encode benchmarks (including claims creation):") | ||
print(f"{'Algorithm':<10} | {'Mean (µs)':<12} | {'Median (µs)':<12}") | ||
print("-" * 38) | ||
for key_type in key_types: | ||
mean, median = benchmark_encode(key_type, iterations) | ||
print(f"{key_type:<10} | {mean:12.2f} | {median:12.2f}") | ||
|
||
print("\nSigning benchmarks (pre-created claims):") | ||
print(f"{'Algorithm':<10} | {'Mean (µs)':<12} | {'Median (µs)':<12}") | ||
print("-" * 38) | ||
for key_type in key_types: | ||
mean, median = benchmark_signing(key_type, iterations) | ||
print(f"{key_type:<10} | {mean:12.2f} | {median:12.2f}") | ||
|
||
print("\nValidation benchmarks:") | ||
print(f"{'Algorithm':<10} | {'Mean (µs)':<12} | {'Median (µs)':<12}") | ||
print("-" * 38) | ||
for key_type in key_types: | ||
mean, median = benchmark_validation(key_type, iterations) | ||
print(f"{key_type:<10} | {mean:12.2f} | {median:12.2f}") | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
use std::collections::HashMap; | ||
|
||
use gel_jwt::{KeyType, PrivateKey, SigningContext}; | ||
|
||
#[divan::bench(args = [&KeyType::ES256, &KeyType::RS256, &KeyType::HS256])] | ||
fn bench_jwt_signing(b: divan::Bencher, key_type: &KeyType) { | ||
let key = PrivateKey::generate(None, *key_type).unwrap(); | ||
let claims = HashMap::from([("sub".to_string(), "test".into())]); | ||
let ctx = SigningContext::default(); | ||
|
||
b.bench_local(move || key.sign(claims.clone(), &ctx)); | ||
} | ||
|
||
#[divan::bench(args = [&KeyType::ES256, &KeyType::RS256, &KeyType::HS256])] | ||
fn bench_jwt_validation(b: divan::Bencher, key_type: &KeyType) { | ||
let key = PrivateKey::generate(None, *key_type).unwrap(); | ||
let claims = HashMap::from([("sub".to_string(), "test".into())]); | ||
let ctx = SigningContext::default(); | ||
let token = key.sign(claims, &ctx).unwrap(); | ||
|
||
b.bench_local(move || key.validate(&token, &ctx)); | ||
} | ||
|
||
#[divan::bench(args = [&KeyType::ES256, &KeyType::RS256, &KeyType::HS256])] | ||
fn bench_jwt_encode(b: divan::Bencher, key_type: &KeyType) { | ||
let key = PrivateKey::generate(None, *key_type).unwrap(); | ||
|
||
b.bench_local(move || { | ||
let claims = HashMap::from([("sub".to_string(), "test".into())]); | ||
let ctx = SigningContext::default(); | ||
key.sign(claims, &ctx).unwrap() | ||
}); | ||
} | ||
|
||
fn main() { | ||
// Run registered benchmarks. | ||
divan::main(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# JWT support | ||
|
||
This crate provides support for JWT tokens. | ||
|
||
## Key types | ||
|
||
HS256: symmetric key | ||
RS256: asymmetric key (RSA 2048+ + SHA256) | ||
ES256: asymmetric key (P-256 + SHA256) | ||
|
||
## Supported key formats | ||
|
||
HS256: raw data | ||
RS256: PKCS1/PKCS8 PEM | ||
ES256: SEC1/PKCS8 PEM | ||
|
Oops, something went wrong.