diff --git a/edb/lib/ext/auth.edgeql b/edb/lib/ext/auth.edgeql index 563301ba1e8..f171984e201 100644 --- a/edb/lib/ext/auth.edgeql +++ b/edb/lib/ext/auth.edgeql @@ -177,6 +177,29 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' { }; }; + create type ext::auth::OpenIDConnectProvider + extending ext::auth::OAuthProviderConfig { + alter property name { + set protected := false; + }; + + alter property display_name { + set protected := false; + }; + + create required property issuer_url: std::str { + create annotation std::description := + "The issuer URL of the provider."; + }; + + create property logo_url: std::str { + create annotation std::description := + "A url to an image of the provider's logo."; + }; + + create constraint exclusive on ((.issuer_url, .client_id)); + }; + create type ext::auth::AppleOAuthProvider extending ext::auth::OAuthProviderConfig { alter property name { diff --git a/edb/server/protocol/auth_ext/apple.py b/edb/server/protocol/auth_ext/apple.py index f5efa1b4750..8d34df9a8d1 100644 --- a/edb/server/protocol/auth_ext/apple.py +++ b/edb/server/protocol/auth_ext/apple.py @@ -24,13 +24,12 @@ from . import base -class AppleProvider(base.OpenIDProvider): +class AppleProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "apple", "https://appleid.apple.com", *args, - content_type=base.ContentType.FORM_ENCODED, **kwargs, ) diff --git a/edb/server/protocol/auth_ext/azure.py b/edb/server/protocol/auth_ext/azure.py index 4255707c592..caa3fcf2681 100644 --- a/edb/server/protocol/auth_ext/azure.py +++ b/edb/server/protocol/auth_ext/azure.py @@ -22,12 +22,11 @@ from . import base -class AzureProvider(base.OpenIDProvider): +class AzureProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "azure", "https://login.microsoftonline.com/common/v2.0", *args, - content_type=base.ContentType.FORM_ENCODED, **kwargs, ) diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index 1c191c77966..e97edbc8c4f 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -70,17 +70,15 @@ class ContentType(enum.StrEnum): FORM_ENCODED = "application/x-www-form-urlencoded" -class OpenIDProvider(BaseProvider): +class OpenIDConnectProvider(BaseProvider): def __init__( self, name: str, issuer_url: str, *args: Any, - content_type: ContentType = ContentType.JSON, **kwargs: Any, ): super().__init__(name, issuer_url, *args, **kwargs) - self.content_type = content_type async def get_code_url( self, state: str, redirect_uri: str, additional_scope: str @@ -114,23 +112,16 @@ async def exchange_code( "redirect_uri": redirect_uri, } headers = {"Accept": ContentType.JSON.value} - if self.content_type == ContentType.JSON: - resp = await client.post( - token_endpoint.path, - json=request_body, - headers=headers, - ) - else: - resp = await client.post( - token_endpoint.path, - data=request_body, - headers=headers, - ) + resp = await client.post( + token_endpoint.path, + data=request_body, + headers=headers, + ) if resp.status_code >= 400: raise errors.OAuthProviderFailure( f"Failed to exchange code: {resp.text}" ) - content_type = resp.headers.get('Content-Type', self.content_type) + content_type = resp.headers.get('Content-Type') if content_type.startswith(str(ContentType.JSON)): response_body = resp.json() else: diff --git a/edb/server/protocol/auth_ext/config.py b/edb/server/protocol/auth_ext/config.py index 431bf479033..ae5b28d5d8b 100644 --- a/edb/server/protocol/auth_ext/config.py +++ b/edb/server/protocol/auth_ext/config.py @@ -53,6 +53,8 @@ class OAuthProviderConfig(ProviderConfig): client_id: str secret: str additional_scope: Optional[str] + issuer_url: Optional[str] + logo_url: Optional[str] class WebAuthnProviderConfig(ProviderConfig): diff --git a/edb/server/protocol/auth_ext/google.py b/edb/server/protocol/auth_ext/google.py index 61e3a3d2f72..c2fece83c23 100644 --- a/edb/server/protocol/auth_ext/google.py +++ b/edb/server/protocol/auth_ext/google.py @@ -22,7 +22,7 @@ from . import base -class GoogleProvider(base.OpenIDProvider): +class GoogleProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "google", "https://accounts.google.com", *args, **kwargs diff --git a/edb/server/protocol/auth_ext/http_client.py b/edb/server/protocol/auth_ext/http_client.py index 2e158fe2edf..52302b97c31 100644 --- a/edb/server/protocol/auth_ext/http_client.py +++ b/edb/server/protocol/auth_ext/http_client.py @@ -50,7 +50,7 @@ async def post( # type: ignore[override] **kwargs: Any, ) -> httpx.Response: if self.edgedb_orig_base_url: - path = f'{self.edgedb_orig_base_url}/{path}' + path = f'{self.edgedb_orig_base_url}{path}' return await super().post( path, *args, **kwargs ) @@ -62,5 +62,5 @@ async def get( # type: ignore[override] **kwargs: Any, ) -> httpx.Response: if self.edgedb_orig_base_url: - path = f'{self.edgedb_orig_base_url}/{path}' + path = f'{self.edgedb_orig_base_url}{path}' return await super().get(path, *args, **kwargs) diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py index d02d21423f6..a34d49f2fae 100644 --- a/edb/server/protocol/auth_ext/oauth.py +++ b/edb/server/protocol/auth_ext/oauth.py @@ -19,7 +19,7 @@ import json -from typing import cast, Any, Type +from typing import cast, Any from edb.server.protocol import execute from . import github, google, azure, apple, discord, slack @@ -39,34 +39,56 @@ def __init__( ) provider_config = self._get_provider_config(provider_name) - provider_args = (provider_config.client_id, provider_config.secret) + provider_args: tuple[str, str] | tuple[str, str, str, str] = ( + provider_config.client_id, + provider_config.secret, + ) provider_kwargs = { "http_factory": http_factory, "additional_scope": provider_config.additional_scope, } - provider_class: Type[base.BaseProvider] - match provider_name: - case "builtin::oauth_github": - provider_class = github.GitHubProvider - case "builtin::oauth_google": - provider_class = google.GoogleProvider - case "builtin::oauth_azure": - provider_class = azure.AzureProvider - case "builtin::oauth_apple": - provider_class = apple.AppleProvider - case "builtin::oauth_discord": - provider_class = discord.DiscordProvider - case "builtin::oauth_slack": - provider_class = slack.SlackProvider + match (provider_name, provider_config.issuer_url): + case ("builtin::oauth_github", _): + self.provider = github.GitHubProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_google", _): + self.provider = google.GoogleProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_azure", _): + self.provider = azure.AzureProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_apple", _): + self.provider = apple.AppleProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_discord", _): + self.provider = discord.DiscordProvider( + *provider_args, + **provider_kwargs, + ) + case ("builtin::oauth_slack", _): + self.provider = slack.SlackProvider( + *provider_args, + **provider_kwargs, + ) + case (provider_name, str(issuer_url)): + self.provider = base.OpenIDConnectProvider( + provider_name, + issuer_url, + *provider_args, + **provider_kwargs, + ) case _: raise errors.InvalidData(f"Invalid provider: {provider_name}") - self.provider = provider_class( - *provider_args, - **provider_kwargs, # type: ignore[arg-type] - ) - async def get_authorize_url(self, state: str, redirect_uri: str) -> str: return await self.provider.get_code_url( state=state, @@ -121,7 +143,7 @@ async def _handle_identity( return ( data.Identity(**result_json[0]['identity']), - result_json[0]['new'] + result_json[0]['new'], ) def _get_provider_config( @@ -139,6 +161,8 @@ def _get_provider_config( client_id=cfg.client_id, secret=cfg.secret, additional_scope=cfg.additional_scope, + issuer_url=getattr(cfg, 'issuer_url', None), + logo_url=getattr(cfg, 'logo_url', None), ) raise errors.MissingConfiguration( diff --git a/edb/server/protocol/auth_ext/slack.py b/edb/server/protocol/auth_ext/slack.py index bc6909e4580..b486cfb83f8 100644 --- a/edb/server/protocol/auth_ext/slack.py +++ b/edb/server/protocol/auth_ext/slack.py @@ -22,12 +22,11 @@ from . import base -class SlackProvider(base.OpenIDProvider): +class SlackProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "slack", "https://slack.com", *args, - content_type=base.ContentType.FORM_ENCODED, **kwargs, ) diff --git a/edb/server/protocol/auth_ext/ui/__init__.py b/edb/server/protocol/auth_ext/ui/__init__.py index 04ec441c870..d066a0c0984 100644 --- a/edb/server/protocol/auth_ext/ui/__init__.py +++ b/edb/server/protocol/auth_ext/ui/__init__.py @@ -56,10 +56,8 @@ def render_signin_page( webauthn_provider = p elif p.name == 'builtin::local_magic_link': magic_link_provider = p - elif p.name.startswith('builtin::oauth_'): - oauth_providers.append( - cast(auth_config.OAuthProviderConfig, p) - ) + elif p.name.startswith('builtin::oauth_') or hasattr(p, "issuer_url"): + oauth_providers.append(cast(auth_config.OAuthProviderConfig, p)) base_email_factor_form = f""" @@ -68,7 +66,8 @@ def render_signin_page( """ - password_input = f""" + password_input = ( + f"""
- """ if password_provider else '' + """ + if password_provider + else '' + ) email_factor_form = render_email_factor_form( base_email_factor_form=base_email_factor_form, @@ -110,7 +112,8 @@ def render_signin_page( name='callback_url', value=redirect_to ) if magic_link_provider else ''} ''', - password_form=f""" + password_form=( + f"""
- """ if password_provider else None, - webauthn_form=f""" + """ + if password_provider + else None + ), + webauthn_form=( + f""" - """ if webauthn_provider else None, - magic_link_form=f""" + """ + if webauthn_provider + else None + ), + magic_link_form=( + f""" - """ if magic_link_provider else None + """ + if magic_link_provider + else None + ), ) if email_factor_form: email_factor_form += render.bottom_note( - "Don't have an account?", link='Sign up', href='signup') + "Don't have an account?", link='Sign up', href='signup' + ) oauth_buttons = render.oauth_buttons( oauth_providers=oauth_providers, @@ -203,9 +218,9 @@ def render_email_factor_form( magic_link_form: Optional[str], ) -> Optional[str]: if ( - password_form is None and - webauthn_form is None and - magic_link_form is None + password_form is None + and webauthn_form is None + and magic_link_form is None ): return None @@ -217,28 +232,36 @@ def render_email_factor_form( case (None, None, _): return magic_link_form - if ( - base_email_factor_form is None or - (webauthn_form is not None and magic_link_form is not None) + if base_email_factor_form is None or ( + webauthn_form is not None and magic_link_form is not None ): tabs = [ - ('Passkey', webauthn_form, selected_tab == 'webauthn') - if webauthn_form else None, - ('Password', password_form, selected_tab == 'password') - if password_form else None, - ('Email Link', magic_link_form, selected_tab == 'magic_link') - if magic_link_form else None + ( + ('Passkey', webauthn_form, selected_tab == 'webauthn') + if webauthn_form + else None + ), + ( + ('Password', password_form, selected_tab == 'password') + if password_form + else None + ), + ( + ('Email Link', magic_link_form, selected_tab == 'magic_link') + if magic_link_form + else None + ), ] selected_tabs = [t[2] for t in tabs if t is not None] selected_index = ( - selected_tabs.index(True) if True in selected_tabs else 0) + selected_tabs.index(True) if True in selected_tabs else 0 + ) - return ( - render.tabs_buttons( - [t[0] for t in tabs if t is not None], selected_index) + - render.tabs_content( - [t[1] for t in tabs if t is not None], selected_index) + return render.tabs_buttons( + [t[0] for t in tabs if t is not None], selected_index + ) + render.tabs_content( + [t[1] for t in tabs if t is not None], selected_index ) slider_content = [ @@ -256,7 +279,7 @@ def render_email_factor_form( secondary=True, type="button")} {render.button("Sign in with password", id="password-signin")} - ''' + ''', ] return f""" @@ -301,10 +324,8 @@ def render_signup_page( webauthn_provider = p elif p.name == 'builtin::local_magic_link': magic_link_provider = p - elif p.name.startswith('builtin::oauth_'): - oauth_providers.append( - cast(auth_config.OAuthProviderConfig, p) - ) + elif p.name.startswith('builtin::oauth_') or hasattr(p, "issuer_url"): + oauth_providers.append(cast(auth_config.OAuthProviderConfig, p)) base_email_factor_form = f""" @@ -315,7 +336,8 @@ def render_signup_page( email_factor_form = render_email_factor_form( selected_tab=selected_tab, - password_form=f""" + password_form=( + f""" {render.button("Sign Up", id="password-signup")}
- """ if password_provider else None, - webauthn_form=f""" + """ + if password_provider + else None + ), + webauthn_form=( + f"""
- """ if webauthn_provider else None, - magic_link_form=f""" + """ + if webauthn_provider + else None + ), + magic_link_form=( + f""" - """ if magic_link_provider else None + """ + if magic_link_provider + else None + ), ) if email_factor_form: email_factor_form += render.bottom_note( - 'Already have an account?', link='Sign in', href='signin') + 'Already have an account?', link='Sign in', href='signin' + ) oauth_buttons = render.oauth_buttons( oauth_providers=oauth_providers, - label_prefix=('Sign up with' if email_factor_form - else 'Continue with'), + label_prefix=('Sign up with' if email_factor_form else 'Continue with'), challenge=challenge, redirect_to=redirect_to, redirect_to_on_signup=redirect_to_on_signup, - collapsed=email_factor_form is not None and len(oauth_providers) >= 3 + collapsed=email_factor_form is not None and len(oauth_providers) >= 3, ) return render.base_page( @@ -662,6 +695,7 @@ def render_magic_link_sent_page( ''', ) + # emails diff --git a/edb/server/protocol/auth_ext/ui/components.py b/edb/server/protocol/auth_ext/ui/components.py index ac2355266d3..3520960d45c 100644 --- a/edb/server/protocol/auth_ext/ui/components.py +++ b/edb/server/protocol/auth_ext/ui/components.py @@ -148,16 +148,21 @@ def _oauth_button( *, label_prefix: str, ) -> str: - href = '../authorize?' + urllib.parse.urlencode({ - 'provider': provider.name, - **params - }) - img = ( - f'''{provider.display_name} Icon''' - if provider.name in known_oauth_provider_names - else '' + href = '../authorize?' + urllib.parse.urlencode( + {'provider': provider.name, **params} ) + if ( + provider.name.startswith('builtin::') + and provider.name in known_oauth_provider_names + ): + img = f'''{provider.display_name} Icon''' + elif provider.logo_url is not None: + img = f'''{provider.display_name} Icon''' + else: + img = '' + label = f'{label_prefix} {provider.display_name}' return f''' diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 044a9bb1224..52ba3d6f2fe 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -200,6 +200,28 @@ ], } +GENERIC_OIDC_DISCOVERY_DOCUMENT = { + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "userinfo_endpoint": "https://example.com/userinfo", + "jwks_uri": "https://example.com/jwks", + "scopes_supported": ["openid", "profile", "email"], + "response_types_supported": ["code"], + "response_modes_supported": ["query"], + "grant_types_supported": ["authorization_code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + "claims_supported": ["sub", "auth_time", "iss"], + "claims_parameter_supported": False, + "request_parameter_supported": False, + "request_uri_parameter_supported": True, + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + ], +} + def utcnow(): return datetime.datetime.now(datetime.timezone.utc) @@ -212,6 +234,7 @@ def utcnow(): APPLE_SECRET = 'c' * 32 DISCORD_SECRET = 'd' * 32 SLACK_SECRET = 'd' * 32 +GENERIC_OIDC_SECRET = 'e' * 32 APP_NAME = "Test App" LOGO_URL = "http://example.com/logo.png" DARK_LOGO_URL = "http://example.com/darklogo.png" @@ -294,6 +317,16 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): client_id := '{uuid.uuid4()}', }}; + CONFIGURE CURRENT DATABASE + INSERT ext::auth::OpenIDConnectProvider {{ + secret := '{GENERIC_OIDC_SECRET}', + client_id := '{uuid.uuid4()}', + name := 'generic_oidc', + display_name := 'My Generic OIDC Provider', + issuer_url := 'https://example.com', + additional_scope := 'custom_provider_scope_string', + }}; + CONFIGURE CURRENT DATABASE INSERT ext::auth::EmailPasswordProviderConfig {{ require_verification := false, @@ -321,7 +354,8 @@ def setUpClass(cls): def setUp(self): self.mock_provider = tb.MockHttpServer( - handler_type=tb.MultiHostMockHttpServerHandler) + handler_type=tb.MultiHostMockHttpServerHandler + ) self.mock_provider.start() HTTP_TEST_PORT.set(self.mock_provider.get_base_url()) @@ -371,18 +405,23 @@ def http_con_send_request(self, *args, headers=None, **kwargs): headers['x-edgedb-oauth-test-server'] = test_port return super().http_con_send_request(*args, headers=headers, **kwargs) - async def get_builtin_provider_config_by_name(self, provider_name: str): + async def get_provider_config_by_name(self, fqn: str): return await self.con.query_single( """ - SELECT assert_exists(assert_single( + SELECT assert_exists( cfg::Config.extensions[is ext::auth::AuthConfig].providers { *, [is ext::auth::OAuthProviderConfig].client_id, [is ext::auth::OAuthProviderConfig].additional_scope, - } filter .name = 'builtin::' ++ $0 - )); + } filter .name = $0 + ); """, - provider_name, + fqn, + ) + + async def get_builtin_provider_config_by_name(self, provider_name: str): + return await self.get_provider_config_by_name( + f"builtin::{provider_name}" ) async def get_auth_config_value(self, key: str): @@ -617,7 +656,7 @@ async def test_http_auth_ext_github_callback_01(self): token_request = ( "POST", "https://github.com", - "/login/oauth/access_token", + "login/oauth/access_token", ) self.mock_provider.register_route_handler(*token_request)( ( @@ -632,7 +671,7 @@ async def test_http_auth_ext_github_callback_01(self): ) ) - user_request = ("GET", "https://api.github.com", "/user") + user_request = ("GET", "https://api.github.com", "user") self.mock_provider.register_route_handler(*user_request)( ( json.dumps( @@ -792,7 +831,7 @@ async def test_http_auth_ext_github_callback_failure_01(self): token_request = ( "POST", "https://github.com", - "/login/oauth/access_token", + "login/oauth/access_token", ) self.mock_provider.register_route_handler(*token_request)( ( @@ -858,7 +897,7 @@ async def test_http_auth_ext_github_callback_failure_02(self): token_request = ( "POST", "https://github.com", - "/login/oauth/access_token", + "login/oauth/access_token", ) self.mock_provider.register_route_handler(*token_request)( ( @@ -999,7 +1038,7 @@ async def test_http_auth_ext_discord_callback_01(self): token_request = ( "POST", "https://discord.com", - "/api/oauth2/token", + "api/oauth2/token", ) self.mock_provider.register_route_handler(*token_request)( ( @@ -1014,7 +1053,7 @@ async def test_http_auth_ext_discord_callback_01(self): ) ) - user_request = ("GET", "https://discord.com/api/v10", "/users/@me") + user_request = ("GET", "https://discord.com/api/v10", "users/@me") self.mock_provider.register_route_handler(*user_request)( ( json.dumps( @@ -1177,7 +1216,7 @@ async def test_http_auth_ext_google_callback_01(self) -> None: discovery_request = ( "GET", "https://accounts.google.com", - "/.well-known/openid-configuration", + ".well-known/openid-configuration", ) self.mock_provider.register_route_handler(*discovery_request)( ( @@ -1190,7 +1229,7 @@ async def test_http_auth_ext_google_callback_01(self) -> None: jwks_request = ( "GET", "https://www.googleapis.com", - "/oauth2/v3/certs", + "oauth2/v3/certs", ) # Generate a JWK Set k = jwk.JWK.generate(kty='RSA', size=4096) @@ -1210,7 +1249,7 @@ async def test_http_auth_ext_google_callback_01(self) -> None: token_request = ( "POST", "https://oauth2.googleapis.com", - "/token", + "token", ) id_token_claims = { "iss": "https://accounts.google.com", @@ -1284,20 +1323,21 @@ async def test_http_auth_ext_google_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = ( - self.mock_provider.requests[discovery_request]) + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] self.assertEqual(len(requests_for_discovery), 2) requests_for_token = self.mock_provider.requests[token_request] self.assertEqual(len(requests_for_token), 1) self.assertEqual( - json.loads(requests_for_token[0]["body"]), + urllib.parse.parse_qs(requests_for_token[0]["body"]), { - "grant_type": "authorization_code", - "code": "abc123", - "client_id": client_id, - "client_secret": client_secret, - "redirect_uri": f"{self.http_addr}/callback", + "grant_type": ["authorization_code"], + "code": ["abc123"], + "client_id": [client_id], + "client_secret": [client_secret], + "redirect_uri": [f"{self.http_addr}/callback"], }, ) @@ -1337,7 +1377,7 @@ async def test_http_auth_ext_google_authorize_01(self): discovery_request = ( "GET", "https://accounts.google.com", - "/.well-known/openid-configuration", + ".well-known/openid-configuration", ) self.mock_provider.register_route_handler(*discovery_request)( ( @@ -1381,8 +1421,9 @@ async def test_http_auth_ext_google_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = ( - self.mock_provider.requests[discovery_request]) + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -1406,7 +1447,7 @@ async def test_http_auth_ext_azure_authorize_01(self): discovery_request = ( "GET", "https://login.microsoftonline.com/common/v2.0", - "/.well-known/openid-configuration", + ".well-known/openid-configuration", ) self.mock_provider.register_route_handler(*discovery_request)( ( @@ -1452,8 +1493,9 @@ async def test_http_auth_ext_azure_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = ( - self.mock_provider.requests[discovery_request]) + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -1479,7 +1521,7 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: discovery_request = ( "GET", "https://login.microsoftonline.com/common/v2.0", - "/.well-known/openid-configuration", + ".well-known/openid-configuration", ) self.mock_provider.register_route_handler(*discovery_request)( ( @@ -1491,7 +1533,7 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: jwks_request = ( "GET", "https://login.microsoftonline.com", - "/common/discovery/v2.0/keys", + "common/discovery/v2.0/keys", ) # Generate a JWK Set k = jwk.JWK.generate(kty='RSA', size=4096) @@ -1511,7 +1553,7 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: token_request = ( "POST", "https://login.microsoftonline.com", - "/common/oauth2/v2.0/token", + "common/oauth2/v2.0/token", ) id_token_claims = { "iss": "https://login.microsoftonline.com/common/v2.0", @@ -1585,8 +1627,9 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = ( - self.mock_provider.requests[discovery_request]) + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] self.assertEqual(len(requests_for_discovery), 2) requests_for_token = self.mock_provider.requests[token_request] @@ -1622,7 +1665,7 @@ async def test_http_auth_ext_apple_authorize_01(self): discovery_request = ( "GET", "https://appleid.apple.com", - "/.well-known/openid-configuration", + ".well-known/openid-configuration", ) self.mock_provider.register_route_handler(*discovery_request)( ( @@ -1666,8 +1709,9 @@ async def test_http_auth_ext_apple_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = ( - self.mock_provider.requests[discovery_request]) + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -1693,7 +1737,7 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: discovery_request = ( "GET", "https://appleid.apple.com", - "/.well-known/openid-configuration", + ".well-known/openid-configuration", ) self.mock_provider.register_route_handler(*discovery_request)( ( @@ -1705,7 +1749,7 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: jwks_request = ( "GET", "https://appleid.apple.com", - "/auth/keys", + "auth/keys", ) # Generate a JWK Set k = jwk.JWK.generate(kty='RSA', size=4096) @@ -1725,7 +1769,7 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: token_request = ( "POST", "https://appleid.apple.com", - "/auth/token", + "auth/token", ) id_token_claims = { "iss": "https://appleid.apple.com", @@ -1804,8 +1848,9 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = ( - self.mock_provider.requests[discovery_request]) + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] self.assertEqual(len(requests_for_discovery), 2) requests_for_token = self.mock_provider.requests[token_request] @@ -1836,7 +1881,7 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( discovery_request = ( "GET", "https://appleid.apple.com", - "/.well-known/openid-configuration", + ".well-known/openid-configuration", ) self.mock_provider.register_route_handler(*discovery_request)( ( @@ -1848,7 +1893,7 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( jwks_request = ( "GET", "https://appleid.apple.com", - "/auth/keys", + "auth/keys", ) # Generate a JWK Set k = jwk.JWK.generate(kty='RSA', size=4096) @@ -1868,7 +1913,7 @@ async def test_http_auth_ext_apple_callback_redirect_on_signup_02( token_request = ( "POST", "https://appleid.apple.com", - "/auth/token", + "auth/token", ) id_token_claims = { "iss": "https://appleid.apple.com", @@ -1984,7 +2029,7 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: discovery_request = ( "GET", "https://slack.com", - "/.well-known/openid-configuration", + ".well-known/openid-configuration", ) self.mock_provider.register_route_handler(*discovery_request)( ( @@ -1997,7 +2042,7 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: jwks_request = ( "GET", "https://slack.com", - "/openid/connect/keys", + "openid/connect/keys", ) # Generate a JWK Set k = jwk.JWK.generate(kty='RSA', size=4096) @@ -2017,7 +2062,7 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: token_request = ( "POST", "https://slack.com", - "/api/openid.connect.token", + "api/openid.connect.token", ) id_token_claims = { "iss": "https://slack.com", @@ -2091,8 +2136,9 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = ( - self.mock_provider.requests[discovery_request]) + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] self.assertEqual(len(requests_for_discovery), 2) requests_for_token = self.mock_provider.requests[token_request] @@ -2144,7 +2190,7 @@ async def test_http_auth_ext_slack_authorize_01(self): discovery_request = ( "GET", "https://slack.com", - "/.well-known/openid-configuration", + ".well-known/openid-configuration", ) self.mock_provider.register_route_handler(*discovery_request)( ( @@ -2188,8 +2234,90 @@ async def test_http_auth_ext_slack_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = ( - self.mock_provider.requests[discovery_request]) + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] + self.assertEqual(len(requests_for_discovery), 1) + + pkce = await self.con.query( + """ + select ext::auth::PKCEChallenge + filter .challenge = $challenge + """, + challenge=challenge, + ) + self.assertEqual(len(pkce), 1) + + async def test_http_auth_ext_generic_oidc_authorize_01(self): + with self.http_con() as http_con: + provider_config = await self.get_provider_config_by_name( + "generic_oidc" + ) + provider_name = provider_config.name + client_id = provider_config.client_id + challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256( + base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=') + ).digest() + ) + .rstrip(b'=') + .decode() + ) + + discovery_request = ( + "GET", + "https://example.com", + ".well-known/openid-configuration", + ) + self.mock_provider.register_route_handler(*discovery_request)( + ( + json.dumps(GENERIC_OIDC_DISCOVERY_DOCUMENT), + 200, + ) + ) + + redirect_to = f"{self.http_addr}/some/path" + _, headers, status = self.http_con_request( + http_con, + { + "provider": provider_name, + "redirect_to": redirect_to, + "challenge": challenge, + }, + path="authorize", + ) + + self.assertEqual(status, 302) + + location = headers.get("location") + assert location is not None + url = urllib.parse.urlparse(location) + qs = urllib.parse.parse_qs(url.query, keep_blank_values=True) + self.assertEqual(url.scheme, "https") + self.assertEqual(url.hostname, "example.com") + self.assertEqual(url.path, "/auth") + self.assertEqual( + qs.get("scope"), + ["openid profile email custom_provider_scope_string"], + ) + + state = qs.get("state") + assert state is not None + + claims = await self.extract_jwt_claims(state[0]) + self.assertEqual(claims.get("provider"), provider_name) + self.assertEqual(claims.get("iss"), self.http_addr) + self.assertEqual(claims.get("redirect_to"), redirect_to) + + self.assertEqual( + qs.get("redirect_uri"), [f"{self.http_addr}/callback"] + ) + self.assertEqual(qs.get("client_id"), [client_id]) + + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -2201,6 +2329,161 @@ async def test_http_auth_ext_slack_authorize_01(self): ) self.assertEqual(len(pkce), 1) + async def test_http_auth_ext_generic_oidc_callback_01(self): + with self.http_con() as http_con: + provider_config = await self.get_provider_config_by_name( + "generic_oidc" + ) + provider_name = provider_config.name + client_id = provider_config.client_id + client_secret = GENERIC_OIDC_SECRET + + now = utcnow() + + discovery_request = ( + "GET", + "https://example.com", + ".well-known/openid-configuration", + ) + self.mock_provider.register_route_handler(*discovery_request)( + ( + json.dumps(GENERIC_OIDC_DISCOVERY_DOCUMENT), + 200, + {"cache-control": "max-age=3600"}, + ) + ) + + jwks_request = ( + "GET", + "https://example.com", + "jwks", + ) + # Generate a JWK Set + k = jwk.JWK.generate(kty='RSA', size=4096) + ks = jwk.JWKSet() + ks.add(k) + jwk_set: dict[str, Any] = ks.export( + private_keys=False, as_dict=True + ) + + self.mock_provider.register_route_handler(*jwks_request)( + ( + json.dumps(jwk_set), + 200, + ) + ) + + token_request = ( + "POST", + "https://example.com", + "token", + ) + id_token_claims = { + "iss": "https://example.com", + "sub": "1", + "aud": client_id, + "exp": (now + datetime.timedelta(minutes=5)).timestamp(), + "iat": now.timestamp(), + "email": "test@example.com", + } + id_token = jwt.JWT(header={"alg": "RS256"}, claims=id_token_claims) + id_token.make_signed_token(k) + + self.mock_provider.register_route_handler(*token_request)( + ( + json.dumps( + { + "access_token": "oidc_access_token", + "id_token": id_token.serialize(), + "scope": "openid", + "token_type": "bearer", + } + ), + 200, + ) + ) + + challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256( + base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=') + ).digest() + ) + .rstrip(b'=') + .decode() + ) + await self.con.query( + """ + insert ext::auth::PKCEChallenge { + challenge := $challenge, + } + """, + challenge=challenge, + ) + + signing_key = await self.get_signing_key() + + expires_at = now + datetime.timedelta(minutes=5) + state_claims = { + "iss": self.http_addr, + "provider": str(provider_name), + "exp": expires_at.timestamp(), + "redirect_to": f"{self.http_addr}/some/path", + "challenge": challenge, + } + state_token = self.generate_state_value(state_claims, signing_key) + + data, headers, status = self.http_con_request( + http_con, + {"state": state_token, "code": "abc123"}, + path="callback", + ) + + self.assertEqual(data, b"") + self.assertEqual(status, 302) + + location = headers.get("location") + assert location is not None + server_url = urllib.parse.urlparse(self.http_addr) + url = urllib.parse.urlparse(location) + self.assertEqual(url.scheme, server_url.scheme) + self.assertEqual(url.hostname, server_url.hostname) + self.assertEqual(url.path, f"{server_url.path}/some/path") + + requests_for_discovery = self.mock_provider.requests[ + discovery_request + ] + self.assertEqual(len(requests_for_discovery), 2) + + requests_for_token = self.mock_provider.requests[token_request] + self.assertEqual(len(requests_for_token), 1) + self.assertEqual( + urllib.parse.parse_qs(requests_for_token[0]["body"]), + { + "grant_type": ["authorization_code"], + "code": ["abc123"], + "client_id": [client_id], + "client_secret": [client_secret], + "redirect_uri": [f"{self.http_addr}/callback"], + }, + ) + + identity = await self.con.query( + """ + SELECT ext::auth::Identity + FILTER .subject = '1' + AND .issuer = 'https://example.com' + """ + ) + self.assertEqual(len(identity), 1) + + session_claims = await self.extract_session_claims(headers) + self.assertEqual(session_claims.get("sub"), str(identity[0].id)) + self.assertEqual(session_claims.get("iss"), str(self.http_addr)) + tomorrow = now + datetime.timedelta(hours=25) + self.assertTrue(session_claims.get("exp") > now.timestamp()) + self.assertTrue(session_claims.get("exp") < tomorrow.timestamp()) + async def test_http_auth_ext_local_password_register_form_01(self): with self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name( @@ -2409,7 +2692,9 @@ async def test_http_auth_ext_local_password_register_form_02(self): self.assertEqual(status, 400) # Non-matching port - form_data["redirect_to"] = "https://oauth.example.com:8080/app/some/path" + form_data["redirect_to"] = ( + "https://oauth.example.com:8080/app/some/path" + ) form_data_encoded = urllib.parse.urlencode(form_data).encode() _, _, status = self.http_con_request( @@ -2424,7 +2709,9 @@ async def test_http_auth_ext_local_password_register_form_02(self): self.assertEqual(status, 400) # Path doesn't match - form_data["redirect_to"] = "https://oauth.example.com/wrong-base/path" + form_data["redirect_to"] = ( + "https://oauth.example.com/wrong-base/path" + ) form_data_encoded = urllib.parse.urlencode(form_data).encode() _, _, status = self.http_con_request( @@ -3476,6 +3763,11 @@ async def test_http_auth_ext_ui_signin(self): self.assertIn(APP_NAME, body_str) self.assertIn(LOGO_URL, body_str) self.assertIn(BRAND_COLOR, body_str) + + # Check for OAuth buttons + self.assertIn("Sign in with Google", body_str) + self.assertIn("Sign in with GitHub", body_str) + self.assertIn("Sign in with My Generic OIDC Provider", body_str) self.assertEqual(status, 200) async def test_http_auth_ext_webauthn_register_options(self):