diff --git a/edb/server/metrics.py b/edb/server/metrics.py index 352f04a25a69..c76b62378e31 100644 --- a/edb/server/metrics.py +++ b/edb/server/metrics.py @@ -223,30 +223,6 @@ labels=("tenant",), ) -auth_provider_jwkset_fetch_success = registry.new_labeled_counter( - "auth_provider_jwkset_fetch_success_total", - "Number of successful Auth extension JWK Set fetches.", - labels=("provider",), -) - -auth_provider_jwkset_fetch_errors = registry.new_labeled_counter( - "auth_provider_jwkset_fetch_errors_total", - "Number of failed Auth extension JWK Set fetches.", - labels=("provider",), -) - -auth_provider_token_validation_success = registry.new_labeled_counter( - "auth_provider_token_validation_success_total", - "Number of successful Auth extension provider token validations.", - labels=("provider",), -) - -auth_provider_token_validation_errors = registry.new_labeled_counter( - "auth_provider_token_validation_errors_total", - "Number of failed Auth extension provider token validations.", - labels=("provider",), -) - mt_tenants_total = registry.new_gauge( 'mt_tenants_current', 'Total number of currently-registered tenants.', diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index 2884b4b6e285..d43a492e4a61 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -18,19 +18,15 @@ import uuid import urllib.parse +import json import enum -import logging from typing import Any, Callable +from jwcrypto import jwt, jwk from datetime import datetime from . import data, errors from edb.server.http import HttpClient -from edb.server import auth as jwt_auth -from edb.server.protocol.auth_ext import util as auth_util -from edb.server import metrics - -logger = logging.getLogger("edb.server.ext.auth") class BaseProvider: @@ -149,48 +145,28 @@ async def fetch_user_info( ) id_token = token_response.id_token - # Retrieve JWK Set, potentially from the cache + # Retrieve JWK Set oidc_config = await self._get_oidc_config() + jwks_uri = urllib.parse.urlparse(oidc_config.jwks_uri) + async with self.http_factory( + base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}" + ) as client: + r = await client.get(jwks_uri.path) + # Load the token as a JWT object and verify it directly try: - async def fetcher(url: str) -> jwt_auth.JWKSet: - jwks_uri = urllib.parse.urlparse(url) - async with self.http_factory( - base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}" - ) as client: - r = await client.get(jwks_uri.path) - jwk_set = jwt_auth.JWKSet() - jwk_set.load_json(r.text) - jwk_set.set_audience(self.client_id) - jwk_set.set_expiry(3600) - metrics.auth_provider_jwkset_fetch_success.inc( - 1.0, self.name - ) - return jwk_set - - jwk_set = await auth_util.get_remote_jwtset( - oidc_config.jwks_uri, fetcher - ) - except Exception as e: - metrics.auth_provider_jwkset_fetch_errors.inc(1.0, self.name) - logger.exception( - f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}" - ) - raise errors.MisconfiguredProvider( - f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}" - ) from e - - # Load the token as a JWT object and verify it directly. This will - # validate the audience and expiry. - try: - payload = jwk_set.validate(id_token) + jwk_set = jwk.JWKSet.from_json(r.text) + id_token_verified = jwt.JWT(key=jwk_set, jwt=id_token) + payload = json.loads(id_token_verified.claims) except Exception as e: - metrics.auth_provider_token_validation_errors.inc(1.0, self.name) raise errors.MisconfiguredProvider( "Failed to parse ID token with provider keyset" ) from e - - metrics.auth_provider_token_validation_success.inc(1.0, self.name) + if payload.get("aud") != self.client_id: + raise errors.InvalidData( + "Invalid value for aud in id_token: " + f"{payload.get('aud')} != {self.client_id}" + ) return data.UserInfo( sub=str(payload["sub"]), diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 3585ee92f836..6fe948b7916c 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -43,12 +43,13 @@ ) import aiosmtplib +from jwcrypto import jwk, jwt from edb import errors as edb_errors from edb.common import debug from edb.common import markup from edb.ir import statypes -from edb.server import tenant as edbtenant, metrics, auth as jwt_auth +from edb.server import tenant as edbtenant, metrics from edb.server.config.types import CompositeConfigType from . import ( @@ -2064,16 +2065,13 @@ async def _maybe_send_webhook(self, event: webhook.Event) -> None: def _get_callback_url(self) -> str: return f"{self.base_path}/callback" - def _get_auth_signing_key(self, info: str | None = None) -> jwt_auth.JWKSet: + def _get_auth_signing_key(self) -> jwk.JWK: auth_signing_key = util.get_config( self.db, "ext::auth::AuthConfig::auth_signing_key" ) - if info is None: - return jwt_auth.JWKSet.from_hs256_key(auth_signing_key.encode()) - else: - return jwt_auth.JWKSet.from_hs256_key( - util.derive_key_raw(auth_signing_key, info) - ) + key_bytes = base64.b64encode(auth_signing_key.encode()) + + return jwk.JWK(kty="oct", k=key_bytes.decode()) def _make_state_claims( self, @@ -2083,19 +2081,25 @@ def _make_state_claims( challenge: str, ) -> str: signing_key = self._get_auth_signing_key() - signing_ctx = jwt_auth.SigningCtx() - signing_ctx.set_expiry(5 * 60) - signing_ctx.set_not_before(30) - signing_ctx.set_issuer(self.base_path) + expires_at = datetime.datetime.now( + datetime.timezone.utc + ) + datetime.timedelta(minutes=5) state_claims = { + "iss": self.base_path, "provider": provider, + "exp": expires_at.timestamp(), "redirect_to": redirect_to, "challenge": challenge, } if redirect_to_on_signup: state_claims['redirect_to_on_signup'] = redirect_to_on_signup - return signing_key.sign(state_claims, ctx=signing_ctx) + state_token = jwt.JWT( + header={"alg": "HS256"}, + claims=state_claims, + ) + state_token.make_signed_token(signing_key) + return cast(str, state_token.serialize()) def _make_session_token(self, identity_id: str) -> str: signing_key = self._get_auth_signing_key() @@ -2104,15 +2108,34 @@ def _make_session_token(self, identity_id: str) -> str: "ext::auth::AuthConfig::token_time_to_live", statypes.Duration, ) - signing_ctx = jwt_auth.SigningCtx() - signing_ctx.set_expiry(int(auth_expiration_time.to_timedelta().total_seconds())) - signing_ctx.set_not_before(30) - signing_key.set_issuer(self.base_path) - session_token = signing_key.sign({ + expires_in = auth_expiration_time.to_timedelta() + expires_at = datetime.datetime.now(datetime.timezone.utc) + expires_in + + claims: dict[str, Any] = { + "iss": self.base_path, "sub": identity_id, - }, ctx=signing_ctx) + } + if expires_in.total_seconds() != 0: + claims["exp"] = expires_at.timestamp() + session_token = jwt.JWT( + header={"alg": "HS256"}, + claims=claims, + ) + session_token.make_signed_token(signing_key) metrics.auth_successful_logins.inc(1.0, self.tenant.get_instance_name()) - return session_token + return cast(str, session_token.serialize()) + + def _get_from_claims(self, state: str, key: str) -> str: + signing_key = self._get_auth_signing_key() + try: + state_token = jwt.JWT(key=signing_key, jwt=state) + except Exception: + raise errors.InvalidData("Invalid state token") + state_claims: dict[str, str] = json.loads(state_token.claims) + value = state_claims.get(key) + if value is None: + raise errors.InvalidData("Invalid state token") + return value def _make_secret_token( self, @@ -2124,7 +2147,8 @@ def _make_secret_token( ) = None, expires_in: datetime.timedelta | None = None, ) -> str: - signing_key = self._get_auth_signing_key(derive_for_info) + input_key_material = self._get_auth_signing_key() + signing_key = util.derive_key(input_key_material, derive_for_info) expires_in = ( datetime.timedelta(minutes=10) if expires_in is None else expires_in ) @@ -2142,8 +2166,15 @@ def _make_secret_token( def _verify_and_extract_claims( self, jwtStr: str, key_info: str | None = None ) -> dict[str, str | int | float | bool]: - signing_key = self._get_auth_signing_key(key_info) - return signing_key.validate(jwtStr) + input_key_material = self._get_auth_signing_key() + if key_info is None: + signing_key = input_key_material + else: + signing_key = util.derive_key(input_key_material, key_info) + verified = jwt.JWT(key=signing_key, jwt=jwtStr) + return cast( + dict[str, str | int | float | bool], json.loads(verified.claims) + ) def _get_data_from_magic_link_token( self, token: str @@ -2219,14 +2250,14 @@ def _get_data_from_verification_token( ): case ( str(id), - float(issued_at) | int(issued_at), + float(issued_at), verify_url, challenge, redirect_to, ): return_value = ( id, - float(issued_at), + issued_at, verify_url, challenge, redirect_to, @@ -2324,8 +2355,7 @@ def _make_verification_token( "Verify URL does not match any allowed URLs.", ) - now = datetime.datetime.now(datetime.timezone.utc) - issued_at = int(now.timestamp()) + issued_at = datetime.datetime.now(datetime.timezone.utc).timestamp() return self._make_secret_token( identity_id=identity_id, secret=str(uuid.uuid4()), diff --git a/edb/server/protocol/auth_ext/local.py b/edb/server/protocol/auth_ext/local.py index 286c13b72117..2022eea504a5 100644 --- a/edb/server/protocol/auth_ext/local.py +++ b/edb/server/protocol/auth_ext/local.py @@ -19,10 +19,12 @@ import datetime import json +import base64 +from jwcrypto import jwk from typing import Any, cast from edb.server.protocol import execute -from edb.server import auth as jwt_auth + from . import util, data @@ -30,16 +32,13 @@ class Client: def __init__(self, db: Any): self.db = db - def _get_signing_key(self, info: str | None = None) -> jwt_auth.JWKSet: + def _get_signing_key(self) -> jwk.JWK: auth_signing_key = util.get_config( self.db, "ext::auth::AuthConfig::auth_signing_key" ) - if info is None: - return jwt_auth.JWKSet.from_hs256_key(auth_signing_key.encode()) - else: - return jwt_auth.JWKSet.from_hs256_key( - util.derive_key_raw(auth_signing_key, info) - ) + key_bytes = base64.b64encode(auth_signing_key.encode()) + + return jwk.JWK(kty="oct", k=key_bytes.decode()) async def verify_email( self, identity_id: str, verified_at: datetime.datetime diff --git a/edb/server/protocol/auth_ext/magic_link.py b/edb/server/protocol/auth_ext/magic_link.py index ae478093cfe3..8474a5c886ee 100644 --- a/edb/server/protocol/auth_ext/magic_link.py +++ b/edb/server/protocol/auth_ext/magic_link.py @@ -103,7 +103,8 @@ def make_magic_link_token( callback_url: str, challenge: str, ) -> str: - signing_key = self._get_signing_key("magic_link") + initial_key_material = self._get_signing_key() + signing_key = util.derive_key(initial_key_material, "magic_link") return util.make_token( signing_key=signing_key, issuer=self.issuer, diff --git a/edb/server/protocol/auth_ext/util.py b/edb/server/protocol/auth_ext/util.py index 98619a3d827a..57b6977386bf 100644 --- a/edb/server/protocol/auth_ext/util.py +++ b/edb/server/protocol/auth_ext/util.py @@ -19,22 +19,19 @@ from __future__ import annotations +import base64 import urllib.parse import datetime import html -import logging -import asyncio from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand from cryptography.hazmat.backends import default_backend -from typing import ( - TypeVar, Type, overload, Any, cast, Optional, TYPE_CHECKING, Callable, - Awaitable -) +from jwcrypto import jwt, jwk +from typing import TypeVar, Type, overload, Any, cast, Optional, TYPE_CHECKING -from edb.server import config as edb_config, auth as jwt_auth +from edb.server import config as edb_config from edb.server.config.types import CompositeConfigType from . import errors, config @@ -44,11 +41,6 @@ T = TypeVar("T") -logger = logging.getLogger('edb.server.ext.auth') - -# Cache JWKSets for 10 minutes -jwtset_cache = jwt_auth.JWKSetCache(60 * 10) - def maybe_get_config_unchecked(db: edbtenant.dbview.Database, key: str) -> Any: return edb_config.lookup(key, db.db_config, spec=db.user_config_spec) @@ -162,60 +154,45 @@ def join_url_params(url: str, params: dict[str, str]) -> str: return parsed_url._replace(query=new_query_params).geturl() -async def get_remote_jwtset( - url: str, fetch_lambda: Callable[[str], Awaitable[jwt_auth.JWKSet]] -) -> jwt_auth.JWKSet: - """ - Get a JWKSet from the cache, or fetch it from the given URL if it's not in - the cache. - """ - is_fresh, jwtset = jwtset_cache.get(url) - if is_fresh and jwtset is not None: - return jwtset - - if jwtset is None: - jwtset = await fetch_lambda(url) - jwtset_cache.set(url, jwtset) - return jwtset - - # Run fetch in background to refresh cache - async def refresh_cache(url: str) -> None: - try: - new_jwtset = await fetch_lambda(url) - jwtset_cache.set(url, new_jwtset) - except Exception: - logger.exception("Failed to refresh JWKSet cache for %s", url) - - asyncio.create_task(refresh_cache(url)) - return jwtset - - def make_token( - signing_key: jwt_auth.JWKSet, + signing_key: jwk.JWK, issuer: str, subject: str, additional_claims: dict[str, str | int | float | bool | None] | None = None, include_issued_at: bool = False, expires_in: datetime.timedelta | None = None, ) -> str: - signing_ctx = jwt_auth.SigningCtx() - signing_ctx.set_issuer(issuer) - if expires_in is not None and int(expires_in.total_seconds()) != 0: - signing_ctx.set_expiry(int(expires_in.total_seconds())) - if include_issued_at: - signing_ctx.set_not_before(30) + now = datetime.datetime.now(datetime.timezone.utc) + expires_in = ( + datetime.timedelta(seconds=0) if expires_in is None else expires_in + ) + expires_at = now + expires_in claims: dict[str, Any] = { + "iss": issuer, "sub": subject, **(additional_claims or {}), } + if expires_in.total_seconds() != 0: + claims["exp"] = expires_at.timestamp() + if include_issued_at: + claims["iat"] = now.timestamp() - return signing_key.sign(claims, ctx=signing_ctx) + token = jwt.JWT( + header={"alg": "HS256"}, + claims=claims, + ) + token.make_signed_token(signing_key) + return cast(str, token.serialize()) -def derive_key_raw(key: str, info: str) -> bytes: + +def derive_key(key: jwk.JWK, info: str) -> jwk.JWK: """Derive a new key from the given symmetric key using HKDF.""" - input_key_material = key.encode() + + # n.b. the key is returned as a base64url-encoded string + raw_key_base64url = cast(str, key.get_op_key()) + input_key_material = base64.urlsafe_b64decode(raw_key_base64url) backend = default_backend() hkdf = HKDFExpand( @@ -225,4 +202,7 @@ def derive_key_raw(key: str, info: str) -> bytes: backend=backend, ) new_key_bytes = hkdf.derive(input_key_material) - return new_key_bytes + return jwk.JWK( + kty="oct", + k=new_key_bytes.hex(), + ) diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 0c51d402a1c5..1fd6323e9352 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -30,14 +30,14 @@ import hashlib import hmac -from typing import Optional, cast +from typing import Any, Optional, cast +from jwcrypto import jwt, jwk from email.message import EmailMessage from edgedb import QueryAssertionError from edb.testbase import http as tb from edb.common import assert_data_shape from edb.server.protocol.auth_ext import util as auth_util -from edb.server.auth import JWKSet ph = argon2.PasswordHasher() @@ -359,7 +359,6 @@ def setUpClass(cls): mock_oauth_server: tb.MockHttpServer mock_net_server: tb.MockHttpServer - jwkset_cache: dict[str, JWKSet] = {} def setUp(self): self.mock_oauth_server = tb.MockHttpServer( @@ -450,28 +449,34 @@ async def get_auth_config_value(self, key: str): """ ) - async def get_signing_key(self, info: str | None = None) -> JWKSet: + async def get_signing_key(self): auth_signing_key = SIGNING_KEY - if info is not None: - signing_key = JWKSet.from_hs256_key( - auth_util.derive_key_raw(auth_signing_key, info) - ) - else: - signing_key = JWKSet.from_hs256_key(auth_signing_key.encode()) - signing_key.set_expiry(3600) - signing_key.set_not_before(30) + key_bytes = base64.b64encode(auth_signing_key.encode()) + signing_key = jwk.JWK(k=key_bytes.decode(), kty="oct") return signing_key def generate_state_value( self, - state_claims: dict[str, str], - auth_signing_key: JWKSet, + state_claims: dict[str, str | float], + auth_signing_key: jwk.JWK, ) -> str: - return auth_signing_key.sign(state_claims) + state_token = jwt.JWT( + header={"alg": "HS256"}, + claims=state_claims, + ) + state_token.make_signed_token(auth_signing_key) + return state_token.serialize() async def extract_jwt_claims(self, raw_jwt: str, info: str | None = None): - signing_key = await self.get_signing_key(info) - return signing_key.validate(raw_jwt) + input_key_material = await self.get_signing_key() + if info is not None: + signing_key = auth_util.derive_key(input_key_material, info) + else: + signing_key = input_key_material + + jwt_token = jwt.JWT(key=signing_key, jwt=raw_jwt) + claims = json.loads(jwt_token.claims) + return claims def maybe_get_cookie_value( self, headers: dict[str, str], name: str @@ -487,76 +492,6 @@ def maybe_get_cookie_value( def maybe_get_auth_token(self, headers: dict[str, str]) -> Optional[str]: return self.maybe_get_cookie_value(headers, "edgedb-session") - def generate_and_serve_jwk( - self, - client_id: str, - jwk_cert_url: str, - token_url: str, - issuer: str, - access_token_name: str, - sub: str = "1", - ) -> tuple[str, str, str]: - parts = jwk_cert_url.split("/", 3) - host = parts[0] + "//" + parts[2] - path = parts[3] - jwks_request = ( - "GET", - host, - path, - ) - - # Because we have an internal cache, ensure that we only generate one - # set per issuer - jwks = self.jwkset_cache.get(issuer) - if jwks is None: - jwks = JWKSet() - jwks.set_issuer(issuer) - jwks.generate(kid=None, kty="RS256") - self.jwkset_cache[issuer] = jwks - - jwk_json = jwks.export_json(private_keys=False).decode() - - self.mock_oauth_server.register_route_handler(*jwks_request)( - ( - jwk_json, - 200, - ) - ) - - parts = token_url.split("/", 3) - host = parts[0] + "//" + parts[2] - path = parts[3] - token_request = ( - "POST", - host, - path, - ) - - jwks.set_issuer(issuer) - jwks.set_audience(client_id) - jwks.set_expiry(3600) - jwks.set_not_before(30) - - id_token = jwks.sign({ - "sub": sub, - "email": "test@example.com", - }) - - self.mock_oauth_server.register_route_handler(*token_request)( - ( - json.dumps( - { - "access_token": access_token_name, - "id_token": id_token, - "scope": "openid", - "token_type": "bearer", - } - ), - 200, - ) - ) - return token_request - async def test_http_auth_ext_github_authorize_01(self): with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( @@ -641,8 +576,10 @@ async def test_http_auth_ext_github_callback_missing_provider_01(self): with self.http_con() as http_con: signing_key = await self.get_signing_key() + expires_at = utcnow() + datetime.timedelta(minutes=5) missing_provider_state_claims = { "iss": self.http_addr, + "exp": expires_at.timestamp(), } state_token = self.generate_state_value( missing_provider_state_claims, signing_key @@ -662,12 +599,15 @@ async def test_http_auth_ext_github_callback_wrong_key_01(self): "oauth_github" ) provider_name = provider_config.name - signing_key = JWKSet() - signing_key.generate(kid=None, kty="ES256") + signing_key = jwk.JWK( + k=base64.b64encode(("abcd" * 8).encode()).decode(), kty="oct" + ) + expires_at = utcnow() + datetime.timedelta(minutes=5) missing_provider_state_claims = { "iss": self.http_addr, "provider": provider_name, + "exp": expires_at.timestamp(), } state_token_value = self.generate_state_value( missing_provider_state_claims, signing_key @@ -685,9 +625,11 @@ async def test_http_auth_ext_github_unknown_provider_01(self): with self.http_con() as http_con: signing_key = await self.get_signing_key() + expires_at = utcnow() + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": "beepboopbeep", + "exp": expires_at.timestamp(), } state_token = self.generate_state_value(state_claims, signing_key) @@ -777,9 +719,11 @@ async def test_http_auth_ext_github_callback_01(self): signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -884,6 +828,7 @@ async def test_http_auth_ext_github_callback_failure_01(self): ) provider_name = provider_config.name + now = utcnow() token_request = ( "POST", "https://github.com", @@ -904,9 +849,11 @@ async def test_http_auth_ext_github_callback_failure_01(self): signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", } state_token = self.generate_state_value(state_claims, signing_key) @@ -947,6 +894,7 @@ async def test_http_auth_ext_github_callback_failure_02(self): ) provider_name = provider_config.name + now = utcnow() token_request = ( "POST", "https://github.com", @@ -967,9 +915,11 @@ async def test_http_auth_ext_github_callback_failure_02(self): signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", } state_token = self.generate_state_value(state_claims, signing_key) @@ -1140,9 +1090,11 @@ async def test_http_auth_ext_discord_callback_01(self): signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -1248,6 +1200,8 @@ async def test_http_auth_ext_google_callback_01(self) -> None: client_id = provider_config.client_id client_secret = GOOGLE_SECRET + now = utcnow() + discovery_request = ( "GET", "https://accounts.google.com", @@ -1261,13 +1215,56 @@ async def test_http_auth_ext_google_callback_01(self) -> None: ) ) - token_request = self.generate_and_serve_jwk( - client_id, - "https://www.googleapis.com/oauth2/v3/certs", - "https://oauth2.googleapis.com/token", - "https://accounts.google.com", - "google_access_token", + jwks_request = ( + "GET", + "https://www.googleapis.com", + "oauth2/v3/certs", + ) + # Generate a JWK Set + k = jwk.JWK.generate(kty='RSA', size=4096) + ks = jwk.JWKSet() + ks.add(k) + jwk_set: dict[str, Any] = ks.export( + private_keys=False, as_dict=True + ) + + self.mock_oauth_server.register_route_handler(*jwks_request)( + ( + json.dumps(jwk_set), + 200, + ) + ) + + token_request = ( + "POST", + "https://oauth2.googleapis.com", + "token", + ) + id_token_claims = { + "iss": "https://accounts.google.com", + "sub": "1", + "aud": client_id, + "exp": (now + datetime.timedelta(minutes=5)).timestamp(), + "iat": now.timestamp(), + "email": "test@example.com", + } + id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) + id_token.make_signed_token(k) + + self.mock_oauth_server.register_route_handler(*token_request)( + ( + json.dumps( + { + "access_token": "google_access_token", + "id_token": id_token.serialize(), + "scope": "openid", + "token_type": "bearer", + } + ), + 200, + ) ) + challenge = ( base64.urlsafe_b64encode( hashlib.sha256( @@ -1288,9 +1285,11 @@ async def test_http_auth_ext_google_callback_01(self) -> None: signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -1499,6 +1498,8 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: client_id = provider_config.client_id client_secret = AZURE_SECRET + now = utcnow() + discovery_request = ( "GET", "https://login.microsoftonline.com", @@ -1511,13 +1512,56 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: ) ) - token_request = self.generate_and_serve_jwk( - client_id, - "https://login.microsoftonline.com/common/discovery/v2.0/keys", - "https://login.microsoftonline.com/common/oauth2/v2.0/token", + jwks_request = ( + "GET", "https://login.microsoftonline.com", - "azure_access_token", + "common/discovery/v2.0/keys", + ) + # Generate a JWK Set + k = jwk.JWK.generate(kty='RSA', size=4096) + ks = jwk.JWKSet() + ks.add(k) + jwk_set: dict[str, Any] = ks.export( + private_keys=False, as_dict=True ) + + self.mock_oauth_server.register_route_handler(*jwks_request)( + ( + json.dumps(jwk_set), + 200, + ) + ) + + token_request = ( + "POST", + "https://login.microsoftonline.com", + "common/oauth2/v2.0/token", + ) + id_token_claims = { + "iss": "https://login.microsoftonline.com/common/v2.0", + "sub": "1", + "aud": client_id, + "exp": (now + datetime.timedelta(minutes=5)).timestamp(), + "iat": now.timestamp(), + "email": "test@example.com", + } + id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) + id_token.make_signed_token(k) + + self.mock_oauth_server.register_route_handler(*token_request)( + ( + json.dumps( + { + "access_token": "azure_access_token", + "id_token": id_token.serialize(), + "scope": "openid", + "token_type": "bearer", + } + ), + 200, + ) + ) + challenge = ( base64.urlsafe_b64encode( hashlib.sha256( @@ -1538,9 +1582,11 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -1552,7 +1598,7 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: path="callback", ) - self.assertEqual(data, b"", data) + self.assertEqual(data, b"") self.assertEqual(status, 302) location = headers.get("location") @@ -1668,6 +1714,8 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: client_id = provider_config.client_id client_secret = APPLE_SECRET + now = utcnow() + discovery_request = ( "GET", "https://appleid.apple.com", @@ -1680,12 +1728,54 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: ) ) - token_request = self.generate_and_serve_jwk( - client_id, - "https://appleid.apple.com/auth/keys", - "https://appleid.apple.com/auth/token", + jwks_request = ( + "GET", + "https://appleid.apple.com", + "auth/keys", + ) + # Generate a JWK Set + k = jwk.JWK.generate(kty='RSA', size=4096) + ks = jwk.JWKSet() + ks.add(k) + jwk_set: dict[str, Any] = ks.export( + private_keys=False, as_dict=True + ) + + self.mock_oauth_server.register_route_handler(*jwks_request)( + ( + json.dumps(jwk_set), + 200, + ) + ) + + token_request = ( + "POST", "https://appleid.apple.com", - "apple_access_token", + "auth/token", + ) + id_token_claims = { + "iss": "https://appleid.apple.com", + "sub": "1", + "aud": client_id, + "exp": (now + datetime.timedelta(minutes=5)).timestamp(), + "iat": now.timestamp(), + "email": "test@example.com", + } + id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) + id_token.make_signed_token(k) + + self.mock_oauth_server.register_route_handler(*token_request)( + ( + json.dumps( + { + "access_token": "apple_access_token", + "id_token": id_token.serialize(), + "scope": "openid", + "token_type": "bearer", + } + ), + 200, + ) ) challenge = ( @@ -1708,9 +1798,11 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -1766,6 +1858,8 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( provider_name = provider_config.name client_id = provider_config.client_id + now = utcnow() + discovery_request = ( "GET", "https://appleid.apple.com", @@ -1778,13 +1872,54 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( ) ) - _token_request = self.generate_and_serve_jwk( - client_id, - "https://appleid.apple.com/auth/keys", - "https://appleid.apple.com/auth/token", + jwks_request = ( + "GET", "https://appleid.apple.com", - "apple_access_token", - sub="2", + "auth/keys", + ) + # Generate a JWK Set + k = jwk.JWK.generate(kty='RSA', size=4096) + ks = jwk.JWKSet() + ks.add(k) + jwk_set: dict[str, Any] = ks.export( + private_keys=False, as_dict=True + ) + + self.mock_oauth_server.register_route_handler(*jwks_request)( + ( + json.dumps(jwk_set), + 200, + ) + ) + + token_request = ( + "POST", + "https://appleid.apple.com", + "auth/token", + ) + id_token_claims = { + "iss": "https://appleid.apple.com", + "sub": "2", + "aud": client_id, + "exp": (now + datetime.timedelta(minutes=5)).timestamp(), + "iat": now.timestamp(), + "email": "test@example.com", + } + id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) + id_token.make_signed_token(k) + + self.mock_oauth_server.register_route_handler(*token_request)( + ( + json.dumps( + { + "access_token": "apple_access_token", + "id_token": id_token.serialize(), + "scope": "openid", + "token_type": "bearer", + } + ), + 200, + ) ) challenge = ( @@ -1807,9 +1942,11 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "redirect_to_on_signup": f"{self.http_addr}/some/other/path", "challenge": challenge, @@ -1827,7 +1964,7 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( headers={"Content-Type": "application/x-www-form-urlencoded"}, ) - self.assertEqual(data, b"", data) + self.assertEqual(data, b"") self.assertEqual(status, 302) location = headers.get("location") @@ -1869,6 +2006,8 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: client_id = provider_config.client_id client_secret = SLACK_SECRET + now = utcnow() + discovery_request = ( "GET", "https://slack.com", @@ -1882,12 +2021,54 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: ) ) - token_request = self.generate_and_serve_jwk( - client_id, - "https://slack.com/openid/connect/keys", - "https://slack.com/api/openid.connect.token", + jwks_request = ( + "GET", + "https://slack.com", + "openid/connect/keys", + ) + # Generate a JWK Set + k = jwk.JWK.generate(kty='RSA', size=4096) + ks = jwk.JWKSet() + ks.add(k) + jwk_set: dict[str, Any] = ks.export( + private_keys=False, as_dict=True + ) + + self.mock_oauth_server.register_route_handler(*jwks_request)( + ( + json.dumps(jwk_set), + 200, + ) + ) + + token_request = ( + "POST", "https://slack.com", - "slack_access_token", + "api/openid.connect.token", + ) + id_token_claims = { + "iss": "https://slack.com", + "sub": "1", + "aud": client_id, + "exp": (now + datetime.timedelta(minutes=5)).timestamp(), + "iat": now.timestamp(), + "email": "test@example.com", + } + id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) + id_token.make_signed_token(k) + + self.mock_oauth_server.register_route_handler(*token_request)( + ( + json.dumps( + { + "access_token": "slack_access_token", + "id_token": id_token.serialize(), + "scope": "openid", + "token_type": "bearer", + } + ), + 200, + ) ) challenge = ( @@ -1910,9 +2091,11 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -2130,6 +2313,8 @@ async def test_http_auth_ext_generic_oidc_callback_01(self): client_id = provider_config.client_id client_secret = GENERIC_OIDC_SECRET + now = utcnow() + discovery_request = ( "GET", "https://example.com", @@ -2143,12 +2328,54 @@ async def test_http_auth_ext_generic_oidc_callback_01(self): ) ) - token_request = self.generate_and_serve_jwk( - client_id, - "https://example.com/jwks", - "https://example.com/token", + jwks_request = ( + "GET", "https://example.com", - "oidc_access_token", + "jwks", + ) + # Generate a JWK Set + k = jwk.JWK.generate(kty='RSA', size=4096) + ks = jwk.JWKSet() + ks.add(k) + jwk_set: dict[str, Any] = ks.export( + private_keys=False, as_dict=True + ) + + self.mock_oauth_server.register_route_handler(*jwks_request)( + ( + json.dumps(jwk_set), + 200, + ) + ) + + token_request = ( + "POST", + "https://example.com", + "token", + ) + id_token_claims = { + "iss": "https://example.com", + "sub": "1", + "aud": client_id, + "exp": (now + datetime.timedelta(minutes=5)).timestamp(), + "iat": now.timestamp(), + "email": "test@example.com", + } + id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) + id_token.make_signed_token(k) + + self.mock_oauth_server.register_route_handler(*token_request)( + ( + json.dumps( + { + "access_token": "oidc_access_token", + "id_token": id_token.serialize(), + "scope": "openid", + "token_type": "bearer", + } + ), + 200, + ) ) challenge = ( @@ -2171,9 +2398,11 @@ async def test_http_auth_ext_generic_oidc_callback_01(self): signing_key = await self.get_signing_key() + expires_at = now + datetime.timedelta(minutes=5) state_claims = { "iss": self.http_addr, "provider": str(provider_name), + "exp": expires_at.timestamp(), "redirect_to": f"{self.http_addr}/some/path", "challenge": challenge, } @@ -2651,7 +2880,7 @@ async def test_http_auth_ext_local_password_register_json_02(self): headers={"Content-Type": "application/json"}, ) - self.assertEqual(status, 201, body) + self.assertEqual(status, 201) identity = await self.con.query_single( """ @@ -3052,7 +3281,7 @@ async def test_http_auth_ext_local_emailpassword_resend_verification(self): } resend_data_encoded = urllib.parse.urlencode(resend_data).encode() - body, _, status = self.http_con_request( + _, _, status = self.http_con_request( http_con, None, path="resend-verification-email", @@ -3061,7 +3290,7 @@ async def test_http_auth_ext_local_emailpassword_resend_verification(self): headers={"Content-Type": "application/x-www-form-urlencoded"}, ) - self.assertEqual(status, 200, body) + self.assertEqual(status, 200) # Resend verification email with just the email resend_data = { @@ -3359,7 +3588,7 @@ async def test_http_auth_ext_token_01(self): path="token", ) - self.assertEqual(status, 200, body) + self.assertEqual(status, 200) body_json = json.loads(body) self.assertEqual( body_json, @@ -3542,9 +3771,12 @@ async def test_http_auth_ext_local_password_forgot_form_01(self): claims = await self.extract_jwt_claims( reset_url.split('=', maxsplit=1)[1], "reset" ) - # Expiry checked as part of the validation self.assertEqual(claims.get("sub"), str(identity[0].id)) self.assertEqual(claims.get("iss"), str(self.http_addr)) + now = utcnow() + tenMinutesLater = now + datetime.timedelta(minutes=10) + self.assertTrue(claims.get("exp") > now.timestamp()) + self.assertTrue(claims.get("exp") < tenMinutesLater.timestamp()) password_credential = await self.con.query( """ @@ -3700,7 +3932,7 @@ async def test_http_auth_ext_local_password_reset_form_01(self): headers={"Content-Type": "application/x-www-form-urlencoded"}, ) - self.assertEqual(status, 200, body) + self.assertEqual(status, 200) file_name_hash = hashlib.sha256( f"{SENDER}{email}".encode() @@ -4175,7 +4407,7 @@ async def test_http_auth_ext_magic_link_with_link_url(self): with self.http_con() as http_con: # Register with link_url - body, _, status = self.http_con_request( + _, _, status = self.http_con_request( http_con, method="POST", path="magic-link/register", @@ -4194,7 +4426,7 @@ async def test_http_auth_ext_magic_link_with_link_url(self): "Accept": "application/json", }, ) - self.assertEqual(status, 200, body) + self.assertEqual(status, 200) # Get the token from email file_name_hash = hashlib.sha256( @@ -4327,7 +4559,7 @@ async def test_http_auth_ext_magic_link_without_link_url(self): with self.http_con() as http_con: # Register without link_url - body, _, status = self.http_con_request( + _, _, status = self.http_con_request( http_con, method="POST", path="magic-link/register", @@ -4345,7 +4577,7 @@ async def test_http_auth_ext_magic_link_without_link_url(self): "Accept": "application/json", }, ) - self.assertEqual(status, 200, body) + self.assertEqual(status, 200) # Get the token from email file_name_hash = hashlib.sha256( diff --git a/tests/test_server_auth.py b/tests/test_server_auth.py index bb871209dff9..eeeeeefdd572 100644 --- a/tests/test_server_auth.py +++ b/tests/test_server_auth.py @@ -27,6 +27,8 @@ import urllib.error import urllib.request +import jwcrypto.jwk + import edgedb from edb import errors @@ -34,7 +36,6 @@ from edb.common import secretkey from edb.server import args from edb.server import cluster as edbcluster -from edb.server.auth import JWKSet, generate_gel_token from edb.schema import defines as s_def from edb.testbase import server as tb @@ -373,16 +374,15 @@ def _jwt_gql_request(self, server, *, sk=None, password=None): async def test_server_auth_jwt_1(self): jwk_fd, jwk_file = tempfile.mkstemp() - jws = JWKSet() - jws.generate(kid=None, kty="ES256") + key = jwcrypto.jwk.JWK(generate='EC') with open(jwk_fd, "wb") as f: - f.write(jws.export_pem()) - + f.write(key.export_to_pem(private_key=True, password=None)) + jwk = secretkey.load_secret_key(pathlib.Path(jwk_file)) async with tb.start_edgedb_server( jws_key_file=pathlib.Path(jwk_file), default_auth_method=args.ServerAuthMethod.JWT, ) as sd: - base_sk = generate_gel_token(jws) + base_sk = secretkey.generate_secret_key(jwk) conn = await sd.connect(secret_key=base_sk) await conn.execute(''' CREATE SUPERUSER ROLE foo { @@ -419,7 +419,7 @@ async def test_server_auth_jwt_1(self): ): await sd.connect(secret_key='wrong') - sk = generate_gel_token(jws) + sk = secretkey.generate_secret_key(jwk) corrupt_sk = sk[:50] + "0" + sk[51:] with self.assertRaisesRegex( @@ -467,7 +467,7 @@ async def test_server_auth_jwt_1(self): for params in good_keys: params_dict = dict(params) with self.subTest(**params_dict): - sk = generate_gel_token(jws, **params_dict) + sk = secretkey.generate_secret_key(jwk, **params_dict) conn = await sd.connect(secret_key=sk) await conn.aclose() @@ -491,7 +491,7 @@ async def test_server_auth_jwt_1(self): for params, msg in bad_keys.items(): params_dict = dict(params) with self.subTest(**params_dict): - sk = generate_gel_token(jws, **params_dict) + sk = secretkey.generate_secret_key(jwk, **params_dict) with self.assertRaisesRegex( edgedb.AuthenticationError, "authentication failed: " + msg, @@ -510,10 +510,9 @@ async def test_server_auth_jwt_1(self): async def test_server_auth_jwt_2(self): jwk_fd, jwk_file = tempfile.mkstemp() - jws = JWKSet() - jws.generate(kid=None, kty="ES256") + key = jwcrypto.jwk.JWK(generate='EC') with open(jwk_fd, "wb") as f: - f.write(jws.export_pem()) + f.write(key.export_to_pem(private_key=True, password=None)) allowlist_fd, allowlist_file = tempfile.mkstemp() os.close(allowlist_fd) @@ -597,10 +596,9 @@ async def test_server_auth_jwt_2(self): async def test_server_auth_multiple_methods(self): jwk_fd, jwk_file = tempfile.mkstemp() - jws = JWKSet() - jws.generate(kid=None, kty="ES256") + key = jwcrypto.jwk.JWK(generate='EC') with open(jwk_fd, "wb") as f: - f.write(jws.export_pem()) + f.write(key.export_to_pem(private_key=True, password=None)) jwk = secretkey.load_secret_key(pathlib.Path(jwk_file)) async with tb.start_edgedb_server( jws_key_file=pathlib.Path(jwk_file),