Skip to content

Commit

Permalink
Support arbitrary OpenID Connect providers (#7510)
Browse files Browse the repository at this point in the history
Adds support for OpenID Connect providers that use OpenID Connect Discovery and
Authorization Code Flow (without PKCE), like Google and Facebook.
  • Loading branch information
scotttrinh authored Jul 8, 2024
1 parent cb1afc3 commit 37a9f8b
Show file tree
Hide file tree
Showing 12 changed files with 525 additions and 157 deletions.
23 changes: 23 additions & 0 deletions edb/lib/ext/auth.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,29 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' {
};
};

create type ext::auth::OpenIDConnectProvider
extending ext::auth::OAuthProviderConfig {
alter property name {
set protected := false;
};

alter property display_name {
set protected := false;
};

create required property issuer_url: std::str {
create annotation std::description :=
"The issuer URL of the provider.";
};

create property logo_url: std::str {
create annotation std::description :=
"A url to an image of the provider's logo.";
};

create constraint exclusive on ((.issuer_url, .client_id));
};

create type ext::auth::AppleOAuthProvider
extending ext::auth::OAuthProviderConfig {
alter property name {
Expand Down
3 changes: 1 addition & 2 deletions edb/server/protocol/auth_ext/apple.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@
from . import base


class AppleProvider(base.OpenIDProvider):
class AppleProvider(base.OpenIDConnectProvider):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(
"apple",
"https://appleid.apple.com",
*args,
content_type=base.ContentType.FORM_ENCODED,
**kwargs,
)

Expand Down
3 changes: 1 addition & 2 deletions edb/server/protocol/auth_ext/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
from . import base


class AzureProvider(base.OpenIDProvider):
class AzureProvider(base.OpenIDConnectProvider):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(
"azure",
"https://login.microsoftonline.com/common/v2.0",
*args,
content_type=base.ContentType.FORM_ENCODED,
**kwargs,
)
23 changes: 7 additions & 16 deletions edb/server/protocol/auth_ext/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,15 @@ class ContentType(enum.StrEnum):
FORM_ENCODED = "application/x-www-form-urlencoded"


class OpenIDProvider(BaseProvider):
class OpenIDConnectProvider(BaseProvider):
def __init__(
self,
name: str,
issuer_url: str,
*args: Any,
content_type: ContentType = ContentType.JSON,
**kwargs: Any,
):
super().__init__(name, issuer_url, *args, **kwargs)
self.content_type = content_type

async def get_code_url(
self, state: str, redirect_uri: str, additional_scope: str
Expand Down Expand Up @@ -114,23 +112,16 @@ async def exchange_code(
"redirect_uri": redirect_uri,
}
headers = {"Accept": ContentType.JSON.value}
if self.content_type == ContentType.JSON:
resp = await client.post(
token_endpoint.path,
json=request_body,
headers=headers,
)
else:
resp = await client.post(
token_endpoint.path,
data=request_body,
headers=headers,
)
resp = await client.post(
token_endpoint.path,
data=request_body,
headers=headers,
)
if resp.status_code >= 400:
raise errors.OAuthProviderFailure(
f"Failed to exchange code: {resp.text}"
)
content_type = resp.headers.get('Content-Type', self.content_type)
content_type = resp.headers.get('Content-Type')
if content_type.startswith(str(ContentType.JSON)):
response_body = resp.json()
else:
Expand Down
2 changes: 2 additions & 0 deletions edb/server/protocol/auth_ext/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class OAuthProviderConfig(ProviderConfig):
client_id: str
secret: str
additional_scope: Optional[str]
issuer_url: Optional[str]
logo_url: Optional[str]


class WebAuthnProviderConfig(ProviderConfig):
Expand Down
2 changes: 1 addition & 1 deletion edb/server/protocol/auth_ext/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from . import base


class GoogleProvider(base.OpenIDProvider):
class GoogleProvider(base.OpenIDConnectProvider):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(
"google", "https://accounts.google.com", *args, **kwargs
Expand Down
4 changes: 2 additions & 2 deletions edb/server/protocol/auth_ext/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def post( # type: ignore[override]
**kwargs: Any,
) -> httpx.Response:
if self.edgedb_orig_base_url:
path = f'{self.edgedb_orig_base_url}/{path}'
path = f'{self.edgedb_orig_base_url}{path}'
return await super().post(
path, *args, **kwargs
)
Expand All @@ -62,5 +62,5 @@ async def get( # type: ignore[override]
**kwargs: Any,
) -> httpx.Response:
if self.edgedb_orig_base_url:
path = f'{self.edgedb_orig_base_url}/{path}'
path = f'{self.edgedb_orig_base_url}{path}'
return await super().get(path, *args, **kwargs)
68 changes: 46 additions & 22 deletions edb/server/protocol/auth_ext/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import json

from typing import cast, Any, Type
from typing import cast, Any
from edb.server.protocol import execute

from . import github, google, azure, apple, discord, slack
Expand All @@ -39,34 +39,56 @@ def __init__(
)

provider_config = self._get_provider_config(provider_name)
provider_args = (provider_config.client_id, provider_config.secret)
provider_args: tuple[str, str] | tuple[str, str, str, str] = (
provider_config.client_id,
provider_config.secret,
)
provider_kwargs = {
"http_factory": http_factory,
"additional_scope": provider_config.additional_scope,
}

provider_class: Type[base.BaseProvider]
match provider_name:
case "builtin::oauth_github":
provider_class = github.GitHubProvider
case "builtin::oauth_google":
provider_class = google.GoogleProvider
case "builtin::oauth_azure":
provider_class = azure.AzureProvider
case "builtin::oauth_apple":
provider_class = apple.AppleProvider
case "builtin::oauth_discord":
provider_class = discord.DiscordProvider
case "builtin::oauth_slack":
provider_class = slack.SlackProvider
match (provider_name, provider_config.issuer_url):
case ("builtin::oauth_github", _):
self.provider = github.GitHubProvider(
*provider_args,
**provider_kwargs,
)
case ("builtin::oauth_google", _):
self.provider = google.GoogleProvider(
*provider_args,
**provider_kwargs,
)
case ("builtin::oauth_azure", _):
self.provider = azure.AzureProvider(
*provider_args,
**provider_kwargs,
)
case ("builtin::oauth_apple", _):
self.provider = apple.AppleProvider(
*provider_args,
**provider_kwargs,
)
case ("builtin::oauth_discord", _):
self.provider = discord.DiscordProvider(
*provider_args,
**provider_kwargs,
)
case ("builtin::oauth_slack", _):
self.provider = slack.SlackProvider(
*provider_args,
**provider_kwargs,
)
case (provider_name, str(issuer_url)):
self.provider = base.OpenIDConnectProvider(
provider_name,
issuer_url,
*provider_args,
**provider_kwargs,
)
case _:
raise errors.InvalidData(f"Invalid provider: {provider_name}")

self.provider = provider_class(
*provider_args,
**provider_kwargs, # type: ignore[arg-type]
)

async def get_authorize_url(self, state: str, redirect_uri: str) -> str:
return await self.provider.get_code_url(
state=state,
Expand Down Expand Up @@ -121,7 +143,7 @@ async def _handle_identity(

return (
data.Identity(**result_json[0]['identity']),
result_json[0]['new']
result_json[0]['new'],
)

def _get_provider_config(
Expand All @@ -139,6 +161,8 @@ def _get_provider_config(
client_id=cfg.client_id,
secret=cfg.secret,
additional_scope=cfg.additional_scope,
issuer_url=getattr(cfg, 'issuer_url', None),
logo_url=getattr(cfg, 'logo_url', None),
)

raise errors.MissingConfiguration(
Expand Down
3 changes: 1 addition & 2 deletions edb/server/protocol/auth_ext/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
from . import base


class SlackProvider(base.OpenIDProvider):
class SlackProvider(base.OpenIDConnectProvider):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(
"slack",
"https://slack.com",
*args,
content_type=base.ContentType.FORM_ENCODED,
**kwargs,
)
Loading

0 comments on commit 37a9f8b

Please sign in to comment.