diff --git a/edb/lib/ext/auth.edgeql b/edb/lib/ext/auth.edgeql index 4d056f49964..64f9843653d 100644 --- a/edb/lib/ext/auth.edgeql +++ b/edb/lib/ext/auth.edgeql @@ -130,6 +130,10 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' { create annotation std::description := "Identity provider's refresh token."; }; + create property id_token: std::str { + create annotation std::description := + "Identity provider's OpenID Connect id_token."; + }; create link identity: ext::auth::Identity { on target delete delete source; }; diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index 59e957cc5a4..d43a492e4a6 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -173,6 +173,7 @@ async def fetch_user_info( name=payload.get("name"), email=payload.get("email"), picture=payload.get("picture"), + source_id_token=id_token, ) async def _get_oidc_config(self) -> data.OpenIDConfig: diff --git a/edb/server/protocol/auth_ext/data.py b/edb/server/protocol/auth_ext/data.py index 6464a70673b..035a0c657af 100644 --- a/edb/server/protocol/auth_ext/data.py +++ b/edb/server/protocol/auth_ext/data.py @@ -51,6 +51,7 @@ class UserInfo: phone_number_verified: Optional[bool] = None address: Optional[dict[str, str]] = None updated_at: Optional[float] = None + source_id_token: Optional[str] = None def __str__(self) -> str: return self.sub diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 9f9450c0ef9..5bbfd55c4bf 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -409,6 +409,7 @@ async def handle_callback( new_identity, auth_token, refresh_token, + id_token, ) = await oauth_client.handle_callback(code, self._get_callback_url()) pkce_code = await pkce.link_identity_challenge( self.db, identity.id, challenge @@ -419,6 +420,7 @@ async def handle_callback( id=pkce_code, auth_token=auth_token, refresh_token=refresh_token, + id_token=id_token, ) new_url = util.join_url_params( ( @@ -486,6 +488,7 @@ async def handle_token( "identity_id": pkce_object.identity_id, "provider_token": pkce_object.auth_token, "provider_refresh_token": pkce_object.refresh_token, + "provider_id_token": pkce_object.id_token, } ).encode() else: diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py index 1ed14ff1ccd..a1333ee2e81 100644 --- a/edb/server/protocol/auth_ext/oauth.py +++ b/edb/server/protocol/auth_ext/oauth.py @@ -104,16 +104,18 @@ async def get_authorize_url(self, state: str, redirect_uri: str) -> str: async def handle_callback( self, code: str, redirect_uri: str - ) -> tuple[data.Identity, bool, str | None, str | None]: + ) -> tuple[data.Identity, bool, str | None, str | None, str | None]: response = await self.provider.exchange_code(code, redirect_uri) user_info = await self.provider.fetch_user_info(response) auth_token = response.access_token refresh_token = response.refresh_token + source_id_token = user_info.source_id_token return ( *(await self._handle_identity(user_info)), auth_token, refresh_token, + source_id_token, ) async def _handle_identity( diff --git a/edb/server/protocol/auth_ext/pkce.py b/edb/server/protocol/auth_ext/pkce.py index b4d1c5ddde8..755310c9df0 100644 --- a/edb/server/protocol/auth_ext/pkce.py +++ b/edb/server/protocol/auth_ext/pkce.py @@ -46,6 +46,7 @@ class PKCEChallenge: challenge: str auth_token: str | None refresh_token: str | None + id_token: str | None identity_id: str | None @@ -95,6 +96,7 @@ async def add_provider_tokens( id: str, auth_token: str | None, refresh_token: str | None, + id_token: str | None, ) -> str: r = await execute.parse_execute_json( db, @@ -104,12 +106,14 @@ async def add_provider_tokens( set { auth_token := $auth_token, refresh_token := $refresh_token, + id_token := $id_token, } """, variables={ "id": id, "auth_token": auth_token, "refresh_token": refresh_token, + "id_token": id_token, }, cached_globally=True, ) @@ -129,6 +133,7 @@ async def get_by_id(db: edbtenant.dbview.Database, id: str) -> PKCEChallenge: challenge, auth_token, refresh_token, + id_token, identity_id := .identity.id } filter .id = $id diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index eaefccde7e6..d9836236d9c 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -3474,6 +3474,7 @@ async def test_http_auth_ext_token_01(self): challenge := $challenge, auth_token := $auth_token, refresh_token := $refresh_token, + id_token := $id_token, identity := ( insert ext::auth::Identity { issuer := "https://example.com", @@ -3486,12 +3487,14 @@ async def test_http_auth_ext_token_01(self): challenge, auth_token, refresh_token, + id_token, identity_id := .identity.id } """, challenge=challenge.decode(), auth_token="a_provider_token", refresh_token="a_refresh_token", + id_token="an_id_token", ) # Correct code, random verifier @@ -3530,6 +3533,7 @@ async def test_http_auth_ext_token_01(self): "identity_id": str(pkce.identity_id), "provider_token": "a_provider_token", "provider_refresh_token": "a_refresh_token", + "provider_id_token": "an_id_token", }, ) async for tr in self.try_until_succeeds(