Skip to content

Commit

Permalink
Python updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Feb 7, 2025
1 parent cf90c69 commit 3fed1c7
Show file tree
Hide file tree
Showing 28 changed files with 729 additions and 779 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ members = [
resolver = "2"

[workspace.dependencies]
pyo3 = { version = "0.23", features = ["extension-module", "serde", "macros"] }
pyo3 = { version = "0.23.4", features = ["serde", "macros"] }
tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "time", "sync", "net", "io-util"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter"] }
Expand Down
83 changes: 2 additions & 81 deletions edb/common/secretkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,83 +17,12 @@
#

from __future__ import annotations
from typing import Optional, AbstractSet, Iterable
from typing import Iterable

import pathlib
import uuid

from datetime import datetime, timedelta, timezone

from jwcrypto import jwk, jwt

from . import uuidgen


class SecretKeyReadError(Exception):
pass


def generate_secret_key(
skey: jwk.JWK,
*,
instances: Optional[list[str] | AbstractSet[str]] = None,
roles: Optional[list[str] | AbstractSet[str]] = None,
databases: Optional[list[str] | AbstractSet[str]] = None,
subject: Optional[str] = None,
key_id: Optional[str] = None,
) -> str:
claims = {
"iat": int(datetime.now(timezone.utc).timestamp()),
"iss": "edgedb-server",
}

if instances is None:
claims["edb.i.all"] = True
else:
claims["edb.i"] = list(instances)

if roles is None:
claims["edb.r.all"] = True
else:
claims["edb.r"] = list(roles)

if databases is None:
claims["edb.d.all"] = True
else:
claims["edb.d"] = list(databases)

if subject is not None:
claims["sub"] = subject

if key_id is None:
key_id = str(uuidgen.uuid4())

claims["jti"] = key_id

token = jwt.JWT(
header={"alg": "ES256" if skey["kty"] == "EC" else "RS256"},
claims=claims,
)
token.make_signed_token(skey)
return "edbt1_" + token.serialize()


def load_secret_key(key_file: pathlib.Path) -> jwk.JWK:
try:
with open(key_file, 'rb') as kf:
jws_key = jwk.JWK.from_pem(kf.read())
except Exception as e:
raise SecretKeyReadError(f"cannot load JWS key: {e}") from e

if (
not jws_key.has_public
or jws_key['kty'] not in {"RSA", "EC"}
):
raise SecretKeyReadError(
f"the cluster JWS key file does not "
f"contain a valid RSA or EC public key")

return jws_key
from datetime import datetime, timedelta


def generate_tls_cert(
Expand Down Expand Up @@ -154,11 +83,3 @@ def generate_tls_cert(
)
)
tls_key_file.chmod(0o600)


def generate_jwk(keys_file: pathlib.Path) -> None:
key = jwk.JWK(generate='EC')
with keys_file.open("wb") as f:
f.write(key.export_to_pem(private_key=True, password=None))

keys_file.chmod(0o600)
5 changes: 3 additions & 2 deletions edb/server/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,8 +984,9 @@ def resolve_envvar_value(self, ctx: click.Context):
cls=EnvvarResolver,
hidden=True,
help='Specifies a path to a file containing a public key in PEM '
'format used to verify JWT signatures. The file could also '
'contain a private key to sign JWT for local testing.'),
'or JSON JWK format used to verify JWT signatures. The file may '
'also contain a private key to sign JWT tokens for '
'SCRAM-over-HTTP.'),
click.option(
'--jwe-key-file',
type=PathPath(),
Expand Down
68 changes: 61 additions & 7 deletions edb/server/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import datetime
import pathlib

from typing import TYPE_CHECKING, Iterable, List, Optional, Any

if TYPE_CHECKING:
Expand All @@ -21,17 +24,19 @@ def set_issuer(self, issuer: str) -> None: ...
def set_audience(self, audience: str) -> None: ...
def set_expiry(self, expiry: int) -> None: ...
def set_not_before(self, not_before: int) -> None: ...
def allow(self, claim: str, values: List[str]) -> None: ...
def deny(self, claim: str, values: List[str]) -> None: ...
def export_pem(self, *, private_keys: bool) -> bytes: ...
def export_json(self, *, private_keys: bool) -> bytes: ...
def allow(self, claim: str, values: Iterable[str]) -> None: ...
def deny(self, claim: str, values: Iterable[str]) -> None: ...
def export_pem(self, *, private_keys: bool=True) -> bytes: ...
def export_json(self, *, private_keys: bool=True) -> bytes: ...
def can_sign(self) -> bool: ...
def can_validate(self) -> bool: ...
def has_public_keys(self) -> bool: ...
def has_private_keys(self) -> bool: ...
def has_symmetric_keys(self) -> bool: ...
def sign(
self, claims: dict[str, Any], *, ctx: Optional[SigningCtx] = None
) -> str: ...
def validate(self, token: str) -> dict[str, Any]: ...
def to_json(self, *, private_keys: bool) -> str: ...
def to_pem(self, *, private_keys: bool) -> str: ...

class JWKSetCache:
def __init__(self, expiry_seconds: int) -> None: ...
Expand All @@ -45,6 +50,55 @@ def generate_gel_token(
instances: Optional[List[str] | Iterable[str]] = None,
roles: Optional[List[str] | Iterable[str]] = None,
databases: Optional[List[str] | Iterable[str]] = None,
**kwargs: Any,
) -> str: ...

def validate_gel_token(
registry: JWKSet,
token: str,
user: str,
dbname: str,
instance_name: str,
) -> str | None: ...
else:
from edb.server._rust_native._jwt import JWKSet, JWKSetCache, generate_gel_token, SigningCtx # noqa
from edb.server._rust_native._jwt import (
JWKSet, JWKSetCache, generate_gel_token, validate_gel_token, SigningCtx # noqa
)


def load_secret_key(key_file: pathlib.Path) -> JWKSet:
try:
with open(key_file, 'rb') as kf:
jws_key = JWKSet()
jws_key.load(kf.read().decode('ascii'))
except Exception as e:
raise SecretKeyReadError(f"cannot load JWS key {key_file}: {e}") from e
if not jws_key.can_validate():
raise SecretKeyReadError(
f"the cluster JWS key file {key_file} does not "
f"contain a valid key for token validation (RSA, EC or "
f"HMAC-SHA256)")

return jws_key


def generate_jwk(keys_file: pathlib.Path) -> None:
key = JWKSet()
# kid is yyyymmdd
kid = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d")
key.generate(kid=kid, kty='ES256')
if keys_file.name.endswith(".pem"):
with keys_file.open("wb") as f:
f.write(key.export_pem())
elif keys_file.name.endswith(".json"):
with keys_file.open("wb") as f:
f.write(key.export_json())
else:
raise ValueError(f"Unsupported key file extension {keys_file.suffix}. "
"Use .pem or .json extension when generating a key.")

keys_file.chmod(0o600)


class SecretKeyReadError(Exception):
pass
5 changes: 2 additions & 3 deletions edb/server/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@
import tempfile
import time

from jwcrypto import jwk

from edb import buildmeta
from edb.common import devmode
from edb.edgeql import quote

from edb.server import args as edgedb_args
from edb.server import defines as edgedb_defines
from edb.server import pgconnparams
from edb.server import auth

from . import pgcluster

Expand Down Expand Up @@ -428,7 +427,7 @@ def __init__(
self._edgedb_cmd.extend(['-D', str(self._data_dir)])
self._pg_connect_args['user'] = pg_superuser
self._pg_connect_args['database'] = 'template1'
self._jws_key: Optional[jwk.JWK] = None
self._jws_key: Optional[auth.JWKSet] = None

async def _new_pg_cluster(self) -> pgcluster.Cluster:
return await pgcluster.get_local_pg_cluster(
Expand Down
24 changes: 24 additions & 0 deletions edb/server/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,30 @@
labels=("tenant",),
)

auth_provider_jwkset_fetch_success = registry.new_labeled_counter(
"auth_provider_jwkset_fetch_success_total",
"Number of successful Auth extension JWK Set fetches.",
labels=("provider",),
)

auth_provider_jwkset_fetch_errors = registry.new_labeled_counter(
"auth_provider_jwkset_fetch_errors_total",
"Number of failed Auth extension JWK Set fetches.",
labels=("provider",),
)

auth_provider_token_validation_success = registry.new_labeled_counter(
"auth_provider_token_validation_success_total",
"Number of successful Auth extension provider token validations.",
labels=("provider",),
)

auth_provider_token_validation_errors = registry.new_labeled_counter(
"auth_provider_token_validation_errors_total",
"Number of failed Auth extension provider token validations.",
labels=("provider",),
)

mt_tenants_total = registry.new_gauge(
'mt_tenants_current',
'Total number of currently-registered tenants.',
Expand Down
10 changes: 5 additions & 5 deletions edb/server/protocol/auth/scram.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from edb.common import debug
from edb.common import markup
from edb.common import secretkey
from edb.server import auth

if TYPE_CHECKING:
from edb.server import tenant as edbtenant
Expand Down Expand Up @@ -89,7 +89,8 @@ def handle_request(
response.close_connection = True
return

if not server.get_jws_key().has_private: # type: ignore[union-attr]
jws = server.get_jws_key()
if jws is None or not jws.has_private_keys():
response.body = b"Server doesn't support HTTP SCRAM authentication"
response.status = http.HTTPStatus.FORBIDDEN
response.close_connection = True
Expand Down Expand Up @@ -268,9 +269,8 @@ def handle_request(
).decode("ascii")

try:
response.body = secretkey.generate_secret_key(
server.get_jws_key(),
roles=[username],
response.body = auth.generate_gel_token(
jws, roles=[username],
).encode("ascii")
except ValueError as ex:
if debug.flags.server:
Expand Down
58 changes: 41 additions & 17 deletions edb/server/protocol/auth_ext/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@

import uuid
import urllib.parse
import json
import enum
import logging

from typing import Any, Callable
from jwcrypto import jwt, jwk
from datetime import datetime

from . import data, errors
from edb.server.http import HttpClient
from edb.server import auth as jwt_auth
from edb.server.protocol.auth_ext import util as auth_util
from edb.server import metrics

logger = logging.getLogger("edb.server.ext.auth")


class BaseProvider:
Expand Down Expand Up @@ -145,28 +149,48 @@ async def fetch_user_info(
)
id_token = token_response.id_token

# Retrieve JWK Set
# Retrieve JWK Set, potentially from the cache
oidc_config = await self._get_oidc_config()
jwks_uri = urllib.parse.urlparse(oidc_config.jwks_uri)
async with self.http_factory(
base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}"
) as client:
r = await client.get(jwks_uri.path)

# Load the token as a JWT object and verify it directly
try:
jwk_set = jwk.JWKSet.from_json(r.text)
id_token_verified = jwt.JWT(key=jwk_set, jwt=id_token)
payload = json.loads(id_token_verified.claims)
async def fetcher(url: str) -> jwt_auth.JWKSet:
jwks_uri = urllib.parse.urlparse(url)
async with self.http_factory(
base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}"
) as client:
r = await client.get(jwks_uri.path)
jwk_set = jwt_auth.JWKSet()
jwk_set.load_json(r.text)
jwk_set.set_audience(self.client_id)
jwk_set.set_expiry(3600)
metrics.auth_provider_jwkset_fetch_success.inc(
1.0, self.name
)
return jwk_set

jwk_set = await auth_util.get_remote_jwtset(
oidc_config.jwks_uri, fetcher
)
except Exception as e:
metrics.auth_provider_jwkset_fetch_errors.inc(1.0, self.name)
logger.exception(
f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}"
)
raise errors.MisconfiguredProvider(
f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}"
) from e

# Load the token as a JWT object and verify it directly. This will
# validate the audience and expiry.
try:
payload = jwk_set.validate(id_token)
except Exception as e:
metrics.auth_provider_token_validation_errors.inc(1.0, self.name)
raise errors.MisconfiguredProvider(
"Failed to parse ID token with provider keyset"
) from e
if payload.get("aud") != self.client_id:
raise errors.InvalidData(
"Invalid value for aud in id_token: "
f"{payload.get('aud')} != {self.client_id}"
)

metrics.auth_provider_token_validation_success.inc(1.0, self.name)

return data.UserInfo(
sub=str(payload["sub"]),
Expand Down
Loading

0 comments on commit 3fed1c7

Please sign in to comment.