From e25404edd437500ff1e946fa6f4a441d33286737 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Thu, 27 Jun 2024 16:55:39 -0400 Subject: [PATCH] Enable type checking in edb.server.protocol. (#7489) Additionally: - Move/rename `auth_ext.data.Config` to `auth_ext.config.OAuthProviderConfig` - Change to inherit `auth_ext.config.ProviderConfig` - Change to `dataclass` - Add `ai_ext.ProviderConfig` and `ai_ext.ApiStyle` - Add `auth.scram.Session` --- edb/server/dbview/dbview.pyi | 1 + edb/server/protocol/__init__.py | 2 +- edb/server/protocol/ai_ext.py | 57 ++-- edb/server/protocol/auth/__init__.py | 20 +- edb/server/protocol/auth/scram.py | 65 +++-- edb/server/protocol/auth_ext/apple.py | 3 +- edb/server/protocol/auth_ext/azure.py | 4 +- edb/server/protocol/auth_ext/base.py | 8 +- edb/server/protocol/auth_ext/config.py | 9 + edb/server/protocol/auth_ext/data.py | 42 +-- edb/server/protocol/auth_ext/discord.py | 3 +- edb/server/protocol/auth_ext/email.py | 12 +- .../protocol/auth_ext/email_password.py | 16 +- edb/server/protocol/auth_ext/errors.py | 24 +- edb/server/protocol/auth_ext/github.py | 3 +- edb/server/protocol/auth_ext/google.py | 4 +- edb/server/protocol/auth_ext/http.py | 264 ++++++++++++------ edb/server/protocol/auth_ext/http_client.py | 29 +- edb/server/protocol/auth_ext/local.py | 10 +- edb/server/protocol/auth_ext/magic_link.py | 4 +- edb/server/protocol/auth_ext/oauth.py | 20 +- edb/server/protocol/auth_ext/pkce.py | 27 +- edb/server/protocol/auth_ext/slack.py | 4 +- edb/server/protocol/auth_ext/ui/__init__.py | 49 ++-- edb/server/protocol/auth_ext/ui/components.py | 15 +- edb/server/protocol/auth_ext/ui/util.py | 2 +- edb/server/protocol/auth_ext/util.py | 13 +- edb/server/protocol/auth_ext/webauthn.py | 27 +- edb/server/protocol/metrics.py | 21 +- edb/server/protocol/protocol.pyi | 3 +- edb/server/protocol/server_info.py | 23 +- edb/server/protocol/system_api.py | 48 ++-- pyproject.toml | 3 +- 33 files changed, 567 insertions(+), 268 deletions(-) diff --git a/edb/server/dbview/dbview.pyi b/edb/server/dbview/dbview.pyi index 58ce00e4f19..caac4e4171b 100644 --- a/edb/server/dbview/dbview.pyi +++ b/edb/server/dbview/dbview.pyi @@ -50,6 +50,7 @@ class Database: dbver: int db_config: Config extensions: set[str] + user_config_spec: config.Spec @property def server(self) -> server.Server: diff --git a/edb/server/protocol/__init__.py b/edb/server/protocol/__init__.py index 4bc676099f2..151e70e2481 100644 --- a/edb/server/protocol/__init__.py +++ b/edb/server/protocol/__init__.py @@ -19,7 +19,7 @@ from __future__ import annotations -from . import protocol # type: ignore +from . import protocol HttpProtocol = protocol.HttpProtocol diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index 8e42da35e55..13a981bab8b 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -19,6 +19,7 @@ from __future__ import annotations from dataclasses import dataclass from typing import ( + cast, Any, AsyncIterator, ClassVar, @@ -42,6 +43,7 @@ from edb import errors from edb.common import asyncutil from edb.common import debug +from edb.common import enum as s_enum from edb.common import markup from edb.common import uuidgen @@ -101,6 +103,21 @@ class BadRequestError(AIExtError): http_status = http.HTTPStatus.BAD_REQUEST +class ApiStyle(s_enum.StrEnum): + OpenAI = 'OpenAI' + Anthropic = 'Anthropic' + + +@dataclass +class ProviderConfig: + name: str + display_name: str + api_url: str + client_id: str + secret: str + api_style: ApiStyle + + def start_extension( tenant: srv_tenant.Tenant, dbname: str, @@ -565,7 +582,7 @@ async def _update_embeddings_in_db( async def _generate_embeddings( - provider, + provider: ProviderConfig, model_name: str, inputs: list[str], shortening: Optional[int], @@ -578,7 +595,7 @@ async def _generate_embeddings( f"of {provider.name!r} for {len(inputs)} object{suf}" ) - if provider.api_style == "OpenAI": + if provider.api_style == ApiStyle.OpenAI: return await _generate_openai_embeddings( provider, model_name, inputs, shortening, ) @@ -590,7 +607,7 @@ async def _generate_embeddings( async def _generate_openai_embeddings( - provider, + provider: ProviderConfig, model_name: str, inputs: list[str], shortening: Optional[int], @@ -676,9 +693,9 @@ async def _start_chat( protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, - provider, + provider: ProviderConfig, model_name: str, - messages: list[dict], + messages: list[dict[str, Any]], stream: bool, ) -> None: if provider.api_style == "OpenAI": @@ -725,7 +742,7 @@ async def _start_openai_like_chat( response: protocol.HttpResponse, client: httpx.AsyncClient, model_name: str, - messages: list[dict], + messages: list[dict[str, Any]], stream: bool, ) -> None: if stream: @@ -847,9 +864,9 @@ async def _start_openai_chat( protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, - provider, + provider: ProviderConfig, model_name: str, - messages: list[dict], + messages: list[dict[str, Any]], stream: bool, ) -> None: headers = { @@ -879,9 +896,9 @@ async def _start_anthropic_chat( protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, - provider, + provider: ProviderConfig, model_name: str, - messages: list[dict], + messages: list[dict[str, Any]], stream: bool, ) -> None: headers = { @@ -1010,7 +1027,7 @@ async def handle_request( db: dbview.Database, args: list[str], tenant: srv_tenant.Tenant, -): +) -> None: if len(args) != 1 or args[0] not in {"rag", "embeddings"}: response.body = b'Unknown path' response.status = http.HTTPStatus.NOT_FOUND @@ -1239,7 +1256,7 @@ async def _handle_rag_request( "messages": [], } - messages: dict[str, list[dict]] = {} + messages: dict[str, list[dict[str, Any]]] = {} for message in prompt["messages"]: if message["participant_role"] == "User": content = message["content"].format( @@ -1356,7 +1373,7 @@ async def _edgeql_query_json( except Exception as iex: raise iex from None else: - return content + return cast(list[Any], content) async def _db_error( @@ -1394,12 +1411,20 @@ async def _db_error( def _get_provider_config( db: dbview.Database, provider_name: str, -) -> Any: +) -> ProviderConfig: cfg = db.lookup_config("ext::ai::Config::providers") for provider in cfg: if provider.name == provider_name: - return provider + provider = cast(ProviderConfig, provider) + return ProviderConfig( + name=provider.name, + display_name=provider.display_name, + api_url=provider.api_url, + client_id=provider.client_id, + secret=provider.secret, + api_style=provider.api_style, + ) else: raise ConfigurationError( f"provider {provider_name!r} has not been configured" @@ -1447,7 +1472,7 @@ async def _get_model_provider( elif len(models) > 1: raise InternalError("multiple models defined as requested model") - return models[0]["provider"] + return cast(str, models[0]["provider"]) async def _generate_embeddings_for_type( diff --git a/edb/server/protocol/auth/__init__.py b/edb/server/protocol/auth/__init__.py index 683199d4a51..2a9c9592ab1 100644 --- a/edb/server/protocol/auth/__init__.py +++ b/edb/server/protocol/auth/__init__.py @@ -17,6 +17,8 @@ # +from __future__ import annotations +from typing import Type, TYPE_CHECKING import http import json @@ -26,8 +28,17 @@ from . import scram +if TYPE_CHECKING: + from edb.server import tenant as edbtenant + from edb.server.protocol import protocol -async def handle_request(request, response, path_parts, tenant): + +async def handle_request( + request: protocol.HttpRequest, + response: protocol.HttpResponse, + path_parts: list[str], + tenant: edbtenant.Tenant, +) -> None: try: if path_parts == ["token"]: if not request.authorization: @@ -68,7 +79,12 @@ async def handle_request(request, response, path_parts, tenant): ) -def _response_error(response, status, message, ex_type): +def _response_error( + response: protocol.HttpResponse, + status: http.HTTPStatus, + message: str, + ex_type: Type[errors.EdgeDBError], +) -> None: err_dct = { "message": message, "type": str(ex_type.__name__), diff --git a/edb/server/protocol/auth/scram.py b/edb/server/protocol/auth/scram.py index a706d2c5d5c..436c549d1fe 100644 --- a/edb/server/protocol/auth/scram.py +++ b/edb/server/protocol/auth/scram.py @@ -16,6 +16,8 @@ # limitations under the License. # +from __future__ import annotations +from typing import NamedTuple, Optional, TYPE_CHECKING import base64 import collections import hashlib @@ -29,13 +31,36 @@ from edb.common import markup from edb.common import secretkey +if TYPE_CHECKING: + from edb.server import tenant as edbtenant + from edb.server.protocol import protocol -SESSION_TIMEOUT = 30 -SESSION_HIGH_WATER_MARK = SESSION_TIMEOUT * 10 -sessions: collections.OrderedDict[str, tuple] = collections.OrderedDict() +SESSION_TIMEOUT: float = 30 +SESSION_HIGH_WATER_MARK: float = SESSION_TIMEOUT * 10 -def handle_request(scheme, auth_str, response, tenant): + +class Session(NamedTuple): + time: float + client_nonce: str + server_nonce: str + client_first_bare: bytes + cb_flag: bool + server_first: bytes + verifier: scram.SCRAMVerifier + mock_auth: bool + username: str + + +sessions: collections.OrderedDict[str, Session] = collections.OrderedDict() + + +def handle_request( + scheme: str, + auth_str: str, + response: protocol.HttpResponse, + tenant: edbtenant.Tenant, +) -> None: server = tenant.server if scheme != "SCRAM-SHA-256": response.body = ( @@ -64,7 +89,7 @@ def handle_request(scheme, auth_str, response, tenant): response.close_connection = True return - if not server.get_jws_key().has_private: + if not server.get_jws_key().has_private: # type: ignore[union-attr] response.body = b"Server doesn't support HTTP SCRAM authentication" response.status = http.HTTPStatus.FORBIDDEN response.close_connection = True @@ -72,11 +97,16 @@ def handle_request(scheme, auth_str, response, tenant): if sid is None: try: + bare_offset: int + cb_flag: bool + authzid: Optional[bytes] + username_bytes: bytes + client_nonce: str ( bare_offset, cb_flag, authzid, - username, + username_bytes, client_nonce, ) = scram.parse_client_first_message(data) except ValueError as ex: @@ -87,7 +117,7 @@ def handle_request(scheme, auth_str, response, tenant): response.close_connection = True return - username = username.decode("utf-8") + username = username_bytes.decode("utf-8") client_first_bare = data[bare_offset:] if isinstance(cb_flag, str): @@ -120,16 +150,16 @@ def handle_request(scheme, auth_str, response, tenant): response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256" return - server_nonce = scram.generate_nonce() - server_first = scram.build_server_first_message( + server_nonce: str = scram.generate_nonce() + server_first: bytes = scram.build_server_first_message( server_nonce, client_nonce, verifier.salt, verifier.iterations ).encode("utf-8") if len(sessions) > SESSION_HIGH_WATER_MARK: while sessions: - key, value = sessions.popitem(last=False) - if value[0] + SESSION_TIMEOUT > time.monotonic(): - sessions[key] = value + key, session = sessions.popitem(last=False) + if session.time + SESSION_TIMEOUT > time.monotonic(): + sessions[key] = session sessions.move_to_end(key, last=False) break @@ -139,7 +169,7 @@ def handle_request(scheme, auth_str, response, tenant): .rstrip("=") ) assert sid not in sessions - sessions[sid] = ( + sessions[sid] = Session( time.monotonic(), client_nonce, server_nonce, @@ -151,11 +181,11 @@ def handle_request(scheme, auth_str, response, tenant): username, ) - server_first = base64.b64encode(server_first).decode("ascii") + server_first_str = base64.b64encode(server_first).decode("ascii") response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers[ "WWW-Authenticate" - ] = f"SCRAM-SHA-256 sid={sid}, data={server_first}" + ] = f"SCRAM-SHA-256 sid={sid}, data={server_first_str}" else: session = sessions.pop(sid) @@ -255,7 +285,10 @@ def handle_request(scheme, auth_str, response, tenant): ] = f"sid={sid}, data={server_final}" -def get_scram_verifier(user, tenant): +def get_scram_verifier( + user: str, + tenant: edbtenant.Tenant, +) -> tuple[scram.SCRAMVerifier, bool]: roles = tenant.get_roles() rolerec = roles.get(user) diff --git a/edb/server/protocol/auth_ext/apple.py b/edb/server/protocol/auth_ext/apple.py index b253de4cfc5..f5efa1b4750 100644 --- a/edb/server/protocol/auth_ext/apple.py +++ b/edb/server/protocol/auth_ext/apple.py @@ -16,6 +16,7 @@ # limitations under the License. # +from typing import Any import uuid import urllib.parse @@ -24,7 +25,7 @@ class AppleProvider(base.OpenIDProvider): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__( "apple", "https://appleid.apple.com", diff --git a/edb/server/protocol/auth_ext/azure.py b/edb/server/protocol/auth_ext/azure.py index 396fd4e8828..4255707c592 100644 --- a/edb/server/protocol/auth_ext/azure.py +++ b/edb/server/protocol/auth_ext/azure.py @@ -17,11 +17,13 @@ # +from typing import Any + from . import base class AzureProvider(base.OpenIDProvider): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__( "azure", "https://login.microsoftonline.com/common/v2.0", diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index 113d2e1f8e2..1c191c77966 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -21,7 +21,7 @@ import json import enum -from typing import Callable +from typing import Any, Callable from jwcrypto import jwt, jwk from datetime import datetime @@ -75,9 +75,9 @@ def __init__( self, name: str, issuer_url: str, - *args, + *args: Any, content_type: ContentType = ContentType.JSON, - **kwargs, + **kwargs: Any, ): super().__init__(name, issuer_url, *args, **kwargs) self.content_type = content_type @@ -183,7 +183,7 @@ async def fetch_user_info( picture=payload.get("picture"), ) - async def _get_oidc_config(self): + async def _get_oidc_config(self) -> data.OpenIDConfig: client = self.http_factory(base_url=self.issuer_url) response = await client.get('/.well-known/openid-configuration') config = response.json() diff --git a/edb/server/protocol/auth_ext/config.py b/edb/server/protocol/auth_ext/config.py index dfb60ba691a..431bf479033 100644 --- a/edb/server/protocol/auth_ext/config.py +++ b/edb/server/protocol/auth_ext/config.py @@ -42,10 +42,19 @@ class AppDetailsConfig: brand_color: Optional[str] +@dataclass class ProviderConfig: name: str +@dataclass +class OAuthProviderConfig(ProviderConfig): + display_name: str + client_id: str + secret: str + additional_scope: Optional[str] + + class WebAuthnProviderConfig(ProviderConfig): relying_party_origin: str require_verification: bool diff --git a/edb/server/protocol/auth_ext/data.py b/edb/server/protocol/auth_ext/data.py index a3fed0ec029..d2d2b6c051b 100644 --- a/edb/server/protocol/auth_ext/data.py +++ b/edb/server/protocol/auth_ext/data.py @@ -21,7 +21,7 @@ import datetime import base64 -from typing import Optional, NamedTuple +from typing import Any, Optional @dataclasses.dataclass @@ -96,7 +96,7 @@ class OpenIDConfig: token_endpoint: str jwks_uri: str - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): for field in dataclasses.fields(self): setattr(self, field.name, kwargs.get(field.name)) @@ -116,7 +116,7 @@ class OAuthAccessTokenResponse: expires_in: int refresh_token: str | None - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): for field in dataclasses.fields(self): if field.name in kwargs: setattr(self, field.name, kwargs.pop(field.name)) @@ -134,16 +134,10 @@ class OpenIDConnectAccessTokenResponse(OAuthAccessTokenResponse): id_token: str - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): super().__init__(**kwargs) -class ProviderConfig(NamedTuple): - client_id: str - secret: str - additional_scope: Optional[str] - - @dataclasses.dataclass class WebAuthnFactor: id: str @@ -159,15 +153,15 @@ class WebAuthnFactor: def __init__( self, *, - id, - created_at, - modified_at, - identity, - email, - verified_at, - user_handle, - credential_id, - public_key, + id: str, + created_at: datetime.datetime, + modified_at: datetime.datetime, + identity: LocalIdentity, + email: str, + verified_at: Optional[datetime.datetime], + user_handle: bytes, + credential_id: bytes, + public_key: bytes, ): self.id = id self.created_at = created_at @@ -192,7 +186,15 @@ class WebAuthnAuthenticationChallenge: challenge: bytes factors: list[WebAuthnFactor] - def __init__(self, *, id, created_at, modified_at, challenge, factors): + def __init__( + self, + *, + id: str, + created_at: datetime.datetime, + modified_at: datetime.datetime, + challenge: bytes, + factors: list[WebAuthnFactor], + ): self.id = id self.created_at = created_at self.modified_at = modified_at diff --git a/edb/server/protocol/auth_ext/discord.py b/edb/server/protocol/auth_ext/discord.py index 76b58a033e8..dce3f39aa63 100644 --- a/edb/server/protocol/auth_ext/discord.py +++ b/edb/server/protocol/auth_ext/discord.py @@ -17,6 +17,7 @@ # +from typing import Any import urllib.parse import functools @@ -24,7 +25,7 @@ class DiscordProvider(base.BaseProvider): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__("discord", "https://discord.com", *args, **kwargs) self.auth_domain = self.issuer_url self.api_domain = f"{self.issuer_url}/api/v10" diff --git a/edb/server/protocol/auth_ext/email.py b/edb/server/protocol/auth_ext/email.py index dfec5075681..bffe9f1f5c4 100644 --- a/edb/server/protocol/auth_ext/email.py +++ b/edb/server/protocol/auth_ext/email.py @@ -14,7 +14,7 @@ async def send_password_reset_email( to_addr: str, reset_url: str, test_mode: bool, -): +) -> None: from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) if app_details_config is None: @@ -50,7 +50,7 @@ async def send_verification_email( verification_token: str, provider: str, test_mode: bool, -): +) -> None: from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) verification_token_params = urllib.parse.urlencode( @@ -92,7 +92,7 @@ async def send_magic_link_email( to_addr: str, link: str, test_mode: bool, -): +) -> None: from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) if app_details_config is None: @@ -120,8 +120,8 @@ async def send_magic_link_email( await _protected_send(coro, tenant) -async def send_fake_email(tenant: tenant.Tenant): - async def noop_coroutine(): +async def send_fake_email(tenant: tenant.Tenant) -> None: + async def noop_coroutine() -> None: pass coro = noop_coroutine() @@ -130,7 +130,7 @@ async def noop_coroutine(): async def _protected_send( coro: Coroutine[Any, Any, None], tenant: tenant.Tenant -): +) -> None: task = tenant.create_task(coro, interruptable=False) # Prevent timing attack await asyncio.sleep(random.random() * 0.5) diff --git a/edb/server/protocol/auth_ext/email_password.py b/edb/server/protocol/auth_ext/email_password.py index 01ae0e87971..5f8fab553cb 100644 --- a/edb/server/protocol/auth_ext/email_password.py +++ b/edb/server/protocol/auth_ext/email_password.py @@ -22,7 +22,7 @@ import base64 import dataclasses -from typing import Any +from typing import Any, Optional from edb.errors import ConstraintViolationError from edb.server.protocol import execute @@ -42,7 +42,7 @@ def __init__(self, db: Any): super().__init__(db) self.config = self._get_provider_config("builtin::local_emailpassword") - async def register(self, input: dict[str, Any]): + async def register(self, input: dict[str, Any]) -> data.LocalIdentity: match (input.get("email"), input.get("password")): case (str(e), str(p)): email = e @@ -88,7 +88,7 @@ async def register(self, input: dict[str, Any]): return data.LocalIdentity(**result_json[0]) - async def authenticate(self, input: dict[str, Any]): + async def authenticate(self, input: dict[str, Any]) -> data.LocalIdentity: if 'email' not in input or 'password' not in input: raise errors.InvalidData("Missing 'email' or 'password' in data") @@ -143,7 +143,9 @@ async def authenticate(self, input: dict[str, Any]): return local_identity - async def get_identity_and_secret(self, input: dict[str, Any]): + async def get_identity_and_secret( + self, input: dict[str, Any], + ) -> tuple[data.LocalIdentity, str]: if 'email' not in input: raise errors.InvalidData("Missing 'email' in data") @@ -175,7 +177,9 @@ async def get_identity_and_secret(self, input: dict[str, Any]): return (local_identity, secret) - async def validate_reset_secret(self, identity_id: str, secret: str): + async def validate_reset_secret( + self, identity_id: str, secret: str, + ) -> Optional[data.LocalIdentity]: r = await execute.parse_execute_json( db=self.db, query="""\ @@ -204,7 +208,7 @@ async def validate_reset_secret(self, identity_id: str, secret: str): async def update_password( self, identity_id: str, secret: str, input: dict[str, Any] - ): + ) -> data.LocalIdentity: if 'password' not in input: raise errors.InvalidData("Missing 'password' in data") diff --git a/edb/server/protocol/auth_ext/errors.py b/edb/server/protocol/auth_ext/errors.py index 653ab304394..8fffcf39209 100644 --- a/edb/server/protocol/auth_ext/errors.py +++ b/edb/server/protocol/auth_ext/errors.py @@ -29,7 +29,7 @@ class NotFound(AuthExtError): def __init__(self, description: str): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -47,7 +47,7 @@ def __init__(self, key: str, description: str): self.key = key self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"key={self.key!r} " @@ -65,7 +65,7 @@ class InvalidData(AuthExtError): def __init__(self, description: str): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -82,7 +82,7 @@ class MisconfiguredProvider(AuthExtError): def __init__(self, description: str): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -104,7 +104,7 @@ def __init__( ): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -124,7 +124,7 @@ def __init__( ): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -144,7 +144,7 @@ def __init__( ): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -164,7 +164,7 @@ def __init__( ): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -184,7 +184,7 @@ def __init__( ): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -203,7 +203,7 @@ def __init__( ): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -222,7 +222,7 @@ def __init__( ): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" @@ -239,7 +239,7 @@ class WebAuthnAuthenticationFailed(AuthExtError): def __init__(self, description: str = "WebAuthn authentication failed"): self.description = description - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" diff --git a/edb/server/protocol/auth_ext/github.py b/edb/server/protocol/auth_ext/github.py index c99b42daadb..afa7a676162 100644 --- a/edb/server/protocol/auth_ext/github.py +++ b/edb/server/protocol/auth_ext/github.py @@ -17,6 +17,7 @@ # +from typing import Any import urllib.parse import functools @@ -24,7 +25,7 @@ class GitHubProvider(base.BaseProvider): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__("github", "https://github.com", *args, **kwargs) self.auth_domain = self.issuer_url self.api_domain = "https://api.github.com" diff --git a/edb/server/protocol/auth_ext/google.py b/edb/server/protocol/auth_ext/google.py index c84dfed79f1..61e3a3d2f72 100644 --- a/edb/server/protocol/auth_ext/google.py +++ b/edb/server/protocol/auth_ext/google.py @@ -17,11 +17,13 @@ # +from typing import Any + from . import base class GoogleProvider(base.OpenIDProvider): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__( "google", "https://accounts.google.com", *args, **kwargs ) diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 5f66f6dea87..f1d742fe5f8 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -17,6 +17,8 @@ # +from __future__ import annotations + import datetime import http import http.cookies @@ -29,7 +31,7 @@ import mimetypes import uuid -from typing import Any, Optional, Tuple, FrozenSet, cast +from typing import Any, Optional, Tuple, FrozenSet, cast, TYPE_CHECKING import aiosmtplib from jwcrypto import jwk, jwt @@ -54,6 +56,9 @@ magic_link, ) +if TYPE_CHECKING: + from edb.server.protocol import protocol + logger = logging.getLogger('edb.server') @@ -61,15 +66,24 @@ class Router: test_url: Optional[str] - def __init__(self, *, db: Any, base_path: str, tenant: edbtenant.Tenant): + def __init__( + self, + *, + db: edbtenant.dbview.Database, + base_path: str, + tenant: edbtenant.Tenant + ): self.db = db self.base_path = base_path self.tenant = tenant self.test_mode = tenant.server.in_test_mode() async def handle_request( - self, request: Any, response: Any, args: list[str] - ): + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + args: list[str], + ) -> None: if self.db.db_config is None: await self.db.introspection() @@ -83,72 +97,65 @@ async def handle_request( else None ) - handler_args = (request, response) try: match args: # API routes case ("authorize",): - return await self.handle_authorize(*handler_args) + await self.handle_authorize(request, response) case ("callback",): - return await self.handle_callback(*handler_args) + await self.handle_callback(request, response) case ("token",): - return await self.handle_token(*handler_args) + await self.handle_token(request, response) case ("register",): - return await self.handle_register(*handler_args) + await self.handle_register(request, response) case ("authenticate",): - return await self.handle_authenticate(*handler_args) + await self.handle_authenticate(request, response) case ("verify",): - return await self.handle_verify(*handler_args) + await self.handle_verify(request, response) case ("resend-verification-email",): - return await self.handle_resend_verification_email( - *handler_args + await self.handle_resend_verification_email( + request, response ) case ('send-reset-email',): - return await self.handle_send_reset_email(*handler_args) + await self.handle_send_reset_email(request, response) case ('reset-password',): - return await self.handle_reset_password(*handler_args) + await self.handle_reset_password(request, response) case ('magic-link', 'register'): - return await self.handle_magic_link_register(*handler_args) + await self.handle_magic_link_register(request, response) case ('magic-link', 'email'): - return await self.handle_magic_link_email(*handler_args) + await self.handle_magic_link_email(request, response) case ('magic-link', 'authenticate'): - return await self.handle_magic_link_authenticate( - *handler_args - ) + await self.handle_magic_link_authenticate(request, response) # WebAuthn routes case ('webauthn', 'register'): - return await self.handle_webauthn_register(*handler_args) + await self.handle_webauthn_register(request, response) case ('webauthn', 'register', 'options'): - return await self.handle_webauthn_register_options( - *handler_args + await self.handle_webauthn_register_options( + request, response ) case ('webauthn', 'authenticate'): - return await self.handle_webauthn_authenticate( - *handler_args - ) + await self.handle_webauthn_authenticate(request, response) case ('webauthn', 'authenticate', 'options'): - return await self.handle_webauthn_authenticate_options( - *handler_args + await self.handle_webauthn_authenticate_options( + request, response ) # UI routes case ('ui', 'signin'): - return await self.handle_ui_signin(*handler_args) + await self.handle_ui_signin(request, response) case ('ui', 'signup'): - return await self.handle_ui_signup(*handler_args) + await self.handle_ui_signup(request, response) case ('ui', 'forgot-password'): - return await self.handle_ui_forgot_password(*handler_args) + await self.handle_ui_forgot_password(request, response) case ('ui', 'reset-password'): - return await self.handle_ui_reset_password(*handler_args) + await self.handle_ui_reset_password(request, response) case ("ui", "verify"): - return await self.handle_ui_verify(*handler_args) + await self.handle_ui_verify(request, response) case ("ui", "resend-verification"): - return await self.handle_ui_resend_verification( - *handler_args - ) + await self.handle_ui_resend_verification(request, response) case ("ui", "magic-link-sent"): - return await self.handle_ui_magic_link_sent(*handler_args) + await self.handle_ui_magic_link_sent(request, response) case ('ui', '_static', filename): filepath = os.path.join( os.path.dirname(__file__), '_static', filename @@ -234,7 +241,11 @@ async def handle_request( ex=edb_errors.InternalServerError(str(ex)), ) - async def handle_authorize(self, request: Any, response: Any): + async def handle_authorize( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) @@ -269,7 +280,11 @@ async def handle_authorize(self, request: Any, response: Any): response.status = http.HTTPStatus.FOUND response.custom_headers["Location"] = authorize_url - async def handle_callback(self, request: Any, response: Any): + async def handle_callback( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: if request.method == b"POST" and ( request.content_type == b"application/x-www-form-urlencoded" ): @@ -383,7 +398,11 @@ async def handle_callback(self, request: Any, response: Any): response.custom_headers["Location"] = new_url _set_cookie(response, "edgedb-session", session_token) - async def handle_token(self, request: Any, response: Any): + async def handle_token( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) @@ -431,7 +450,11 @@ async def handle_token(self, request: Any, response: Any): else: raise errors.PKCEVerificationFailed - async def handle_register(self, request: Any, response: Any): + async def handle_register( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) maybe_redirect_to = cast(Optional[str], data.get("redirect_to")) @@ -516,7 +539,11 @@ async def handle_register(self, request: Any, response: Any): else: raise ex - async def handle_authenticate(self, request: Any, response: Any): + async def handle_authenticate( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) authenticate_provider_name = data.get("provider") @@ -588,7 +615,11 @@ async def handle_authenticate(self, request: Any, response: Any): else: raise ex - async def handle_verify(self, request: Any, response: Any): + async def handle_verify( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) _check_keyset(data, {"verification_token", "provider"}) @@ -645,8 +676,10 @@ async def handle_verify(self, request: Any, response: Any): response.status = http.HTTPStatus.NO_CONTENT async def handle_resend_verification_email( - self, request: Any, response: Any - ): + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) _check_keyset(data, {"provider"}) @@ -702,7 +735,11 @@ async def handle_resend_verification_email( response.status = http.HTTPStatus.OK - async def handle_send_reset_email(self, request: Any, response: Any): + async def handle_send_reset_email( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) _check_keyset(data, {"provider", "email", "reset_url", "challenge"}) @@ -786,7 +823,11 @@ async def handle_send_reset_email(self, request: Any, response: Any): else: raise ex - async def handle_reset_password(self, request: Any, response: Any): + async def handle_reset_password( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) _check_keyset(data, {"provider", "reset_token"}) @@ -842,7 +883,11 @@ async def handle_reset_password(self, request: Any, response: Any): else: raise ex - async def handle_magic_link_register(self, request: Any, response: Any): + async def handle_magic_link_register( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) _check_keyset( @@ -931,7 +976,11 @@ async def handle_magic_link_register(self, request: Any, response: Any): redirect_on_failure, redirect_params ) - async def handle_magic_link_email(self, request: Any, response: Any): + async def handle_magic_link_email( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) maybe_redirect_to = data.get("redirect_to") @@ -1014,7 +1063,11 @@ async def handle_magic_link_email(self, request: Any, response: Any): redirect_on_failure, redirect_params ) - async def handle_magic_link_authenticate(self, request: Any, response: Any): + async def handle_magic_link_authenticate( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) @@ -1062,8 +1115,10 @@ async def handle_magic_link_authenticate(self, request: Any, response: Any): ) async def handle_webauthn_register_options( - self, request: Any, response: Any - ): + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) @@ -1086,7 +1141,11 @@ async def handle_webauthn_register_options( ) response.body = registration_options - async def handle_webauthn_register(self, request: Any, response: Any): + async def handle_webauthn_register( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) _check_keyset( @@ -1166,8 +1225,10 @@ async def handle_webauthn_register(self, request: Any, response: Any): ).encode() async def handle_webauthn_authenticate_options( - self, request: Any, response: Any - ): + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) @@ -1190,7 +1251,11 @@ async def handle_webauthn_authenticate_options( response.content_type = b"application/json" response.body = registration_options - async def handle_webauthn_authenticate(self, request: Any, response: Any): + async def handle_webauthn_authenticate( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: data = self._get_data_from_request(request) _check_keyset( @@ -1229,7 +1294,11 @@ async def handle_webauthn_authenticate(self, request: Any, response: Any): } ).encode() - async def handle_ui_signin(self, request: Any, response: Any): + async def handle_ui_signin( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: ui_config = self._get_ui_config() if ui_config is None: @@ -1279,7 +1348,11 @@ async def handle_ui_signin(self, request: Any, response: Any): brand_color=app_details_config.brand_color, ) - async def handle_ui_signup(self, request: Any, response: Any): + async def handle_ui_signup( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: ui_config = self._get_ui_config() if ui_config is None: response.status = http.HTTPStatus.NOT_FOUND @@ -1328,7 +1401,11 @@ async def handle_ui_signup(self, request: Any, response: Any): brand_color=app_details_config.brand_color, ) - async def handle_ui_forgot_password(self, request: Any, response: Any): + async def handle_ui_forgot_password( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: ui_config = self._get_ui_config() password_provider = ( self._get_password_provider() if ui_config is not None else None @@ -1365,7 +1442,11 @@ async def handle_ui_forgot_password(self, request: Any, response: Any): brand_color=app_details_config.brand_color, ) - async def handle_ui_reset_password(self, request: Any, response: Any): + async def handle_ui_reset_password( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: ui_config = self._get_ui_config() password_provider = ( self._get_password_provider() if ui_config is not None else None @@ -1401,7 +1482,7 @@ async def handle_ui_reset_password(self, request: Any, response: Any): is_valid = ( await email_password_client.validate_reset_secret( identity_id, secret - ) + ) is not None ) except Exception: is_valid = False @@ -1425,7 +1506,11 @@ async def handle_ui_reset_password(self, request: Any, response: Any): brand_color=app_details_config.brand_color, ) - async def handle_ui_verify(self, request: Any, response: Any): + async def handle_ui_verify( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: error_messages: list[str] = [] ui_config = self._get_ui_config() if ui_config is None: @@ -1536,7 +1621,11 @@ async def handle_ui_verify(self, request: Any, response: Any): brand_color=app_details_config.brand_color, ) - async def handle_ui_resend_verification(self, request: Any, response: Any): + async def handle_ui_resend_verification( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) @@ -1594,7 +1683,11 @@ async def handle_ui_resend_verification(self, request: Any, response: Any): brand_color=app_details_config.brand_color, ) - async def handle_ui_magic_link_sent(self, request: Any, response: Any): + async def handle_ui_magic_link_sent( + self, + request: protocol.HttpRequest, + response: protocol.HttpResponse, + ) -> None: """ Success page for when a magic link is sent """ @@ -1646,7 +1739,7 @@ def _make_state_claims( claims=state_claims, ) state_token.make_signed_token(signing_key) - return state_token.serialize() + return cast(str, state_token.serialize()) def _make_session_token(self, identity_id: str) -> str: signing_key = self._get_auth_signing_key() @@ -1670,7 +1763,7 @@ def _make_session_token(self, identity_id: str) -> str: ) session_token.make_signed_token(signing_key) metrics.auth_successful_logins.inc(1.0, self.tenant.get_instance_name()) - return session_token.serialize() + return cast(str, session_token.serialize()) def _get_from_claims(self, state: str, key: str) -> str: signing_key = self._get_auth_signing_key() @@ -1719,9 +1812,14 @@ def _verify_and_extract_claims( else: signing_key = util.derive_key(input_key_material, key_info) verified = jwt.JWT(key=signing_key, jwt=jwtStr) - return json.loads(verified.claims) + return cast( + dict[str, str | int | float | bool], + json.loads(verified.claims) + ) - def _get_data_from_magic_link_token(self, token: str): + def _get_data_from_magic_link_token( + self, token: str + ) -> tuple[str, str, str]: try: claims = self._verify_and_extract_claims(token, "magic_link") except Exception: @@ -1811,7 +1909,9 @@ def _get_data_from_verification_token( ) return return_value - def _get_data_from_request(self, request: Any) -> dict[Any, Any]: + def _get_data_from_request( + self, request: protocol.HttpRequest, + ) -> dict[Any, Any]: content_type = request.content_type match content_type: case b"application/x-www-form-urlencoded": @@ -1830,10 +1930,10 @@ def _get_data_from_request(self, request: Any) -> dict[Any, Any]: return data case _: raise errors.InvalidData( - f"Unsupported Content-Type: {content_type}" + f"Unsupported Content-Type: {content_type!r}" ) - def _get_ui_config(self): + def _get_ui_config(self) -> config.UIConfig: return cast( config.UIConfig, util.maybe_get_config( @@ -1841,10 +1941,10 @@ def _get_ui_config(self): ), ) - def _get_app_details_config(self): + def _get_app_details_config(self) -> config.AppDetailsConfig: return util.get_app_details_config(self.db) - def _get_password_provider(self): + def _get_password_provider(self) -> Optional[config.ProviderConfig]: providers = cast( list[config.ProviderConfig], util.get_config( @@ -1892,7 +1992,7 @@ async def _send_verification_email( to_addr: str, maybe_challenge: str | None, maybe_redirect_to: str | None, - ): + ) -> None: if not self._is_url_allowed(verify_url): raise errors.InvalidData( "Verify URL does not match any allowed URLs.", @@ -1996,10 +2096,10 @@ def _is_url_allowed(self, url: str) -> bool: def _fail_with_error( *, - response: Any, + response: protocol.HttpResponse, status: http.HTTPStatus, ex: Exception, -): +) -> None: err_dct = { "message": str(ex), "type": str(ex.__class__.__name__), @@ -2044,7 +2144,7 @@ def _maybe_get_form_field( def _get_pkce_challenge( *, - response, + response: protocol.HttpResponse, cookies: http.cookies.SimpleCookie, query_dict: dict[str, list[str]], ) -> str | None: @@ -2061,7 +2161,7 @@ def _get_pkce_challenge( def _set_cookie( - response: Any, + response: protocol.HttpResponse, name: str, value: str, *, @@ -2069,8 +2169,10 @@ def _set_cookie( secure: bool = True, same_site: str = "Strict", path: Optional[str] = None, -): - val: http.cookies.Morsel = http.cookies.SimpleCookie({name: value})[name] +) -> None: + val: http.cookies.Morsel[str] = ( + http.cookies.SimpleCookie({name: value})[name] + ) val["httponly"] = http_only val["secure"] = secure val["samesite"] = same_site @@ -2088,7 +2190,7 @@ def _with_appended_qs(url: str, query: dict[str, list[str]]) -> str: return urllib.parse.urlunparse(url_parts) -def _check_keyset(candidate: dict[str, Any], keyset: set[str]): +def _check_keyset(candidate: dict[str, Any], keyset: set[str]) -> None: missing_fields = [field for field in keyset if field not in candidate] if missing_fields: raise errors.InvalidData( diff --git a/edb/server/protocol/auth_ext/http_client.py b/edb/server/protocol/auth_ext/http_client.py index 8becb135cf0..2e158fe2edf 100644 --- a/edb/server/protocol/auth_ext/http_client.py +++ b/edb/server/protocol/auth_ext/http_client.py @@ -16,6 +16,7 @@ # limitations under the License. # +from typing import Any import urllib.parse import hishel @@ -24,7 +25,11 @@ class HttpClient(httpx.AsyncClient): def __init__( - self, *args, edgedb_test_url: str | None, base_url: str, **kwargs + self, + *args: Any, + edgedb_test_url: str | None, + base_url: str, + **kwargs: Any, ): self.edgedb_orig_base_url = None if edgedb_test_url: @@ -34,14 +39,28 @@ def __init__( transport=httpx.AsyncHTTPTransport(), storage=hishel.AsyncInMemoryStorage(capacity=5), ) - super().__init__(*args, base_url=base_url, transport=cache, **kwargs) + super().__init__( + *args, base_url=base_url, transport=cache, **kwargs + ) - async def post(self, path, *args, **kwargs): + async def post( # type: ignore[override] + self, + path: str, + *args: Any, + **kwargs: Any, + ) -> httpx.Response: if self.edgedb_orig_base_url: path = f'{self.edgedb_orig_base_url}/{path}' - return await super().post(path, *args, **kwargs) + return await super().post( + path, *args, **kwargs + ) - async def get(self, path, *args, **kwargs): + async def get( # type: ignore[override] + self, + path: str, + *args: Any, + **kwargs: Any, + ) -> httpx.Response: if self.edgedb_orig_base_url: path = f'{self.edgedb_orig_base_url}/{path}' return await super().get(path, *args, **kwargs) diff --git a/edb/server/protocol/auth_ext/local.py b/edb/server/protocol/auth_ext/local.py index bf1a0cbb6f4..c2bb3a7e39e 100644 --- a/edb/server/protocol/auth_ext/local.py +++ b/edb/server/protocol/auth_ext/local.py @@ -22,7 +22,7 @@ import base64 from jwcrypto import jwk -from typing import Any +from typing import Any, cast from edb.server.protocol import execute from . import util @@ -42,7 +42,7 @@ def _get_signing_key(self) -> jwk.JWK: async def verify_email( self, identity_id: str, verified_at: datetime.datetime - ): + ) -> bytes: r = await execute.parse_execute_json( db=self.db, query="""\ @@ -80,7 +80,7 @@ async def get_email_by_identity_id(self, identity_id: str) -> str | None: assert len(result_json) == 1 - return result_json[0]["email"] + return cast(str, result_json[0]["email"]) async def get_verified_by_identity_id(self, identity_id: str) -> str | None: r = await execute.parse_execute_json( @@ -100,7 +100,7 @@ async def get_verified_by_identity_id(self, identity_id: str) -> str | None: assert len(result_json) == 1 - return result_json[0]["verified_at"] + return cast(str, result_json[0]["verified_at"]) async def get_identity_id_by_email( self, email: str, *, factor_type: str = 'EmailFactor' @@ -125,4 +125,4 @@ async def get_identity_id_by_email( assert len(result_json) == 1 - return result_json[0] + return cast(str, result_json[0]) diff --git a/edb/server/protocol/auth_ext/magic_link.py b/edb/server/protocol/auth_ext/magic_link.py index 5a1b6540c6a..89b77e7bf30 100644 --- a/edb/server/protocol/auth_ext/magic_link.py +++ b/edb/server/protocol/auth_ext/magic_link.py @@ -61,7 +61,7 @@ def _get_provider(self) -> config.MagicLinkProviderConfig: provider_name, f"Provider is not configured" ) - async def register(self, email: str): + async def register(self, email: str) -> data.LocalIdentity: try: result = await execute.parse_execute_json( self.db, @@ -102,7 +102,7 @@ async def send_magic_link( link_url: str, challenge: str, redirect_on_failure: str, - ): + ) -> None: initial_key_material = self._get_signing_key() identity_id = await self.get_identity_id_by_email( email, factor_type='MagicLinkFactor' diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py index ff941d92005..d02d21423f6 100644 --- a/edb/server/protocol/auth_ext/oauth.py +++ b/edb/server/protocol/auth_ext/oauth.py @@ -19,11 +19,11 @@ import json -from typing import Any, Type +from typing import cast, Any, Type from edb.server.protocol import execute from . import github, google, azure, apple, discord, slack -from . import errors, util, data, base, http_client +from . import config, errors, util, data, base, http_client class Client: @@ -63,7 +63,8 @@ def __init__( raise errors.InvalidData(f"Invalid provider: {provider_name}") self.provider = provider_class( - *provider_args, **provider_kwargs # type: ignore + *provider_args, + **provider_kwargs, # type: ignore[arg-type] ) async def get_authorize_url(self, state: str, redirect_uri: str) -> str: @@ -123,14 +124,21 @@ async def _handle_identity( result_json[0]['new'] ) - def _get_provider_config(self, provider_name: str): + def _get_provider_config( + self, provider_name: str + ) -> config.OAuthProviderConfig: provider_client_config = util.get_config( self.db, "ext::auth::AuthConfig::providers", frozenset ) for cfg in provider_client_config: if cfg.name == provider_name: - return data.ProviderConfig( - cfg.client_id, cfg.secret, cfg.additional_scope + cfg = cast(config.OAuthProviderConfig, cfg) + return config.OAuthProviderConfig( + name=cfg.name, + display_name=cfg.display_name, + client_id=cfg.client_id, + secret=cfg.secret, + additional_scope=cfg.additional_scope, ) raise errors.MissingConfiguration( diff --git a/edb/server/protocol/auth_ext/pkce.py b/edb/server/protocol/auth_ext/pkce.py index 57cd33a9fc5..028a6e9e162 100644 --- a/edb/server/protocol/auth_ext/pkce.py +++ b/edb/server/protocol/auth_ext/pkce.py @@ -49,7 +49,7 @@ class PKCEChallenge: identity_id: str | None -async def create(db, challenge: str): +async def create(db: edbtenant.dbview.Database, challenge: str) -> None: await execute.parse_execute_json( db, """ @@ -65,7 +65,11 @@ async def create(db, challenge: str): ) -async def link_identity_challenge(db, identity_id: str, challenge: str) -> str: +async def link_identity_challenge( + db: edbtenant.dbview.Database, + identity_id: str, + challenge: str, +) -> str: r = await execute.parse_execute_json( db, """ @@ -83,12 +87,15 @@ async def link_identity_challenge(db, identity_id: str, challenge: str) -> str: result_json = json.loads(r.decode()) assert len(result_json) == 1 - return result_json[0]["id"] + return typing.cast(str, result_json[0]["id"]) async def add_provider_tokens( - db, id: str, auth_token: str | None, refresh_token: str | None -): + db: edbtenant.dbview.Database, + id: str, + auth_token: str | None, + refresh_token: str | None, +) -> str: r = await execute.parse_execute_json( db, """ @@ -110,10 +117,10 @@ async def add_provider_tokens( result_json = json.loads(r.decode()) assert len(result_json) == 1 - return result_json[0]["id"] + return typing.cast(str, result_json[0]["id"]) -async def get_by_id(db, id: str) -> PKCEChallenge: +async def get_by_id(db: edbtenant.dbview.Database, id: str) -> PKCEChallenge: r = await execute.parse_execute_json( db, """ @@ -137,7 +144,7 @@ async def get_by_id(db, id: str) -> PKCEChallenge: return PKCEChallenge(**result_json[0]) -async def delete(db, id: str) -> None: +async def delete(db: edbtenant.dbview.Database, id: str) -> None: r = await execute.parse_execute_json( db, """ @@ -151,7 +158,7 @@ async def delete(db, id: str) -> None: assert len(result_json) == 1 -async def _gc(tenant: edbtenant.Tenant): +async def _gc(tenant: edbtenant.Tenant) -> None: try: async with asyncio.TaskGroup() as g: for db in tenant.iter_dbs(): @@ -176,7 +183,7 @@ async def _gc(tenant: edbtenant.Tenant): ) -async def gc(server: edbserver.BaseServer): +async def gc(server: edbserver.BaseServer) -> None: while True: try: tasks = [ diff --git a/edb/server/protocol/auth_ext/slack.py b/edb/server/protocol/auth_ext/slack.py index 954fa099549..bc6909e4580 100644 --- a/edb/server/protocol/auth_ext/slack.py +++ b/edb/server/protocol/auth_ext/slack.py @@ -17,11 +17,13 @@ # +from typing import Any + from . import base class SlackProvider(base.OpenIDProvider): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__( "slack", "https://slack.com", diff --git a/edb/server/protocol/auth_ext/ui/__init__.py b/edb/server/protocol/auth_ext/ui/__init__.py index 1c42083aca4..04ec441c870 100644 --- a/edb/server/protocol/auth_ext/ui/__init__.py +++ b/edb/server/protocol/auth_ext/ui/__init__.py @@ -16,20 +16,23 @@ # limitations under the License. # -from typing import Optional +from __future__ import annotations +from typing import cast, Optional import html from email.mime import multipart from email.mime import text as mime_text +from edb.server.protocol.auth_ext import config as auth_config + from . import components as render def render_signin_page( *, base_path: str, - providers: frozenset, + providers: frozenset[auth_config.ProviderConfig], error_message: Optional[str] = None, email: Optional[str] = None, challenge: str, @@ -41,7 +44,7 @@ def render_signin_page( logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, -): +) -> bytes: password_provider = None webauthn_provider = None magic_link_provider = None @@ -54,7 +57,9 @@ def render_signin_page( elif p.name == 'builtin::local_magic_link': magic_link_provider = p elif p.name.startswith('builtin::oauth_'): - oauth_providers.append(p) + oauth_providers.append( + cast(auth_config.OAuthProviderConfig, p) + ) base_email_factor_form = f""" @@ -189,14 +194,14 @@ def render_signin_page( def render_email_factor_form( *, - base_email_factor_form=None, - password_input='', - selected_tab=None, - single_form_fields='', - password_form, - webauthn_form, - magic_link_form, -): + base_email_factor_form: Optional[str] = None, + password_input: str = '', + selected_tab: Optional[str] = None, + single_form_fields: str = '', + password_form: Optional[str], + webauthn_form: Optional[str], + magic_link_form: Optional[str], +) -> Optional[str]: if ( password_form is None and webauthn_form is None and @@ -272,7 +277,7 @@ def render_email_factor_form( def render_signup_page( *, base_path: str, - providers: frozenset, + providers: frozenset[auth_config.ProviderConfig], error_message: Optional[str] = None, email: Optional[str] = None, challenge: str, @@ -284,7 +289,7 @@ def render_signup_page( logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, -): +) -> bytes: password_provider = None webauthn_provider = None magic_link_provider = None @@ -297,7 +302,9 @@ def render_signup_page( elif p.name == 'builtin::local_magic_link': magic_link_provider = p elif p.name.startswith('builtin::oauth_'): - oauth_providers.append(p) + oauth_providers.append( + cast(auth_config.OAuthProviderConfig, p) + ) base_email_factor_form = f""" @@ -410,7 +417,7 @@ def render_forgot_password_page( logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, -): +) -> bytes: if email_sent is not None: content = render.success_message( f'Password reset email has been sent to {email_sent}' @@ -464,7 +471,7 @@ def render_reset_password_page( logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, -): +) -> bytes: if not is_valid and challenge is None: content = render.error_message( f'''Reset token is invalid, challenge string is missing. Please @@ -519,7 +526,7 @@ def render_email_verification_page( logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, -): +) -> bytes: resend_url = None if verification_token: verification_token = html.escape(verification_token) @@ -562,7 +569,7 @@ def render_email_verification_expired_page( logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, -): +) -> bytes: verification_token = html.escape(verification_token) content = render.error_message( f''' @@ -597,7 +604,7 @@ def render_resend_verification_done_page( logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, -): +) -> bytes: if verification_token is None: content = render.error_message( f""" @@ -639,7 +646,7 @@ def render_magic_link_sent_page( logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, -): +) -> bytes: content = render.success_message( "A sign in link has been sent to your email. Please check your email." ) diff --git a/edb/server/protocol/auth_ext/ui/components.py b/edb/server/protocol/auth_ext/ui/components.py index e39b7f8e54e..ac2355266d3 100644 --- a/edb/server/protocol/auth_ext/ui/components.py +++ b/edb/server/protocol/auth_ext/ui/components.py @@ -16,13 +16,17 @@ # limitations under the License. # -from typing import Optional +from __future__ import annotations +from typing import Optional, TYPE_CHECKING import html import urllib.parse from . import util +if TYPE_CHECKING: + from edb.server.protocol.auth_ext import config as auth_config + known_oauth_provider_names = [ 'builtin::oauth_github', 'builtin::oauth_google', @@ -110,7 +114,7 @@ def oauth_buttons( redirect_to: str, challenge: str, redirect_to_on_signup: Optional[str], - oauth_providers: list, + oauth_providers: list[auth_config.OAuthProviderConfig], label_prefix: str, collapsed: bool ) -> str: @@ -138,7 +142,12 @@ def oauth_buttons( ''' -def _oauth_button(provider, params: dict, *, label_prefix: str) -> str: +def _oauth_button( + provider: auth_config.OAuthProviderConfig, + params: dict[str, str], + *, + label_prefix: str, +) -> str: href = '../authorize?' + urllib.parse.urlencode({ 'provider': provider.name, **params diff --git a/edb/server/protocol/auth_ext/ui/util.py b/edb/server/protocol/auth_ext/ui/util.py index 38b53f246ad..1c7def49878 100644 --- a/edb/server/protocol/auth_ext/ui/util.py +++ b/edb/server/protocol/auth_ext/ui/util.py @@ -23,7 +23,7 @@ hex_color_regexp = re.compile(r'[0-9a-fA-F]{6}') -def get_colour_vars(bg_hex: str): +def get_colour_vars(bg_hex: str) -> str: bg_rgb = hex_to_rgb(bg_hex) bg_hsl = rgb_to_hsl(*bg_rgb) luma = rgb_to_luma(*bg_rgb) diff --git a/edb/server/protocol/auth_ext/util.py b/edb/server/protocol/auth_ext/util.py index cc2723568b1..60440224d8a 100644 --- a/edb/server/protocol/auth_ext/util.py +++ b/edb/server/protocol/auth_ext/util.py @@ -17,6 +17,8 @@ # +from __future__ import annotations + import base64 import urllib.parse import datetime @@ -25,17 +27,20 @@ from cryptography.hazmat.backends import default_backend from jwcrypto import jwt, jwk -from typing import TypeVar, Type, overload, Any, cast, Optional +from typing import TypeVar, Type, overload, Any, cast, Optional, TYPE_CHECKING from edb.server import config as edb_config from edb.server.config.types import CompositeConfigType from . import errors, config +if TYPE_CHECKING: + from edb.server import tenant as edbtenant + T = TypeVar("T") -def maybe_get_config_unchecked(db: Any, key: str) -> Any: +def maybe_get_config_unchecked(db: edbtenant.dbview.Database, key: str) -> Any: return edb_config.lookup(key, db.db_config, spec=db.user_config_spec) @@ -122,7 +127,7 @@ def get_app_details_config(db: Any) -> config.AppDetailsConfig: ) -def join_url_params(url: str, params: dict[str, str]): +def join_url_params(url: str, params: dict[str, str]) -> str: parsed_url = urllib.parse.urlparse(url) query_params = { **urllib.parse.parse_qs(parsed_url.query), @@ -162,7 +167,7 @@ def make_token( ) token.make_signed_token(signing_key) - return token.serialize() + return cast(str, token.serialize()) def derive_key(key: jwk.JWK, info: str) -> jwk.JWK: diff --git a/edb/server/protocol/auth_ext/webauthn.py b/edb/server/protocol/auth_ext/webauthn.py index 78737f4a3ee..9e84045d84d 100644 --- a/edb/server/protocol/auth_ext/webauthn.py +++ b/edb/server/protocol/auth_ext/webauthn.py @@ -16,12 +16,14 @@ # limitations under the License. # +from __future__ import annotations + import dataclasses import base64 import json import webauthn -from typing import Any, Optional, Tuple +from typing import Optional, Tuple, TYPE_CHECKING from webauthn.helpers import ( parse_authentication_credential_json, structs as webauthn_structs, @@ -33,6 +35,9 @@ from . import config, data, errors, util, local +if TYPE_CHECKING: + from edb.server import tenant as edbtenant + @dataclasses.dataclass(repr=False) class WebAuthnRegistrationChallenge: @@ -47,7 +52,7 @@ class WebAuthnRegistrationChallenge: class Client(local.Client): - def __init__(self, db: Any): + def __init__(self, db: edbtenant.dbview.Database): self.db = db self.provider = self._get_provider() self.app_name = self._get_app_name() @@ -72,7 +77,9 @@ def _get_provider(self) -> config.WebAuthnProvider: def _get_app_name(self) -> Optional[str]: return util.maybe_get_config(self.db, "ext::auth::AuthConfig::app_name") - async def create_registration_options_for_email(self, email: str): + async def create_registration_options_for_email( + self, email: str, + ) -> tuple[str, bytes]: maybe_user_handle = await self._maybe_get_existing_user_handle( email=email ) @@ -95,7 +102,9 @@ async def create_registration_options_for_email(self, email: str): webauthn.options_to_json(registration_options).encode(), ) - async def _maybe_get_existing_user_handle(self, email: str): + async def _maybe_get_existing_user_handle( + self, email: str, + ) -> Optional[bytes]: result = await execute.parse_execute_json( self.db, """ @@ -123,7 +132,7 @@ async def _create_registration_challenge( email: str, challenge: bytes, user_handle: bytes, - ): + ) -> None: await execute.parse_execute_json( self.db, """ @@ -149,7 +158,7 @@ async def register( credentials: str, email: str, user_handle: bytes, - ): + ) -> data.LocalIdentity: registration_challenge = await self._get_registration_challenge( email=email, user_handle=user_handle, @@ -248,7 +257,7 @@ async def _delete_registration_challenges( self, email: str, user_handle: bytes, - ): + ) -> None: await execute.parse_execute_json( self.db, """ @@ -377,7 +386,7 @@ async def _get_authentication_challenge( self, email: str, credential_id: bytes, - ): + ) -> data.WebAuthnAuthenticationChallenge: result = await execute.parse_execute_json( self.db, """ @@ -429,7 +438,7 @@ async def _delete_authentication_challenges( self, email: str, credential_id: bytes, - ): + ) -> None: await execute.parse_execute_json( self.db, """ diff --git a/edb/server/protocol/metrics.py b/edb/server/protocol/metrics.py index 2ccf562753f..c88b3206118 100644 --- a/edb/server/protocol/metrics.py +++ b/edb/server/protocol/metrics.py @@ -17,6 +17,8 @@ # +from __future__ import annotations +from typing import Type, TYPE_CHECKING import http from edb import errors @@ -26,12 +28,16 @@ from edb.common import debug from edb.common import markup +if TYPE_CHECKING: + from edb.server import tenant as edbtenant + from edb.server.protocol import protocol + async def handle_request( - request, - response, - tenant, -): + request: protocol.HttpRequest, + response: protocol.HttpResponse, + tenant: edbtenant.Tenant, +) -> None: try: if tenant is None or isinstance(tenant.server, server.Server): output = metrics.registry.generate() @@ -56,7 +62,12 @@ async def handle_request( ) -def _response_error(response, status, message, ex_type): +def _response_error( + response: protocol.HttpResponse, + status: http.HTTPStatus, + message: str, + ex_type: Type[errors.EdgeDBError], +) -> None: response.body = ( f'Unexpected error in /metrics.\n\n' f'{ex_type.__name__}: {message}' diff --git a/edb/server/protocol/protocol.pyi b/edb/server/protocol/protocol.pyi index cb454a9bc35..3a5db29220b 100644 --- a/edb/server/protocol/protocol.pyi +++ b/edb/server/protocol/protocol.pyi @@ -19,13 +19,14 @@ import asyncio import http import http.cookies +import httptools import ssl from edb.server import args as srvargs from edb.server import server class HttpRequest: - url: str + url: httptools.URL version: bytes should_keep_alive: bool content_type: bytes diff --git a/edb/server/protocol/server_info.py b/edb/server/protocol/server_info.py index 346dd87e56b..65d65a7baff 100644 --- a/edb/server/protocol/server_info.py +++ b/edb/server/protocol/server_info.py @@ -17,6 +17,8 @@ # +from __future__ import annotations +from typing import Any, Type, TYPE_CHECKING import dataclasses import http import json @@ -30,10 +32,14 @@ from edb.common import debug from edb.common import markup +if TYPE_CHECKING: + from edb.server import server as edbserver + from edb.server.protocol import protocol + class ImmutableEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: Any) -> Any: if isinstance(obj, (set, frozenset)): return list(obj) if isinstance(obj, immutables.Map): @@ -50,10 +56,10 @@ def default(self, obj): async def handle_request( - request, - response, - server, -): + request: protocol.HttpRequest, + response: protocol.HttpResponse, + server: edbserver.Server, +) -> None: try: output = ImmutableEncoder().encode(server.get_debug_info()) response.status = http.HTTPStatus.OK @@ -73,7 +79,12 @@ async def handle_request( ) -def _response_error(response, status, message, ex_type): +def _response_error( + response: protocol.HttpResponse, + status: http.HTTPStatus, + message: str, + ex_type: Type[errors.EdgeDBError], +) -> None: response.body = ( f'Unexpected error in /server-info.\n\n' f'{ex_type.__name__}: {message}' diff --git a/edb/server/protocol/system_api.py b/edb/server/protocol/system_api.py index 1e5beaadf75..1fb4671b294 100644 --- a/edb/server/protocol/system_api.py +++ b/edb/server/protocol/system_api.py @@ -16,7 +16,8 @@ # limitations under the License. # - +from __future__ import annotations +from typing import Type, TYPE_CHECKING import http import json @@ -28,16 +29,20 @@ from edb.server import compiler from edb.server import defines as edbdef -from . import execute # type: ignore +from . import execute + +if TYPE_CHECKING: + from edb.server import tenant as edbtenant, server as edbserver + from edb.server.protocol import protocol async def handle_request( - request, - response, - path_parts, - server, - tenant, -): + request: protocol.HttpRequest, + response: protocol.HttpResponse, + path_parts: list[str], + server: edbserver.Server, + tenant: edbtenant.Tenant, +) -> None: try: if tenant is None: try: @@ -92,7 +97,12 @@ async def handle_request( ) -def _response_error(response, status, message, ex_type): +def _response_error( + response: protocol.HttpResponse, + status: http.HTTPStatus, + message: str, + ex_type: Type[errors.EdgeDBError], +) -> None: err_dct = { 'message': message, 'type': str(ex_type.__name__), @@ -104,13 +114,13 @@ def _response_error(response, status, message, ex_type): response.close_connection = True -def _response_ok(response, message): +def _response_ok(response: protocol.HttpResponse, message: bytes) -> None: response.status = http.HTTPStatus.OK response.content_type = b'application/json' response.body = message -async def _ping(tenant): +async def _ping(tenant: edbtenant.Tenant) -> bytes: if tenant.get_backend_runtime_params().has_create_database: dbname = edbdef.EDGEDB_SYSTEM_DB else: @@ -129,18 +139,18 @@ async def _ping(tenant): async def handle_liveness_query( - request, - response, - tenant, -): + request: protocol.HttpRequest, + response: protocol.HttpResponse, + tenant: edbtenant.Tenant, +) -> None: _response_ok(response, await _ping(tenant)) async def handle_readiness_query( - request, - response, - tenant, -): + request: protocol.HttpRequest, + response: protocol.HttpResponse, + tenant: edbtenant.Tenant, +) -> None: if not tenant.is_ready(): _response_error( response, diff --git a/pyproject.toml b/pyproject.toml index 5c6823b26ab..d08de217708 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,11 +161,12 @@ module = [ "edb.schema.*", "edb.schema.reflection.*", "edb.server.cluster", + "edb.server.compiler.*", "edb.server.config", "edb.server.connpool.*", + "edb.server.protocol.*", "edb.server.pgcluster", "edb.server.pgconnparams", - "edb.server.compiler.*", ] # Equivalent of --strict on the command line, # but without disallow_untyped_calls: