Skip to content

Commit

Permalink
Allow OAuth callback to be specified directly (#8022)
Browse files Browse the repository at this point in the history
Currently, we expect that the OAuth Identity Provider should always
redirect back to the auth extension's server endpoint, so we build this
URL ourselves. However, there might be times when users want to control
the OAuth flow themselves, so the callback should redirect to some URL
that they've specified.

One confusing thing here is that sometimes the URL provided at
`redirect_to` might be refered to as the "callback" URL, but that URL
specifies where the auth extension server should call the application at
the end of the flow.
  • Loading branch information
scotttrinh authored Nov 22, 2024
1 parent 7ac745e commit 009dd9c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
35 changes: 25 additions & 10 deletions edb/server/protocol/auth_ext/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,27 +135,27 @@ async def handle_request(

try:
match args:
# API routes
# PKCE token exchange route
case ("token",):
await self.handle_token(request, response)

# OAuth routes
case ("authorize",):
await self.handle_authorize(request, response)
case ("callback",):
await self.handle_callback(request, response)
case ("token",):
await self.handle_token(request, response)

# Email/password routes
case ("register",):
await self.handle_register(request, response)
case ("authenticate",):
await self.handle_authenticate(request, response)
case ("verify",):
await self.handle_verify(request, response)
case ("resend-verification-email",):
await self.handle_resend_verification_email(
request, response
)
case ('send-reset-email',):
await self.handle_send_reset_email(request, response)
case ('reset-password',):
await self.handle_reset_password(request, response)

# Magic link routes
case ('magic-link', 'register'):
await self.handle_magic_link_register(request, response)
case ('magic-link', 'email'):
Expand All @@ -177,6 +177,14 @@ async def handle_request(
request, response
)

# Email verification routes
case ("verify",):
await self.handle_verify(request, response)
case ("resend-verification-email",):
await self.handle_resend_verification_email(
request, response
)

# UI routes
case ('ui', 'signin'):
await self.handle_ui_signin(request, response)
Expand Down Expand Up @@ -292,6 +300,9 @@ async def handle_authorize(
allowed_redirect_to_on_signup = self._maybe_make_allowed_url(
_maybe_get_search_param(query, "redirect_to_on_signup")
)
allowed_callback_url = self._maybe_make_allowed_url(
_maybe_get_search_param(query, "callback_url")
)
challenge = _get_search_param(
query, "challenge", fallback_keys=["code_challenge"]
)
Expand All @@ -303,7 +314,11 @@ async def handle_authorize(
)
await pkce.create(self.db, challenge)
authorize_url = await oauth_client.get_authorize_url(
redirect_uri=self._get_callback_url(),
redirect_uri=(
allowed_callback_url.url
if allowed_callback_url
else self._get_callback_url()
),
state=self._make_state_claims(
provider_name,
allowed_redirect_to.url,
Expand Down
4 changes: 3 additions & 1 deletion tests/test_http_ext_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ async def test_http_auth_ext_github_authorize_01(self):
provider_name = provider_config.name
client_id = provider_config.client_id
redirect_to = f"{self.http_addr}/some/path"
callback_url = f"{self.http_addr}/some/callback/url"
challenge = (
base64.urlsafe_b64encode(
hashlib.sha256(
Expand All @@ -512,6 +513,7 @@ async def test_http_auth_ext_github_authorize_01(self):
"provider": provider_name,
"redirect_to": redirect_to,
"challenge": challenge,
"callback_url": callback_url,
}

_, headers, status = self.http_con_request(
Expand Down Expand Up @@ -540,7 +542,7 @@ async def test_http_auth_ext_github_authorize_01(self):
self.assertEqual(claims.get("redirect_to"), redirect_to)

self.assertEqual(
qs.get("redirect_uri"), [f"{self.http_addr}/callback"]
qs.get("redirect_uri"), [callback_url]
)
self.assertEqual(qs.get("client_id"), [client_id])

Expand Down

0 comments on commit 009dd9c

Please sign in to comment.