Skip to content

Commit

Permalink
Move app details configuration into AuthConfig (#6754)
Browse files Browse the repository at this point in the history
  • Loading branch information
scotttrinh authored Feb 7, 2024
1 parent c02db66 commit bb41d4a
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 75 deletions.
21 changes: 21 additions & 0 deletions edb/lib/ext/auth.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,27 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' {
UI is disabled.";
};

create property app_name: std::str {
create annotation std::description :=
"The name of your application.";
};

create property logo_url: std::str {
create annotation std::description :=
"A url to an image of your application's logo.";
};

create property dark_logo_url: std::str {
create annotation std::description :=
"A url to an image of your application's logo to be used \
with the dark theme.";
};

create property brand_color: std::str {
create annotation std::description :=
"The brand color of your application as a hex string.";
};

create property auth_signing_key -> std::str {
set secret := true;
create annotation std::description :=
Expand Down
5 changes: 5 additions & 0 deletions edb/server/protocol/auth_ext/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@


from typing import Optional
from dataclasses import dataclass


class UIConfig:
redirect_to: str
redirect_to_on_signup: Optional[str]


@dataclass
class AppDetailsConfig:
app_name: Optional[str]
logo_url: Optional[str]
dark_logo_url: Optional[str]
Expand Down
33 changes: 14 additions & 19 deletions edb/server/protocol/auth_ext/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import urllib.parse
import random

from typing import Any, Coroutine, cast
from typing import Any, Coroutine
from edb.server import tenant
from edb.server.config.types import CompositeConfigType

from . import util, ui, smtp, config
from . import util, ui, smtp


async def send_password_reset_email(
Expand All @@ -17,17 +16,15 @@ async def send_password_reset_email(
test_mode: bool,
):
from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender")
ui_config = cast(config.UIConfig, util.maybe_get_config(
db, "ext::auth::AuthConfig::ui", CompositeConfigType
))
if ui_config is None:
app_details_config = util.get_app_details_config(db)
if app_details_config is None:
email_args = {}
else:
email_args = dict(
app_name=ui_config.app_name,
logo_url=ui_config.logo_url,
dark_logo_url=ui_config.dark_logo_url,
brand_color=ui_config.brand_color,
app_name=app_details_config.app_name,
logo_url=app_details_config.logo_url,
dark_logo_url=app_details_config.dark_logo_url,
brand_color=app_details_config.brand_color,
)
msg = ui.render_password_reset_email(
from_addr=from_addr,
Expand Down Expand Up @@ -55,9 +52,7 @@ async def send_verification_email(
test_mode: bool,
):
from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender")
ui_config = cast(config.UIConfig, util.maybe_get_config(
db, "ext::auth::AuthConfig::ui", CompositeConfigType
))
app_details_config = util.get_app_details_config(db)
verification_token_params = urllib.parse.urlencode(
{
"verification_token": verification_token,
Expand All @@ -66,14 +61,14 @@ async def send_verification_email(
}
)
verify_url = f"{verify_url}?{verification_token_params}"
if ui_config is None:
if app_details_config is None:
email_args = {}
else:
email_args = dict(
app_name=ui_config.app_name,
logo_url=ui_config.logo_url,
dark_logo_url=ui_config.dark_logo_url,
brand_color=ui_config.brand_color,
app_name=app_details_config.app_name,
logo_url=app_details_config.logo_url,
dark_logo_url=app_details_config.dark_logo_url,
brand_color=app_details_config.brand_color,
)
msg = ui.render_verification_email(
from_addr=from_addr,
Expand Down
97 changes: 56 additions & 41 deletions edb/server/protocol/auth_ext/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,9 @@ async def handle_callback(self, request: Any, response: Any):
}
if error_description is not None:
params["error_description"] = error_description
response.custom_headers[
"Location"
] = _join_url_params(redirect_to, params)
response.custom_headers["Location"] = _join_url_params(
redirect_to, params
)
response.status = http.HTTPStatus.FOUND
return

Expand Down Expand Up @@ -326,8 +326,9 @@ async def handle_callback(self, request: Any, response: Any):
)
new_url = _join_url_params(
(redirect_to_on_signup or redirect_to)
if new_identity else redirect_to,
{"code": pkce_code, "provider": provider_name}
if new_identity
else redirect_to,
{"code": pkce_code, "provider": provider_name},
)
session_token = self._make_session_token(identity.id)
response.status = http.HTTPStatus.FOUND
Expand Down Expand Up @@ -432,7 +433,7 @@ async def handle_register(self, request: Any, response: Any):
if require_verification
else {
"code": cast(str, pkce_code),
"provider": register_provider_name
"provider": register_provider_name,
}
)
response.custom_headers["Location"] = _join_url_params(
Expand All @@ -448,10 +449,9 @@ async def handle_register(self, request: Any, response: Any):
else:
if pkce_code is None:
raise errors.PKCECreationFailed
response.body = json.dumps({
"code": pkce_code,
"provider": register_provider_name
}).encode()
response.body = json.dumps(
{"code": pkce_code, "provider": register_provider_name}
).encode()
except Exception as ex:
redirect_on_failure = data.get(
"redirect_on_failure", maybe_redirect_to
Expand All @@ -460,7 +460,7 @@ async def handle_register(self, request: Any, response: Any):
response.status = http.HTTPStatus.FOUND
redirect_params = {
"error": str(ex),
"email": data.get('email', '')
"email": data.get('email', ''),
}
response.custom_headers["Location"] = _join_url_params(
redirect_on_failure, redirect_params
Expand Down Expand Up @@ -803,6 +803,7 @@ async def handle_ui_signin(self, request: Any, response: Any):
'No providers are configured',
)

app_details_config = self._get_app_details_config()
query = urllib.parse.parse_qs(
request.url.query.decode("ascii") if request.url.query else ""
)
Expand All @@ -827,10 +828,10 @@ async def handle_ui_signin(self, request: Any, response: Any):
error_message=_maybe_get_search_param(query, 'error'),
email=_maybe_get_search_param(query, 'email'),
challenge=maybe_challenge,
app_name=ui_config.app_name,
logo_url=ui_config.logo_url,
dark_logo_url=ui_config.dark_logo_url,
brand_color=ui_config.brand_color,
app_name=app_details_config.app_name,
logo_url=app_details_config.logo_url,
dark_logo_url=app_details_config.dark_logo_url,
brand_color=app_details_config.brand_color,
)

async def handle_ui_signup(self, request: Any, response: Any):
Expand Down Expand Up @@ -863,6 +864,7 @@ async def handle_ui_signup(self, request: Any, response: Any):
raise errors.InvalidData(
'Missing "challenge" in register request'
)
app_details_config = self._get_app_details_config()

response.status = http.HTTPStatus.OK
response.content_type = b'text/html'
Expand All @@ -874,10 +876,10 @@ async def handle_ui_signup(self, request: Any, response: Any):
error_message=_maybe_get_search_param(query, 'error'),
email=_maybe_get_search_param(query, 'email'),
challenge=maybe_challenge,
app_name=ui_config.app_name,
logo_url=ui_config.logo_url,
dark_logo_url=ui_config.dark_logo_url,
brand_color=ui_config.brand_color,
app_name=app_details_config.app_name,
logo_url=app_details_config.logo_url,
dark_logo_url=app_details_config.dark_logo_url,
brand_color=app_details_config.brand_color,
)

async def handle_ui_forgot_password(self, request: Any, response: Any):
Expand All @@ -898,6 +900,7 @@ async def handle_ui_forgot_password(self, request: Any, response: Any):
request.url.query.decode("ascii") if request.url.query else ""
)
challenge = _get_search_param(query, "challenge")
app_details_config = self._get_app_details_config()

response.status = http.HTTPStatus.OK
response.content_type = b'text/html'
Expand All @@ -908,10 +911,10 @@ async def handle_ui_forgot_password(self, request: Any, response: Any):
email=_maybe_get_search_param(query, 'email'),
email_sent=_maybe_get_search_param(query, 'email_sent'),
challenge=challenge,
app_name=ui_config.app_name,
logo_url=ui_config.logo_url,
dark_logo_url=ui_config.dark_logo_url,
brand_color=ui_config.brand_color,
app_name=app_details_config.app_name,
logo_url=app_details_config.logo_url,
dark_logo_url=app_details_config.dark_logo_url,
brand_color=app_details_config.brand_color,
)

async def handle_ui_reset_password(self, request: Any, response: Any):
Expand Down Expand Up @@ -956,6 +959,7 @@ async def handle_ui_reset_password(self, request: Any, response: Any):
else:
is_valid = False

app_details_config = self._get_app_details_config()
response.status = http.HTTPStatus.OK
response.content_type = b'text/html'
response.body = ui.render_reset_password_page(
Expand All @@ -966,10 +970,10 @@ async def handle_ui_reset_password(self, request: Any, response: Any):
reset_token=reset_token,
challenge=challenge,
error_message=_maybe_get_search_param(query, 'error'),
app_name=ui_config.app_name,
logo_url=ui_config.logo_url,
dark_logo_url=ui_config.dark_logo_url,
brand_color=ui_config.brand_color,
app_name=app_details_config.app_name,
logo_url=app_details_config.logo_url,
dark_logo_url=app_details_config.dark_logo_url,
brand_color=app_details_config.brand_color,
)

async def handle_ui_verify(self, request: Any, response: Any):
Expand Down Expand Up @@ -1066,16 +1070,17 @@ async def handle_ui_verify(self, request: Any, response: Any):
response.custom_headers["Location"] = redirect_to
return

app_details_config = self._get_app_details_config()
response.status = http.HTTPStatus.OK
response.content_type = b'text/html'
response.body = ui.render_email_verification_page(
verification_token=maybe_verification_token,
is_valid=is_valid,
error_messages=error_messages,
app_name=ui_config.app_name,
logo_url=ui_config.logo_url,
dark_logo_url=ui_config.dark_logo_url,
brand_color=ui_config.brand_color,
app_name=app_details_config.app_name,
logo_url=app_details_config.logo_url,
dark_logo_url=app_details_config.dark_logo_url,
brand_color=app_details_config.brand_color,
)

