diff --git a/edb/lib/ext/auth.edgeql b/edb/lib/ext/auth.edgeql index 563301ba1e8..f171984e201 100644 --- a/edb/lib/ext/auth.edgeql +++ b/edb/lib/ext/auth.edgeql @@ -177,6 +177,29 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' { }; }; + create type ext::auth::OpenIDConnectProvider + extending ext::auth::OAuthProviderConfig { + alter property name { + set protected := false; + }; + + alter property display_name { + set protected := false; + }; + + create required property issuer_url: std::str { + create annotation std::description := + "The issuer URL of the provider."; + }; + + create property logo_url: std::str { + create annotation std::description := + "A url to an image of the provider's logo."; + }; + + create constraint exclusive on ((.issuer_url, .client_id)); + }; + create type ext::auth::AppleOAuthProvider extending ext::auth::OAuthProviderConfig { alter property name { diff --git a/edb/server/protocol/auth_ext/apple.py b/edb/server/protocol/auth_ext/apple.py index f5efa1b4750..8d34df9a8d1 100644 --- a/edb/server/protocol/auth_ext/apple.py +++ b/edb/server/protocol/auth_ext/apple.py @@ -24,13 +24,12 @@ from . import base -class AppleProvider(base.OpenIDProvider): +class AppleProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "apple", "https://appleid.apple.com", *args, - content_type=base.ContentType.FORM_ENCODED, **kwargs, ) diff --git a/edb/server/protocol/auth_ext/azure.py b/edb/server/protocol/auth_ext/azure.py index 4255707c592..caa3fcf2681 100644 --- a/edb/server/protocol/auth_ext/azure.py +++ b/edb/server/protocol/auth_ext/azure.py @@ -22,12 +22,11 @@ from . import base -class AzureProvider(base.OpenIDProvider): +class AzureProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "azure", "https://login.microsoftonline.com/common/v2.0", *args, - content_type=base.ContentType.FORM_ENCODED, **kwargs, ) diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index 1c191c77966..e97edbc8c4f 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -70,17 +70,15 @@ class ContentType(enum.StrEnum): FORM_ENCODED = "application/x-www-form-urlencoded" -class OpenIDProvider(BaseProvider): +class OpenIDConnectProvider(BaseProvider): def __init__( self, name: str, issuer_url: str, *args: Any, - content_type: ContentType = ContentType.JSON, **kwargs: Any, ): super().__init__(name, issuer_url, *args, **kwargs) - self.content_type = content_type async def get_code_url( self, state: str, redirect_uri: str, additional_scope: str @@ -114,23 +112,16 @@ async def exchange_code( "redirect_uri": redirect_uri, } headers = {"Accept": ContentType.JSON.value} - if self.content_type == ContentType.JSON: - resp = await client.post( - token_endpoint.path, - json=request_body, - headers=headers, - ) - else: - resp = await client.post( - token_endpoint.path, - data=request_body, - headers=headers, - ) + resp = await client.post( + token_endpoint.path, + data=request_body, + headers=headers, + ) if resp.status_code >= 400: raise errors.OAuthProviderFailure( f"Failed to exchange code: {resp.text}" ) - content_type = resp.headers.get('Content-Type', self.content_type) + content_type = resp.headers.get('Content-Type') if content_type.startswith(str(ContentType.JSON)): response_body = resp.json() else: diff --git a/edb/server/protocol/auth_ext/config.py b/edb/server/protocol/auth_ext/config.py index 431bf479033..ae5b28d5d8b 100644 --- a/edb/server/protocol/auth_ext/config.py +++ b/edb/server/protocol/auth_ext/config.py @@ -53,6 +53,8 @@ class OAuthProviderConfig(ProviderConfig): client_id: str secret: str additional_scope: Optional[str] + issuer_url: Optional[str] + logo_url: Optional[str] class WebAuthnProviderConfig(ProviderConfig): diff --git a/edb/server/protocol/auth_ext/google.py b/edb/server/protocol/auth_ext/google.py index 61e3a3d2f72..c2fece83c23 100644 --- a/edb/server/protocol/auth_ext/google.py +++ b/edb/server/protocol/auth_ext/google.py @@ -22,7 +22,7 @@ from . import base -class GoogleProvider(base.OpenIDProvider): +class GoogleProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "google", "https://accounts.google.com", *args, **kwargs diff --git a/edb/server/protocol/auth_ext/http_client.py b/edb/server/protocol/auth_ext/http_client.py index 2e158fe2edf..52302b97c31 100644 --- a/edb/server/protocol/auth_ext/http_client.py +++ b/edb/server/protocol/auth_ext/http_client.py @@ -50,7 +50,7 @@ async def post( # type: ignore[override] **kwargs: Any, ) -> httpx.Response: if self.edgedb_orig_base_url: - path = f'{self.edgedb_orig_base_url}/{path}' + path = f'{self.edgedb_orig_base_url}{path}' return await super().post( path, *args, **kwargs ) @@ -62,5 +62,5 @@ async def get( # type: ignore[override] **kwargs: Any, ) -> httpx.Response: if self.edgedb_orig_base_url: - path = f'{self.edgedb_orig_base_url}/{path}' + path = f'{self.edgedb_orig_base_url}{path}' return await super().get(path, *args, **kwargs) diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py index d02d21423f6..a34d49f2fae 100644 --- a/edb/server/protocol/auth_ext/oauth.py +++ b/edb/server/protocol/auth_ext/oauth.py @@ -19,7 +19,7 @@ import json -from typing import cast, Any, Type +from typing import cast, Any from edb.server.protocol import execute from . import github, google, azure, apple, discord, slack @@ -39,34 +39,56 @@ def __init__( ) provider_config = self._get_provider_config(provider_name) - provider_args = (provider_config.client_id, provider_config.secret) + provider_args: tuple[str, str] | tuple[str, str, str, str] = ( + provider_config.client_id, + provider_config.secret, + ) provider_kwargs = { "http_factory": http_factory, "additional_scope": provider_config.additional_scope, } - provider_class: Type[base.BaseProvider] - match provider_name: - case "builtin::oauth_github": - provider_class = github.GitHubProvider - case "builtin::oauth_google": - provider_class = google.GoogleProvider - case "builtin::oauth_azure": - provider_class = azure.AzureProvider - case "builtin::oauth_apple": - provider_class = apple.AppleProvider - case "builtin::oauth_discord": - provider_class = discord.DiscordProvider - case "builtin::oauth_slack": - provider_class = slack.SlackProvider + match (provider_name, provider_config.issuer_url): + case ("builtin::oauth_github", _): + self.provider = github.GitHubProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_google", _): + self.provider = google.GoogleProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_azure", _): + self.provider = azure.AzureProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_apple", _): + self.provider = apple.AppleProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_discord", _): + self.provider = discord.DiscordProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_slack", _): + self.provider = slack.SlackProvider( + *provider_args, + **provider_kwargs, + ) + case (provider_name, str(issuer_url)): + self.provider = base.OpenIDConnectProvider( + provider_name, + issuer_url, + *provider_args, + **provider_kwargs, + ) case _: raise errors.InvalidData(f"Invalid provider: {provider_name}") - self.provider = provider_class( - *provider_args, - **provider_kwargs, # type: ignore[arg-type] - ) - async def get_authorize_url(self, state: str, redirect_uri: str) -> str: return await self.provider.get_code_url( state=state, @@ -121,7 +143,7 @@ async def _handle_identity( return ( data.Identity(**result_json[0]['identity']), - result_json[0]['new'] + result_json[0]['new'], ) def _get_provider_config( @@ -139,6 +161,8 @@ def _get_provider_config( client_id=cfg.client_id, secret=cfg.secret, additional_scope=cfg.additional_scope, + issuer_url=getattr(cfg, 'issuer_url', None), + logo_url=getattr(cfg, 'logo_url', None), ) raise errors.MissingConfiguration( diff --git a/edb/server/protocol/auth_ext/slack.py b/edb/server/protocol/auth_ext/slack.py index bc6909e4580..b486cfb83f8 100644 --- a/edb/server/protocol/auth_ext/slack.py +++ b/edb/server/protocol/auth_ext/slack.py @@ -22,12 +22,11 @@ from . import base -class SlackProvider(base.OpenIDProvider): +class SlackProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "slack", "https://slack.com", *args, - content_type=base.ContentType.FORM_ENCODED, **kwargs, ) diff --git a/edb/server/protocol/auth_ext/ui/__init__.py b/edb/server/protocol/auth_ext/ui/__init__.py index 04ec441c870..d066a0c0984 100644 --- a/edb/server/protocol/auth_ext/ui/__init__.py +++ b/edb/server/protocol/auth_ext/ui/__init__.py @@ -56,10 +56,8 @@ def render_signin_page( webauthn_provider = p elif p.name == 'builtin::local_magic_link': magic_link_provider = p - elif p.name.startswith('builtin::oauth_'): - oauth_providers.append( - cast(auth_config.OAuthProviderConfig, p) - ) + elif p.name.startswith('builtin::oauth_') or hasattr(p, "issuer_url"): + oauth_providers.append(cast(auth_config.OAuthProviderConfig, p)) base_email_factor_form = f""" @@ -68,7 +66,8 @@ def render_signin_page( """ - password_input = f""" + password_input = ( + f"""
- """ if password_provider else '' + """ + if password_provider + else '' + ) email_factor_form = render_email_factor_form( base_email_factor_form=base_email_factor_form, @@ -110,7 +112,8 @@ def render_signin_page( name='callback_url', value=redirect_to ) if magic_link_provider else ''} ''', - password_form=f""" + password_form=( + f""" {render.button("Sign Up", id="password-signup")} - """ if password_provider else None, - webauthn_form=f""" + """ + if password_provider + else None + ), + webauthn_form=( + f"""