Skip to content

Commit

Permalink
Pass through source OIDC id_token (#7952)
Browse files Browse the repository at this point in the history
We currently only use a small amount of data from the identity provider's ID
token, but applications might want to use different bits of this data for their
own user profile information. We currently require them to fetch this on their
own, but this will make this process slightly easier by returning the ID token,
if we have one.
  • Loading branch information
scotttrinh authored Nov 4, 2024
1 parent 8f2b29b commit 9494b22
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 1 deletion.
4 changes: 4 additions & 0 deletions edb/lib/ext/auth.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' {
create annotation std::description :=
"Identity provider's refresh token.";
};
create property id_token: std::str {
create annotation std::description :=
"Identity provider's OpenID Connect id_token.";
};
create link identity: ext::auth::Identity {
on target delete delete source;
};
Expand Down
1 change: 1 addition & 0 deletions edb/server/protocol/auth_ext/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ async def fetch_user_info(
name=payload.get("name"),
email=payload.get("email"),
picture=payload.get("picture"),
source_id_token=id_token,
)

async def _get_oidc_config(self) -> data.OpenIDConfig:
Expand Down
1 change: 1 addition & 0 deletions edb/server/protocol/auth_ext/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class UserInfo:
phone_number_verified: Optional[bool] = None
address: Optional[dict[str, str]] = None
updated_at: Optional[float] = None
source_id_token: Optional[str] = None

def __str__(self) -> str:
return self.sub
Expand Down
3 changes: 3 additions & 0 deletions edb/server/protocol/auth_ext/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ async def handle_callback(
new_identity,
auth_token,
refresh_token,
id_token,
) = await oauth_client.handle_callback(code, self._get_callback_url())
pkce_code = await pkce.link_identity_challenge(
self.db, identity.id, challenge
Expand All @@ -419,6 +420,7 @@ async def handle_callback(
id=pkce_code,
auth_token=auth_token,
refresh_token=refresh_token,
id_token=id_token,
)
new_url = util.join_url_params(
(
Expand Down Expand Up @@ -486,6 +488,7 @@ async def handle_token(
"identity_id": pkce_object.identity_id,
"provider_token": pkce_object.auth_token,
"provider_refresh_token": pkce_object.refresh_token,
"provider_id_token": pkce_object.id_token,
}
).encode()
else:
Expand Down
4 changes: 3 additions & 1 deletion edb/server/protocol/auth_ext/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,18 @@ async def get_authorize_url(self, state: str, redirect_uri: str) -> str:

async def handle_callback(
self, code: str, redirect_uri: str
) -> tuple[data.Identity, bool, str | None, str | None]:
) -> tuple[data.Identity, bool, str | None, str | None, str | None]:
response = await self.provider.exchange_code(code, redirect_uri)
user_info = await self.provider.fetch_user_info(response)
auth_token = response.access_token
refresh_token = response.refresh_token
source_id_token = user_info.source_id_token

return (
*(await self._handle_identity(user_info)),
auth_token,
refresh_token,
source_id_token,
)

async def _handle_identity(
Expand Down
5 changes: 5 additions & 0 deletions edb/server/protocol/auth_ext/pkce.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class PKCEChallenge:
challenge: str
auth_token: str | None
refresh_token: str | None
id_token: str | None
identity_id: str | None


Expand Down Expand Up @@ -95,6 +96,7 @@ async def add_provider_tokens(
id: str,
auth_token: str | None,
refresh_token: str | None,
id_token: str | None,
) -> str:
r = await execute.parse_execute_json(
db,
Expand All @@ -104,12 +106,14 @@ async def add_provider_tokens(
set {
auth_token := <optional str>$auth_token,
refresh_token := <optional str>$refresh_token,
id_token := <optional str>$id_token,
}
""",
variables={
"id": id,
"auth_token": auth_token,
"refresh_token": refresh_token,
"id_token": id_token,
},
cached_globally=True,
)
Expand All @@ -129,6 +133,7 @@ async def get_by_id(db: edbtenant.dbview.Database, id: str) -> PKCEChallenge:
challenge,
auth_token,
refresh_token,
id_token,
identity_id := .identity.id
}
filter .id = <uuid>$id
Expand Down
4 changes: 4 additions & 0 deletions tests/test_http_ext_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3474,6 +3474,7 @@ async def test_http_auth_ext_token_01(self):
challenge := <str>$challenge,
auth_token := <str>$auth_token,
refresh_token := <str>$refresh_token,
id_token := <str>$id_token,
identity := (
insert ext::auth::Identity {
issuer := "https://example.com",
Expand All @@ -3486,12 +3487,14 @@ async def test_http_auth_ext_token_01(self):
challenge,
auth_token,
refresh_token,
id_token,
identity_id := .identity.id
}
""",
challenge=challenge.decode(),
auth_token="a_provider_token",
refresh_token="a_refresh_token",
id_token="an_id_token",
)

# Correct code, random verifier
Expand Down Expand Up @@ -3530,6 +3533,7 @@ async def test_http_auth_ext_token_01(self):
"identity_id": str(pkce.identity_id),
"provider_token": "a_provider_token",
"provider_refresh_token": "a_refresh_token",
"provider_id_token": "an_id_token",
},
)
async for tr in self.try_until_succeeds(
Expand Down

0 comments on commit 9494b22

Please sign in to comment.