async def handle_ui_resend_verification(self, request: Any, response: Any):
Expand Down Expand Up @@ -1122,17 +1127,18 @@ async def handle_ui_resend_verification(self, request: Any, response: Any):
except Exception:
is_valid = False

app_details_config = self._get_app_details_config()
response.status = http.HTTPStatus.OK
response.content_type = b"text/html"
response.body = ui.render_resend_verification_done_page(
is_valid=is_valid,
verification_token=_maybe_get_search_param(
query, "verification_token"
),
app_name=ui_config.app_name,
logo_url=ui_config.logo_url,
dark_logo_url=ui_config.dark_logo_url,
brand_color=ui_config.brand_color,
app_name=app_details_config.app_name,
logo_url=app_details_config.logo_url,
dark_logo_url=app_details_config.dark_logo_url,
brand_color=app_details_config.brand_color,
)

def _get_callback_url(self) -> str:
Expand Down Expand Up @@ -1302,10 +1308,18 @@ def _get_data_from_verification_token(
maybe_redirect_to,
):
case (
str(id), float(issued_at), verify_url, challenge, redirect_to
str(id),
float(issued_at),
verify_url,
challenge,
redirect_to,
):
return_value = (
id, issued_at, verify_url, challenge, redirect_to
id,
issued_at,
verify_url,
challenge,
redirect_to,
)
case (_, _, _, _, _):
raise errors.InvalidData(
Expand Down Expand Up @@ -1343,6 +1357,9 @@ def _get_ui_config(self):
),
)

def _get_app_details_config(self):
return util.get_app_details_config(self.db)

def _get_password_provider(self):
providers = cast(
list[config.ProviderConfig],
Expand Down Expand Up @@ -1428,9 +1445,7 @@ def _is_url_allowed(self, url: str) -> bool:

ui_config = self._get_ui_config()
if ui_config:
allowed_urls = allowed_urls.union(
{ui_config.redirect_to}
)
allowed_urls = allowed_urls.union({ui_config.redirect_to})
if ui_config.redirect_to_on_signup:
allowed_urls = allowed_urls.union(
{ui_config.redirect_to_on_signup}
Expand Down Expand Up @@ -1535,7 +1550,7 @@ def _join_url_params(url: str, params: dict[str, str]):
parsed_url = urllib.parse.urlparse(url)
query_params = {
**urllib.parse.parse_qs(parsed_url.query),
**{key: [val] for key, val in params.items()}
**{key: [val] for key, val in params.items()},
}
new_query_params = urllib.parse.urlencode(query_params, doseq=True)
return parsed_url._replace(query=new_query_params).geturl()
Loading

0 comments on commit bb41d4a

Please sign in to comment.