diff --git a/edb/lib/ext/auth.edgeql b/edb/lib/ext/auth.edgeql
index cc0b8196427..df41ebda009 100644
--- a/edb/lib/ext/auth.edgeql
+++ b/edb/lib/ext/auth.edgeql
@@ -145,6 +145,17 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' {
};
};
+ create type ext::auth::DiscordOAuthProvider
+ extending ext::auth::OAuthProviderConfig {
+ alter property name {
+ set default := 'builtin::oauth_discord';
+ };
+
+ alter property display_name {
+ set default := 'Discord';
+ };
+ };
+
create type ext::auth::GitHubOAuthProvider
extending ext::auth::OAuthProviderConfig {
alter property name {
diff --git a/edb/server/protocol/auth_ext/_static/icon_discord.svg b/edb/server/protocol/auth_ext/_static/icon_discord.svg
new file mode 100644
index 00000000000..3b0b3b1f664
--- /dev/null
+++ b/edb/server/protocol/auth_ext/_static/icon_discord.svg
@@ -0,0 +1,3 @@
+
diff --git a/edb/server/protocol/auth_ext/discord.py b/edb/server/protocol/auth_ext/discord.py
new file mode 100644
index 00000000000..76b58a033e8
--- /dev/null
+++ b/edb/server/protocol/auth_ext/discord.py
@@ -0,0 +1,95 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import urllib.parse
+import functools
+
+from . import base, data, errors
+
+
+class DiscordProvider(base.BaseProvider):
+ def __init__(self, *args, **kwargs):
+ super().__init__("discord", "https://discord.com", *args, **kwargs)
+ self.auth_domain = self.issuer_url
+ self.api_domain = f"{self.issuer_url}/api/v10"
+ self.auth_client = functools.partial(
+ self.http_factory, base_url=self.auth_domain
+ )
+ self.api_client = functools.partial(
+ self.http_factory, base_url=self.api_domain
+ )
+
+ async def get_code_url(
+ self, state: str, redirect_uri: str, additional_scope: str
+ ) -> str:
+ params = {
+ "client_id": self.client_id,
+ "scope": f"email identify {additional_scope}",
+ "state": state,
+ "redirect_uri": redirect_uri,
+ "response_type": "code",
+ }
+ encoded = urllib.parse.urlencode(params)
+ return f"{self.auth_domain}/oauth2/authorize?{encoded}"
+
+ async def exchange_code(
+ self, code: str, redirect_uri: str
+ ) -> data.OAuthAccessTokenResponse:
+ async with self.auth_client() as client:
+ resp = await client.post(
+ "/api/oauth2/token",
+ data={
+ "grant_type": "authorization_code",
+ "code": code,
+ "client_id": self.client_id,
+ "client_secret": self.client_secret,
+ "redirect_uri": redirect_uri,
+ },
+ headers={
+ "accept": "application/json",
+ },
+ )
+ if resp.status_code >= 400:
+ raise errors.OAuthProviderFailure(
+ f"Failed to exchange code: {resp.text}"
+ )
+ json = resp.json()
+
+ return data.OAuthAccessTokenResponse(**json)
+
+ async def fetch_user_info(
+ self, token_response: data.OAuthAccessTokenResponse
+ ) -> data.UserInfo:
+ async with self.api_client() as client:
+ resp = await client.get(
+ "/users/@me",
+ headers={
+ "Authorization": f"Bearer {token_response.access_token}",
+ "Accept": "application/json",
+ "Cache-Control": "no-store",
+ },
+ )
+ payload = resp.json()
+ return data.UserInfo(
+ sub=str(payload["id"]),
+ preferred_username=payload.get("username"),
+ name=payload.get("global_name"),
+ email=payload.get("email"),
+ picture=payload.get("avatar"),
+ )
diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py
index 70e5099f2ff..6440df8e83c 100644
--- a/edb/server/protocol/auth_ext/oauth.py
+++ b/edb/server/protocol/auth_ext/oauth.py
@@ -22,7 +22,7 @@
from typing import Any, Type
from edb.server.protocol import execute
-from . import github, google, azure, apple
+from . import github, google, azure, apple, discord
from . import errors, util, data, base, http_client
@@ -55,6 +55,8 @@ def __init__(
provider_class = azure.AzureProvider
case "builtin::oauth_apple":
provider_class = apple.AppleProvider
+ case "builtin::oauth_discord":
+ provider_class = discord.DiscordProvider
case _:
raise errors.InvalidData(f"Invalid provider: {provider_name}")
diff --git a/edb/server/protocol/auth_ext/ui.py b/edb/server/protocol/auth_ext/ui.py
index 6ca3dc2c255..72a9e3b0b40 100644
--- a/edb/server/protocol/auth_ext/ui.py
+++ b/edb/server/protocol/auth_ext/ui.py
@@ -30,6 +30,7 @@
'builtin::oauth_google',
'builtin::oauth_apple',
'builtin::oauth_azure',
+ 'builtin::oauth_discord',
]
diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py
index a3378e13a03..c4be059a706 100644
--- a/tests/test_http_ext_auth.py
+++ b/tests/test_http_ext_auth.py
@@ -242,14 +242,19 @@ def handle_request(
# Parse and save the request details
parsed_path = urllib.parse.urlparse(path)
- request_details = {
- 'headers': {k.lower(): v for k, v in dict(handler.headers).items()},
- 'query_params': urllib.parse.parse_qs(parsed_path.query),
- 'body': handler.rfile.read(
- int(handler.headers['Content-Length'])
+ headers = {k.lower(): v for k, v in dict(handler.headers).items()}
+ query_params = urllib.parse.parse_qs(parsed_path.query)
+ if 'content-length' in headers:
+ body = handler.rfile.read(
+ int(headers['content-length'])
).decode()
- if 'Content-Length' in handler.headers
- else None,
+ else:
+ body = None
+
+ request_details = {
+ 'headers': headers,
+ 'query_params': query_params,
+ 'body': body,
}
self.requests[key].append(request_details)
@@ -344,6 +349,7 @@ def __exit__(self, *exc):
GOOGLE_SECRET = 'c' * 32
AZURE_SECRET = 'c' * 32
APPLE_SECRET = 'c' * 32
+DISCORD_SECRET = 'd' * 32
class TestHttpExtAuth(tb.ExtAuthTestCase):
@@ -391,6 +397,12 @@ class TestHttpExtAuth(tb.ExtAuthTestCase):
client_id := '{uuid.uuid4()}',
}};
+ CONFIGURE CURRENT DATABASE
+ INSERT ext::auth::DiscordOAuthProvider {{
+ secret := '{DISCORD_SECRET}',
+ client_id := '{uuid.uuid4()}',
+ }};
+
CONFIGURE CURRENT DATABASE
INSERT ext::auth::EmailPasswordProviderConfig {{
require_verification := false,
@@ -983,6 +995,261 @@ async def test_http_auth_ext_github_callback_failure_02(self):
"error=access_denied",
)
+ async def test_http_auth_ext_discord_authorize_01(self):
+ with MockAuthProvider(), self.http_con() as http_con:
+ provider_config = await self.get_builtin_provider_config_by_name(
+ "oauth_discord"
+ )
+ provider_name = provider_config.name
+ client_id = provider_config.client_id
+ redirect_to = f"{self.http_addr}/some/path"
+ challenge = (
+ base64.urlsafe_b64encode(
+ hashlib.sha256(
+ base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=')
+ ).digest()
+ )
+ .rstrip(b'=')
+ .decode()
+ )
+ query = {
+ "provider": provider_name,
+ "redirect_to": redirect_to,
+ "challenge": challenge,
+ }
+
+ _, headers, status = self.http_con_request(
+ http_con,
+ query,
+ path="authorize",
+ )
+
+ self.assertEqual(status, 302)
+
+ location = headers.get("location")
+ assert location is not None
+ url = urllib.parse.urlparse(location)
+ qs = urllib.parse.parse_qs(url.query, keep_blank_values=True)
+ self.assertEqual(url.scheme, "https")
+ self.assertEqual(url.hostname, "discord.com")
+ self.assertEqual(url.path, "/oauth2/authorize")
+ self.assertEqual(qs.get("scope"), ["email identify "])
+
+ state = qs.get("state")
+ assert state is not None
+
+ claims = await self.extract_jwt_claims(state[0])
+ self.assertEqual(claims.get("provider"), provider_name)
+ self.assertEqual(claims.get("iss"), self.http_addr)
+ self.assertEqual(claims.get("redirect_to"), redirect_to)
+
+ self.assertEqual(
+ qs.get("redirect_uri"), [f"{self.http_addr}/callback"]
+ )
+ self.assertEqual(qs.get("client_id"), [client_id])
+
+ pkce = await self.con.query(
+ """
+ select ext::auth::PKCEChallenge
+ filter .challenge = $challenge
+ """,
+ challenge=challenge,
+ )
+ self.assertEqual(len(pkce), 1)
+
+ _, _, repeat_status = self.http_con_request(
+ http_con,
+ query,
+ path="authorize",
+ )
+ self.assertEqual(repeat_status, 302)
+
+ repeat_pkce = await self.con.query_single(
+ """
+ select ext::auth::PKCEChallenge
+ filter .challenge = $challenge
+ """,
+ challenge=challenge,
+ )
+ self.assertEqual(pkce[0].id, repeat_pkce.id)
+
+ async def test_http_auth_ext_discord_callback_01(self):
+ with MockAuthProvider() as mock_provider, self.http_con() as http_con:
+ provider_config = await self.get_builtin_provider_config_by_name(
+ "oauth_discord"
+ )
+ provider_name = provider_config.name
+ client_id = provider_config.client_id
+ client_secret = DISCORD_SECRET
+
+ now = utcnow()
+ token_request = (
+ "POST",
+ "https://discord.com",
+ "/api/oauth2/token",
+ )
+ mock_provider.register_route_handler(*token_request)(
+ (
+ json.dumps(
+ {
+ "access_token": "discord_access_token",
+ "scope": "read:user",
+ "token_type": "bearer",
+ }
+ ),
+ 200,
+ )
+ )
+
+ user_request = ("GET", "https://discord.com/api/v10", "/users/@me")
+ mock_provider.register_route_handler(*user_request)(
+ (
+ json.dumps(
+ {
+ "id": 1,
+ "username": "dischord",
+ "global_name": "Ian MacKaye",
+ "email": "ian@example.com",
+ "picture": "https://example.com/example.jpg",
+ }
+ ),
+ 200,
+ )
+ )
+
+ challenge = (
+ base64.urlsafe_b64encode(
+ hashlib.sha256(
+ base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=')
+ ).digest()
+ )
+ .rstrip(b'=')
+ .decode()
+ )
+ await self.con.query(
+ """
+ insert ext::auth::PKCEChallenge {
+ challenge := $challenge,
+ }
+ """,
+ challenge=challenge,
+ )
+
+ signing_key = await self.get_signing_key()
+
+ expires_at = now + datetime.timedelta(minutes=5)
+ state_claims = {
+ "iss": self.http_addr,
+ "provider": str(provider_name),
+ "exp": expires_at.timestamp(),
+ "redirect_to": f"{self.http_addr}/some/path",
+ "challenge": challenge,
+ }
+ state_token = self.generate_state_value(state_claims, signing_key)
+
+ data, headers, status = self.http_con_request(
+ http_con,
+ {"state": state_token, "code": "abc123"},
+ path="callback",
+ )
+
+ self.assertEqual(data, b"")
+ self.assertEqual(status, 302)
+
+ location = headers.get("location")
+ assert location is not None
+ server_url = urllib.parse.urlparse(self.http_addr)
+ url = urllib.parse.urlparse(location)
+ self.assertEqual(url.scheme, server_url.scheme)
+ self.assertEqual(url.hostname, server_url.hostname)
+ self.assertEqual(url.path, f"{server_url.path}/some/path")
+
+ requests_for_token = mock_provider.requests[token_request]
+ self.assertEqual(len(requests_for_token), 1)
+
+ self.assertEqual(
+ urllib.parse.parse_qs(requests_for_token[0]["body"]),
+ {
+ "grant_type": ["authorization_code"],
+ "code": ["abc123"],
+ "client_id": [client_id],
+ "client_secret": [client_secret],
+ "redirect_uri": [f"{self.http_addr}/callback"],
+ },
+ )
+
+ requests_for_user = mock_provider.requests[user_request]
+ self.assertEqual(len(requests_for_user), 1)
+ self.assertEqual(
+ requests_for_user[0]["headers"]["authorization"],
+ "Bearer discord_access_token",
+ )
+
+ identity = await self.con.query(
+ """
+ SELECT ext::auth::Identity
+ FILTER .subject = '1'
+ AND .issuer = 'https://discord.com'
+ """
+ )
+ self.assertEqual(len(identity), 1)
+
+ session_claims = await self.extract_session_claims(headers)
+ self.assertEqual(session_claims.get("sub"), str(identity[0].id))
+ self.assertEqual(session_claims.get("iss"), str(self.http_addr))
+ tomorrow = now + datetime.timedelta(hours=25)
+ self.assertTrue(session_claims.get("exp") > now.timestamp())
+ self.assertTrue(session_claims.get("exp") < tomorrow.timestamp())
+
+ pkce_object = await self.con.query(
+ """
+ SELECT ext::auth::PKCEChallenge
+ { id, auth_token, refresh_token }
+ filter .identity.id = $identity_id
+ """,
+ identity_id=identity[0].id,
+ )
+
+ self.assertEqual(len(pkce_object), 1)
+ self.assertEqual(pkce_object[0].auth_token, "discord_access_token")
+ self.assertIsNone(pkce_object[0].refresh_token)
+
+ mock_provider.register_route_handler(*user_request)(
+ (
+ json.dumps(
+ {
+ "id": 1,
+ "login": "octocat",
+ "name": "monalisa octocat",
+ "email": "octocat+2@example.com",
+ "avatar_url": "https://example.com/example.jpg",
+ "updated_at": now.isoformat(),
+ }
+ ),
+ 200,
+ )
+ )
+ (_, new_headers, _) = self.http_con_request(
+ http_con,
+ {"state": state_token, "code": "abc123"},
+ path="callback",
+ )
+
+ same_identity = await self.con.query(
+ """
+ SELECT ext::auth::Identity
+ FILTER .subject = '1'
+ AND .issuer = 'https://discord.com'
+ """
+ )
+ self.assertEqual(len(same_identity), 1)
+ self.assertEqual(identity[0].id, same_identity[0].id)
+
+ new_session_claims = await self.extract_session_claims(new_headers)
+ self.assertTrue(
+ new_session_claims.get("exp") > session_claims.get("exp")
+ )
+
async def test_http_auth_ext_google_callback_01(self) -> None:
with MockAuthProvider() as mock_provider, self.http_con() as http_con:
provider_config = await self.get_builtin_provider_config_by_name(