Skip to content

Commit

Permalink
JWT
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Feb 7, 2025
1 parent e6a7994 commit 2345968
Show file tree
Hide file tree
Showing 37 changed files with 4,234 additions and 15 deletions.
416 changes: 401 additions & 15 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"rust/conn_pool",
"rust/db_proto",
"rust/gel-auth",
"rust/gel-jwt",
"rust/gel-stream",
"rust/pgrust",
"rust/http",
Expand All @@ -23,6 +24,7 @@ tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter"] }

gel-auth = { path = "rust/gel-auth" }
gel-jwt = { path = "rust/gel-jwt" }
gel-stream = { path = "rust/gel-stream" }
db_proto = { path = "rust/db_proto" }
captive_postgres = { path = "rust/captive_postgres" }
Expand Down
2 changes: 2 additions & 0 deletions edb/server/_rust_native/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pyo3_util.workspace = true
conn_pool = { workspace = true, features = [ "python_extension" ] }
pgrust = { workspace = true, features = [ "python_extension" ] }
http = { workspace = true, features = [ "python_extension" ] }
gel-auth = { workspace = true, features = [ "python_extension" ] }
gel-jwt = { workspace = true, features = [ "python_extension" ] }

[lib]
crate-type = ["lib", "cdylib"]
Expand Down
1 change: 1 addition & 0 deletions edb/server/_rust_native/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ fn _rust_native(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
add_child_module(py, m, "_conn_pool", conn_pool::python::_conn_pool)?;
add_child_module(py, m, "_pg_rust", pgrust::python::_pg_rust)?;
add_child_module(py, m, "_http", http::python::_http)?;
add_child_module(py, m, "_jwt", gel_jwt::python::_jwt)?;

Ok(())
}
50 changes: 50 additions & 0 deletions edb/server/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING, Iterable, List, Optional, Any

if TYPE_CHECKING:
class SigningCtx:
def __init__(self) -> None: ...
def set_issuer(self, issuer: str) -> None: ...
def set_audience(self, audience: str) -> None: ...
def set_expiry(self, expiry: int) -> None: ...
def set_not_before(self, not_before: int) -> None: ...
def allow(self, claim: str, values: List[str]) -> None: ...

class JWKSet:
@staticmethod
def from_hs256_key(key: bytes) -> "JWKSet": ...
def __init__(self) -> None: ...
def generate(self, *, kid: Optional[str], kty: str) -> None: ...
def add(self, **kwargs: Any) -> None: ...
def load(self, keys: str) -> int: ...
def load_json(self, keys: str) -> int: ...
def set_issuer(self, issuer: str) -> None: ...
def set_audience(self, audience: str) -> None: ...
def set_expiry(self, expiry: int) -> None: ...
def set_not_before(self, not_before: int) -> None: ...
def allow(self, claim: str, values: List[str]) -> None: ...
def deny(self, claim: str, values: List[str]) -> None: ...
def export_pem(self, *, private_keys: bool) -> bytes: ...
def export_json(self, *, private_keys: bool) -> bytes: ...
def can_sign(self) -> bool: ...
def sign(
self, claims: dict[str, Any], *, ctx: Optional[SigningCtx] = None
) -> str: ...
def validate(self, token: str) -> dict[str, Any]: ...
def to_json(self, *, private_keys: bool) -> str: ...
def to_pem(self, *, private_keys: bool) -> str: ...

class JWKSetCache:
def __init__(self, expiry_seconds: int) -> None: ...
# Returns a tuple of (is_fresh, registry)
def get(self, key: str) -> tuple[bool, Optional[JWKSet]]: ...
def set(self, key: str, registry: JWKSet) -> None: ...

def generate_gel_token(
registry: JWKSet,
*,
instances: Optional[List[str] | Iterable[str]] = None,
roles: Optional[List[str] | Iterable[str]] = None,
databases: Optional[List[str] | Iterable[str]] = None,
) -> str: ...
else:
from edb.server._rust_native._jwt import JWKSet, JWKSetCache, generate_gel_token, SigningCtx # noqa
55 changes: 55 additions & 0 deletions rust/gel-jwt/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
[package]
name = "gel-jwt"
version = "0.1.0"
edition = "2021"

[features]
python_extension = ["pyo3/extension-module"]

[dependencies]
pyo3 = { workspace = true, optional = true }
pyo3_util.workspace = true

# This is required to be in sync w/jsonwebtoken
rand = "0.8.5"

md5 = "0.7.0"
sha2 = "0.10.8"
constant_time_eq = "0.3"
base64 = "0.22"
thiserror = "2"
hmac = "0.12.1"
derive_more = { version = "2", features = ["debug", "from", "display"] }

rustls-pki-types = "1"
serde = "1"
serde_derive = "1"
serde_json = "1"
jsonwebtoken = { version = "9", default-features = false }
ring = { version = "0.17", default-features = false }
rsa = { version = "0.9.7", default-features = false, features = ["std"] }
pkcs1 = "0.7.5"
pkcs8 = "0.10.2"
sec1 = { version = "0.7.3", features = ["der", "pkcs8", "alloc"] }
pem = "3"
const-oid = { version ="0.9.6", features = ["db"] }
p256 = { version = "0.13.2", features = ["jwk"] }
base64ct = { version = "1", features = ["alloc"] }
der = "0.7.9"
libc = "0.2"
elliptic-curve = { version = "0.13.8", features = ["arithmetic"] }
num-bigint-dig = "0.8.4"
zeroize = { version = "1", features = ["derive", "serde"] }
uuid = { version = "1", features = ["v4", "serde"] }

