diff --git a/Cargo.lock b/Cargo.lock index 904cc783f5f3..8f7e6f033262 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1141,6 +1141,7 @@ dependencies = [ "serde_json", "sha2", "thiserror 2.0.3", + "tracing", "uuid", "zeroize", ] diff --git a/Cargo.toml b/Cargo.toml index 898a876a0338..b6d4350cc4cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ members = [ resolver = "2" [workspace.dependencies] -pyo3 = { version = "0.23", features = ["extension-module", "serde", "macros"] } +pyo3 = { version = "0.23.4", features = ["serde", "macros"] } tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "time", "sync", "net", "io-util"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter"] } diff --git a/edb/common/secretkey.py b/edb/common/secretkey.py index 7fbaff9bc92c..9ce1d58b2a46 100644 --- a/edb/common/secretkey.py +++ b/edb/common/secretkey.py @@ -17,83 +17,12 @@ # from __future__ import annotations -from typing import Optional, AbstractSet, Iterable +from typing import Iterable import pathlib import uuid -from datetime import datetime, timedelta, timezone - -from jwcrypto import jwk, jwt - -from . import uuidgen - - -class SecretKeyReadError(Exception): - pass - - -def generate_secret_key( - skey: jwk.JWK, - *, - instances: Optional[list[str] | AbstractSet[str]] = None, - roles: Optional[list[str] | AbstractSet[str]] = None, - databases: Optional[list[str] | AbstractSet[str]] = None, - subject: Optional[str] = None, - key_id: Optional[str] = None, -) -> str: - claims = { - "iat": int(datetime.now(timezone.utc).timestamp()), - "iss": "edgedb-server", - } - - if instances is None: - claims["edb.i.all"] = True - else: - claims["edb.i"] = list(instances) - - if roles is None: - claims["edb.r.all"] = True - else: - claims["edb.r"] = list(roles) - - if databases is None: - claims["edb.d.all"] = True - else: - claims["edb.d"] = list(databases) - - if subject is not None: - claims["sub"] = subject - - if key_id is None: - key_id = str(uuidgen.uuid4()) - - claims["jti"] = key_id - - token = jwt.JWT( - header={"alg": "ES256" if skey["kty"] == "EC" else "RS256"}, - claims=claims, - ) - token.make_signed_token(skey) - return "edbt1_" + token.serialize() - - -def load_secret_key(key_file: pathlib.Path) -> jwk.JWK: - try: - with open(key_file, 'rb') as kf: - jws_key = jwk.JWK.from_pem(kf.read()) - except Exception as e: - raise SecretKeyReadError(f"cannot load JWS key: {e}") from e - - if ( - not jws_key.has_public - or jws_key['kty'] not in {"RSA", "EC"} - ): - raise SecretKeyReadError( - f"the cluster JWS key file does not " - f"contain a valid RSA or EC public key") - - return jws_key +from datetime import datetime, timedelta def generate_tls_cert( @@ -154,11 +83,3 @@ def generate_tls_cert( ) ) tls_key_file.chmod(0o600) - - -def generate_jwk(keys_file: pathlib.Path) -> None: - key = jwk.JWK(generate='EC') - with keys_file.open("wb") as f: - f.write(key.export_to_pem(private_key=True, password=None)) - - keys_file.chmod(0o600) diff --git a/edb/server/args.py b/edb/server/args.py index 8c630d3e3c1d..506357d69536 100644 --- a/edb/server/args.py +++ b/edb/server/args.py @@ -984,8 +984,9 @@ def resolve_envvar_value(self, ctx: click.Context): cls=EnvvarResolver, hidden=True, help='Specifies a path to a file containing a public key in PEM ' - 'format used to verify JWT signatures. The file could also ' - 'contain a private key to sign JWT for local testing.'), + 'or JSON JWK format used to verify JWT signatures. The file may ' + 'also contain a private key to sign JWT tokens for ' + 'SCRAM-over-HTTP.'), click.option( '--jwe-key-file', type=PathPath(), diff --git a/edb/server/auth.py b/edb/server/auth.py index c04992478916..c27b48214260 100644 --- a/edb/server/auth.py +++ b/edb/server/auth.py @@ -1,3 +1,6 @@ +import datetime +import pathlib + from typing import TYPE_CHECKING, Iterable, List, Optional, Any if TYPE_CHECKING: @@ -21,17 +24,19 @@ def set_issuer(self, issuer: str) -> None: ... def set_audience(self, audience: str) -> None: ... def set_expiry(self, expiry: int) -> None: ... def set_not_before(self, not_before: int) -> None: ... - def allow(self, claim: str, values: List[str]) -> None: ... - def deny(self, claim: str, values: List[str]) -> None: ... - def export_pem(self, *, private_keys: bool) -> bytes: ... - def export_json(self, *, private_keys: bool) -> bytes: ... + def allow(self, claim: str, values: Iterable[str]) -> None: ... + def deny(self, claim: str, values: Iterable[str]) -> None: ... + def export_pem(self, *, private_keys: bool=True) -> bytes: ... + def export_json(self, *, private_keys: bool=True) -> bytes: ... def can_sign(self) -> bool: ... + def can_validate(self) -> bool: ... + def has_public_keys(self) -> bool: ... + def has_private_keys(self) -> bool: ... + def has_symmetric_keys(self) -> bool: ... def sign( self, claims: dict[str, Any], *, ctx: Optional[SigningCtx] = None ) -> str: ... def validate(self, token: str) -> dict[str, Any]: ... - def to_json(self, *, private_keys: bool) -> str: ... - def to_pem(self, *, private_keys: bool) -> str: ... class JWKSetCache: def __init__(self, expiry_seconds: int) -> None: ... @@ -45,6 +50,55 @@ def generate_gel_token( instances: Optional[List[str] | Iterable[str]] = None, roles: Optional[List[str] | Iterable[str]] = None, databases: Optional[List[str] | Iterable[str]] = None, + **kwargs: Any, ) -> str: ... + + def validate_gel_token( + registry: JWKSet, + token: str, + user: str, + dbname: str, + instance_name: str, + ) -> str | None: ... else: - from edb.server._rust_native._jwt import JWKSet, JWKSetCache, generate_gel_token, SigningCtx # noqa + from edb.server._rust_native._jwt import ( + JWKSet, JWKSetCache, generate_gel_token, validate_gel_token, SigningCtx # noqa + ) + + +def load_secret_key(key_file: pathlib.Path) -> JWKSet: + try: + with open(key_file, 'rb') as kf: + jws_key = JWKSet() + jws_key.load(kf.read().decode('ascii')) + except Exception as e: + raise SecretKeyReadError(f"cannot load JWS key {key_file}: {e}") from e + if not jws_key.can_validate(): + raise SecretKeyReadError( + f"the cluster JWS key file {key_file} does not " + f"contain a valid key for token validation (RSA, EC or " + f"HMAC-SHA256)") + + return jws_key + + +def generate_jwk(keys_file: pathlib.Path) -> None: + key = JWKSet() + # kid is yyyymmdd + kid = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d") + key.generate(kid=kid, kty='ES256') + if keys_file.name.endswith(".pem"): + with keys_file.open("wb") as f: + f.write(key.export_pem()) + elif keys_file.name.endswith(".json"): + with keys_file.open("wb") as f: + f.write(key.export_json()) + else: + raise ValueError(f"Unsupported key file extension {keys_file.suffix}. " + "Use .pem or .json extension when generating a key.") + + keys_file.chmod(0o600) + + +class SecretKeyReadError(Exception): + pass diff --git a/edb/server/cluster.py b/edb/server/cluster.py index 6f4524bb06f7..8307f91ced34 100644 --- a/edb/server/cluster.py +++ b/edb/server/cluster.py @@ -30,8 +30,6 @@ import tempfile import time -from jwcrypto import jwk - from edb import buildmeta from edb.common import devmode from edb.edgeql import quote @@ -39,6 +37,7 @@ from edb.server import args as edgedb_args from edb.server import defines as edgedb_defines from edb.server import pgconnparams +from edb.server import auth from . import pgcluster @@ -428,7 +427,7 @@ def __init__( self._edgedb_cmd.extend(['-D', str(self._data_dir)]) self._pg_connect_args['user'] = pg_superuser self._pg_connect_args['database'] = 'template1' - self._jws_key: Optional[jwk.JWK] = None + self._jws_key: Optional[auth.JWKSet] = None async def _new_pg_cluster(self) -> pgcluster.Cluster: return await pgcluster.get_local_pg_cluster( diff --git a/edb/server/metrics.py b/edb/server/metrics.py index c76b62378e31..352f04a25a69 100644 --- a/edb/server/metrics.py +++ b/edb/server/metrics.py @@ -223,6 +223,30 @@ labels=("tenant",), ) +auth_provider_jwkset_fetch_success = registry.new_labeled_counter( + "auth_provider_jwkset_fetch_success_total", + "Number of successful Auth extension JWK Set fetches.", + labels=("provider",), +) + +auth_provider_jwkset_fetch_errors = registry.new_labeled_counter( + "auth_provider_jwkset_fetch_errors_total", + "Number of failed Auth extension JWK Set fetches.", + labels=("provider",), +) + +auth_provider_token_validation_success = registry.new_labeled_counter( + "auth_provider_token_validation_success_total", + "Number of successful Auth extension provider token validations.", + labels=("provider",), +) + +auth_provider_token_validation_errors = registry.new_labeled_counter( + "auth_provider_token_validation_errors_total", + "Number of failed Auth extension provider token validations.", + labels=("provider",), +) + mt_tenants_total = registry.new_gauge( 'mt_tenants_current', 'Total number of currently-registered tenants.', diff --git a/edb/server/protocol/auth/scram.py b/edb/server/protocol/auth/scram.py index 436c549d1fe9..39d5ea33a0d9 100644 --- a/edb/server/protocol/auth/scram.py +++ b/edb/server/protocol/auth/scram.py @@ -29,7 +29,7 @@ from edb.common import debug from edb.common import markup -from edb.common import secretkey +from edb.server import auth if TYPE_CHECKING: from edb.server import tenant as edbtenant @@ -89,7 +89,8 @@ def handle_request( response.close_connection = True return - if not server.get_jws_key().has_private: # type: ignore[union-attr] + jws = server.get_jws_key() + if jws is None or not jws.has_private_keys(): response.body = b"Server doesn't support HTTP SCRAM authentication" response.status = http.HTTPStatus.FORBIDDEN response.close_connection = True @@ -268,9 +269,8 @@ def handle_request( ).decode("ascii") try: - response.body = secretkey.generate_secret_key( - server.get_jws_key(), - roles=[username], + response.body = auth.generate_gel_token( + jws, roles=[username], ).encode("ascii") except ValueError as ex: if debug.flags.server: diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index d43a492e4a61..2884b4b6e285 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -18,15 +18,19 @@ import uuid import urllib.parse -import json import enum +import logging from typing import Any, Callable -from jwcrypto import jwt, jwk from datetime import datetime from . import data, errors from edb.server.http import HttpClient +from edb.server import auth as jwt_auth +from edb.server.protocol.auth_ext import util as auth_util +from edb.server import metrics + +logger = logging.getLogger("edb.server.ext.auth") class BaseProvider: @@ -145,28 +149,48 @@ async def fetch_user_info( ) id_token = token_response.id_token - # Retrieve JWK Set + # Retrieve JWK Set, potentially from the cache oidc_config = await self._get_oidc_config() - jwks_uri = urllib.parse.urlparse(oidc_config.jwks_uri) - async with self.http_factory( - base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}" - ) as client: - r = await client.get(jwks_uri.path) - # Load the token as a JWT object and verify it directly try: - jwk_set = jwk.JWKSet.from_json(r.text) - id_token_verified = jwt.JWT(key=jwk_set, jwt=id_token) - payload = json.loads(id_token_verified.claims) + async def fetcher(url: str) -> jwt_auth.JWKSet: + jwks_uri = urllib.parse.urlparse(url) + async with self.http_factory( + base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}" + ) as client: + r = await client.get(jwks_uri.path) + jwk_set = jwt_auth.JWKSet() + jwk_set.load_json(r.text) + jwk_set.set_audience(self.client_id) + jwk_set.set_expiry(3600) + metrics.auth_provider_jwkset_fetch_success.inc( + 1.0, self.name + ) + return jwk_set + + jwk_set = await auth_util.get_remote_jwtset( + oidc_config.jwks_uri, fetcher + ) + except Exception as e: + metrics.auth_provider_jwkset_fetch_errors.inc(1.0, self.name) + logger.exception( + f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}" + ) + raise errors.MisconfiguredProvider( + f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}" + ) from e + + # Load the token as a JWT object and verify it directly. This will + # validate the audience and expiry. + try: + payload = jwk_set.validate(id_token) except Exception as e: + metrics.auth_provider_token_validation_errors.inc(1.0, self.name) raise errors.MisconfiguredProvider( "Failed to parse ID token with provider keyset" ) from e - if payload.get("aud") != self.client_id: - raise errors.InvalidData( - "Invalid value for aud in id_token: " - f"{payload.get('aud')} != {self.client_id}" - ) + + metrics.auth_provider_token_validation_success.inc(1.0, self.name) return data.UserInfo( sub=str(payload["sub"]), diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 6fe948b7916c..3585ee92f836 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -43,13 +43,12 @@ ) import aiosmtplib -from jwcrypto import jwk, jwt from edb import errors as edb_errors from edb.common import debug from edb.common import markup from edb.ir import statypes -from edb.server import tenant as edbtenant, metrics +from edb.server import tenant as edbtenant, metrics, auth as jwt_auth from edb.server.config.types import CompositeConfigType from . import ( @@ -2065,13 +2064,16 @@ async def _maybe_send_webhook(self, event: webhook.Event) -> None: def _get_callback_url(self) -> str: return f"{self.base_path}/callback" - def _get_auth_signing_key(self) -> jwk.JWK: + def _get_auth_signing_key(self, info: str | None = None) -> jwt_auth.JWKSet: auth_signing_key = util.get_config( self.db, "ext::auth::AuthConfig::auth_signing_key" ) - key_bytes = base64.b64encode(auth_signing_key.encode()) - - return jwk.JWK(kty="oct", k=key_bytes.decode()) + if info is None: + return jwt_auth.JWKSet.from_hs256_key(auth_signing_key.encode()) + else: + return jwt_auth.JWKSet.from_hs256_key( + util.derive_key_raw(auth_signing_key, info) + ) def _make_state_claims( self, @@ -2081,25 +2083,19 @@ def _make_state_claims( challenge: str, ) -> str: signing_key = self._get_auth_signing_key() - expires_at = datetime.datetime.now( - datetime.timezone.utc - ) + datetime.timedelta(minutes=5) + signing_ctx = jwt_auth.SigningCtx() + signing_ctx.set_expiry(5 * 60) + signing_ctx.set_not_before(30) + signing_ctx.set_issuer(self.base_path) state_claims = { - "iss": self.base_path, "provider": provider, - "exp": expires_at.timestamp(), "redirect_to": redirect_to, "challenge": challenge, } if redirect_to_on_signup: state_claims['redirect_to_on_signup'] = redirect_to_on_signup - state_token = jwt.JWT( - header={"alg": "HS256"}, - claims=state_claims, - ) - state_token.make_signed_token(signing_key) - return cast(str, state_token.serialize()) + return signing_key.sign(state_claims, ctx=signing_ctx) def _make_session_token(self, identity_id: str) -> str: signing_key = self._get_auth_signing_key() @@ -2108,34 +2104,15 @@ def _make_session_token(self, identity_id: str) -> str: "ext::auth::AuthConfig::token_time_to_live", statypes.Duration, ) - expires_in = auth_expiration_time.to_timedelta() - expires_at = datetime.datetime.now(datetime.timezone.utc) + expires_in - - claims: dict[str, Any] = { - "iss": self.base_path, + signing_ctx = jwt_auth.SigningCtx() + signing_ctx.set_expiry(int(auth_expiration_time.to_timedelta().total_seconds())) + signing_ctx.set_not_before(30) + signing_key.set_issuer(self.base_path) + session_token = signing_key.sign({ "sub": identity_id, - } - if expires_in.total_seconds() != 0: - claims["exp"] = expires_at.timestamp() - session_token = jwt.JWT( - header={"alg": "HS256"}, - claims=claims, - ) - session_token.make_signed_token(signing_key) + }, ctx=signing_ctx) metrics.auth_successful_logins.inc(1.0, self.tenant.get_instance_name()) - return cast(str, session_token.serialize()) - - def _get_from_claims(self, state: str, key: str) -> str: - signing_key = self._get_auth_signing_key() - try: - state_token = jwt.JWT(key=signing_key, jwt=state) - except Exception: - raise errors.InvalidData("Invalid state token") - state_claims: dict[str, str] = json.loads(state_token.claims) - value = state_claims.get(key) - if value is None: - raise errors.InvalidData("Invalid state token") - return value + return session_token def _make_secret_token( self, @@ -2147,8 +2124,7 @@ def _make_secret_token( ) = None, expires_in: datetime.timedelta | None = None, ) -> str: - input_key_material = self._get_auth_signing_key() - signing_key = util.derive_key(input_key_material, derive_for_info) + signing_key = self._get_auth_signing_key(derive_for_info) expires_in = ( datetime.timedelta(minutes=10) if expires_in is None else expires_in ) @@ -2166,15 +2142,8 @@ def _make_secret_token( def _verify_and_extract_claims( self, jwtStr: str, key_info: str | None = None ) -> dict[str, str | int | float | bool]: - input_key_material = self._get_auth_signing_key() - if key_info is None: - signing_key = input_key_material - else: - signing_key = util.derive_key(input_key_material, key_info) - verified = jwt.JWT(key=signing_key, jwt=jwtStr) - return cast( - dict[str, str | int | float | bool], json.loads(verified.claims) - ) + signing_key = self._get_auth_signing_key(key_info) + return signing_key.validate(jwtStr) def _get_data_from_magic_link_token( self, token: str @@ -2250,14 +2219,14 @@ def _get_data_from_verification_token( ): case ( str(id), - float(issued_at), + float(issued_at) | int(issued_at), verify_url, challenge, redirect_to, ): return_value = ( id, - issued_at, + float(issued_at), verify_url, challenge, redirect_to, @@ -2355,7 +2324,8 @@ def _make_verification_token( "Verify URL does not match any allowed URLs.", ) - issued_at = datetime.datetime.now(datetime.timezone.utc).timestamp() + now = datetime.datetime.now(datetime.timezone.utc) + issued_at = int(now.timestamp()) return self._make_secret_token( identity_id=identity_id, secret=str(uuid.uuid4()), diff --git a/edb/server/protocol/auth_ext/local.py b/edb/server/protocol/auth_ext/local.py index 2022eea504a5..286c13b72117 100644 --- a/edb/server/protocol/auth_ext/local.py +++ b/edb/server/protocol/auth_ext/local.py @@ -19,12 +19,10 @@ import datetime import json -import base64 -from jwcrypto import jwk from typing import Any, cast from edb.server.protocol import execute - +from edb.server import auth as jwt_auth from . import util, data @@ -32,13 +30,16 @@ class Client: def __init__(self, db: Any): self.db = db - def _get_signing_key(self) -> jwk.JWK: + def _get_signing_key(self, info: str | None = None) -> jwt_auth.JWKSet: auth_signing_key = util.get_config( self.db, "ext::auth::AuthConfig::auth_signing_key" ) - key_bytes = base64.b64encode(auth_signing_key.encode()) - - return jwk.JWK(kty="oct", k=key_bytes.decode()) + if info is None: + return jwt_auth.JWKSet.from_hs256_key(auth_signing_key.encode()) + else: + return jwt_auth.JWKSet.from_hs256_key( + util.derive_key_raw(auth_signing_key, info) + ) async def verify_email( self, identity_id: str, verified_at: datetime.datetime diff --git a/edb/server/protocol/auth_ext/magic_link.py b/edb/server/protocol/auth_ext/magic_link.py index 8474a5c886ee..ae478093cfe3 100644 --- a/edb/server/protocol/auth_ext/magic_link.py +++ b/edb/server/protocol/auth_ext/magic_link.py @@ -103,8 +103,7 @@ def make_magic_link_token( callback_url: str, challenge: str, ) -> str: - initial_key_material = self._get_signing_key() - signing_key = util.derive_key(initial_key_material, "magic_link") + signing_key = self._get_signing_key("magic_link") return util.make_token( signing_key=signing_key, issuer=self.issuer, diff --git a/edb/server/protocol/auth_ext/util.py b/edb/server/protocol/auth_ext/util.py index 57b6977386bf..98619a3d827a 100644 --- a/edb/server/protocol/auth_ext/util.py +++ b/edb/server/protocol/auth_ext/util.py @@ -19,19 +19,22 @@ from __future__ import annotations -import base64 import urllib.parse import datetime import html +import logging +import asyncio from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand from cryptography.hazmat.backends import default_backend -from jwcrypto import jwt, jwk -from typing import TypeVar, Type, overload, Any, cast, Optional, TYPE_CHECKING +from typing import ( + TypeVar, Type, overload, Any, cast, Optional, TYPE_CHECKING, Callable, + Awaitable +) -from edb.server import config as edb_config +from edb.server import config as edb_config, auth as jwt_auth from edb.server.config.types import CompositeConfigType from . import errors, config @@ -41,6 +44,11 @@ T = TypeVar("T") +logger = logging.getLogger('edb.server.ext.auth') + +# Cache JWKSets for 10 minutes +jwtset_cache = jwt_auth.JWKSetCache(60 * 10) + def maybe_get_config_unchecked(db: edbtenant.dbview.Database, key: str) -> Any: return edb_config.lookup(key, db.db_config, spec=db.user_config_spec) @@ -154,45 +162,60 @@ def join_url_params(url: str, params: dict[str, str]) -> str: return parsed_url._replace(query=new_query_params).geturl() +async def get_remote_jwtset( + url: str, fetch_lambda: Callable[[str], Awaitable[jwt_auth.JWKSet]] +) -> jwt_auth.JWKSet: + """ + Get a JWKSet from the cache, or fetch it from the given URL if it's not in + the cache. + """ + is_fresh, jwtset = jwtset_cache.get(url) + if is_fresh and jwtset is not None: + return jwtset + + if jwtset is None: + jwtset = await fetch_lambda(url) + jwtset_cache.set(url, jwtset) + return jwtset + + # Run fetch in background to refresh cache + async def refresh_cache(url: str) -> None: + try: + new_jwtset = await fetch_lambda(url) + jwtset_cache.set(url, new_jwtset) + except Exception: + logger.exception("Failed to refresh JWKSet cache for %s", url) + + asyncio.create_task(refresh_cache(url)) + return jwtset + + def make_token( - signing_key: jwk.JWK, + signing_key: jwt_auth.JWKSet, issuer: str, subject: str, additional_claims: dict[str, str | int | float | bool | None] | None = None, include_issued_at: bool = False, expires_in: datetime.timedelta | None = None, ) -> str: - now = datetime.datetime.now(datetime.timezone.utc) - expires_in = ( - datetime.timedelta(seconds=0) if expires_in is None else expires_in - ) - expires_at = now + expires_in + signing_ctx = jwt_auth.SigningCtx() + signing_ctx.set_issuer(issuer) + if expires_in is not None and int(expires_in.total_seconds()) != 0: + signing_ctx.set_expiry(int(expires_in.total_seconds())) + if include_issued_at: + signing_ctx.set_not_before(30) claims: dict[str, Any] = { - "iss": issuer, "sub": subject, **(additional_claims or {}), } - if expires_in.total_seconds() != 0: - claims["exp"] = expires_at.timestamp() - if include_issued_at: - claims["iat"] = now.timestamp() - token = jwt.JWT( - header={"alg": "HS256"}, - claims=claims, - ) - token.make_signed_token(signing_key) + return signing_key.sign(claims, ctx=signing_ctx) - return cast(str, token.serialize()) - -def derive_key(key: jwk.JWK, info: str) -> jwk.JWK: +def derive_key_raw(key: str, info: str) -> bytes: """Derive a new key from the given symmetric key using HKDF.""" - - # n.b. the key is returned as a base64url-encoded string - raw_key_base64url = cast(str, key.get_op_key()) - input_key_material = base64.urlsafe_b64decode(raw_key_base64url) + input_key_material = key.encode() backend = default_backend() hkdf = HKDFExpand( @@ -202,7 +225,4 @@ def derive_key(key: jwk.JWK, info: str) -> jwk.JWK: backend=backend, ) new_key_bytes = hkdf.derive(input_key_material) - return jwk.JWK( - kty="oct", - k=new_key_bytes.hex(), - ) + return new_key_bytes diff --git a/edb/server/protocol/auth_helpers.pxd b/edb/server/protocol/auth_helpers.pxd index 9100d670051f..3fde9e8ff8a0 100644 --- a/edb/server/protocol/auth_helpers.pxd +++ b/edb/server/protocol/auth_helpers.pxd @@ -19,8 +19,6 @@ cdef extract_token_from_auth_data(bytes auth_data) cdef auth_jwt(tenant, prefixed_token, str user, str dbname) -cdef _check_jwt_authz(tenant, claims, token_version, str user, str dbname) -cdef _get_jwt_edb_scope(claims, claim) cdef scram_get_verifier(tenant, str user) cdef parse_basic_auth(str auth_payload) cdef extract_http_user(scheme, auth_payload, params) diff --git a/edb/server/protocol/auth_helpers.pyx b/edb/server/protocol/auth_helpers.pyx index f0b22bc764f6..76edba42e27d 100644 --- a/edb/server/protocol/auth_helpers.pyx +++ b/edb/server/protocol/auth_helpers.pyx @@ -25,10 +25,8 @@ import hashlib import json import logging -from jwcrypto import jwt - from edb import errors - +from edb.server.auth import validate_gel_token cdef object logger = logging.getLogger('edb.server') @@ -44,107 +42,15 @@ cdef auth_jwt(tenant, prefixed_token: str | None, user: str, dbname: str): raise errors.AuthenticationError( 'authentication failed: no authorization data provided') - token_version = 0 - for prefix in ["nbwt1_", "nbwt_", "edbt1_", "edbt_"]: - encoded_token = prefixed_token.removeprefix(prefix) - if encoded_token != prefixed_token: - if prefix == "nbwt1_" or prefix == "edbt1_": - token_version = 1 - break - else: - raise errors.AuthenticationError( - 'authentication failed: malformed JWT') + key = tenant.server.get_jws_key() + if err := validate_gel_token(key, prefixed_token, user, dbname, tenant.get_instance_name()): + raise errors.AuthenticationError(str(err)) + # Ensure it's a valid role, but check after the JWT is validated role = tenant.get_roles().get(user) if role is None: raise errors.AuthenticationError('authentication failed') - skey = tenant.server.get_jws_key() - - try: - token = jwt.JWT( - key=skey, - algs=["RS256", "ES256"], - jwt=encoded_token, - ) - except jwt.JWException as e: - logger.debug('authentication failure', exc_info=True) - raise errors.AuthenticationError( - f'authentication failed: {e.args[0]}' - ) from None - except Exception as e: - logger.debug('authentication failure', exc_info=True) - raise errors.AuthenticationError( - f'authentication failed: cannot decode JWT' - ) from None - - try: - claims = json.loads(token.claims) - except Exception as e: - raise errors.AuthenticationError( - f'authentication failed: malformed claims section in JWT' - ) from None - - _check_jwt_authz( - tenant, claims, token_version, user, dbname) - - -cdef _check_jwt_authz(tenant, claims, token_version, user: str, dbname: str): - # Check general key validity (e.g. whether it's a revoked key) - tenant.check_jwt(claims) - - token_instances = None - token_roles = None - token_databases = None - - if token_version == 1: - token_roles = _get_jwt_edb_scope(claims, "edb.r") - token_instances = _get_jwt_edb_scope(claims, "edb.i") - token_databases = _get_jwt_edb_scope(claims, "edb.d") - else: - namespace = "edgedb.server" - if not claims.get(f"{namespace}.any_role"): - token_roles = claims.get(f"{namespace}.roles") - if not isinstance(token_roles, list): - raise errors.AuthenticationError( - f'authentication failed: malformed claims section in' - f' JWT: expected a list in "{namespace}.roles"' - ) - - if ( - token_instances is not None - and tenant.get_instance_name() not in token_instances - ): - raise errors.AuthenticationError( - 'authentication failed: secret key does not authorize ' - f'access to this instance') - - if ( - token_databases is not None - and dbname not in token_databases - ): - raise errors.AuthenticationError( - 'authentication failed: secret key does not authorize ' - f'access to database "{dbname}"') - - if token_roles is not None and user not in token_roles: - raise errors.AuthenticationError( - 'authentication failed: secret key does not authorize ' - f'access in role "{user}"') - - -cdef _get_jwt_edb_scope(claims, claim): - if not claims.get(f"{claim}.all"): - scope = claims.get(claim, []) - if not isinstance(scope, list): - raise errors.AuthenticationError( - f'authentication failed: malformed claims section in' - f' JWT: expected a list in "{claim}"' - ) - return frozenset(scope) - else: - return None - cdef scram_get_verifier(tenant, user: str): roles = tenant.get_roles() diff --git a/edb/server/server.py b/edb/server/server.py index ddd841803a9b..3f784707b7b6 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -48,7 +48,6 @@ import uuid import immutables -from jwcrypto import jwk from edb import buildmeta from edb import errors @@ -62,6 +61,7 @@ from edb.schema import reflection as s_refl from edb.schema import schema as s_schema +from edb.server import auth from edb.server import args as srvargs from edb.server import cache from edb.server import config @@ -241,7 +241,7 @@ def __init__( self._sslctx: ssl.SSLContext | Any = None self._sslctx_pgext: ssl.SSLContext | Any = None - self._jws_key: jwk.JWK | None = None + self._jws_key: auth.JWKSet | None = None self._jws_keys_newly_generated = False self._default_auth_method_spec = default_auth_method @@ -1018,10 +1018,12 @@ def start_watching_files(self): # TODO(fantix): include the monitor_fs() lines above pass - def load_jwcrypto(self, jws_key_file: pathlib.Path) -> None: + def load_jwcrypto(self, jws_key_file: pathlib.Path) -> auth.JWKSet: try: - self._jws_key = secretkey.load_secret_key(jws_key_file) - except secretkey.SecretKeyReadError as e: + jws_key = auth.load_secret_key(jws_key_file) + self._jws_key = jws_key + return jws_key + except auth.SecretKeyReadError as e: raise StartupError(e.args[0]) from e def init_jwcrypto( @@ -1032,7 +1034,7 @@ def init_jwcrypto( self.load_jwcrypto(jws_key_file) self._jws_keys_newly_generated = jws_keys_newly_generated - def get_jws_key(self) -> jwk.JWK | None: + def get_jws_key(self) -> auth.JWKSet | None: return self._jws_key async def _stop_servers(self, servers): @@ -1264,7 +1266,7 @@ async def maybe_generate_pki( logger.info( f'generating JOSE key pair in "{args.jws_key_file}"' ) - secretkey.generate_jwk(args.jws_key_file) + auth.generate_jwk(args.jws_key_file) jws_keys_newly_generated = True return tls_cert_newly_generated, jws_keys_newly_generated @@ -1765,8 +1767,8 @@ def _get_status(self) -> dict[str, Any]: return status def load_jwcrypto(self, jws_key_file: pathlib.Path) -> None: - super().load_jwcrypto(jws_key_file) - self._tenant.load_jwcrypto() + jws_key = super().load_jwcrypto(jws_key_file) + self._tenant.load_jwcrypto(jws_key) def request_shutdown(self): self._tenant.stop_accepting_connections() diff --git a/edb/server/tenant.py b/edb/server/tenant.py index bbdb0939e695..5d911438037a 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -58,6 +58,7 @@ from edb.common import verutils from edb.common.log import current_tenant +from . import auth from . import args as srvargs from . import config from . import connpool @@ -1537,7 +1538,8 @@ def schedule_reported_config_if_needed(self, setting_name: str) -> None: if setting and setting.report and self._accept_new_tasks: self.create_task(self._load_reported_config(), interruptable=True) - def load_jwcrypto(self) -> None: + def load_jwcrypto(self, jwk_key: auth.JWKSet) -> None: + self._jws_key = jwk_key self.load_jwt_sub_allowlist() self.load_jwt_revocation_list() @@ -1551,6 +1553,8 @@ def load_jwt_sub_allowlist(self) -> None: self._jwt_sub_allowlist = frozenset( self._jwt_sub_allowlist_file.read_text().splitlines(), ) + if self._jws_key is not None: + self._jws_key.allow("sub", self._jwt_sub_allowlist) except Exception as e: from . import server as edbserver @@ -1568,6 +1572,8 @@ def load_jwt_revocation_list(self) -> None: self._jwt_revocation_list = frozenset( self._jwt_revocation_list_file.read_text().splitlines(), ) + if self._jws_key is not None: + self._jws_key.deny("jti", self._jwt_revocation_list) except Exception as e: from . import server as edbserver @@ -1575,33 +1581,6 @@ def load_jwt_revocation_list(self) -> None: f"cannot load JWT revocation list: {e}" ) from e - def check_jwt(self, claims: dict[str, Any]) -> None: - """Check JWT for validity""" - - if self._jwt_sub_allowlist is not None: - subject = claims.get("sub") - if not subject: - raise errors.AuthenticationError( - "authentication failed: " - "JWT does not contain a valid subject claim" - ) - if subject not in self._jwt_sub_allowlist: - raise errors.AuthenticationError( - "authentication failed: unauthorized subject" - ) - - if self._jwt_revocation_list is not None: - key_id = claims.get("jti") - if not key_id: - raise errors.AuthenticationError( - "authentication failed: " - "JWT does not contain a valid key id" - ) - if key_id in self._jwt_revocation_list: - raise errors.AuthenticationError( - "authentication failed: revoked key" - ) - def reload_readiness_state(self) -> None: if self._readiness_state_file is None: return diff --git a/edb/testbase/server.py b/edb/testbase/server.py index 354ca9abc998..3b1feccbd702 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -70,6 +70,7 @@ from edb.server import cluster as edgedb_cluster from edb.server import pgcluster from edb.server import defines as edgedb_defines +from edb.server import auth from edb.server.pgconnparams import ConnectionParams from edb.common import assert_data_shape @@ -131,7 +132,7 @@ def get_test_cases(tests): bag = assert_data_shape.bag -generate_jwk = secretkey.generate_jwk +generate_jwk = auth.generate_jwk generate_tls_cert = secretkey.generate_tls_cert diff --git a/edb/tools/test/runner.py b/edb/tools/test/runner.py index b1be01269424..bfca6beafe41 100644 --- a/edb/tools/test/runner.py +++ b/edb/tools/test/runner.py @@ -993,7 +993,7 @@ def run( not os.environ.get("EDGEDB_SERVER_JWS_KEY_FILE") and not os.environ.get("GEL_SERVER_JWS_KEY_FILE") ): - jwk_file = pathlib.Path(tempdir.name) / "jwk.pem" + jwk_file = pathlib.Path(tempdir.name) / "jwk.json" if self.verbosity >= 1: self._echo( 'Generating JSON Web Key...', diff --git a/pyproject.toml b/pyproject.toml index 6d601b285b3f..641ff8c6a695 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ dependencies = [ 'click~=8.0', 'cryptography~=42.0', 'graphql-core~=3.1.5', - 'jwcrypto~=1.3.1', 'psutil~=5.8', 'setproctitle~=1.2', diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 20e20c7b43ff..ae43d2ab08ba 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.81" +channel = "1.82" components = [ "rustfmt", "clippy" ] diff --git a/rust/gel-jwt/Cargo.toml b/rust/gel-jwt/Cargo.toml index 7ff8188e1e2c..1839ea373ebb 100644 --- a/rust/gel-jwt/Cargo.toml +++ b/rust/gel-jwt/Cargo.toml @@ -9,6 +9,7 @@ python_extension = ["pyo3/extension-module"] [dependencies] pyo3 = { workspace = true, optional = true } pyo3_util.workspace = true +tracing.workspace = true # This is required to be in sync w/jsonwebtoken rand = "0.8.5" diff --git a/rust/gel-jwt/src/bare_key.rs b/rust/gel-jwt/src/bare_key.rs index 1b09330b8da3..4b69d08935b2 100644 --- a/rust/gel-jwt/src/bare_key.rs +++ b/rust/gel-jwt/src/bare_key.rs @@ -1059,11 +1059,10 @@ fn handle_rsa_pubkey(key: &Pem) -> Result { /// Decode a base64 string with optional padding, since jwcrypto also seems to /// accept this. /// -/// :JWKs make use of the base64url encoding as defined in RFC 4648 [RFC4648]. -/// As allowed by Section 3.2 of the RFC, this specification mandates that -/// base64url encoding when used with JWKs MUST NOT use padding. Notes on -/// implementing base64url encoding can be found in the JWS [JWS] -/// specification."" +/// > JWKs make use of the base64url encoding as defined in RFC 4648 As allowed +/// > by Section 3.2 of the RFC, this specification mandates that base64url +/// > encoding when used with JWKs MUST NOT use padding. Notes on implementing +/// > base64url encoding can be found in the JWS specification. fn b64_decode(s: &str) -> Result>, KeyError> { let vec = if s.ends_with('=') { base64ct::Base64Url::decode_vec(s).map_err(|_| KeyError::DecodeError)? diff --git a/rust/gel-jwt/src/lib.rs b/rust/gel-jwt/src/lib.rs index 5cfda4d93b83..a4c53cdb4e7e 100644 --- a/rust/gel-jwt/src/lib.rs +++ b/rust/gel-jwt/src/lib.rs @@ -118,6 +118,29 @@ pub enum Any { Object(HashMap, Any>), } +impl Any { + pub fn as_str(&self) -> Option<&str> { + match self { + Any::String(s) => Some(s.as_ref()), + _ => None, + } + } + + pub fn as_array(&self) -> Option<&[Any]> { + match self { + Any::Array(a) => Some(a), + _ => None, + } + } + + pub fn as_object(&self) -> Option<&HashMap, Any>> { + match self { + Any::Object(o) => Some(o), + _ => None, + } + } +} + impl From for Any { fn from(value: bool) -> Self { Any::Bool(value) @@ -485,7 +508,7 @@ mod tests { ("number".to_owned(), Any::Number(123)), ( "array".to_owned(), - Any::Array(vec![Any::String("1".into()), Any::String("2".into())].into()), + Any::Array(vec![Any::String("1".into()), Any::String("2".into())]), ), ]); let json = serde_json::to_string(&map).unwrap(); diff --git a/rust/gel-jwt/src/python.rs b/rust/gel-jwt/src/python.rs index 341bd2a5903a..5e6d4385f0e8 100644 --- a/rust/gel-jwt/src/python.rs +++ b/rust/gel-jwt/src/python.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, time::{Duration, Instant}, }; @@ -13,6 +13,7 @@ use pyo3::{ types::{PyBytes, PyDict, PyList}, }; use serde::{Deserialize, Serialize}; +use tracing::warn; use uuid::Uuid; impl From for PyErr { @@ -186,17 +187,23 @@ impl JWKSet { self.context.expiry = Some(Duration::from_secs(expiry as u64)); } - pub fn allow(&mut self, claim: &str, values: Bound) -> PyResult<()> { - self.context - .allow - .insert(claim.to_string(), values.extract()?); + pub fn allow(&mut self, claim: &str, values: Bound) -> PyResult<()> { + let mut entries: HashSet = HashSet::new(); + for value in values.try_iter()? { + let value = value?.extract::()?; + entries.insert(value); + } + self.context.allow.insert(claim.to_string(), entries); Ok(()) } - pub fn deny(&mut self, claim: &str, values: Bound) -> PyResult<()> { - self.context - .deny - .insert(claim.to_string(), values.extract()?); + pub fn deny(&mut self, claim: &str, values: Bound) -> PyResult<()> { + let mut entries: HashSet = HashSet::new(); + for value in values.try_iter()? { + let value = value?.extract::()?; + entries.insert(value); + } + self.context.deny.insert(claim.to_string(), entries); Ok(()) } @@ -219,10 +226,6 @@ impl JWKSet { .into_bytes()) } - pub fn can_sign(&self) -> bool { - self.registry.can_sign() - } - /// Sign a claims object with the default or given signing context. #[pyo3(signature = (claims, *, ctx=None))] pub fn sign(&self, claims: Bound, ctx: Option<&SigningCtx>) -> PyResult { @@ -238,9 +241,33 @@ impl JWKSet { Ok(claims) } + pub fn can_sign(&self) -> bool { + self.registry.can_sign() + } + + pub fn can_validate(&self) -> bool { + self.registry.can_validate() + } + + pub fn has_public_keys(&self) -> bool { + self.registry.has_public_keys() + } + + pub fn has_private_keys(&self) -> bool { + self.registry.has_private_keys() + } + + pub fn has_symmetric_keys(&self) -> bool { + self.registry.has_symmetric_keys() + } + pub fn __repr__(&self) -> String { format!("JWKSet(keys={})", self.registry.len()) } + + pub fn __len__(&self) -> usize { + self.registry.len() + } } #[derive(Debug, Default, Serialize, Deserialize)] @@ -324,13 +351,15 @@ impl JWKSetCache { } } +/// Generate a token with optional additional claims. #[pyfunction] -#[pyo3(signature = (registry, *, instances=None, roles=None, databases=None))] +#[pyo3(signature = (registry, *, instances=None, roles=None, databases=None, **kwargs))] fn generate_gel_token( registry: &JWKSet, instances: Option>, roles: Option>, databases: Option>, + kwargs: Option>, ) -> PyResult { let mut claims = GelClaims::default(); @@ -353,26 +382,176 @@ fn generate_gel_token( } claims.jti = Uuid::new_v4(); + let mut claims_map = HashMap::new(); + claims_map.insert("jti".to_string(), Any::from(claims.jti.to_string())); + + if claims.all_instances { + claims_map.insert("edb.i.all".to_string(), Any::from(true)); + } else if let Some(instances) = claims.instances { + claims_map.insert("edb.i".to_string(), Any::from(instances)); + } + + if claims.all_roles { + claims_map.insert("edb.r.all".to_string(), Any::from(true)); + } else if let Some(roles) = claims.roles { + claims_map.insert("edb.r".to_string(), Any::from(roles)); + } - let claims = HashMap::from([ - ("edb.i".to_string(), Any::from(claims.instances)), - ("edb.i.all".to_string(), Any::from(claims.all_instances)), - ("edb.r".to_string(), Any::from(claims.roles)), - ("edb.r.all".to_string(), Any::from(claims.all_roles)), - ("edb.d".to_string(), Any::from(claims.databases)), - ("edb.d.all".to_string(), Any::from(claims.all_databases)), - ("jti".to_string(), Any::from(claims.jti.to_string())), - ]); - - let token = registry.registry.sign(claims, ®istry.context)?; + if claims.all_databases { + claims_map.insert("edb.d.all".to_string(), Any::from(true)); + } else if let Some(databases) = claims.databases { + claims_map.insert("edb.d".to_string(), Any::from(databases)); + } + + if let Some(kwargs) = kwargs { + for (key, value) in kwargs.iter() { + let key = key.extract::()?; + let value = value.extract::()?; + claims_map.insert(key, value); + } + } + + let token = registry.registry.sign(claims_map, ®istry.context)?; Ok(format!("edbt1_{}", token)) } +#[derive(Debug, Default)] +enum TokenMatch { + #[default] + None, + All, + Some(HashSet), +} + +impl TokenMatch { + fn from_claims( + claims: &HashMap, + all_key: &str, + array_key: &str, + ) -> PyResult { + if claims.contains_key(all_key) { + Ok(TokenMatch::All) + } else { + let Some(array) = claims.get(array_key).and_then(|v| v.as_array()) else { + warn!("Missing claims array key: {array_key}"); + return Err(PyErr::new::( + "authentication failed: malformed JWT", + )); + }; + Ok(TokenMatch::Some( + array + .iter() + .map(|v| v.as_str().unwrap_or_default().to_string()) + .collect::>(), + )) + } + } + + fn matches(&self, value: &str) -> bool { + match self { + TokenMatch::All => true, + TokenMatch::Some(set) => set.contains(value), + TokenMatch::None => false, + } + } +} + +#[derive(Debug, Default)] +struct TokenClaims { + instances: TokenMatch, + roles: TokenMatch, + databases: TokenMatch, +} + +#[pyfunction] +#[pyo3(signature = (registry, token, user, dbname, instance_name))] +fn validate_gel_token( + registry: &JWKSet, + token: &str, + user: &str, + dbname: &str, + instance_name: &str, +) -> PyResult> { + let mut token_version = 0; + let encoded_token = if let Some(stripped) = token.strip_prefix("nbwt1_") { + token_version = 1; + stripped + } else if let Some(stripped) = token.strip_prefix("nbwt_") { + stripped + } else if let Some(stripped) = token.strip_prefix("edbt1_") { + token_version = 1; + stripped + } else if let Some(stripped) = token.strip_prefix("edbt_") { + stripped + } else { + warn!( + "Invalid token prefix: [{}...]", + &token[0..token.len().min(7)] + ); + return Ok(Some("authentication failed: malformed JWT".to_string())); + }; + + // Validate and decode the JWT + let decoded = match registry.registry.validate(encoded_token, ®istry.context) { + Ok(claims) => claims, + Err(e) => { + warn!("Invalid token: {}", e.error_string_not_for_user()); + return Ok(Some( + "authentication failed: Verification failed".to_string(), + )); + } + }; + + let claims = if token_version == 0 { + // Legacy v0 token: "edgedb.server.any_role" is a boolean, "edgedb.server.roles" is an array of strings + let roles = + TokenMatch::from_claims(&decoded, "edgedb.server.any_role", "edgedb.server.roles")?; + TokenClaims { + roles, + instances: TokenMatch::All, + databases: TokenMatch::All, + } + } else { + // New v1 token: "edb.{i,r,d}.all" are booleans, "edb.{i,r,d}" are arrays of strings + let instances = TokenMatch::from_claims(&decoded, "edb.i.all", "edb.i")?; + let roles = TokenMatch::from_claims(&decoded, "edb.r.all", "edb.r")?; + let databases = TokenMatch::from_claims(&decoded, "edb.d.all", "edb.d")?; + TokenClaims { + instances, + roles, + databases, + } + }; + + if !claims.instances.matches(instance_name) { + warn!("Instance not in token: {instance_name}"); + return Ok(Some( + "authentication failed: secret key does not authorize access to this instance" + .to_string(), + )); + } + if !claims.roles.matches(user) { + warn!("Role not in token: {user}"); + return Ok(Some(format!( + "authentication failed: secret key does not authorize access in role {user:?}" + ))); + } + if !claims.databases.matches(dbname) { + warn!("Database not in token: {dbname}"); + return Ok(Some(format!( + "authentication failed: secret key does not authorize access to database {dbname:?}" + ))); + } + + Ok(None) +} + #[pymodule] pub fn _jwt(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(generate_gel_token, m)?)?; + m.add_function(wrap_pyfunction!(validate_gel_token, m)?)?; Ok(()) } diff --git a/rust/gel-jwt/src/registry.rs b/rust/gel-jwt/src/registry.rs index 9daddf7d2dbf..466307db2a30 100644 --- a/rust/gel-jwt/src/registry.rs +++ b/rust/gel-jwt/src/registry.rs @@ -257,12 +257,6 @@ impl KeyRegistry { Err(result.unwrap_or(OpaqueValidationFailureReason::NoAppropriateKey.into())) } - pub fn can_sign(&self) -> bool { - self.active_key() - .map(|(_, k)| K::encoding_key(k).is_some()) - .unwrap_or(false) - } - pub fn sign( &self, claims: HashMap, @@ -274,11 +268,97 @@ impl KeyRegistry { } } -impl KeyRegistry {} +impl KeyRegistry { + pub fn can_sign(&self) -> bool { + self.has_private_keys() || self.has_symmetric_keys() + } + + pub fn can_validate(&self) -> bool { + self.has_public_keys() || self.has_symmetric_keys() + } + + pub fn has_private_keys(&self) -> bool { + !self.is_empty() + } + + pub fn has_public_keys(&self) -> bool { + self.key_to_ordinal + .iter() + .any(|(k, _)| k.bare_key.key_type() != KeyType::HS256) + } + + pub fn has_symmetric_keys(&self) -> bool { + self.key_to_ordinal + .iter() + .any(|(k, _)| k.bare_key.key_type() == KeyType::HS256) + } +} + +impl KeyRegistry { + pub fn can_sign(&self) -> bool { + self.has_private_keys() || self.has_symmetric_keys() + } + + pub fn can_validate(&self) -> bool { + self.has_public_keys() || self.has_symmetric_keys() + } + + pub fn has_public_keys(&self) -> bool { + !self.is_empty() + } -impl KeyRegistry {} + pub fn has_private_keys(&self) -> bool { + false + } + + pub fn has_symmetric_keys(&self) -> bool { + false + } +} impl KeyRegistry { + pub fn can_sign(&self) -> bool { + self.has_private_keys() || self.has_symmetric_keys() + } + + pub fn can_validate(&self) -> bool { + self.has_public_keys() || self.has_symmetric_keys() + } + + pub fn has_private_keys(&self) -> bool { + for k in self.key_to_ordinal.keys() { + if let KeyInner::Private(_) = k { + return true; + } + } + false + } + + pub fn has_public_keys(&self) -> bool { + for k in self.key_to_ordinal.keys() { + if let KeyInner::Public(_) = k { + return true; + } + if let KeyInner::Private(k) = k { + if k.bare_key.key_type() != KeyType::HS256 { + return true; + } + } + } + false + } + + pub fn has_symmetric_keys(&self) -> bool { + for k in self.key_to_ordinal.keys() { + if let KeyInner::Private(k) = k { + if k.bare_key.key_type() == KeyType::HS256 { + return true; + } + } + } + false + } + /// Export the registry as a PEM file containing only the public keys. /// This will fail if the registry contains symmetric keys. pub fn to_pem_public(&self) -> Result { diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 1fd6323e9352..0c51d402a1c5 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -30,14 +30,14 @@ import hashlib import hmac -from typing import Any, Optional, cast -from jwcrypto import jwt, jwk +from typing import Optional, cast from email.message import EmailMessage from edgedb import QueryAssertionError from edb.testbase import http as tb from edb.common import assert_data_shape from edb.server.protocol.auth_ext import util as auth_util +from edb.server.auth import JWKSet ph = argon2.PasswordHasher() @@ -359,6 +359,7 @@ def setUpClass(cls): mock_oauth_server: tb.MockHttpServer mock_net_server: tb.MockHttpServer + jwkset_cache: dict[str, JWKSet] = {} def setUp(self): self.mock_oauth_server = tb.MockHttpServer( @@ -449,34 +450,28 @@ async def get_auth_config_value(self, key: str): """ ) - async def get_signing_key(self): + async def get_signing_key(self, info: str | None = None) -> JWKSet: auth_signing_key = SIGNING_KEY - key_bytes = base64.b64encode(auth_signing_key.encode()) - signing_key = jwk.JWK(k=key_bytes.decode(), kty="oct") + if info is not None: + signing_key = JWKSet.from_hs256_key( + auth_util.derive_key_raw(auth_signing_key, info) + ) + else: + signing_key = JWKSet.from_hs256_key(auth_signing_key.encode()) + signing_key.set_expiry(3600) + signing_key.set_not_before(30) return signing_key def generate_state_value( self, - state_claims: dict[str, str | float], - auth_signing_key: jwk.JWK, + state_claims: dict[str, str], + auth_signing_key: JWKSet, ) -> str: - state_token = jwt.JWT( - header={"alg": "HS256"}, - claims=state_claims, - ) - state_token.make_signed_token(auth_signing_key) - return state_token.serialize() + return auth_signing_key.sign(state_claims) async def extract_jwt_claims(self, raw_jwt: str, info: str | None = None): - input_key_material = await self.get_signing_key() - if info is not None: - signing_key = auth_util.derive_key(input_key_material, info) - else: - signing_key = input_key_material - - jwt_token = jwt.JWT(key=signing_key, jwt=raw_jwt) - claims = json.loads(jwt_token.claims) - return claims + signing_key = await self.get_signing_key(info) + return signing_key.validate(raw_jwt) def maybe_get_cookie_value( self, headers: dict[str, str], name: str @@ -492,6 +487,76 @@ def maybe_get_cookie_value( def maybe_get_auth_token(self, headers: dict[str, str]) -> Optional[str]: return self.maybe_get_cookie_value(headers, "edgedb-session") + def generate_and_serve_jwk( + self, + client_id: str, + jwk_cert_url: str, + token_url: str, + issuer: str, + access_token_name: str, + sub: str = "1", + ) -> tuple[str, str, str]: + parts = jwk_cert_url.split("/", 3) + host = parts[0] + "//" + parts[2] + path = parts[3] + jwks_request = ( + "GET", + host, + path, + ) + + # Because we have an internal cache, ensure that we only generate one + # set per issuer + jwks = self.jwkset_cache.get(issuer) + if jwks is None: + jwks = JWKSet() + jwks.set_issuer(issuer) + jwks.generate(kid=None, kty="RS256") + self.jwkset_cache[issuer] = jwks + + jwk_json = jwks.export_json(private_keys=False).decode() + + self.mock_oauth_server.register_route_handler(*jwks_request)( + ( + jwk_json, + 200, + ) + ) + + parts = token_url.split("/", 3) + host = parts[0] + "//" + parts[2] + path = parts[3] + token_request = ( + "POST", + host, + path, + ) + + jwks.set_issuer(issuer) + jwks.set_audience(client_id) + jwks.set_expiry(3600) + jwks.set_not_before(30) + + id_token = jwks.sign({ + "sub": sub, + "email": "test@example.com", + }) + + self.mock_oauth_server.register_route_handler(*token_request)( + ( + json.dumps( + { + "access_token": access_token_name, + "id_token": id_token, + "scope": "openid", + "token_type": "bearer", + } + ), + 200, + ) + ) + return token_request + async def test_http_auth_ext_github_authorize_01(self): with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( @@ -576,10 +641,8 @@ async def test_http_auth_ext_github_callback_missing_provider_01(self): with self.http_con() as http_con: signing_key = await self.get_signing_key() - expires_at = utcnow() + datetime.timedelta(minutes=5) missing_provider_state_claims = { "iss": self.http_addr, - "exp": expires_at.timestamp(), } state_token = self.generate_state_value( missing_provider_state_claims, signing_key @@ -599,15 +662,12 @@ async def test_http_auth_ext_github_callback_wrong_key_01(self): "oauth_github" ) provider_name = provider_config.name - signing_key = jwk.JWK( - k=base64.b64encode(("abcd" * 8).encode()).decode(), kty="oct" - ) + signing_key = JWKSet() + signing_key.generate(kid=None, kty="ES256") - expires_at = utcnow() + datetime.timedelta(minutes=5) missing_provider_state_claims = { "iss": self.http_addr, "provider": provider_name, - "exp": expires_at.timestamp(), } state_token_value = self.generate_state_value( missing_provider_state_claims, signing_key @@ -625,11 +685,9 @@ async def test_http_auth_ext_github_unknown_provider_01(self): with self.http_con() as http_con: signing_key = await self.get_signing_key() - expires_at = utcnow() + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": "beepboopbeep", - "exp": expires_at.timestamp(), } state_token = self.generate_state_value(state_claims, signing_key) @@ -719,11 +777,9 @@ async def test_http_auth_ext_github_callback_01(self): signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -828,7 +884,6 @@ async def test_http_auth_ext_github_callback_failure_01(self): ) provider_name = provider_config.name - now = utcnow() token_request = ( "POST", "https://github.com", @@ -849,11 +904,9 @@ async def test_http_auth_ext_github_callback_failure_01(self): signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", } state_token = self.generate_state_value(state_claims, signing_key) @@ -894,7 +947,6 @@ async def test_http_auth_ext_github_callback_failure_02(self): ) provider_name = provider_config.name - now = utcnow() token_request = ( "POST", "https://github.com", @@ -915,11 +967,9 @@ async def test_http_auth_ext_github_callback_failure_02(self): signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", } state_token = self.generate_state_value(state_claims, signing_key) @@ -1090,11 +1140,9 @@ async def test_http_auth_ext_discord_callback_01(self): signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -1200,8 +1248,6 @@ async def test_http_auth_ext_google_callback_01(self) -> None: client_id = provider_config.client_id client_secret = GOOGLE_SECRET - now = utcnow() - discovery_request = ( "GET", "https://accounts.google.com", @@ -1215,56 +1261,13 @@ async def test_http_auth_ext_google_callback_01(self) -> None: ) ) - jwks_request = ( - "GET", - "https://www.googleapis.com", - "oauth2/v3/certs", - ) - # Generate a JWK Set - k = jwk.JWK.generate(kty='RSA', size=4096) - ks = jwk.JWKSet() - ks.add(k) - jwk_set: dict[str, Any] = ks.export( - private_keys=False, as_dict=True - ) - - self.mock_oauth_server.register_route_handler(*jwks_request)( - ( - json.dumps(jwk_set), - 200, - ) - ) - - token_request = ( - "POST", - "https://oauth2.googleapis.com", - "token", - ) - id_token_claims = { - "iss": "https://accounts.google.com", - "sub": "1", - "aud": client_id, - "exp": (now + datetime.timedelta(minutes=5)).timestamp(), - "iat": now.timestamp(), - "email": "test@example.com", - } - id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) - id_token.make_signed_token(k) - - self.mock_oauth_server.register_route_handler(*token_request)( - ( - json.dumps( - { - "access_token": "google_access_token", - "id_token": id_token.serialize(), - "scope": "openid", - "token_type": "bearer", - } - ), - 200, - ) + token_request = self.generate_and_serve_jwk( + client_id, + "https://www.googleapis.com/oauth2/v3/certs", + "https://oauth2.googleapis.com/token", + "https://accounts.google.com", + "google_access_token", ) - challenge = ( base64.urlsafe_b64encode( hashlib.sha256( @@ -1285,11 +1288,9 @@ async def test_http_auth_ext_google_callback_01(self) -> None: signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -1498,8 +1499,6 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: client_id = provider_config.client_id client_secret = AZURE_SECRET - now = utcnow() - discovery_request = ( "GET", "https://login.microsoftonline.com", @@ -1512,56 +1511,13 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: ) ) - jwks_request = ( - "GET", + token_request = self.generate_and_serve_jwk( + client_id, + "https://login.microsoftonline.com/common/discovery/v2.0/keys", + "https://login.microsoftonline.com/common/oauth2/v2.0/token", "https://login.microsoftonline.com", - "common/discovery/v2.0/keys", - ) - # Generate a JWK Set - k = jwk.JWK.generate(kty='RSA', size=4096) - ks = jwk.JWKSet() - ks.add(k) - jwk_set: dict[str, Any] = ks.export( - private_keys=False, as_dict=True + "azure_access_token", ) - - self.mock_oauth_server.register_route_handler(*jwks_request)( - ( - json.dumps(jwk_set), - 200, - ) - ) - - token_request = ( - "POST", - "https://login.microsoftonline.com", - "common/oauth2/v2.0/token", - ) - id_token_claims = { - "iss": "https://login.microsoftonline.com/common/v2.0", - "sub": "1", - "aud": client_id, - "exp": (now + datetime.timedelta(minutes=5)).timestamp(), - "iat": now.timestamp(), - "email": "test@example.com", - } - id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) - id_token.make_signed_token(k) - - self.mock_oauth_server.register_route_handler(*token_request)( - ( - json.dumps( - { - "access_token": "azure_access_token", - "id_token": id_token.serialize(), - "scope": "openid", - "token_type": "bearer", - } - ), - 200, - ) - ) - challenge = ( base64.urlsafe_b64encode( hashlib.sha256( @@ -1582,11 +1538,9 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -1598,7 +1552,7 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: path="callback", ) - self.assertEqual(data, b"") + self.assertEqual(data, b"", data) self.assertEqual(status, 302) location = headers.get("location") @@ -1714,8 +1668,6 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: client_id = provider_config.client_id client_secret = APPLE_SECRET - now = utcnow() - discovery_request = ( "GET", "https://appleid.apple.com", @@ -1728,54 +1680,12 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: ) ) - jwks_request = ( - "GET", - "https://appleid.apple.com", - "auth/keys", - ) - # Generate a JWK Set - k = jwk.JWK.generate(kty='RSA', size=4096) - ks = jwk.JWKSet() - ks.add(k) - jwk_set: dict[str, Any] = ks.export( - private_keys=False, as_dict=True - ) - - self.mock_oauth_server.register_route_handler(*jwks_request)( - ( - json.dumps(jwk_set), - 200, - ) - ) - - token_request = ( - "POST", + token_request = self.generate_and_serve_jwk( + client_id, + "https://appleid.apple.com/auth/keys", + "https://appleid.apple.com/auth/token", "https://appleid.apple.com", - "auth/token", - ) - id_token_claims = { - "iss": "https://appleid.apple.com", - "sub": "1", - "aud": client_id, - "exp": (now + datetime.timedelta(minutes=5)).timestamp(), - "iat": now.timestamp(), - "email": "test@example.com", - } - id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) - id_token.make_signed_token(k) - - self.mock_oauth_server.register_route_handler(*token_request)( - ( - json.dumps( - { - "access_token": "apple_access_token", - "id_token": id_token.serialize(), - "scope": "openid", - "token_type": "bearer", - } - ), - 200, - ) + "apple_access_token", ) challenge = ( @@ -1798,11 +1708,9 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -1858,8 +1766,6 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( provider_name = provider_config.name client_id = provider_config.client_id - now = utcnow() - discovery_request = ( "GET", "https://appleid.apple.com", @@ -1872,54 +1778,13 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( ) ) - jwks_request = ( - "GET", + _token_request = self.generate_and_serve_jwk( + client_id, + "https://appleid.apple.com/auth/keys", + "https://appleid.apple.com/auth/token", "https://appleid.apple.com", - "auth/keys", - ) - # Generate a JWK Set - k = jwk.JWK.generate(kty='RSA', size=4096) - ks = jwk.JWKSet() - ks.add(k) - jwk_set: dict[str, Any] = ks.export( - private_keys=False, as_dict=True - ) - - self.mock_oauth_server.register_route_handler(*jwks_request)( - ( - json.dumps(jwk_set), - 200, - ) - ) - - token_request = ( - "POST", - "https://appleid.apple.com", - "auth/token", - ) - id_token_claims = { - "iss": "https://appleid.apple.com", - "sub": "2", - "aud": client_id, - "exp": (now + datetime.timedelta(minutes=5)).timestamp(), - "iat": now.timestamp(), - "email": "test@example.com", - } - id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) - id_token.make_signed_token(k) - - self.mock_oauth_server.register_route_handler(*token_request)( - ( - json.dumps( - { - "access_token": "apple_access_token", - "id_token": id_token.serialize(), - "scope": "openid", - "token_type": "bearer", - } - ), - 200, - ) + "apple_access_token", + sub="2", ) challenge = ( @@ -1942,11 +1807,9 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "redirect_to_on_signup": f"{self.http_addr}/some/other/path", "challenge": challenge, @@ -1964,7 +1827,7 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( headers={"Content-Type": "application/x-www-form-urlencoded"}, ) - self.assertEqual(data, b"") + self.assertEqual(data, b"", data) self.assertEqual(status, 302) location = headers.get("location") @@ -2006,8 +1869,6 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: client_id = provider_config.client_id client_secret = SLACK_SECRET - now = utcnow() - discovery_request = ( "GET", "https://slack.com", @@ -2021,54 +1882,12 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: ) ) - jwks_request = ( - "GET", - "https://slack.com", - "openid/connect/keys", - ) - # Generate a JWK Set - k = jwk.JWK.generate(kty='RSA', size=4096) - ks = jwk.JWKSet() - ks.add(k) - jwk_set: dict[str, Any] = ks.export( - private_keys=False, as_dict=True - ) - - self.mock_oauth_server.register_route_handler(*jwks_request)( - ( - json.dumps(jwk_set), - 200, - ) - ) - - token_request = ( - "POST", + token_request = self.generate_and_serve_jwk( + client_id, + "https://slack.com/openid/connect/keys", + "https://slack.com/api/openid.connect.token", "https://slack.com", - "api/openid.connect.token", - ) - id_token_claims = { - "iss": "https://slack.com", - "sub": "1", - "aud": client_id, - "exp": (now + datetime.timedelta(minutes=5)).timestamp(), - "iat": now.timestamp(), - "email": "test@example.com", - } - id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) - id_token.make_signed_token(k) - - self.mock_oauth_server.register_route_handler(*token_request)( - ( - json.dumps( - { - "access_token": "slack_access_token", - "id_token": id_token.serialize(), - "scope": "openid", - "token_type": "bearer", - } - ), - 200, - ) + "slack_access_token", ) challenge = ( @@ -2091,11 +1910,9 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -2313,8 +2130,6 @@ async def test_http_auth_ext_generic_oidc_callback_01(self): client_id = provider_config.client_id client_secret = GENERIC_OIDC_SECRET - now = utcnow() - discovery_request = ( "GET", "https://example.com", @@ -2328,54 +2143,12 @@ async def test_http_auth_ext_generic_oidc_callback_01(self): ) ) - jwks_request = ( - "GET", + token_request = self.generate_and_serve_jwk( + client_id, + "https://example.com/jwks", + "https://example.com/token", "https://example.com", - "jwks", - ) - # Generate a JWK Set - k = jwk.JWK.generate(kty='RSA', size=4096) - ks = jwk.JWKSet() - ks.add(k) - jwk_set: dict[str, Any] = ks.export( - private_keys=False, as_dict=True - ) - - self.mock_oauth_server.register_route_handler(*jwks_request)( - ( - json.dumps(jwk_set), - 200, - ) - ) - - token_request = ( - "POST", - "https://example.com", - "token", - ) - id_token_claims = { - "iss": "https://example.com", - "sub": "1", - "aud": client_id, - "exp": (now + datetime.timedelta(minutes=5)).timestamp(), - "iat": now.timestamp(), - "email": "test@example.com", - } - id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) - id_token.make_signed_token(k) - - self.mock_oauth_server.register_route_handler(*token_request)( - ( - json.dumps( - { - "access_token": "oidc_access_token", - "id_token": id_token.serialize(), - "scope": "openid", - "token_type": "bearer", - } - ), - 200, - ) + "oidc_access_token", ) challenge = ( @@ -2398,11 +2171,9 @@ async def test_http_auth_ext_generic_oidc_callback_01(self): signing_key = await self.get_signing_key() - expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), - "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -2880,7 +2651,7 @@ async def test_http_auth_ext_local_password_register_json_02(self): headers={"Content-Type": "application/json"}, ) - self.assertEqual(status, 201) + self.assertEqual(status, 201, body) identity = await self.con.query_single( """ @@ -3281,7 +3052,7 @@ async def test_http_auth_ext_local_emailpassword_resend_verification(self): } resend_data_encoded = urllib.parse.urlencode(resend_data).encode() - _, _, status = self.http_con_request( + body, _, status = self.http_con_request( http_con, None, path="resend-verification-email", @@ -3290,7 +3061,7 @@ async def test_http_auth_ext_local_emailpassword_resend_verification(self): headers={"Content-Type": "application/x-www-form-urlencoded"}, ) - self.assertEqual(status, 200) + self.assertEqual(status, 200, body) # Resend verification email with just the email resend_data = { @@ -3588,7 +3359,7 @@ async def test_http_auth_ext_token_01(self): path="token", ) - self.assertEqual(status, 200) + self.assertEqual(status, 200, body) body_json = json.loads(body) self.assertEqual( body_json, @@ -3771,12 +3542,9 @@ async def test_http_auth_ext_local_password_forgot_form_01(self): claims = await self.extract_jwt_claims( reset_url.split('=', maxsplit=1)[1], "reset" ) + # Expiry checked as part of the validation self.assertEqual(claims.get("sub"), str(identity[0].id)) self.assertEqual(claims.get("iss"), str(self.http_addr)) - now = utcnow() - tenMinutesLater = now + datetime.timedelta(minutes=10) - self.assertTrue(claims.get("exp") > now.timestamp()) - self.assertTrue(claims.get("exp") < tenMinutesLater.timestamp()) password_credential = await self.con.query( """ @@ -3932,7 +3700,7 @@ async def test_http_auth_ext_local_password_reset_form_01(self): headers={"Content-Type": "application/x-www-form-urlencoded"}, ) - self.assertEqual(status, 200) + self.assertEqual(status, 200, body) file_name_hash = hashlib.sha256( f"{SENDER}{email}".encode() @@ -4407,7 +4175,7 @@ async def test_http_auth_ext_magic_link_with_link_url(self): with self.http_con() as http_con: # Register with link_url - _, _, status = self.http_con_request( + body, _, status = self.http_con_request( http_con, method="POST", path="magic-link/register", @@ -4426,7 +4194,7 @@ async def test_http_auth_ext_magic_link_with_link_url(self): "Accept": "application/json", }, ) - self.assertEqual(status, 200) + self.assertEqual(status, 200, body) # Get the token from email file_name_hash = hashlib.sha256( @@ -4559,7 +4327,7 @@ async def test_http_auth_ext_magic_link_without_link_url(self): with self.http_con() as http_con: # Register without link_url - _, _, status = self.http_con_request( + body, _, status = self.http_con_request( http_con, method="POST", path="magic-link/register", @@ -4577,7 +4345,7 @@ async def test_http_auth_ext_magic_link_without_link_url(self): "Accept": "application/json", }, ) - self.assertEqual(status, 200) + self.assertEqual(status, 200, body) # Get the token from email file_name_hash = hashlib.sha256( diff --git a/tests/test_server_auth.py b/tests/test_server_auth.py index eeeeeefdd572..3f7c30d54980 100644 --- a/tests/test_server_auth.py +++ b/tests/test_server_auth.py @@ -27,15 +27,13 @@ import urllib.error import urllib.request -import jwcrypto.jwk - import edgedb from edb import errors from edb import protocol -from edb.common import secretkey from edb.server import args from edb.server import cluster as edbcluster +from edb.server.auth import JWKSet, generate_gel_token, load_secret_key from edb.schema import defines as s_def from edb.testbase import server as tb @@ -374,15 +372,16 @@ def _jwt_gql_request(self, server, *, sk=None, password=None): async def test_server_auth_jwt_1(self): jwk_fd, jwk_file = tempfile.mkstemp() - key = jwcrypto.jwk.JWK(generate='EC') + jws = JWKSet() + jws.generate(kid=None, kty="ES256") with open(jwk_fd, "wb") as f: - f.write(key.export_to_pem(private_key=True, password=None)) - jwk = secretkey.load_secret_key(pathlib.Path(jwk_file)) + f.write(jws.export_pem()) + async with tb.start_edgedb_server( jws_key_file=pathlib.Path(jwk_file), default_auth_method=args.ServerAuthMethod.JWT, ) as sd: - base_sk = secretkey.generate_secret_key(jwk) + base_sk = generate_gel_token(jws) conn = await sd.connect(secret_key=base_sk) await conn.execute(''' CREATE SUPERUSER ROLE foo { @@ -419,7 +418,7 @@ async def test_server_auth_jwt_1(self): ): await sd.connect(secret_key='wrong') - sk = secretkey.generate_secret_key(jwk) + sk = generate_gel_token(jws) corrupt_sk = sk[:50] + "0" + sk[51:] with self.assertRaisesRegex( @@ -467,7 +466,7 @@ async def test_server_auth_jwt_1(self): for params in good_keys: params_dict = dict(params) with self.subTest(**params_dict): - sk = secretkey.generate_secret_key(jwk, **params_dict) + sk = generate_gel_token(jws, **params_dict) conn = await sd.connect(secret_key=sk) await conn.aclose() @@ -491,7 +490,7 @@ async def test_server_auth_jwt_1(self): for params, msg in bad_keys.items(): params_dict = dict(params) with self.subTest(**params_dict): - sk = secretkey.generate_secret_key(jwk, **params_dict) + sk = generate_gel_token(jws, **params_dict) with self.assertRaisesRegex( edgedb.AuthenticationError, "authentication failed: " + msg, @@ -510,9 +509,10 @@ async def test_server_auth_jwt_1(self): async def test_server_auth_jwt_2(self): jwk_fd, jwk_file = tempfile.mkstemp() - key = jwcrypto.jwk.JWK(generate='EC') + jws = JWKSet() + jws.generate(kid=None, kty="ES256") with open(jwk_fd, "wb") as f: - f.write(key.export_to_pem(private_key=True, password=None)) + f.write(jws.export_pem()) allowlist_fd, allowlist_file = tempfile.mkstemp() os.close(allowlist_fd) @@ -539,7 +539,7 @@ async def test_server_auth_jwt_2(self): jwt_revocation_list_file=revokelist_file, ) as sd: - jwk = secretkey.load_secret_key(pathlib.Path(jwk_file)) + jwk = load_secret_key(pathlib.Path(jwk_file)) # enable JWT conn = await sd.connect() @@ -555,14 +555,14 @@ async def test_server_auth_jwt_2(self): await conn.aclose() # Try connecting with "test" not being in the allowlist. - sk = secretkey.generate_secret_key( + sk = generate_gel_token( jwk, subject=subject, key_id=key_id, ) with self.assertRaisesRegex( edgedb.AuthenticationError, - 'authentication failed: unauthorized subject', + 'authentication failed: Verification failed', ): await sd.connect(secret_key=sk) @@ -596,10 +596,11 @@ async def test_server_auth_jwt_2(self): async def test_server_auth_multiple_methods(self): jwk_fd, jwk_file = tempfile.mkstemp() - key = jwcrypto.jwk.JWK(generate='EC') + jws = JWKSet() + jws.generate(kid=None, kty="ES256") with open(jwk_fd, "wb") as f: - f.write(key.export_to_pem(private_key=True, password=None)) - jwk = secretkey.load_secret_key(pathlib.Path(jwk_file)) + f.write(jws.export_pem()) + jwk = load_secret_key(pathlib.Path(jwk_file)) async with tb.start_edgedb_server( jws_key_file=pathlib.Path(jwk_file), default_auth_method=args.ServerAuthMethods({ @@ -613,7 +614,7 @@ async def test_server_auth_multiple_methods(self): ], }), ) as sd: - base_sk = secretkey.generate_secret_key(jwk) + base_sk = generate_gel_token(jwk) conn = await sd.connect(secret_key=base_sk) await conn.execute(''' CREATE EXTENSION edgeql_http; @@ -633,7 +634,7 @@ async def test_server_auth_multiple_methods(self): c1 = await sd.connect(secret_key='wrong') await c1.aclose() - sk = secretkey.generate_secret_key(jwk) + sk = generate_gel_token(jwk) body, _, code = await self._jwt_http_request(sd, sk=sk) self.assertEqual(code, 200, f"Wrong result: {body}")