diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 3ee26c32bd1..fd91dc21d35 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -135,27 +135,27 @@ async def handle_request( try: match args: - # API routes + # PKCE token exchange route + case ("token",): + await self.handle_token(request, response) + + # OAuth routes case ("authorize",): await self.handle_authorize(request, response) case ("callback",): await self.handle_callback(request, response) - case ("token",): - await self.handle_token(request, response) + + # Email/password routes case ("register",): await self.handle_register(request, response) case ("authenticate",): await self.handle_authenticate(request, response) - case ("verify",): - await self.handle_verify(request, response) - case ("resend-verification-email",): - await self.handle_resend_verification_email( - request, response - ) case ('send-reset-email',): await self.handle_send_reset_email(request, response) case ('reset-password',): await self.handle_reset_password(request, response) + + # Magic link routes case ('magic-link', 'register'): await self.handle_magic_link_register(request, response) case ('magic-link', 'email'): @@ -177,6 +177,14 @@ async def handle_request( request, response ) + # Email verification routes + case ("verify",): + await self.handle_verify(request, response) + case ("resend-verification-email",): + await self.handle_resend_verification_email( + request, response + ) + # UI routes case ('ui', 'signin'): await self.handle_ui_signin(request, response) @@ -292,6 +300,9 @@ async def handle_authorize( allowed_redirect_to_on_signup = self._maybe_make_allowed_url( _maybe_get_search_param(query, "redirect_to_on_signup") ) + allowed_callback_url = self._maybe_make_allowed_url( + _maybe_get_search_param(query, "callback_url") + ) challenge = _get_search_param( query, "challenge", fallback_keys=["code_challenge"] ) @@ -303,7 +314,11 @@ async def handle_authorize( ) await pkce.create(self.db, challenge) authorize_url = await oauth_client.get_authorize_url( - redirect_uri=self._get_callback_url(), + redirect_uri=( + allowed_callback_url.url + if allowed_callback_url + else self._get_callback_url() + ), state=self._make_state_claims( provider_name, allowed_redirect_to.url, diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 4827fdffdc0..4d102879678 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -499,6 +499,7 @@ async def test_http_auth_ext_github_authorize_01(self): provider_name = provider_config.name client_id = provider_config.client_id redirect_to = f"{self.http_addr}/some/path" + callback_url = f"{self.http_addr}/some/callback/url" challenge = ( base64.urlsafe_b64encode( hashlib.sha256( @@ -512,6 +513,7 @@ async def test_http_auth_ext_github_authorize_01(self): "provider": provider_name, "redirect_to": redirect_to, "challenge": challenge, + "callback_url": callback_url, } _, headers, status = self.http_con_request( @@ -540,7 +542,7 @@ async def test_http_auth_ext_github_authorize_01(self): self.assertEqual(claims.get("redirect_to"), redirect_to) self.assertEqual( - qs.get("redirect_uri"), [f"{self.http_addr}/callback"] + qs.get("redirect_uri"), [callback_url] ) self.assertEqual(qs.get("client_id"), [client_id])