[dev-dependencies]
pretty_assertions = "1"
rstest = "0.24.0"
hex-literal = "0.4.1"
divan = "0.1.17"

[[bench]]
name = "encode"
harness = false

[lib]
115 changes: 115 additions & 0 deletions rust/gel-jwt/benches/bench-jwcrypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from jwcrypto import jwt, jwk
import time
import statistics

def generate_key(key_type):
if key_type == "ES256":
return jwk.JWK.generate(kty='EC', crv='P-256')
elif key_type == "RS256":
return jwk.JWK.generate(kty='RSA', size=2048)
elif key_type == "HS256":
return jwk.JWK.generate(kty='oct', size=256)
raise ValueError(f"Unsupported key type: {key_type}")

def benchmark_encode(key_type, iterations=100):
# Generate key outside the loop
key = generate_key(key_type)

# Benchmark full encoding process including claims creation
times = []
for _ in range(iterations):
start = time.perf_counter_ns()

# Create claims and sign in the timed section
claims = {"sub": "test"}
token = jwt.JWT(
header={"alg": key_type},
claims=claims
)
token.make_signed_token(key)

end = time.perf_counter_ns()
times.append(end - start)

mean = statistics.mean(times) / 1000 # Convert to microseconds
median = statistics.median(times) / 1000
return mean, median

def benchmark_signing(key_type, iterations=100):
# Generate key outside the loop
key = generate_key(key_type)
claims = {"sub": "test"}

# Benchmark signing
times = []
for _ in range(iterations):
start = time.perf_counter_ns()

# Signing
token = jwt.JWT(
header={"alg": key_type},
claims=claims
)
token.make_signed_token(key)

end = time.perf_counter_ns()
times.append(end - start)

mean = statistics.mean(times) / 1000
median = statistics.median(times) / 1000
return mean, median

def benchmark_validation(key_type, iterations=100):
# Generate key and token outside the loop
key = generate_key(key_type)
token = jwt.JWT(
header={"alg": key_type},
claims={"sub": "test"}
)
token.make_signed_token(key)
token_string = token.serialize()

# Benchmark validation
times = []
for _ in range(iterations):
start = time.perf_counter_ns()

# Validation
jwt.JWT(jwt=token_string, key=key)

end = time.perf_counter_ns()
times.append(end - start)

mean = statistics.mean(times) / 1000
median = statistics.median(times) / 1000
return mean, median

def main():
key_types = ["ES256", "RS256", "HS256"]
iterations = 100

print(f"Running {iterations} iterations for each algorithm")

print("\nFull encode benchmarks (including claims creation):")
print(f"{'Algorithm':<10} | {'Mean (µs)':<12} | {'Median (µs)':<12}")
print("-" * 38)
for key_type in key_types:
mean, median = benchmark_encode(key_type, iterations)
print(f"{key_type:<10} | {mean:12.2f} | {median:12.2f}")

print("\nSigning benchmarks (pre-created claims):")
print(f"{'Algorithm':<10} | {'Mean (µs)':<12} | {'Median (µs)':<12}")
print("-" * 38)
for key_type in key_types:
mean, median = benchmark_signing(key_type, iterations)
print(f"{key_type:<10} | {mean:12.2f} | {median:12.2f}")

print("\nValidation benchmarks:")
print(f"{'Algorithm':<10} | {'Mean (µs)':<12} | {'Median (µs)':<12}")
print("-" * 38)
for key_type in key_types:
mean, median = benchmark_validation(key_type, iterations)
print(f"{key_type:<10} | {mean:12.2f} | {median:12.2f}")

if __name__ == "__main__":
main()
38 changes: 38 additions & 0 deletions rust/gel-jwt/benches/encode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::collections::HashMap;

use gel_jwt::{KeyType, PrivateKey, SigningContext};

#[divan::bench(args = [&KeyType::ES256, &KeyType::RS256, &KeyType::HS256])]
fn bench_jwt_signing(b: divan::Bencher, key_type: &KeyType) {
let key = PrivateKey::generate(None, *key_type).unwrap();
let claims = HashMap::from([("sub".to_string(), "test".into())]);
let ctx = SigningContext::default();

b.bench_local(move || key.sign(claims.clone(), &ctx));
}

#[divan::bench(args = [&KeyType::ES256, &KeyType::RS256, &KeyType::HS256])]
fn bench_jwt_validation(b: divan::Bencher, key_type: &KeyType) {
let key = PrivateKey::generate(None, *key_type).unwrap();
let claims = HashMap::from([("sub".to_string(), "test".into())]);
let ctx = SigningContext::default();
let token = key.sign(claims, &ctx).unwrap();

b.bench_local(move || key.validate(&token, &ctx));
}

#[divan::bench(args = [&KeyType::ES256, &KeyType::RS256, &KeyType::HS256])]
fn bench_jwt_encode(b: divan::Bencher, key_type: &KeyType) {
let key = PrivateKey::generate(None, *key_type).unwrap();

b.bench_local(move || {
let claims = HashMap::from([("sub".to_string(), "test".into())]);
let ctx = SigningContext::default();
key.sign(claims, &ctx).unwrap()
});
}

fn main() {
// Run registered benchmarks.
divan::main();
}
16 changes: 16 additions & 0 deletions rust/gel-jwt/src/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# JWT support

This crate provides support for JWT tokens.

## Key types

HS256: symmetric key
RS256: asymmetric key (RSA 2048+ + SHA256)
ES256: asymmetric key (P-256 + SHA256)

## Supported key formats

HS256: raw data
RS256: PKCS1/PKCS8 PEM
ES256: SEC1/PKCS8 PEM

Loading

0 comments on commit 2345968

Please sign in to comment.