diff --git a/edb/lib/cfg.edgeql b/edb/lib/cfg.edgeql index 939913da510..56df07eacbb 100644 --- a/edb/lib/cfg.edgeql +++ b/edb/lib/cfg.edgeql @@ -102,6 +102,71 @@ CREATE TYPE cfg::Auth EXTENDING cfg::ConfigObject { }; }; +CREATE SCALAR TYPE cfg::SMTPSecurity EXTENDING enum< + PlainText, + TLS, + STARTTLS, + STARTTLSOrPlainText, +>; + +CREATE ABSTRACT TYPE cfg::EmailProviderConfig EXTENDING cfg::ConfigObject { + CREATE REQUIRED PROPERTY name -> std::str { + CREATE CONSTRAINT std::exclusive; + CREATE ANNOTATION std::description := + "The name of the email provider."; + }; +}; + +CREATE TYPE cfg::SMTPProviderConfig EXTENDING cfg::EmailProviderConfig { + CREATE PROPERTY sender -> std::str { + CREATE ANNOTATION std::description := + "\"From\" address of system emails sent for e.g. \ + password reset, etc."; + }; + CREATE PROPERTY host -> std::str { + CREATE ANNOTATION std::description := + "Host of SMTP server to use for sending emails. \ + If not set, \"localhost\" will be used."; + }; + CREATE PROPERTY port -> std::int32 { + CREATE ANNOTATION std::description := + "Port of SMTP server to use for sending emails. \ + If not set, common defaults will be used depending on security: \ + 465 for TLS, 587 for STARTTLS, 25 otherwise."; + }; + CREATE PROPERTY username -> std::str { + CREATE ANNOTATION std::description := + "Username to login as after connected to SMTP server."; + }; + CREATE PROPERTY password -> std::str { + SET secret := true; + CREATE ANNOTATION std::description := + "Password for login after connected to SMTP server."; + }; + CREATE REQUIRED PROPERTY security -> cfg::SMTPSecurity { + SET default := cfg::SMTPSecurity.STARTTLSOrPlainText; + CREATE ANNOTATION std::description := + "Security mode of the connection to SMTP server. \ + By default, initiate a STARTTLS upgrade if supported by the \ + server, or fallback to PlainText."; + }; + CREATE REQUIRED PROPERTY validate_certs -> std::bool { + SET default := true; + CREATE ANNOTATION std::description := + "Determines if SMTP server certificates are validated."; + }; + CREATE REQUIRED PROPERTY timeout_per_email -> std::duration { + SET default := '60 seconds'; + CREATE ANNOTATION std::description := + "Maximum time to send an email, including retry attempts."; + }; + CREATE REQUIRED PROPERTY timeout_per_attempt -> std::duration { + SET default := '15 seconds'; + CREATE ANNOTATION std::description := + "Maximum time for each SMTP request."; + }; +}; + CREATE ABSTRACT TYPE cfg::AbstractConfig extending cfg::ConfigObject; CREATE ABSTRACT TYPE cfg::ExtensionConfig EXTENDING cfg::ConfigObject { @@ -158,6 +223,16 @@ ALTER TYPE cfg::AbstractConfig { CREATE ANNOTATION cfg::system := 'true'; }; + CREATE MULTI LINK email_providers -> cfg::EmailProviderConfig { + CREATE ANNOTATION std::description := + 'The list of email providers that can be used to send emails.'; + }; + + CREATE PROPERTY current_email_provider_name -> std::str { + CREATE ANNOTATION std::description := + 'The name of the current email provider.'; + }; + CREATE PROPERTY allow_dml_in_functions -> std::bool { SET default := false; CREATE ANNOTATION cfg::affects_compilation := 'true'; diff --git a/edb/lib/ext/auth.edgeql b/edb/lib/ext/auth.edgeql index 64f9843653d..d1d31c08889 100644 --- a/edb/lib/ext/auth.edgeql +++ b/edb/lib/ext/auth.edgeql @@ -469,58 +469,6 @@ CREATE EXTENSION PACKAGE auth VERSION '1.0' { }; }; - create scalar type ext::auth::SMTPSecurity extending enum; - - create type ext::auth::SMTPConfig extending cfg::ExtensionConfig { - create property sender: std::str { - create annotation std::description := - "\"From\" address of system emails sent for e.g. \ - password reset, etc."; - }; - create property host: std::str { - create annotation std::description := - "Host of SMTP server to use for sending emails. \ - If not set, \"localhost\" will be used."; - }; - create property port: std::int32 { - create annotation std::description := - "Port of SMTP server to use for sending emails. \ - If not set, common defaults will be used depending on security: \ - 465 for TLS, 587 for STARTTLS, 25 otherwise."; - }; - create property username: std::str { - create annotation std::description := - "Username to login as after connected to SMTP server."; - }; - create property password: std::str { - set secret := true; - create annotation std::description := - "Password for login after connected to SMTP server."; - }; - create required property security: ext::auth::SMTPSecurity { - set default := ext::auth::SMTPSecurity.STARTTLSOrPlainText; - create annotation std::description := - "Security mode of the connection to SMTP server. \ - By default, initiate a STARTTLS upgrade if supported by the \ - server, or fallback to PlainText."; - }; - create required property validate_certs: std::bool { - set default := true; - create annotation std::description := - "Determines if SMTP server certificates are validated."; - }; - create required property timeout_per_email: std::duration { - set default := '60 seconds'; - create annotation std::description := - "Maximum time to send an email, including retry attempts."; - }; - create required property timeout_per_attempt: std::duration { - set default := '15 seconds'; - create annotation std::description := - "Maximum time for each SMTP request."; - }; - }; - create function ext::auth::signing_key_exists() -> std::bool { using ( select exists cfg::Config.extensions[is ext::auth::AuthConfig] diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 5813181485c..d1be9c262ca 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -905,6 +905,62 @@ def describe_database_dump( blocks=descriptors, ) + def _reprocess_restore_config( + self, + stmts: list[qlast.Base], + ) -> list[qlast.Base]: + '''Do any rewrites to the restore script needed. + + This is intended to patch over certain backwards incompatible + changes to config. We try not to do that too much, but when we + do, dumps still need to work. + ''' + + new_stmts = [] + smtp_config = {} + + for stmt in stmts: + # ext::auth::SMTPConfig got removed and moved into a cfg + # object, so intercept those and rewrite them. + if ( + isinstance(stmt, qlast.ConfigSet) + and stmt.name.module == 'ext::auth::SMTPConfig' + ): + smtp_config[stmt.name.name] = stmt.expr + else: + new_stmts.append(stmt) + + if smtp_config: + # Do the rewrite of SMTPConfig + smtp_config['name'] = qlast.Constant.string('_default') + + new_stmts.append( + qlast.ConfigInsert( + scope=qltypes.ConfigScope.DATABASE, + name=qlast.ObjectRef( + module='cfg', name='SMTPProviderConfig' + ), + shape=[ + qlast.ShapeElement( + expr=qlast.Path(steps=[qlast.Ptr(name=name)]), + compexpr=expr, + ) + for name, expr in smtp_config.items() + ], + ) + ) + new_stmts.append( + qlast.ConfigSet( + scope=qltypes.ConfigScope.DATABASE, + name=qlast.ObjectRef( + name='current_email_provider_name' + ), + expr=qlast.Constant.string('_default'), + ) + ) + + return new_stmts + def describe_database_restore( self, user_schema_pickle: bytes, @@ -984,7 +1040,11 @@ def describe_database_restore( # The state serializer generated below is somehow inappropriate, # so it's simply ignored here and the I/O process will do it on its own - units = compile(ctx=ctx, source=ddl_source).units + statements = edgeql.parse_block(ddl_source) + statements = self._reprocess_restore_config(statements) + units = _try_compile_ast( + ctx=ctx, source=ddl_source, statements=statements + ).units _check_force_database_error(ctx, scope='restore') @@ -2363,8 +2423,26 @@ def _try_compile( if text.startswith(sentinel): time.sleep(float(text[len(sentinel):text.index("\n")])) - default_cardinality = enums.Cardinality.NO_RESULT statements = edgeql.parse_block(source) + return _try_compile_ast(statements=statements, source=source, ctx=ctx) + + +def _try_compile_ast( + *, + ctx: CompileContext, + statements: list[qlast.Base], + source: edgeql.Source, +) -> dbstate.QueryUnitGroup: + if _get_config_val(ctx, '__internal_testmode'): + # This is a bad but simple way to emulate a slow compilation for tests. + # Ideally, we should have a testmode function that is hooked to sleep + # as `simple_special_case`, or wait for a notification from the test. + sentinel = "# EDGEDB_TEST_COMPILER_SLEEP = " + text = source.text() + if text.startswith(sentinel): + time.sleep(float(text[len(sentinel):text.index("\n")])) + + default_cardinality = enums.Cardinality.NO_RESULT statements_len = len(statements) if not len(statements): # pragma: no cover diff --git a/edb/server/protocol/auth_ext/email.py b/edb/server/protocol/auth_ext/email.py index 826ae5bb111..c5879d45a8d 100644 --- a/edb/server/protocol/auth_ext/email.py +++ b/edb/server/protocol/auth_ext/email.py @@ -3,9 +3,9 @@ import random from typing import Any, Coroutine -from edb.server import tenant +from edb.server import tenant, smtp -from . import util, ui, smtp +from . import util, ui async def send_password_reset_email( @@ -15,7 +15,6 @@ async def send_password_reset_email( reset_url: str, test_mode: bool, ) -> None: - from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} @@ -27,16 +26,13 @@ async def send_password_reset_email( brand_color=app_details_config.brand_color, ) msg = ui.render_password_reset_email( - from_addr=from_addr, to_addr=to_addr, reset_url=reset_url, **email_args, ) - coro = smtp.send_email( - db, + smtp_provider = smtp.SMTP(db) + coro = smtp_provider.send( msg, - sender=from_addr, - recipients=to_addr, test_mode=test_mode, ) await _protected_send(coro, tenant) @@ -51,7 +47,6 @@ async def send_verification_email( provider: str, test_mode: bool, ) -> None: - from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) verification_token_params = urllib.parse.urlencode( { @@ -71,16 +66,13 @@ async def send_verification_email( brand_color=app_details_config.brand_color, ) msg = ui.render_verification_email( - from_addr=from_addr, to_addr=to_addr, verify_url=verify_url, **email_args, ) - coro = smtp.send_email( - db, + smtp_provider = smtp.SMTP(db) + coro = smtp_provider.send( msg, - sender=from_addr, - recipients=to_addr, test_mode=test_mode, ) await _protected_send(coro, tenant) @@ -93,7 +85,6 @@ async def send_magic_link_email( link: str, test_mode: bool, ) -> None: - from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} @@ -105,16 +96,13 @@ async def send_magic_link_email( brand_color=app_details_config.brand_color, ) msg = ui.render_magic_link_email( - from_addr=from_addr, to_addr=to_addr, link=link, **email_args, ) - coro = smtp.send_email( - db, + smtp_provider = smtp.SMTP(db) + coro = smtp_provider.send( msg, - sender=from_addr, - recipients=to_addr, test_mode=test_mode, ) await _protected_send(coro, tenant) diff --git a/edb/server/protocol/auth_ext/smtp.py b/edb/server/protocol/auth_ext/smtp.py deleted file mode 100644 index e365defcf2f..00000000000 --- a/edb/server/protocol/auth_ext/smtp.py +++ /dev/null @@ -1,196 +0,0 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2023-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Any, Optional, Union, Sequence - -import asyncio -import email -import email.message -import os -import pickle -import hashlib -import logging - -import aiosmtplib - -from edb.common import retryloop -from edb.ir import statypes - -from . import util - - -_semaphore: asyncio.BoundedSemaphore | None = None - - -logger = logging.getLogger('edb.server.smtp') - - -async def send_email( - db: Any, - message: Union[ - email.message.EmailMessage, - email.message.Message, - str, - bytes, - ], - sender: Optional[str] = None, - recipients: Optional[Union[str, Sequence[str]]] = None, - test_mode: bool = False, -) -> None: - global _semaphore - if _semaphore is None: - _semaphore = asyncio.BoundedSemaphore( - int(os.environ.get("EDGEDB_SERVER_AUTH_SMTP_CONCURRENCY", 5)) - ) - - host = ( - util.maybe_get_config( - db, - "ext::auth::SMTPConfig::host", - ) - or "localhost" - ) - port = util.maybe_get_config( - db, - "ext::auth::SMTPConfig::port", - expected_type=int, - ) - username = util.maybe_get_config( - db, - "ext::auth::SMTPConfig::username", - ) - password = util.maybe_get_config( - db, - "ext::auth::SMTPConfig::password", - ) - timeout_per_attempt = util.get_config( - db, - "ext::auth::SMTPConfig::timeout_per_attempt", - expected_type=statypes.Duration, - ) - req_timeout = timeout_per_attempt.to_microseconds() / 1_000_000.0 - timeout_per_email = util.get_config( - db, - "ext::auth::SMTPConfig::timeout_per_email", - expected_type=statypes.Duration, - ) - validate_certs = util.get_config( - db, - "ext::auth::SMTPConfig::validate_certs", - expected_type=bool, - ) - security = util.get_config( - db, - "ext::auth::SMTPConfig::security", - ) - start_tls: bool | None - match security: - case "PlainText": - use_tls = False - start_tls = False - - case "TLS": - use_tls = True - start_tls = False - - case "STARTTLS": - use_tls = False - start_tls = True - - case "STARTTLSOrPlainText": - use_tls = False - start_tls = None - - case _: - raise NotImplementedError - - rloop = retryloop.RetryLoop( - timeout=timeout_per_email.to_microseconds() / 1_000_000.0, - backoff=retryloop.exp_backoff(), - ignore=( - aiosmtplib.SMTPConnectError, - aiosmtplib.SMTPHeloError, - aiosmtplib.SMTPServerDisconnected, - aiosmtplib.SMTPConnectTimeoutError, - aiosmtplib.SMTPConnectResponseError, - ), - ) - async for iteration in rloop: - async with iteration: - async with _semaphore: - # Currently we are not reusing SMTP connections, but ideally we - # should replace this with a pool of connections, and drop idle - # connections after configured time. - if test_mode: - recipients_list: list[str] - if isinstance(recipients, str): - recipients_list = [recipients] - elif recipients is None: - recipients_list = [] - else: - recipients_list = list(recipients) - - hash_input = f"{sender}{','.join(recipients_list)}" - file_name_hash = hashlib.sha256( - hash_input.encode() - ).hexdigest() - file_name = f"/tmp/edb-test-email-{file_name_hash}.pickle" - test_file = os.environ.get( - "EDGEDB_TEST_EMAIL_FILE", - file_name, - ) - if os.path.exists(test_file): - os.unlink(test_file) - with open(test_file, "wb") as f: - args = dict( - message=message, - sender=sender, - recipients=recipients, - hostname=host, - port=port, - username=username, - password=password, - timeout=req_timeout, - use_tls=use_tls, - start_tls=start_tls, - validate_certs=validate_certs, - ) - pickle.dump(args, f) - else: - logger.info(f"Sending SMTP message to {host}:{port}") - errors, response = await aiosmtplib.send( - message, - sender=sender, - recipients=recipients, - hostname=host, - port=port, - username=username, - password=password, - timeout=req_timeout, - use_tls=use_tls, - start_tls=start_tls, - validate_certs=validate_certs, - ) - if errors: - logger.error( - f"SMTP server returned errors: {errors}" - ) - else: - logger.info( - f"SMTP message sent successfully: {response}" - ) diff --git a/edb/server/protocol/auth_ext/ui/__init__.py b/edb/server/protocol/auth_ext/ui/__init__.py index d066a0c0984..975f3393e45 100644 --- a/edb/server/protocol/auth_ext/ui/__init__.py +++ b/edb/server/protocol/auth_ext/ui/__init__.py @@ -20,9 +20,7 @@ from typing import cast, Optional import html - -from email.mime import multipart -from email.mime import text as mime_text +import email.message from edb.server.protocol.auth_ext import config as auth_config @@ -701,21 +699,17 @@ def render_magic_link_sent_page( def render_password_reset_email( *, - from_addr: str, to_addr: str, reset_url: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = "#007bff", -) -> multipart.MIMEMultipart: - msg = multipart.MIMEMultipart() - msg["From"] = from_addr +) -> email.message.EmailMessage: + msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = "Reset password" - alternative = multipart.MIMEMultipart('alternative') - plain_text_msg = mime_text.MIMEText( - f""" + plain_text_content = f""" Somebody requested a new password for the {app_name or ''} account associated with {to_addr}. @@ -723,13 +717,8 @@ def render_password_reset_email( email address: {reset_url} - """, - "plain", - "utf-8", - ) - alternative.attach(plain_text_msg) - - content = f""" + """ + html_content = f""" multipart.MIMEMultipart: - msg = multipart.MIMEMultipart() - msg["From"] = from_addr +) -> email.message.EmailMessage: + msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = ( f"Verify your email{f' for {app_name}' if app_name else ''}" ) - alternative = multipart.MIMEMultipart('alternative') - plain_text_msg = mime_text.MIMEText( - f""" + plain_text_content = f""" Congratulations, you're registered{f' at {app_name}' if app_name else ''}! Please paste the following URL into your browser address bar to verify your email address: {verify_url} - """, - "plain", - "utf-8", - ) - alternative.attach(plain_text_msg) - - content = f""" + """ + html_content = f""" - -""" - - html_msg = mime_text.MIMEText( + """ + msg.set_content(plain_text_content, subtype="plain") + msg.set_content( render.base_default_email( + content=html_content, app_name=app_name, logo_url=logo_url, - content=content, ), - "html", - "utf-8", + subtype="html", ) - alternative.attach(html_msg) - msg.attach(alternative) return msg def render_magic_link_email( *, - from_addr: str, to_addr: str, link: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = "#007bff", -) -> multipart.MIMEMultipart: - msg = multipart.MIMEMultipart() - msg["From"] = from_addr +) -> email.message.EmailMessage: + msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = "Sign in link" - alternative = multipart.MIMEMultipart('alternative') - plain_text_msg = mime_text.MIMEText( - f""" + plain_text_content = f""" Please paste the following URL into your browser address bar to be signed into your account: {link} - """, - "plain", - "utf-8", - ) - alternative.attach(plain_text_msg) - content = f""" + """ + html_content = f""" """ - html_msg = mime_text.MIMEText( + msg.set_content(plain_text_content, subtype="plain") + msg.set_content( render.base_default_email( + content=html_content, app_name=app_name, logo_url=logo_url, - content=content, ), - "html", - "utf-8", + subtype="html", ) - alternative.attach(html_msg) - msg.attach(alternative) return msg diff --git a/edb/server/server.py b/edb/server/server.py index 557fa39eaf0..3c40bb84926 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -1274,6 +1274,7 @@ async def init(self) -> None: await self._load_instance_data() await self._maybe_patch() await self._tenant.init() + self._load_sidechannel_configs() await super().init() def get_default_tenant(self) -> edbtenant.Tenant: @@ -1282,6 +1283,20 @@ def get_default_tenant(self) -> edbtenant.Tenant: def iter_tenants(self) -> Iterator[edbtenant.Tenant]: yield self._tenant + def _load_sidechannel_configs(self) -> None: + # TODO(fantix): Do something like this for multitenant + magic_smtp = os.getenv('EDGEDB_MAGIC_SMTP_CONFIG') + if magic_smtp: + email_type = self._config_settings['email_providers'].type + assert not isinstance(email_type, type) + configs = [ + config.CompositeConfigType.from_json_value( + entry, tspec=email_type, spec=self._config_settings + ) + for entry in json.loads(magic_smtp) + ] + self._tenant.set_sidechannel_configs(configs) + async def _get_patch_log( self, conn: pgcon.PGConnection, idx: int ) -> Optional[bootstrap.PatchEntry]: diff --git a/edb/server/smtp.py b/edb/server/smtp.py new file mode 100644 index 00000000000..28bf59b86ca --- /dev/null +++ b/edb/server/smtp.py @@ -0,0 +1,231 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import dataclasses +import email.message +import asyncio +import logging +import os +import hashlib +import pickle +import aiosmtplib + +from typing import Optional + +from edb.common import retryloop +from edb.ir import statypes +from edb import errors +from . import dbview + + +_semaphore: asyncio.BoundedSemaphore | None = None + +logger = logging.getLogger('edb.server.smtp') + + +@dataclasses.dataclass +class SMTPProviderConfig: + name: str + sender: Optional[str] + host: Optional[str] + port: Optional[int] + username: Optional[str] + password: Optional[str] + security: str + validate_certs: bool + timeout_per_email: statypes.Duration + timeout_per_attempt: statypes.Duration + + +class SMTP: + def __init__(self, db: dbview.Database): + current_provider = _get_current_email_provider(db) + self.sender = current_provider.sender or "noreply@example.com" + default_port = ( + 465 + if current_provider.security == "TLS" + else 587 if current_provider.security == "STARTTLS" else 25 + ) + use_tls: bool + start_tls: bool | None + match current_provider.security: + case "PlainText": + use_tls = False + start_tls = False + + case "TLS": + use_tls = True + start_tls = False + + case "STARTTLS": + use_tls = False + start_tls = True + + case "STARTTLSOrPlainText": + use_tls = False + start_tls = None + + case _: + raise NotImplementedError + + host = current_provider.host or "localhost" + port = current_provider.port or default_port + username = current_provider.username + password = current_provider.password + validate_certs = current_provider.validate_certs + timeout_per_attempt = current_provider.timeout_per_attempt + + req_timeout = timeout_per_attempt.to_microseconds() / 1_000_000.0 + self.timeout_per_email = ( + current_provider.timeout_per_email.to_microseconds() / 1_000_000.0 + ) + self.client = aiosmtplib.SMTP( + hostname=host, + port=port, + username=username, + password=password, + timeout=req_timeout, + use_tls=use_tls, + start_tls=start_tls, + validate_certs=validate_certs, + ) + + async def send( + self, + message: email.message.Message, + *, + test_mode: bool = False, + ) -> None: + global _semaphore + if _semaphore is None: + _semaphore = asyncio.BoundedSemaphore( + int( + os.environ.get( + "EDGEDB_SERVER_AUTH_SMTP_CONCURRENCY", + os.environ.get("EDGEDB_SERVER_SMTP_CONCURRENCY", 5), + ) + ) + ) + + # n.b. When constructing EmailMessage objects, we don't set the "From" + # header since that is configured in the SmtpProviderConfig. However, + # the EmailMessage will have the correct "To" header. + message["From"] = self.sender + rloop = retryloop.RetryLoop( + timeout=self.timeout_per_email, + backoff=retryloop.exp_backoff(), + ignore=( + aiosmtplib.SMTPConnectError, + aiosmtplib.SMTPHeloError, + aiosmtplib.SMTPServerDisconnected, + aiosmtplib.SMTPConnectTimeoutError, + aiosmtplib.SMTPConnectResponseError, + ), + ) + async for iteration in rloop: + async with iteration: + async with _semaphore: + # Currently we are not reusing SMTP connections, but + # ideally we should replace this with a pool of + # connections, and drop idle connections after configured + # time. + if test_mode: + self._send_test_mode_email(message) + else: + logger.info( + "Sending SMTP message to " + f"{self.client.hostname}:{self.client.port}" + ) + + async with self.client: + errors, response = await self.client.send_message( + message + ) + if errors: + logger.error( + f"SMTP server returned errors: {errors}" + ) + else: + logger.info( + f"SMTP message sent successfully: {response}" + ) + + def _send_test_mode_email(self, message: email.message.Message): + sender = message["From"] + recipients = message["To"] + recipients_list: list[str] + if isinstance(recipients, str): + recipients_list = [recipients] + elif recipients is None: + recipients_list = [] + else: + recipients_list = list(recipients) + + hash_input = f"{sender}{','.join(recipients_list)}" + file_name_hash = hashlib.sha256(hash_input.encode()).hexdigest() + file_name = f"/tmp/edb-test-email-{file_name_hash}.pickle" + test_file = os.environ.get( + "EDGEDB_TEST_EMAIL_FILE", + file_name, + ) + if os.path.exists(test_file): + os.unlink(test_file) + with open(test_file, "wb") as f: + logger.info(f"Dumping SMTP message to {test_file}") + args = dict( + message=message, + sender=sender, + recipients=recipients, + hostname=self.client.hostname, + port=self.client.port, + username=self.client._login_username, + password=self.client._login_password, + timeout=self.client.timeout, + use_tls=self.client.use_tls, + start_tls=self.client._start_tls_on_connect, + validate_certs=self.client.validate_certs, + ) + pickle.dump(args, f) + + +def _get_current_email_provider( + db: dbview.Database, +) -> SMTPProviderConfig: + current_provider_name = db.lookup_config("current_email_provider_name") + if current_provider_name is None: + raise errors.ConfigurationError("No email provider configured") + + found = None + objs = ( + list(db.lookup_config("email_providers")) + + db.tenant._sidechannel_email_configs + ) + for obj in objs: + if obj.name == current_provider_name: + as_json = obj.to_json_value() + as_json.pop('_tname', None) + found = SMTPProviderConfig(**as_json) + break + + if found is None: + raise errors.ConfigurationError( + f"No email provider named {current_provider_name!r}" + ) + return found diff --git a/edb/server/tenant.py b/edb/server/tenant.py index 187f3d10b00..3242acb3767 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -138,6 +138,8 @@ class Tenant(ha_base.ClusterProtocol): _http_client: HttpClient | None + _sidechannel_email_configs: list[Any] + def __init__( self, cluster: pgcluster.BaseCluster, @@ -161,6 +163,7 @@ def __init__( self._accept_new_tasks = False self._file_watch_finalizers = [] self._introspection_locks = weakref.WeakValueDictionary() + self._sidechannel_email_configs = [] self._extensions_dirs = extensions_dir @@ -246,6 +249,9 @@ def set_server(self, server: edbserver.BaseServer) -> None: self._server = server self.__loop = server.get_loop() + def set_sidechannel_configs(self, configs: list[Any]) -> None: + self._sidechannel_email_configs = configs + def get_http_client(self, *, originator: str) -> HttpClient: if self._http_client is None: http_max_connections = self._server.config_lookup( diff --git a/tests/schemas/dump_v4_setup.edgeql b/tests/schemas/dump_v4_setup.edgeql index 496d71dd0e3..7334687a4e1 100644 --- a/tests/schemas/dump_v4_setup.edgeql +++ b/tests/schemas/dump_v4_setup.edgeql @@ -55,8 +55,22 @@ ext::auth::AuthConfig::auth_signing_key := 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'; CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::token_time_to_live := '24 hours'; +# N.B: This CONFIGURE command was the original one, but then we +# removed that flag. We kept it working in dumps, though, so old +# dumps still work and behave as if they had the next two statements +# instead. +# +# CONFIGURE CURRENT DATABASE SET +# ext::auth::SMTPConfig::sender := 'noreply@example.com'; + +CONFIGURE CURRENT DATABASE INSERT cfg::SMTPProviderConfig { + name := "_default", + sender := 'noreply@example.com', +}; + CONFIGURE CURRENT DATABASE SET -ext::auth::SMTPConfig::sender := 'noreply@example.com'; +cfg::current_email_provider_name := "_default"; + CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::allowed_redirect_urls := { diff --git a/tests/test_dump_v4.py b/tests/test_dump_v4.py index 638925596b7..6722e37fcf8 100644 --- a/tests/test_dump_v4.py +++ b/tests/test_dump_v4.py @@ -131,6 +131,39 @@ async def _ensure_schema_data_integrity(self, include_secrets): }] ) + # We didn't specify include_secrets in the dumps we made for + # 4.0, but the way that smtp config was done then, it got + # dumped anyway. (The secret wasn't specified.) + has_smtp = ( + include_secrets + or self._testMethodName == 'test_dumpv4_restore_compatibility_4_0' + ) + + # N.B: This is not what it looked like in the original + # dumps. We patched it up during restore starting with 6.0. + if has_smtp: + await self.assert_query_result( + ''' + select cfg::Config { + email_providers[is cfg::SMTPProviderConfig]: { + name, sender + }, + current_email_provider_name, + }; + ''', + [ + { + "email_providers": [ + { + "name": "_default", + "sender": "noreply@example.com", + } + ], + "current_email_provider_name": "_default" + } + ], + ) + class TestDumpV4(tb.StableDumpTestCase, DumpTestCaseMixin): EXTENSIONS = ["pgvector", "_conf", "pgcrypto", "auth"] diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 349386f4efd..4827fdffdc0 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -31,6 +31,7 @@ from typing import Any, Optional, cast from jwcrypto import jwt, jwk +from email.message import EmailMessage from edgedb import QueryAssertionError from edb.testbase import http as tb @@ -248,6 +249,14 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): SETUP = [ f""" + CONFIGURE CURRENT DATABASE INSERT cfg::SMTPProviderConfig {{ + name := "email_hosting_is_easy", + sender := "{SENDER}", + }}; + + CONFIGURE CURRENT DATABASE SET + cfg::current_email_provider_name := "email_hosting_is_easy"; + CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::auth_signing_key := '{SIGNING_KEY}'; @@ -272,9 +281,6 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): redirect_to_on_signup := 'https://example.com/signup/app', }}; - CONFIGURE CURRENT DATABASE SET - ext::auth::SMTPConfig::sender := '{SENDER}'; - CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::allowed_redirect_urls := {{ 'https://example.com/app' @@ -3203,10 +3209,18 @@ async def test_http_auth_ext_local_emailpassword_resend_verification(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], form_data["email"]) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( - r'

([^<]+)', html_email + r'

([^<]+)', + html_email, ) assert match is not None verify_url = urllib.parse.urlparse(match.group(1)) @@ -3382,8 +3396,11 @@ async def test_http_auth_ext_local_webauthn_resend_verification(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) @@ -3679,8 +3696,11 @@ async def test_http_auth_ext_local_password_forgot_form_01(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) @@ -3866,8 +3886,11 @@ async def test_http_auth_ext_local_password_reset_form_01(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) @@ -4353,8 +4376,11 @@ async def test_http_auth_ext_magic_link_01(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) diff --git a/tests/test_server_ops.py b/tests/test_server_ops.py index 0a904be368a..dd79a831f22 100644 --- a/tests/test_server_ops.py +++ b/tests/test_server_ops.py @@ -1645,9 +1645,6 @@ async def _test_server_ops_global_compile_cache( insert ext::auth::EmailPasswordProviderConfig {{ require_verification := false, }}; - - configure current database set - ext::auth::SMTPConfig::sender := 'noreply@example.com'; ''') finally: await conn.aclose()