Skip to content

Commit

Permalink
Back out python changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Feb 7, 2025
1 parent 1437b07 commit b80c484
Show file tree
Hide file tree
Showing 8 changed files with 492 additions and 300 deletions.
24 changes: 0 additions & 24 deletions edb/server/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,30 +223,6 @@
labels=("tenant",),
)

auth_provider_jwkset_fetch_success = registry.new_labeled_counter(
"auth_provider_jwkset_fetch_success_total",
"Number of successful Auth extension JWK Set fetches.",
labels=("provider",),
)

auth_provider_jwkset_fetch_errors = registry.new_labeled_counter(
"auth_provider_jwkset_fetch_errors_total",
"Number of failed Auth extension JWK Set fetches.",
labels=("provider",),
)

auth_provider_token_validation_success = registry.new_labeled_counter(
"auth_provider_token_validation_success_total",
"Number of successful Auth extension provider token validations.",
labels=("provider",),
)

auth_provider_token_validation_errors = registry.new_labeled_counter(
"auth_provider_token_validation_errors_total",
"Number of failed Auth extension provider token validations.",
labels=("provider",),
)

mt_tenants_total = registry.new_gauge(
'mt_tenants_current',
'Total number of currently-registered tenants.',
Expand Down
58 changes: 17 additions & 41 deletions edb/server/protocol/auth_ext/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@

import uuid
import urllib.parse
import json
import enum
import logging

from typing import Any, Callable
from jwcrypto import jwt, jwk
from datetime import datetime

from . import data, errors
from edb.server.http import HttpClient
from edb.server import auth as jwt_auth
from edb.server.protocol.auth_ext import util as auth_util
from edb.server import metrics

logger = logging.getLogger("edb.server.ext.auth")


class BaseProvider:
Expand Down Expand Up @@ -149,48 +145,28 @@ async def fetch_user_info(
)
id_token = token_response.id_token

# Retrieve JWK Set, potentially from the cache
# Retrieve JWK Set
oidc_config = await self._get_oidc_config()
jwks_uri = urllib.parse.urlparse(oidc_config.jwks_uri)
async with self.http_factory(
base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}"
) as client:
r = await client.get(jwks_uri.path)

# Load the token as a JWT object and verify it directly
try:
async def fetcher(url: str) -> jwt_auth.JWKSet:
jwks_uri = urllib.parse.urlparse(url)
async with self.http_factory(
base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}"
) as client:
r = await client.get(jwks_uri.path)
jwk_set = jwt_auth.JWKSet()
jwk_set.load_json(r.text)
jwk_set.set_audience(self.client_id)
jwk_set.set_expiry(3600)
metrics.auth_provider_jwkset_fetch_success.inc(
1.0, self.name
)
return jwk_set

jwk_set = await auth_util.get_remote_jwtset(
oidc_config.jwks_uri, fetcher
)
except Exception as e:
metrics.auth_provider_jwkset_fetch_errors.inc(1.0, self.name)
logger.exception(
f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}"
)
raise errors.MisconfiguredProvider(
f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}"
) from e

# Load the token as a JWT object and verify it directly. This will
# validate the audience and expiry.
try:
payload = jwk_set.validate(id_token)
jwk_set = jwk.JWKSet.from_json(r.text)
id_token_verified = jwt.JWT(key=jwk_set, jwt=id_token)
payload = json.loads(id_token_verified.claims)
except Exception as e:
metrics.auth_provider_token_validation_errors.inc(1.0, self.name)
raise errors.MisconfiguredProvider(
"Failed to parse ID token with provider keyset"
) from e

metrics.auth_provider_token_validation_success.inc(1.0, self.name)
if payload.get("aud") != self.client_id:
raise errors.InvalidData(
"Invalid value for aud in id_token: "
f"{payload.get('aud')} != {self.client_id}"
)

return data.UserInfo(
sub=str(payload["sub"]),
Expand Down
84 changes: 57 additions & 27 deletions edb/server/protocol/auth_ext/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@
)

import aiosmtplib
from jwcrypto import jwk, jwt

from edb import errors as edb_errors
from edb.common import debug
from edb.common import markup
from edb.ir import statypes
from edb.server import tenant as edbtenant, metrics, auth as jwt_auth
from edb.server import tenant as edbtenant, metrics
from edb.server.config.types import CompositeConfigType

