Skip to content

Commit

Permalink
Add Discord OAuth provider (#6708)
Browse files Browse the repository at this point in the history
  • Loading branch information
scotttrinh authored Jan 23, 2024
1 parent cfa3729 commit a405e32
Show file tree
Hide file tree
Showing 6 changed files with 387 additions and 8 deletions.
11 changes: 11 additions & 0 deletions edb/lib/ext/auth.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,17 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' {
};
};

create type ext::auth::DiscordOAuthProvider
extending ext::auth::OAuthProviderConfig {
alter property name {
set default := 'builtin::oauth_discord';
};

alter property display_name {
set default := 'Discord';
};
};

create type ext::auth::GitHubOAuthProvider
extending ext::auth::OAuthProviderConfig {
alter property name {
Expand Down
3 changes: 3 additions & 0 deletions edb/server/protocol/auth_ext/_static/icon_discord.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
95 changes: 95 additions & 0 deletions edb/server/protocol/auth_ext/discord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import urllib.parse
import functools

from . import base, data, errors


class DiscordProvider(base.BaseProvider):
def __init__(self, *args, **kwargs):
super().__init__("discord", "https://discord.com", *args, **kwargs)
self.auth_domain = self.issuer_url
self.api_domain = f"{self.issuer_url}/api/v10"
self.auth_client = functools.partial(
self.http_factory, base_url=self.auth_domain
)
self.api_client = functools.partial(
self.http_factory, base_url=self.api_domain
)

async def get_code_url(
self, state: str, redirect_uri: str, additional_scope: str
) -> str:
params = {
"client_id": self.client_id,
"scope": f"email identify {additional_scope}",
"state": state,
"redirect_uri": redirect_uri,
"response_type": "code",
}
encoded = urllib.parse.urlencode(params)
return f"{self.auth_domain}/oauth2/authorize?{encoded}"

async def exchange_code(
self, code: str, redirect_uri: str
) -> data.OAuthAccessTokenResponse:
async with self.auth_client() as client:
resp = await client.post(
"/api/oauth2/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": self.client_id,
"client_secret": self.client_secret,
"redirect_uri": redirect_uri,
},
headers={
"accept": "application/json",
},
)
if resp.status_code >= 400:
raise errors.OAuthProviderFailure(
f"Failed to exchange code: {resp.text}"
)
json = resp.json()

return data.OAuthAccessTokenResponse(**json)

async def fetch_user_info(
self, token_response: data.OAuthAccessTokenResponse
) -> data.UserInfo:
async with self.api_client() as client:
resp = await client.get(
"/users/@me",
headers={
"Authorization": f"Bearer {token_response.access_token}",
"Accept": "application/json",
"Cache-Control": "no-store",
},
)
payload = resp.json()
return data.UserInfo(
sub=str(payload["id"]),
preferred_username=payload.get("username"),
name=payload.get("global_name"),
email=payload.get("email"),
picture=payload.get("avatar"),
)
4 changes: 3 additions & 1 deletion edb/server/protocol/auth_ext/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Any, Type
from edb.server.protocol import execute

from . import github, google, azure, apple
from . import github, google, azure, apple, discord
from . import errors, util, data, base, http_client


Expand Down Expand Up @@ -55,6 +55,8 @@ def __init__(
provider_class = azure.AzureProvider
case "builtin::oauth_apple":
provider_class = apple.AppleProvider
case "builtin::oauth_discord":
provider_class = discord.DiscordProvider
case _:
raise errors.InvalidData(f"Invalid provider: {provider_name}")

Expand Down
1 change: 1 addition & 0 deletions edb/server/protocol/auth_ext/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
'builtin::oauth_google',
'builtin::oauth_apple',
'builtin::oauth_azure',
'builtin::oauth_discord',
]


Expand Down
Loading

0 comments on commit a405e32

Please sign in to comment.