From 02628d18627805a0fbe4b9ccccc3f6014ccb4a76 Mon Sep 17 00:00:00 2001 From: Scott Trinh Date: Tue, 5 Nov 2024 16:37:21 -0500 Subject: [PATCH] Add server `smtp` module and switch auth --- edb/lib/cfg.edgeql | 10 +- edb/server/protocol/auth_ext/email.py | 28 +-- edb/server/protocol/auth_ext/ui/__init__.py | 95 +++------ edb/server/smtp.py | 225 ++++++++++++++++++++ tests/test_http_ext_auth.py | 61 ++++-- 5 files changed, 309 insertions(+), 110 deletions(-) create mode 100644 edb/server/smtp.py diff --git a/edb/lib/cfg.edgeql b/edb/lib/cfg.edgeql index 0c8a7e29d672..56df07eacbbb 100644 --- a/edb/lib/cfg.edgeql +++ b/edb/lib/cfg.edgeql @@ -143,8 +143,8 @@ CREATE TYPE cfg::SMTPProviderConfig EXTENDING cfg::EmailProviderConfig { CREATE ANNOTATION std::description := "Password for login after connected to SMTP server."; }; - CREATE REQUIRED PROPERTY security -> ext::auth::SMTPSecurity { - SET default := ext::auth::SMTPSecurity.STARTTLSOrPlainText; + CREATE REQUIRED PROPERTY security -> cfg::SMTPSecurity { + SET default := cfg::SMTPSecurity.STARTTLSOrPlainText; CREATE ANNOTATION std::description := "Security mode of the connection to SMTP server. \ By default, initiate a STARTTLS upgrade if supported by the \ @@ -224,11 +224,13 @@ ALTER TYPE cfg::AbstractConfig { }; CREATE MULTI LINK email_providers -> cfg::EmailProviderConfig { - CREATE ANNOTATION cfg::system := 'true'; + CREATE ANNOTATION std::description := + 'The list of email providers that can be used to send emails.'; }; CREATE PROPERTY current_email_provider_name -> std::str { - CREATE ANNOTATION cfg::system := 'true'; + CREATE ANNOTATION std::description := + 'The name of the current email provider.'; }; CREATE PROPERTY allow_dml_in_functions -> std::bool { diff --git a/edb/server/protocol/auth_ext/email.py b/edb/server/protocol/auth_ext/email.py index 826ae5bb1113..c5879d45a8d0 100644 --- a/edb/server/protocol/auth_ext/email.py +++ b/edb/server/protocol/auth_ext/email.py @@ -3,9 +3,9 @@ import random from typing import Any, Coroutine -from edb.server import tenant +from edb.server import tenant, smtp -from . import util, ui, smtp +from . import util, ui async def send_password_reset_email( @@ -15,7 +15,6 @@ async def send_password_reset_email( reset_url: str, test_mode: bool, ) -> None: - from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} @@ -27,16 +26,13 @@ async def send_password_reset_email( brand_color=app_details_config.brand_color, ) msg = ui.render_password_reset_email( - from_addr=from_addr, to_addr=to_addr, reset_url=reset_url, **email_args, ) - coro = smtp.send_email( - db, + smtp_provider = smtp.SMTP(db) + coro = smtp_provider.send( msg, - sender=from_addr, - recipients=to_addr, test_mode=test_mode, ) await _protected_send(coro, tenant) @@ -51,7 +47,6 @@ async def send_verification_email( provider: str, test_mode: bool, ) -> None: - from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) verification_token_params = urllib.parse.urlencode( { @@ -71,16 +66,13 @@ async def send_verification_email( brand_color=app_details_config.brand_color, ) msg = ui.render_verification_email( - from_addr=from_addr, to_addr=to_addr, verify_url=verify_url, **email_args, ) - coro = smtp.send_email( - db, + smtp_provider = smtp.SMTP(db) + coro = smtp_provider.send( msg, - sender=from_addr, - recipients=to_addr, test_mode=test_mode, ) await _protected_send(coro, tenant) @@ -93,7 +85,6 @@ async def send_magic_link_email( link: str, test_mode: bool, ) -> None: - from_addr = util.get_config(db, "ext::auth::SMTPConfig::sender") app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} @@ -105,16 +96,13 @@ async def send_magic_link_email( brand_color=app_details_config.brand_color, ) msg = ui.render_magic_link_email( - from_addr=from_addr, to_addr=to_addr, link=link, **email_args, ) - coro = smtp.send_email( - db, + smtp_provider = smtp.SMTP(db) + coro = smtp_provider.send( msg, - sender=from_addr, - recipients=to_addr, test_mode=test_mode, ) await _protected_send(coro, tenant) diff --git a/edb/server/protocol/auth_ext/ui/__init__.py b/edb/server/protocol/auth_ext/ui/__init__.py index d066a0c09841..975f3393e45d 100644 --- a/edb/server/protocol/auth_ext/ui/__init__.py +++ b/edb/server/protocol/auth_ext/ui/__init__.py @@ -20,9 +20,7 @@ from typing import cast, Optional import html - -from email.mime import multipart -from email.mime import text as mime_text +import email.message from edb.server.protocol.auth_ext import config as auth_config @@ -701,21 +699,17 @@ def render_magic_link_sent_page( def render_password_reset_email( *, - from_addr: str, to_addr: str, reset_url: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = "#007bff", -) -> multipart.MIMEMultipart: - msg = multipart.MIMEMultipart() - msg["From"] = from_addr +) -> email.message.EmailMessage: + msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = "Reset password" - alternative = multipart.MIMEMultipart('alternative') - plain_text_msg = mime_text.MIMEText( - f""" + plain_text_content = f""" Somebody requested a new password for the {app_name or ''} account associated with {to_addr}. @@ -723,13 +717,8 @@ def render_password_reset_email( email address: {reset_url} - """, - "plain", - "utf-8", - ) - alternative.attach(plain_text_msg) - - content = f""" + """ + html_content = f""" multipart.MIMEMultipart: - msg = multipart.MIMEMultipart() - msg["From"] = from_addr +) -> email.message.EmailMessage: + msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = ( f"Verify your email{f' for {app_name}' if app_name else ''}" ) - alternative = multipart.MIMEMultipart('alternative') - plain_text_msg = mime_text.MIMEText( - f""" + plain_text_content = f""" Congratulations, you're registered{f' at {app_name}' if app_name else ''}! Please paste the following URL into your browser address bar to verify your email address: {verify_url} - """, - "plain", - "utf-8", - ) - alternative.attach(plain_text_msg) - - content = f""" + """ + html_content = f""" - -""" - - html_msg = mime_text.MIMEText( + """ + msg.set_content(plain_text_content, subtype="plain") + msg.set_content( render.base_default_email( + content=html_content, app_name=app_name, logo_url=logo_url, - content=content, ), - "html", - "utf-8", + subtype="html", ) - alternative.attach(html_msg) - msg.attach(alternative) return msg def render_magic_link_email( *, - from_addr: str, to_addr: str, link: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = "#007bff", -) -> multipart.MIMEMultipart: - msg = multipart.MIMEMultipart() - msg["From"] = from_addr +) -> email.message.EmailMessage: + msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = "Sign in link" - alternative = multipart.MIMEMultipart('alternative') - plain_text_msg = mime_text.MIMEText( - f""" + plain_text_content = f""" Please paste the following URL into your browser address bar to be signed into your account: {link} - """, - "plain", - "utf-8", - ) - alternative.attach(plain_text_msg) - content = f""" + """ + html_content = f""" """ - html_msg = mime_text.MIMEText( + msg.set_content(plain_text_content, subtype="plain") + msg.set_content( render.base_default_email( + content=html_content, app_name=app_name, logo_url=logo_url, - content=content, ), - "html", - "utf-8", + subtype="html", ) - alternative.attach(html_msg) - msg.attach(alternative) return msg diff --git a/edb/server/smtp.py b/edb/server/smtp.py new file mode 100644 index 000000000000..799e930713b0 --- /dev/null +++ b/edb/server/smtp.py @@ -0,0 +1,225 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import dataclasses +import email.message +import asyncio +import logging +import os +import hashlib +import pickle +import aiosmtplib + +from typing import Optional + +from edb.common import retryloop +from edb.ir import statypes +from edb import errors +from . import dbview + + +_semaphore: asyncio.BoundedSemaphore | None = None + +logger = logging.getLogger('edb.server.smtp') + + +@dataclasses.dataclass +class SMTPProviderConfig: + name: str + sender: Optional[str] + host: Optional[str] + port: Optional[int] + username: Optional[str] + password: Optional[str] + security: str + validate_certs: bool + timeout_per_email: statypes.Duration + timeout_per_attempt: statypes.Duration + + +class SMTP: + def __init__(self, db: dbview.Database): + current_provider = _get_current_email_provider(db) + self.sender = current_provider.sender or "noreply@example.com" + default_port = ( + 465 + if current_provider.security == "TLS" + else 587 if current_provider.security == "STARTTLS" else 25 + ) + use_tls: bool + start_tls: bool | None + match current_provider.security: + case "PlainText": + use_tls = False + start_tls = False + + case "TLS": + use_tls = True + start_tls = False + + case "STARTTLS": + use_tls = False + start_tls = True + + case "STARTTLSOrPlainText": + use_tls = False + start_tls = None + + case _: + raise NotImplementedError + + host = current_provider.host or "localhost" + port = current_provider.port or default_port + username = current_provider.username + password = current_provider.password + validate_certs = current_provider.validate_certs + timeout_per_attempt = current_provider.timeout_per_attempt + + req_timeout = timeout_per_attempt.to_microseconds() / 1_000_000.0 + self.timeout_per_email = ( + current_provider.timeout_per_email.to_microseconds() / 1_000_000.0 + ) + self.client = aiosmtplib.SMTP( + hostname=host, + port=port, + username=username, + password=password, + timeout=req_timeout, + use_tls=use_tls, + start_tls=start_tls, + validate_certs=validate_certs, + ) + + async def send( + self, + message: email.message.Message, + *, + test_mode: bool = False, + ) -> None: + global _semaphore + if _semaphore is None: + _semaphore = asyncio.BoundedSemaphore( + int( + os.environ.get( + "EDGEDB_SERVER_AUTH_SMTP_CONCURRENCY", + os.environ.get("EDGEDB_SERVER_SMTP_CONCURRENCY", 5), + ) + ) + ) + + # n.b. When constructing EmailMessage objects, we don't set the "From" + # header since that is configured in the SmtpProviderConfig. However, + # the EmailMessage will have the correct "To" header. + message["From"] = self.sender + rloop = retryloop.RetryLoop( + timeout=self.timeout_per_email, + backoff=retryloop.exp_backoff(), + ignore=( + aiosmtplib.SMTPConnectError, + aiosmtplib.SMTPHeloError, + aiosmtplib.SMTPServerDisconnected, + aiosmtplib.SMTPConnectTimeoutError, + aiosmtplib.SMTPConnectResponseError, + ), + ) + async for iteration in rloop: + async with iteration: + async with _semaphore: + # Currently we are not reusing SMTP connections, but + # ideally we should replace this with a pool of + # connections, and drop idle connections after configured + # time. + if test_mode: + self._send_test_mode_email(message) + else: + logger.info( + f"Sending SMTP message to {self.host}:{self.port}" + ) + + errors, response = await self.client.send_message( + message + ) + if errors: + logger.error( + f"SMTP server returned errors: {errors}" + ) + else: + logger.info( + f"SMTP message sent successfully: {response}" + ) + + def _send_test_mode_email(self, message: email.message.Message): + sender = message["From"] + recipients = message["To"] + recipients_list: list[str] + if isinstance(recipients, str): + recipients_list = [recipients] + elif recipients is None: + recipients_list = [] + else: + recipients_list = list(recipients) + + hash_input = f"{sender}{','.join(recipients_list)}" + file_name_hash = hashlib.sha256(hash_input.encode()).hexdigest() + file_name = f"/tmp/edb-test-email-{file_name_hash}.pickle" + test_file = os.environ.get( + "EDGEDB_TEST_EMAIL_FILE", + file_name, + ) + if os.path.exists(test_file): + os.unlink(test_file) + with open(test_file, "wb") as f: + logger.info(f"Dumping SMTP message to {test_file}") + args = dict( + message=message, + sender=sender, + recipients=recipients, + hostname=self.client.hostname, + port=self.client.port, + username=self.client._login_username, + password=self.client._login_password, + timeout=self.client.timeout, + use_tls=self.client.use_tls, + start_tls=self.client._start_tls_on_connect, + validate_certs=self.client.validate_certs, + ) + pickle.dump(args, f) + + +def _get_current_email_provider( + db: dbview.Database, +) -> SMTPProviderConfig: + current_provider_name = db.lookup_config("current_email_provider_name") + if current_provider_name is None: + raise errors.ConfigurationError("No email provider configured") + + found = None + for obj in db.lookup_config("email_providers"): + as_json = obj.to_json_value() + as_json.pop('_tname', None) + if as_json.get("name") == current_provider_name: + found = SMTPProviderConfig(**as_json) + break + + if found is None: + raise errors.ConfigurationError( + f"No email provider named {current_provider_name!r}" + ) + return found diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 983d164a9edb..aa5aec3450e9 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -31,6 +31,7 @@ from typing import Any, Optional, cast from jwcrypto import jwt, jwk +from email.message import EmailMessage from edgedb import QueryAssertionError from edb.testbase import http as tb @@ -248,6 +249,14 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): SETUP = [ f""" + CONFIGURE CURRENT DATABASE INSERT cfg::SMTPProviderConfig {{ + name := "email_hosting_is_easy", + sender := "{SENDER}", + }}; + + CONFIGURE CURRENT DATABASE SET + cfg::current_email_provider_name := "email_hosting_is_easy"; + CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::auth_signing_key := '{SIGNING_KEY}'; @@ -266,22 +275,12 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::brand_color := '{BRAND_COLOR}'; - CONFIGURE CURRENT DATABASE SET - ext::auth::AuthConfig::current_email_provider_name := - "email_hosting_is_easy"; - CONFIGURE CURRENT DATABASE INSERT ext::auth::UIConfig {{ redirect_to := 'https://example.com/app', redirect_to_on_signup := 'https://example.com/signup/app', }}; - CONFIGURE CURRENT DATABASE - INSERT ext::auth::SMTPProviderConfig {{ - name := "email_hosting_is_easy", - sender := '{SENDER}', - }}; - CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::allowed_redirect_urls := {{ 'https://example.com/app' @@ -3210,10 +3209,18 @@ async def test_http_auth_ext_local_emailpassword_resend_verification(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], form_data["email"]) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( - r'

([^<]+)', html_email + r'

([^<]+)', + html_email, ) assert match is not None verify_url = urllib.parse.urlparse(match.group(1)) @@ -3389,8 +3396,11 @@ async def test_http_auth_ext_local_webauthn_resend_verification(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) @@ -3686,8 +3696,11 @@ async def test_http_auth_ext_local_password_forgot_form_01(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) @@ -3873,8 +3886,11 @@ async def test_http_auth_ext_local_password_reset_form_01(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email ) @@ -4360,8 +4376,11 @@ async def test_http_auth_ext_magic_link_01(self): email_args = pickle.load(f) self.assertEqual(email_args["sender"], SENDER) self.assertEqual(email_args["recipients"], email) - html_msg = email_args["message"].get_payload(0).get_payload(1) - html_email = html_msg.get_payload(decode=True).decode("utf-8") + msg = cast(EmailMessage, email_args["message"]).get_body( + ("html",) + ) + assert msg is not None + html_email = msg.get_payload(decode=True).decode("utf-8") match = re.search( r'

([^<]+)', html_email )