From a405e32d1e55814ae99dc097534a61445583f245 Mon Sep 17 00:00:00 2001 From: Scott Trinh Date: Tue, 23 Jan 2024 13:36:23 -0500 Subject: [PATCH] Add Discord OAuth provider (#6708) --- edb/lib/ext/auth.edgeql | 11 + .../auth_ext/_static/icon_discord.svg | 3 + edb/server/protocol/auth_ext/discord.py | 95 ++++++ edb/server/protocol/auth_ext/oauth.py | 4 +- edb/server/protocol/auth_ext/ui.py | 1 + tests/test_http_ext_auth.py | 281 +++++++++++++++++- 6 files changed, 387 insertions(+), 8 deletions(-) create mode 100644 edb/server/protocol/auth_ext/_static/icon_discord.svg create mode 100644 edb/server/protocol/auth_ext/discord.py diff --git a/edb/lib/ext/auth.edgeql b/edb/lib/ext/auth.edgeql index cc0b8196427..df41ebda009 100644 --- a/edb/lib/ext/auth.edgeql +++ b/edb/lib/ext/auth.edgeql @@ -145,6 +145,17 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' { }; }; + create type ext::auth::DiscordOAuthProvider + extending ext::auth::OAuthProviderConfig { + alter property name { + set default := 'builtin::oauth_discord'; + }; + + alter property display_name { + set default := 'Discord'; + }; + }; + create type ext::auth::GitHubOAuthProvider extending ext::auth::OAuthProviderConfig { alter property name { diff --git a/edb/server/protocol/auth_ext/_static/icon_discord.svg b/edb/server/protocol/auth_ext/_static/icon_discord.svg new file mode 100644 index 00000000000..3b0b3b1f664 --- /dev/null +++ b/edb/server/protocol/auth_ext/_static/icon_discord.svg @@ -0,0 +1,3 @@ + + + diff --git a/edb/server/protocol/auth_ext/discord.py b/edb/server/protocol/auth_ext/discord.py new file mode 100644 index 00000000000..76b58a033e8 --- /dev/null +++ b/edb/server/protocol/auth_ext/discord.py @@ -0,0 +1,95 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import urllib.parse +import functools + +from . import base, data, errors + + +class DiscordProvider(base.BaseProvider): + def __init__(self, *args, **kwargs): + super().__init__("discord", "https://discord.com", *args, **kwargs) + self.auth_domain = self.issuer_url + self.api_domain = f"{self.issuer_url}/api/v10" + self.auth_client = functools.partial( + self.http_factory, base_url=self.auth_domain + ) + self.api_client = functools.partial( + self.http_factory, base_url=self.api_domain + ) + + async def get_code_url( + self, state: str, redirect_uri: str, additional_scope: str + ) -> str: + params = { + "client_id": self.client_id, + "scope": f"email identify {additional_scope}", + "state": state, + "redirect_uri": redirect_uri, + "response_type": "code", + } + encoded = urllib.parse.urlencode(params) + return f"{self.auth_domain}/oauth2/authorize?{encoded}" + + async def exchange_code( + self, code: str, redirect_uri: str + ) -> data.OAuthAccessTokenResponse: + async with self.auth_client() as client: + resp = await client.post( + "/api/oauth2/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": self.client_id, + "client_secret": self.client_secret, + "redirect_uri": redirect_uri, + }, + headers={ + "accept": "application/json", + }, + ) + if resp.status_code >= 400: + raise errors.OAuthProviderFailure( + f"Failed to exchange code: {resp.text}" + ) + json = resp.json() + + return data.OAuthAccessTokenResponse(**json) + + async def fetch_user_info( + self, token_response: data.OAuthAccessTokenResponse + ) -> data.UserInfo: + async with self.api_client() as client: + resp = await client.get( + "/users/@me", + headers={ + "Authorization": f"Bearer {token_response.access_token}", + "Accept": "application/json", + "Cache-Control": "no-store", + }, + ) + payload = resp.json() + return data.UserInfo( + sub=str(payload["id"]), + preferred_username=payload.get("username"), + name=payload.get("global_name"), + email=payload.get("email"), + picture=payload.get("avatar"), + ) diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py index 70e5099f2ff..6440df8e83c 100644 --- a/edb/server/protocol/auth_ext/oauth.py +++ b/edb/server/protocol/auth_ext/oauth.py @@ -22,7 +22,7 @@ from typing import Any, Type from edb.server.protocol import execute -from . import github, google, azure, apple +from . import github, google, azure, apple, discord from . import errors, util, data, base, http_client @@ -55,6 +55,8 @@ def __init__( provider_class = azure.AzureProvider case "builtin::oauth_apple": provider_class = apple.AppleProvider + case "builtin::oauth_discord": + provider_class = discord.DiscordProvider case _: raise errors.InvalidData(f"Invalid provider: {provider_name}") diff --git a/edb/server/protocol/auth_ext/ui.py b/edb/server/protocol/auth_ext/ui.py index 6ca3dc2c255..72a9e3b0b40 100644 --- a/edb/server/protocol/auth_ext/ui.py +++ b/edb/server/protocol/auth_ext/ui.py @@ -30,6 +30,7 @@ 'builtin::oauth_google', 'builtin::oauth_apple', 'builtin::oauth_azure', + 'builtin::oauth_discord', ] diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index a3378e13a03..c4be059a706 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -242,14 +242,19 @@ def handle_request( # Parse and save the request details parsed_path = urllib.parse.urlparse(path) - request_details = { - 'headers': {k.lower(): v for k, v in dict(handler.headers).items()}, - 'query_params': urllib.parse.parse_qs(parsed_path.query), - 'body': handler.rfile.read( - int(handler.headers['Content-Length']) + headers = {k.lower(): v for k, v in dict(handler.headers).items()} + query_params = urllib.parse.parse_qs(parsed_path.query) + if 'content-length' in headers: + body = handler.rfile.read( + int(headers['content-length']) ).decode() - if 'Content-Length' in handler.headers - else None, + else: + body = None + + request_details = { + 'headers': headers, + 'query_params': query_params, + 'body': body, } self.requests[key].append(request_details) @@ -344,6 +349,7 @@ def __exit__(self, *exc): GOOGLE_SECRET = 'c' * 32 AZURE_SECRET = 'c' * 32 APPLE_SECRET = 'c' * 32 +DISCORD_SECRET = 'd' * 32 class TestHttpExtAuth(tb.ExtAuthTestCase): @@ -391,6 +397,12 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): client_id := '{uuid.uuid4()}', }}; + CONFIGURE CURRENT DATABASE + INSERT ext::auth::DiscordOAuthProvider {{ + secret := '{DISCORD_SECRET}', + client_id := '{uuid.uuid4()}', + }}; + CONFIGURE CURRENT DATABASE INSERT ext::auth::EmailPasswordProviderConfig {{ require_verification := false, @@ -983,6 +995,261 @@ async def test_http_auth_ext_github_callback_failure_02(self): "error=access_denied", ) + async def test_http_auth_ext_discord_authorize_01(self): + with MockAuthProvider(), self.http_con() as http_con: + provider_config = await self.get_builtin_provider_config_by_name( + "oauth_discord" + ) + provider_name = provider_config.name + client_id = provider_config.client_id + redirect_to = f"{self.http_addr}/some/path" + challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256( + base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=') + ).digest() + ) + .rstrip(b'=') + .decode() + ) + query = { + "provider": provider_name, + "redirect_to": redirect_to, + "challenge": challenge, + } + + _, headers, status = self.http_con_request( + http_con, + query, + path="authorize", + ) + + self.assertEqual(status, 302) + + location = headers.get("location") + assert location is not None + url = urllib.parse.urlparse(location) + qs = urllib.parse.parse_qs(url.query, keep_blank_values=True) + self.assertEqual(url.scheme, "https") + self.assertEqual(url.hostname, "discord.com") + self.assertEqual(url.path, "/oauth2/authorize") + self.assertEqual(qs.get("scope"), ["email identify "]) + + state = qs.get("state") + assert state is not None + + claims = await self.extract_jwt_claims(state[0]) + self.assertEqual(claims.get("provider"), provider_name) + self.assertEqual(claims.get("iss"), self.http_addr) + self.assertEqual(claims.get("redirect_to"), redirect_to) + + self.assertEqual( + qs.get("redirect_uri"), [f"{self.http_addr}/callback"] + ) + self.assertEqual(qs.get("client_id"), [client_id]) + + pkce = await self.con.query( + """ + select ext::auth::PKCEChallenge + filter .challenge = $challenge + """, + challenge=challenge, + ) + self.assertEqual(len(pkce), 1) + + _, _, repeat_status = self.http_con_request( + http_con, + query, + path="authorize", + ) + self.assertEqual(repeat_status, 302) + + repeat_pkce = await self.con.query_single( + """ + select ext::auth::PKCEChallenge + filter .challenge = $challenge + """, + challenge=challenge, + ) + self.assertEqual(pkce[0].id, repeat_pkce.id) + + async def test_http_auth_ext_discord_callback_01(self): + with MockAuthProvider() as mock_provider, self.http_con() as http_con: + provider_config = await self.get_builtin_provider_config_by_name( + "oauth_discord" + ) + provider_name = provider_config.name + client_id = provider_config.client_id + client_secret = DISCORD_SECRET + + now = utcnow() + token_request = ( + "POST", + "https://discord.com", + "/api/oauth2/token", + ) + mock_provider.register_route_handler(*token_request)( + ( + json.dumps( + { + "access_token": "discord_access_token", + "scope": "read:user", + "token_type": "bearer", + } + ), + 200, + ) + ) + + user_request = ("GET", "https://discord.com/api/v10", "/users/@me") + mock_provider.register_route_handler(*user_request)( + ( + json.dumps( + { + "id": 1, + "username": "dischord", + "global_name": "Ian MacKaye", + "email": "ian@example.com", + "picture": "https://example.com/example.jpg", + } + ), + 200, + ) + ) + + challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256( + base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=') + ).digest() + ) + .rstrip(b'=') + .decode() + ) + await self.con.query( + """ + insert ext::auth::PKCEChallenge { + challenge := $challenge, + } + """, + challenge=challenge, + ) + + signing_key = await self.get_signing_key() + + expires_at = now + datetime.timedelta(minutes=5) + state_claims = { + "iss": self.http_addr, + "provider": str(provider_name), + "exp": expires_at.timestamp(), + "redirect_to": f"{self.http_addr}/some/path", + "challenge": challenge, + } + state_token = self.generate_state_value(state_claims, signing_key) + + data, headers, status = self.http_con_request( + http_con, + {"state": state_token, "code": "abc123"}, + path="callback", + ) + + self.assertEqual(data, b"") + self.assertEqual(status, 302) + + location = headers.get("location") + assert location is not None + server_url = urllib.parse.urlparse(self.http_addr) + url = urllib.parse.urlparse(location) + self.assertEqual(url.scheme, server_url.scheme) + self.assertEqual(url.hostname, server_url.hostname) + self.assertEqual(url.path, f"{server_url.path}/some/path") + + requests_for_token = mock_provider.requests[token_request] + self.assertEqual(len(requests_for_token), 1) + + self.assertEqual( + urllib.parse.parse_qs(requests_for_token[0]["body"]), + { + "grant_type": ["authorization_code"], + "code": ["abc123"], + "client_id": [client_id], + "client_secret": [client_secret], + "redirect_uri": [f"{self.http_addr}/callback"], + }, + ) + + requests_for_user = mock_provider.requests[user_request] + self.assertEqual(len(requests_for_user), 1) + self.assertEqual( + requests_for_user[0]["headers"]["authorization"], + "Bearer discord_access_token", + ) + + identity = await self.con.query( + """ + SELECT ext::auth::Identity + FILTER .subject = '1' + AND .issuer = 'https://discord.com' + """ + ) + self.assertEqual(len(identity), 1) + + session_claims = await self.extract_session_claims(headers) + self.assertEqual(session_claims.get("sub"), str(identity[0].id)) + self.assertEqual(session_claims.get("iss"), str(self.http_addr)) + tomorrow = now + datetime.timedelta(hours=25) + self.assertTrue(session_claims.get("exp") > now.timestamp()) + self.assertTrue(session_claims.get("exp") < tomorrow.timestamp()) + + pkce_object = await self.con.query( + """ + SELECT ext::auth::PKCEChallenge + { id, auth_token, refresh_token } + filter .identity.id = $identity_id + """, + identity_id=identity[0].id, + ) + + self.assertEqual(len(pkce_object), 1) + self.assertEqual(pkce_object[0].auth_token, "discord_access_token") + self.assertIsNone(pkce_object[0].refresh_token) + + mock_provider.register_route_handler(*user_request)( + ( + json.dumps( + { + "id": 1, + "login": "octocat", + "name": "monalisa octocat", + "email": "octocat+2@example.com", + "avatar_url": "https://example.com/example.jpg", + "updated_at": now.isoformat(), + } + ), + 200, + ) + ) + (_, new_headers, _) = self.http_con_request( + http_con, + {"state": state_token, "code": "abc123"}, + path="callback", + ) + + same_identity = await self.con.query( + """ + SELECT ext::auth::Identity + FILTER .subject = '1' + AND .issuer = 'https://discord.com' + """ + ) + self.assertEqual(len(same_identity), 1) + self.assertEqual(identity[0].id, same_identity[0].id) + + new_session_claims = await self.extract_session_claims(new_headers) + self.assertTrue( + new_session_claims.get("exp") > session_claims.get("exp") + ) + async def test_http_auth_ext_google_callback_01(self) -> None: with MockAuthProvider() as mock_provider, self.http_con() as http_con: provider_config = await self.get_builtin_provider_config_by_name(