diff --git a/edb/lib/ext/auth.edgeql b/edb/lib/ext/auth.edgeql index 67848b6ec4a..abc9cdb7921 100644 --- a/edb/lib/ext/auth.edgeql +++ b/edb/lib/ext/auth.edgeql @@ -256,6 +256,27 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' { UI is disabled."; }; + create property app_name: std::str { + create annotation std::description := + "The name of your application."; + }; + + create property logo_url: std::str { + create annotation std::description := + "A url to an image of your application's logo."; + }; + + create property dark_logo_url: std::str { + create annotation std::description := + "A url to an image of your application's logo to be used \ + with the dark theme."; + }; + + create property brand_color: std::str { + create annotation std::description := + "The brand color of your application as a hex string."; + }; + create property auth_signing_key -> std::str { set secret := true; create annotation std::description := diff --git a/edb/server/protocol/auth_ext/config.py b/edb/server/protocol/auth_ext/config.py index 5fd18712bfb..79b333cbbef 100644 --- a/edb/server/protocol/auth_ext/config.py +++ b/edb/server/protocol/auth_ext/config.py @@ -18,11 +18,16 @@ from typing import Optional +from dataclasses import dataclass class UIConfig: redirect_to: str redirect_to_on_signup: Optional[str] + + +@dataclass +class AppDetailsConfig: app_name: Optional[str] logo_url: Optional[str] dark_logo_url: Optional[str] diff --git a/edb/server/protocol/auth_ext/email.py b/edb/server/protocol/auth_ext/email.py index 5a2733b7c38..b121d36240b 100644 --- a/edb/server/protocol/auth_ext/email.py +++ b/edb/server/protocol/auth_ext/email.py @@ -2,11 +2,10 @@ import urllib.parse import random -from typing import Any, Coroutine, cast +from typing import Any, Coroutine from edb.server import tenant -from edb.server.config.types import CompositeConfigType -from . import util, ui, smtp, config +from . import util, ui, smtp async def send_password_reset_email( @@ -17,17 +16,15 @@ async def send_password_reset_email( test_mode: bool, ): from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") - ui_config = cast(config.UIConfig, util.maybe_get_config( - db, "ext::auth::AuthConfig::ui", CompositeConfigType - )) - if ui_config is None: + app_details_config = util.get_app_details_config(db) + if app_details_config is None: email_args = {} else: email_args = dict( - app_name=ui_config.app_name, - logo_url=ui_config.logo_url, - dark_logo_url=ui_config.dark_logo_url, - brand_color=ui_config.brand_color, + app_name=app_details_config.app_name, + logo_url=app_details_config.logo_url, + dark_logo_url=app_details_config.dark_logo_url, + brand_color=app_details_config.brand_color, ) msg = ui.render_password_reset_email( from_addr=from_addr, @@ -55,9 +52,7 @@ async def send_verification_email( test_mode: bool, ): from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") - ui_config = cast(config.UIConfig, util.maybe_get_config( - db, "ext::auth::AuthConfig::ui", CompositeConfigType - )) + app_details_config = util.get_app_details_config(db) verification_token_params = urllib.parse.urlencode( { "verification_token": verification_token, @@ -66,14 +61,14 @@ async def send_verification_email( } ) verify_url = f"{verify_url}?{verification_token_params}" - if ui_config is None: + if app_details_config is None: email_args = {} else: email_args = dict( - app_name=ui_config.app_name, - logo_url=ui_config.logo_url, - dark_logo_url=ui_config.dark_logo_url, - brand_color=ui_config.brand_color, + app_name=app_details_config.app_name, + logo_url=app_details_config.logo_url, + dark_logo_url=app_details_config.dark_logo_url, + brand_color=app_details_config.brand_color, ) msg = ui.render_verification_email( from_addr=from_addr, diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index e7bb42f12f2..7a715afaa1b 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -272,9 +272,9 @@ async def handle_callback(self, request: Any, response: Any): } if error_description is not None: params["error_description"] = error_description - response.custom_headers[ - "Location" - ] = _join_url_params(redirect_to, params) + response.custom_headers["Location"] = _join_url_params( + redirect_to, params + ) response.status = http.HTTPStatus.FOUND return @@ -326,8 +326,9 @@ async def handle_callback(self, request: Any, response: Any): ) new_url = _join_url_params( (redirect_to_on_signup or redirect_to) - if new_identity else redirect_to, - {"code": pkce_code, "provider": provider_name} + if new_identity + else redirect_to, + {"code": pkce_code, "provider": provider_name}, ) session_token = self._make_session_token(identity.id) response.status = http.HTTPStatus.FOUND @@ -432,7 +433,7 @@ async def handle_register(self, request: Any, response: Any): if require_verification else { "code": cast(str, pkce_code), - "provider": register_provider_name + "provider": register_provider_name, } ) response.custom_headers["Location"] = _join_url_params( @@ -448,10 +449,9 @@ async def handle_register(self, request: Any, response: Any): else: if pkce_code is None: raise errors.PKCECreationFailed - response.body = json.dumps({ - "code": pkce_code, - "provider": register_provider_name - }).encode() + response.body = json.dumps( + {"code": pkce_code, "provider": register_provider_name} + ).encode() except Exception as ex: redirect_on_failure = data.get( "redirect_on_failure", maybe_redirect_to @@ -460,7 +460,7 @@ async def handle_register(self, request: Any, response: Any): response.status = http.HTTPStatus.FOUND redirect_params = { "error": str(ex), - "email": data.get('email', '') + "email": data.get('email', ''), } response.custom_headers["Location"] = _join_url_params( redirect_on_failure, redirect_params @@ -803,6 +803,7 @@ async def handle_ui_signin(self, request: Any, response: Any): 'No providers are configured', ) + app_details_config = self._get_app_details_config() query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) @@ -827,10 +828,10 @@ async def handle_ui_signin(self, request: Any, response: Any): error_message=_maybe_get_search_param(query, 'error'), email=_maybe_get_search_param(query, 'email'), challenge=maybe_challenge, - app_name=ui_config.app_name, - logo_url=ui_config.logo_url, - dark_logo_url=ui_config.dark_logo_url, - brand_color=ui_config.brand_color, + app_name=app_details_config.app_name, + logo_url=app_details_config.logo_url, + dark_logo_url=app_details_config.dark_logo_url, + brand_color=app_details_config.brand_color, ) async def handle_ui_signup(self, request: Any, response: Any): @@ -863,6 +864,7 @@ async def handle_ui_signup(self, request: Any, response: Any): raise errors.InvalidData( 'Missing "challenge" in register request' ) + app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b'text/html' @@ -874,10 +876,10 @@ async def handle_ui_signup(self, request: Any, response: Any): error_message=_maybe_get_search_param(query, 'error'), email=_maybe_get_search_param(query, 'email'), challenge=maybe_challenge, - app_name=ui_config.app_name, - logo_url=ui_config.logo_url, - dark_logo_url=ui_config.dark_logo_url, - brand_color=ui_config.brand_color, + app_name=app_details_config.app_name, + logo_url=app_details_config.logo_url, + dark_logo_url=app_details_config.dark_logo_url, + brand_color=app_details_config.brand_color, ) async def handle_ui_forgot_password(self, request: Any, response: Any): @@ -898,6 +900,7 @@ async def handle_ui_forgot_password(self, request: Any, response: Any): request.url.query.decode("ascii") if request.url.query else "" ) challenge = _get_search_param(query, "challenge") + app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b'text/html' @@ -908,10 +911,10 @@ async def handle_ui_forgot_password(self, request: Any, response: Any): email=_maybe_get_search_param(query, 'email'), email_sent=_maybe_get_search_param(query, 'email_sent'), challenge=challenge, - app_name=ui_config.app_name, - logo_url=ui_config.logo_url, - dark_logo_url=ui_config.dark_logo_url, - brand_color=ui_config.brand_color, + app_name=app_details_config.app_name, + logo_url=app_details_config.logo_url, + dark_logo_url=app_details_config.dark_logo_url, + brand_color=app_details_config.brand_color, ) async def handle_ui_reset_password(self, request: Any, response: Any): @@ -956,6 +959,7 @@ async def handle_ui_reset_password(self, request: Any, response: Any): else: is_valid = False + app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b'text/html' response.body = ui.render_reset_password_page( @@ -966,10 +970,10 @@ async def handle_ui_reset_password(self, request: Any, response: Any): reset_token=reset_token, challenge=challenge, error_message=_maybe_get_search_param(query, 'error'), - app_name=ui_config.app_name, - logo_url=ui_config.logo_url, - dark_logo_url=ui_config.dark_logo_url, - brand_color=ui_config.brand_color, + app_name=app_details_config.app_name, + logo_url=app_details_config.logo_url, + dark_logo_url=app_details_config.dark_logo_url, + brand_color=app_details_config.brand_color, ) async def handle_ui_verify(self, request: Any, response: Any): @@ -1066,16 +1070,17 @@ async def handle_ui_verify(self, request: Any, response: Any): response.custom_headers["Location"] = redirect_to return + app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b'text/html' response.body = ui.render_email_verification_page( verification_token=maybe_verification_token, is_valid=is_valid, error_messages=error_messages, - app_name=ui_config.app_name, - logo_url=ui_config.logo_url, - dark_logo_url=ui_config.dark_logo_url, - brand_color=ui_config.brand_color, + app_name=app_details_config.app_name, + logo_url=app_details_config.logo_url, + dark_logo_url=app_details_config.dark_logo_url, + brand_color=app_details_config.brand_color, ) async def handle_ui_resend_verification(self, request: Any, response: Any): @@ -1122,6 +1127,7 @@ async def handle_ui_resend_verification(self, request: Any, response: Any): except Exception: is_valid = False + app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b"text/html" response.body = ui.render_resend_verification_done_page( @@ -1129,10 +1135,10 @@ async def handle_ui_resend_verification(self, request: Any, response: Any): verification_token=_maybe_get_search_param( query, "verification_token" ), - app_name=ui_config.app_name, - logo_url=ui_config.logo_url, - dark_logo_url=ui_config.dark_logo_url, - brand_color=ui_config.brand_color, + app_name=app_details_config.app_name, + logo_url=app_details_config.logo_url, + dark_logo_url=app_details_config.dark_logo_url, + brand_color=app_details_config.brand_color, ) def _get_callback_url(self) -> str: @@ -1302,10 +1308,18 @@ def _get_data_from_verification_token( maybe_redirect_to, ): case ( - str(id), float(issued_at), verify_url, challenge, redirect_to + str(id), + float(issued_at), + verify_url, + challenge, + redirect_to, ): return_value = ( - id, issued_at, verify_url, challenge, redirect_to + id, + issued_at, + verify_url, + challenge, + redirect_to, ) case (_, _, _, _, _): raise errors.InvalidData( @@ -1343,6 +1357,9 @@ def _get_ui_config(self): ), ) + def _get_app_details_config(self): + return util.get_app_details_config(self.db) + def _get_password_provider(self): providers = cast( list[config.ProviderConfig], @@ -1428,9 +1445,7 @@ def _is_url_allowed(self, url: str) -> bool: ui_config = self._get_ui_config() if ui_config: - allowed_urls = allowed_urls.union( - {ui_config.redirect_to} - ) + allowed_urls = allowed_urls.union({ui_config.redirect_to}) if ui_config.redirect_to_on_signup: allowed_urls = allowed_urls.union( {ui_config.redirect_to_on_signup} @@ -1535,7 +1550,7 @@ def _join_url_params(url: str, params: dict[str, str]): parsed_url = urllib.parse.urlparse(url) query_params = { **urllib.parse.parse_qs(parsed_url.query), - **{key: [val] for key, val in params.items()} + **{key: [val] for key, val in params.items()}, } new_query_params = urllib.parse.urlencode(query_params, doseq=True) return parsed_url._replace(query=new_query_params).geturl() diff --git a/edb/server/protocol/auth_ext/util.py b/edb/server/protocol/auth_ext/util.py index 6b3c92071f6..7da8d532b13 100644 --- a/edb/server/protocol/auth_ext/util.py +++ b/edb/server/protocol/auth_ext/util.py @@ -22,20 +22,17 @@ from edb.server import config from . import errors +from .config import AppDetailsConfig T = TypeVar("T") -def maybe_get_config_unchecked( - db: Any, key: str -) -> Any: +def maybe_get_config_unchecked(db: Any, key: str) -> Any: return config.lookup(key, db.db_config, spec=db.user_config_spec) @overload -def maybe_get_config( - db: Any, key: str, expected_type: Type[T] -) -> T | None: +def maybe_get_config(db: Any, key: str, expected_type: Type[T]) -> T | None: ... @@ -71,9 +68,7 @@ def get_config(db: Any, key: str) -> str: ... -def get_config( - db: Any, key: str, expected_type: Type[object] = str -) -> object: +def get_config(db: Any, key: str, expected_type: Type[object] = str) -> object: value = maybe_get_config(db, key, expected_type) if value is None: raise errors.MissingConfiguration( @@ -83,9 +78,7 @@ def get_config( return value -def get_config_unchecked( - db: Any, key: str -) -> Any: +def get_config_unchecked(db: Any, key: str) -> Any: value = maybe_get_config_unchecked(db, key) if value is None: raise errors.MissingConfiguration( @@ -97,3 +90,14 @@ def get_config_unchecked( def get_config_typename(config_value: config.SettingValue) -> str: return config_value._tspec.name # type: ignore + + +def get_app_details_config(db: Any) -> AppDetailsConfig: + return AppDetailsConfig( + app_name=maybe_get_config(db, "ext::auth::AuthConfig::app_name"), + logo_url=maybe_get_config(db, "ext::auth::AuthConfig::logo_url"), + dark_logo_url=maybe_get_config( + db, "ext::auth::AuthConfig::dark_logo_url" + ), + brand_color=maybe_get_config(db, "ext::auth::AuthConfig::brand_color"), + ) diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 6e33ce93d8c..3e1bba3df39 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -267,9 +267,7 @@ def handle_request( headers = {k.lower(): v for k, v in dict(handler.headers).items()} query_params = urllib.parse.parse_qs(parsed_path.query) if 'content-length' in headers: - body = handler.rfile.read( - int(headers['content-length']) - ).decode() + body = handler.rfile.read(int(headers['content-length'])).decode() else: body = None @@ -373,6 +371,10 @@ def __exit__(self, *exc): APPLE_SECRET = 'c' * 32 DISCORD_SECRET = 'd' * 32 SLACK_SECRET = 'd' * 32 +APP_NAME = "Test App" +LOGO_URL = "http://example.com/logo.png" +DARK_LOGO_URL = "http://example.com/darklogo.png" +BRAND_COLOR = "f0f8ff" class TestHttpExtAuth(tb.ExtAuthTestCase): @@ -387,6 +389,24 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::token_time_to_live := '24 hours'; + CONFIGURE CURRENT DATABASE SET + ext::auth::AuthConfig::app_name := '{APP_NAME}'; + + CONFIGURE CURRENT DATABASE SET + ext::auth::AuthConfig::logo_url := '{LOGO_URL}'; + + CONFIGURE CURRENT DATABASE SET + ext::auth::AuthConfig::dark_logo_url := '{DARK_LOGO_URL}'; + + CONFIGURE CURRENT DATABASE SET + ext::auth::AuthConfig::brand_color := '{BRAND_COLOR}'; + + CONFIGURE CURRENT DATABASE + INSERT ext::auth::UIConfig {{ + redirect_to := 'https://example.com', + redirect_to_on_signup := 'https://example.com/signup', + }}; + CONFIGURE CURRENT DATABASE SET ext::auth::SMTPConfig::sender := 'noreply@example.com'; @@ -3480,6 +3500,25 @@ async def test_http_auth_ext_local_password_reset_form_02(self): self.assertEqual(status, 400) + async def test_http_auth_ext_ui_signin(self): + with self.http_con() as http_con: + challenge = ( + base64.urlsafe_b64encode(os.urandom(32)).rstrip(b'=').decode() + ) + query_params = urllib.parse.urlencode({"challenge": challenge}) + + body, _, status = self.http_con_request( + http_con, + path=f"ui/signin?{query_params}", + ) + + body_str = body.decode() + + self.assertIn(APP_NAME, body_str) + self.assertIn(LOGO_URL, body_str) + self.assertIn(BRAND_COLOR, body_str) + self.assertEqual(status, 200) + async def test_client_token_identity_card(self): await self.con.query_single( '''