from . import (
Expand Down Expand Up @@ -2064,16 +2065,13 @@ async def _maybe_send_webhook(self, event: webhook.Event) -> None:
def _get_callback_url(self) -> str:
return f"{self.base_path}/callback"

def _get_auth_signing_key(self, info: str | None = None) -> jwt_auth.JWKSet:
def _get_auth_signing_key(self) -> jwk.JWK:
auth_signing_key = util.get_config(
self.db, "ext::auth::AuthConfig::auth_signing_key"
)
if info is None:
return jwt_auth.JWKSet.from_hs256_key(auth_signing_key.encode())
else:
return jwt_auth.JWKSet.from_hs256_key(
util.derive_key_raw(auth_signing_key, info)
)
key_bytes = base64.b64encode(auth_signing_key.encode())

return jwk.JWK(kty="oct", k=key_bytes.decode())

def _make_state_claims(
self,
Expand All @@ -2083,19 +2081,25 @@ def _make_state_claims(
challenge: str,
) -> str:
signing_key = self._get_auth_signing_key()
signing_ctx = jwt_auth.SigningCtx()
signing_ctx.set_expiry(5 * 60)
signing_ctx.set_not_before(30)
signing_ctx.set_issuer(self.base_path)
expires_at = datetime.datetime.now(
datetime.timezone.utc
) + datetime.timedelta(minutes=5)

state_claims = {
"iss": self.base_path,
"provider": provider,
"exp": expires_at.timestamp(),
"redirect_to": redirect_to,
"challenge": challenge,
}
if redirect_to_on_signup:
state_claims['redirect_to_on_signup'] = redirect_to_on_signup
return signing_key.sign(state_claims, ctx=signing_ctx)
state_token = jwt.JWT(
header={"alg": "HS256"},
claims=state_claims,
)
state_token.make_signed_token(signing_key)
return cast(str, state_token.serialize())

def _make_session_token(self, identity_id: str) -> str:
signing_key = self._get_auth_signing_key()
Expand All @@ -2104,15 +2108,34 @@ def _make_session_token(self, identity_id: str) -> str:
"ext::auth::AuthConfig::token_time_to_live",
statypes.Duration,
)
signing_ctx = jwt_auth.SigningCtx()
signing_ctx.set_expiry(int(auth_expiration_time.to_timedelta().total_seconds()))
signing_ctx.set_not_before(30)
signing_key.set_issuer(self.base_path)
session_token = signing_key.sign({
expires_in = auth_expiration_time.to_timedelta()
expires_at = datetime.datetime.now(datetime.timezone.utc) + expires_in

claims: dict[str, Any] = {
"iss": self.base_path,
"sub": identity_id,
}, ctx=signing_ctx)
}
if expires_in.total_seconds() != 0:
claims["exp"] = expires_at.timestamp()
session_token = jwt.JWT(
header={"alg": "HS256"},
claims=claims,
)
session_token.make_signed_token(signing_key)
metrics.auth_successful_logins.inc(1.0, self.tenant.get_instance_name())
return session_token
return cast(str, session_token.serialize())

def _get_from_claims(self, state: str, key: str) -> str:
signing_key = self._get_auth_signing_key()
try:
state_token = jwt.JWT(key=signing_key, jwt=state)
except Exception:
raise errors.InvalidData("Invalid state token")
state_claims: dict[str, str] = json.loads(state_token.claims)
value = state_claims.get(key)
if value is None:
raise errors.InvalidData("Invalid state token")
return value

def _make_secret_token(
self,
Expand All @@ -2124,7 +2147,8 @@ def _make_secret_token(
) = None,
expires_in: datetime.timedelta | None = None,
) -> str:
signing_key = self._get_auth_signing_key(derive_for_info)
input_key_material = self._get_auth_signing_key()
signing_key = util.derive_key(input_key_material, derive_for_info)
expires_in = (
datetime.timedelta(minutes=10) if expires_in is None else expires_in
)
Expand All @@ -2142,8 +2166,15 @@ def _make_secret_token(
def _verify_and_extract_claims(
self, jwtStr: str, key_info: str | None = None
) -> dict[str, str | int | float | bool]:
signing_key = self._get_auth_signing_key(key_info)
return signing_key.validate(jwtStr)
input_key_material = self._get_auth_signing_key()
if key_info is None:
signing_key = input_key_material
else:
signing_key = util.derive_key(input_key_material, key_info)
verified = jwt.JWT(key=signing_key, jwt=jwtStr)
return cast(
dict[str, str | int | float | bool], json.loads(verified.claims)
)

def _get_data_from_magic_link_token(
self, token: str
Expand Down Expand Up @@ -2219,14 +2250,14 @@ def _get_data_from_verification_token(
):
case (
str(id),
float(issued_at) | int(issued_at),
float(issued_at),
verify_url,
challenge,
redirect_to,
):
return_value = (
id,
float(issued_at),
issued_at,
verify_url,
challenge,
redirect_to,
Expand Down Expand Up @@ -2324,8 +2355,7 @@ def _make_verification_token(
"Verify URL does not match any allowed URLs.",
)

now = datetime.datetime.now(datetime.timezone.utc)
issued_at = int(now.timestamp())
issued_at = datetime.datetime.now(datetime.timezone.utc).timestamp()
return self._make_secret_token(
identity_id=identity_id,
secret=str(uuid.uuid4()),
Expand Down
15 changes: 7 additions & 8 deletions edb/server/protocol/auth_ext/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,26 @@

import datetime
import json
import base64

from jwcrypto import jwk
from typing import Any, cast
from edb.server.protocol import execute
from edb.server import auth as jwt_auth

from . import util, data


class Client:
def __init__(self, db: Any):
self.db = db

def _get_signing_key(self, info: str | None = None) -> jwt_auth.JWKSet:
def _get_signing_key(self) -> jwk.JWK:
auth_signing_key = util.get_config(
self.db, "ext::auth::AuthConfig::auth_signing_key"
)
if info is None:
return jwt_auth.JWKSet.from_hs256_key(auth_signing_key.encode())
else:
return jwt_auth.JWKSet.from_hs256_key(
util.derive_key_raw(auth_signing_key, info)
)
key_bytes = base64.b64encode(auth_signing_key.encode())

return jwk.JWK(kty="oct", k=key_bytes.decode())

async def verify_email(
self, identity_id: str, verified_at: datetime.datetime
Expand Down
3 changes: 2 additions & 1 deletion edb/server/protocol/auth_ext/magic_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def make_magic_link_token(
callback_url: str,
challenge: str,
) -> str:
signing_key = self._get_signing_key("magic_link")
initial_key_material = self._get_signing_key()
signing_key = util.derive_key(initial_key_material, "magic_link")
return util.make_token(
signing_key=signing_key,
issuer=self.issuer,
Expand Down
Loading

0 comments on commit b80c484

Please sign in to comment.