From 009dd9c4a4881fc864b8f95bf71648253fe0a5ba Mon Sep 17 00:00:00 2001 From: Scott Trinh Date: Fri, 22 Nov 2024 10:02:12 -0500 Subject: [PATCH] Allow OAuth callback to be specified directly (#8022) Currently, we expect that the OAuth Identity Provider should always redirect back to the auth extension's server endpoint, so we build this URL ourselves. However, there might be times when users want to control the OAuth flow themselves, so the callback should redirect to some URL that they've specified. One confusing thing here is that sometimes the URL provided at `redirect_to` might be refered to as the "callback" URL, but that URL specifies where the auth extension server should call the application at the end of the flow. --- edb/server/protocol/auth_ext/http.py | 35 ++++++++++++++++++++-------- tests/test_http_ext_auth.py | 4 +++- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 3ee26c32bd1..fd91dc21d35 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -135,27 +135,27 @@ async def handle_request( try: match args: - # API routes + # PKCE token exchange route + case ("token",): + await self.handle_token(request, response) + + # OAuth routes case ("authorize",): await self.handle_authorize(request, response) case ("callback",): await self.handle_callback(request, response) - case ("token",): - await self.handle_token(request, response) + + # Email/password routes case ("register",): await self.handle_register(request, response) case ("authenticate",): await self.handle_authenticate(request, response) - case ("verify",): - await self.handle_verify(request, response) - case ("resend-verification-email",): - await self.handle_resend_verification_email( - request, response - ) case ('send-reset-email',): await self.handle_send_reset_email(request, response) case ('reset-password',): await self.handle_reset_password(request, response) + + # Magic link routes case ('magic-link', 'register'): await self.handle_magic_link_register(request, response) case ('magic-link', 'email'): @@ -177,6 +177,14 @@ async def handle_request( request, response ) + # Email verification routes + case ("verify",): + await self.handle_verify(request, response) + case ("resend-verification-email",): + await self.handle_resend_verification_email( + request, response + ) + # UI routes case ('ui', 'signin'): await self.handle_ui_signin(request, response) @@ -292,6 +300,9 @@ async def handle_authorize( allowed_redirect_to_on_signup = self._maybe_make_allowed_url( _maybe_get_search_param(query, "redirect_to_on_signup") ) + allowed_callback_url = self._maybe_make_allowed_url( + _maybe_get_search_param(query, "callback_url") + ) challenge = _get_search_param( query, "challenge", fallback_keys=["code_challenge"] ) @@ -303,7 +314,11 @@ async def handle_authorize( ) await pkce.create(self.db, challenge) authorize_url = await oauth_client.get_authorize_url( - redirect_uri=self._get_callback_url(), + redirect_uri=( + allowed_callback_url.url + if allowed_callback_url + else self._get_callback_url() + ), state=self._make_state_claims( provider_name, allowed_redirect_to.url, diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 4827fdffdc0..4d102879678 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -499,6 +499,7 @@ async def test_http_auth_ext_github_authorize_01(self): provider_name = provider_config.name client_id = provider_config.client_id redirect_to = f"{self.http_addr}/some/path" + callback_url = f"{self.http_addr}/some/callback/url" challenge = ( base64.urlsafe_b64encode( hashlib.sha256( @@ -512,6 +513,7 @@ async def test_http_auth_ext_github_authorize_01(self): "provider": provider_name, "redirect_to": redirect_to, "challenge": challenge, + "callback_url": callback_url, } _, headers, status = self.http_con_request( @@ -540,7 +542,7 @@ async def test_http_auth_ext_github_authorize_01(self): self.assertEqual(claims.get("redirect_to"), redirect_to) self.assertEqual( - qs.get("redirect_uri"), [f"{self.http_addr}/callback"] + qs.get("redirect_uri"), [callback_url] ) self.assertEqual(qs.get("client_id"), [client_id])