From a91d693b9544cb27689dadae4561c71c1aa1e516 Mon Sep 17 00:00:00 2001 From: Scott Trinh Date: Fri, 17 Jan 2025 13:19:15 -0500 Subject: [PATCH] Do not fail if SMTP provider is not configured (#8228) Now that we have webhooks, it is valid to not have an SMTP provider configured. In this case, we log an error and send a fake email instead. --- edb/server/protocol/auth_ext/email.py | 51 +++++--- edb/testbase/http.py | 175 ++++++++++++++------------ tests/test_http_ext_auth.py | 43 ++++++- 3 files changed, 167 insertions(+), 102 deletions(-) diff --git a/edb/server/protocol/auth_ext/email.py b/edb/server/protocol/auth_ext/email.py index c5879d45a8d..70c8758cc33 100644 --- a/edb/server/protocol/auth_ext/email.py +++ b/edb/server/protocol/auth_ext/email.py @@ -1,13 +1,19 @@ import asyncio import urllib.parse import random +import logging +from email.message import EmailMessage from typing import Any, Coroutine from edb.server import tenant, smtp +from edb import errors from . import util, ui +logger = logging.getLogger("edb.server.ext.auth") + + async def send_password_reset_email( db: Any, tenant: tenant.Tenant, @@ -30,12 +36,7 @@ async def send_password_reset_email( reset_url=reset_url, **email_args, ) - smtp_provider = smtp.SMTP(db) - coro = smtp_provider.send( - msg, - test_mode=test_mode, - ) - await _protected_send(coro, tenant) + await _maybe_send_message(msg, tenant, db, test_mode) async def send_verification_email( @@ -70,12 +71,7 @@ async def send_verification_email( verify_url=verify_url, **email_args, ) - smtp_provider = smtp.SMTP(db) - coro = smtp_provider.send( - msg, - test_mode=test_mode, - ) - await _protected_send(coro, tenant) + await _maybe_send_message(msg, tenant, db, test_mode) async def send_magic_link_email( @@ -100,12 +96,7 @@ async def send_magic_link_email( link=link, **email_args, ) - smtp_provider = smtp.SMTP(db) - coro = smtp_provider.send( - msg, - test_mode=test_mode, - ) - await _protected_send(coro, tenant) + await _maybe_send_message(msg, tenant, db, test_mode) async def send_fake_email(tenant: tenant.Tenant) -> None: @@ -116,6 +107,30 @@ async def noop_coroutine() -> None: await _protected_send(coro, tenant) +async def _maybe_send_message( + msg: EmailMessage, + tenant: tenant.Tenant, + db: Any, + test_mode: bool, +) -> None: + try: + smtp_provider = smtp.SMTP(db) + except errors.ConfigurationError as e: + logger.debug( + "ConfigurationError while instantiating SMTP provider, " + f"sending fake email instead: {e}" + ) + smtp_provider = None + if smtp_provider is None: + coro = send_fake_email(tenant) + else: + coro = smtp_provider.send( + msg, + test_mode=test_mode, + ) + await _protected_send(coro, tenant) + + async def _protected_send( coro: Coroutine[Any, Any, None], tenant: tenant.Tenant ) -> None: diff --git a/edb/testbase/http.py b/edb/testbase/http.py index 4e74eb45cc7..6ab949a92c7 100644 --- a/edb/testbase/http.py +++ b/edb/testbase/http.py @@ -47,7 +47,13 @@ class BaseHttpTest(server.QueryTestCase): @classmethod async def _wait_for_db_config( - cls, config_key, *, server=None, instance_config=False, value=None + cls, + config_key, + *, + server=None, + instance_config=False, + value=None, + is_reset=False, ): dbname = cls.get_database_name() # Wait for the database config changes to propagate to the @@ -68,17 +74,21 @@ async def _wait_for_db_config( path="server-info", ) data = json.loads(rdata) - if 'databases' not in data: + if "databases" not in data: # multi-tenant instance - use the first tenant - data = next(iter(data['tenants'].values())) + data = next(iter(data["tenants"].values())) if instance_config: - config = data['instance_config'] + config = data["instance_config"] + else: + config = data["databases"][dbname]["config"] + if is_reset: + if config_key in config: + raise AssertionError("database config not ready") else: - config = data['databases'][dbname]['config'] - if config_key not in config: - raise AssertionError('database config not ready') - if value and config[config_key] != value: - raise AssertionError(f'database config not ready') + if config_key not in config: + raise AssertionError("database config not ready") + if value and config[config_key] != value: + raise AssertionError("database config not ready") class BaseHttpExtensionTest(BaseHttpTest): @@ -90,25 +100,23 @@ def get_extension_path(cls): def get_api_prefix(cls): extpath = cls.get_extension_path() dbname = cls.get_database_name() - return f'/branch/{dbname}/{extpath}' + return f"/branch/{dbname}/{extpath}" class ExtAuthTestCase(BaseHttpExtensionTest): - - EXTENSIONS = ['pgcrypto', 'auth'] + EXTENSIONS = ["pgcrypto", "auth"] @classmethod def get_extension_path(cls): - return 'ext/auth' + return "ext/auth" class EdgeQLTestCase(BaseHttpExtensionTest): - - EXTENSIONS = ['edgeql_http'] + EXTENSIONS = ["edgeql_http"] @classmethod def get_extension_path(cls): - return 'edgeql' + return "edgeql" def edgeql_query( self, @@ -119,46 +127,44 @@ def edgeql_query( globals=None, origin=None, ): - req_data = { - 'query': query - } + req_data = {"query": query} if use_http_post: if variables is not None: - req_data['variables'] = variables + req_data["variables"] = variables if globals is not None: - req_data['globals'] = globals - req = urllib.request.Request(self.http_addr, method='POST') - req.add_header('Content-Type', 'application/json') - req.add_header('Authorization', self.make_auth_header()) + req_data["globals"] = globals + req = urllib.request.Request(self.http_addr, method="POST") + req.add_header("Content-Type", "application/json") + req.add_header("Authorization", self.make_auth_header()) if origin: - req.add_header('Origin', origin) + req.add_header("Origin", origin) response = urllib.request.urlopen( req, json.dumps(req_data).encode(), context=self.tls_context ) resp_data = json.loads(response.read()) else: if variables is not None: - req_data['variables'] = json.dumps(variables) + req_data["variables"] = json.dumps(variables) if globals is not None: - req_data['globals'] = json.dumps(globals) + req_data["globals"] = json.dumps(globals) req = urllib.request.Request( - f'{self.http_addr}/?{urllib.parse.urlencode(req_data)}', + f"{self.http_addr}/?{urllib.parse.urlencode(req_data)}", ) - req.add_header('Authorization', self.make_auth_header()) + req.add_header("Authorization", self.make_auth_header()) response = urllib.request.urlopen( req, context=self.tls_context, ) resp_data = json.loads(response.read()) - if 'data' in resp_data: - return (resp_data['data'], response) + if "data" in resp_data: + return (resp_data["data"], response) - err = resp_data['error'] + err = resp_data["error"] - ex_msg = err['message'].strip() - ex_code = err['code'] + ex_msg = err["message"].strip() + ex_code = err["code"] raise edgedb.EdgeDBError._from_code(ex_code, ex_msg) @@ -177,7 +183,8 @@ def assert_edgeql_query_result( query, use_http_post=use_http_post, variables=variables, - globals=globals) + globals=globals, + ) if sort is not None: # GQL will always have a single object returned. The data is @@ -185,18 +192,16 @@ def assert_edgeql_query_result( for r in res.values(): assert_data_shape.sort_results(r, sort) - assert_data_shape.assert_data_shape( - res, result, self.fail, message=msg) + assert_data_shape.assert_data_shape(res, result, self.fail, message=msg) return res class GraphQLTestCase(BaseHttpExtensionTest): - - EXTENSIONS = ['graphql'] + EXTENSIONS = ["graphql"] @classmethod def get_extension_path(cls): - return 'graphql' + return "graphql" def graphql_query( self, @@ -208,25 +213,25 @@ def graphql_query( globals=None, deprecated_globals=None, ): - req_data = {'query': query} + req_data = {"query": query} if operation_name is not None: - req_data['operationName'] = operation_name + req_data["operationName"] = operation_name if use_http_post: if variables is not None: - req_data['variables'] = variables + req_data["variables"] = variables if globals is not None: if variables is None: - req_data['variables'] = dict() - req_data['variables']['__globals__'] = globals + req_data["variables"] = dict() + req_data["variables"]["__globals__"] = globals # Support testing the old way of sending globals. if deprecated_globals is not None: - req_data['globals'] = deprecated_globals + req_data["globals"] = deprecated_globals - req = urllib.request.Request(self.http_addr, method='POST') - req.add_header('Content-Type', 'application/json') - req.add_header('Authorization', self.make_auth_header()) + req = urllib.request.Request(self.http_addr, method="POST") + req.add_header("Content-Type", "application/json") + req.add_header("Authorization", self.make_auth_header()) response = urllib.request.urlopen( req, json.dumps(req_data).encode(), context=self.tls_context ) @@ -235,45 +240,48 @@ def graphql_query( if globals is not None: if variables is None: variables = dict() - variables['__globals__'] = globals + variables["__globals__"] = globals # Support testing the old way of sending globals. if deprecated_globals is not None: - req_data['globals'] = json.dumps(deprecated_globals) + req_data["globals"] = json.dumps(deprecated_globals) if variables is not None: - req_data['variables'] = json.dumps(variables) + req_data["variables"] = json.dumps(variables) req = urllib.request.Request( - f'{self.http_addr}/?{urllib.parse.urlencode(req_data)}', + f"{self.http_addr}/?{urllib.parse.urlencode(req_data)}", ) - req.add_header('Authorization', self.make_auth_header()) + req.add_header("Authorization", self.make_auth_header()) response = urllib.request.urlopen( req, context=self.tls_context, ) resp_data = json.loads(response.read()) - if 'data' in resp_data: - return resp_data['data'] + if "data" in resp_data: + return resp_data["data"] - err = resp_data['errors'][0] + err = resp_data["errors"][0] - typename, msg = err['message'].split(':', 1) + typename, msg = err["message"].split(":", 1) msg = msg.strip() try: ex_type = getattr(edgedb, typename) except AttributeError: raise AssertionError( - f'server returned an invalid exception typename: {typename!r}' - f'\n Message: {msg}') + f"server returned an invalid exception typename: {typename!r}" + f"\n Message: {msg}" + ) ex = ex_type(msg) - if 'locations' in err: + if "locations" in err: # XXX Fix this when LSP "location" objects are implemented ex._attrs[base_errors.FIELD_LINE_START] = str( - err['locations'][0]['line']).encode() + err["locations"][0]["line"] + ).encode() ex._attrs[base_errors.FIELD_COLUMN_START] = str( - err['locations'][0]['column']).encode() + err["locations"][0]["column"] + ).encode() raise ex @@ -296,7 +304,8 @@ def assert_graphql_query_result( use_http_post=use_http_post, variables=variables, globals=globals, - deprecated_globals=deprecated_globals) + deprecated_globals=deprecated_globals, + ) if sort is not None: # GQL will always have a single object returned. The data is @@ -304,8 +313,7 @@ def assert_graphql_query_result( for r in res.values(): assert_data_shape.sort_results(r, sort) - assert_data_shape.assert_data_shape( - res, result, self.fail, message=msg) + assert_data_shape.assert_data_shape(res, result, self.fail, message=msg) return res @@ -317,12 +325,12 @@ def get_server_and_path(self) -> tuple[str, str]: def do_GET(self): self.close_connection = False server, path = self.get_server_and_path() - self.server.owner.handle_request('GET', server, path, self) + self.server.owner.handle_request("GET", server, path, self) def do_POST(self): self.close_connection = False server, path = self.get_server_and_path() - self.server.owner.handle_request('POST', server, path, self) + self.server.owner.handle_request("POST", server, path, self) def log_message(self, *args): pass @@ -332,10 +340,9 @@ class MultiHostMockHttpServerHandler(MockHttpServerHandler): def get_server_and_path(self) -> tuple[str, str]: # Path looks like: # http://127.0.0.1:32881/https%3A//slack.com/.well-known/openid-configuration - raw_url = urllib.parse.unquote(self.path.lstrip('/')) + raw_url = urllib.parse.unquote(self.path.lstrip("/")) url = urllib.parse.urlparse(raw_url) - return (f'{url.scheme}://{url.netloc}', - url.path.lstrip('/')) + return (f"{url.scheme}://{url.netloc}", url.path.lstrip("/")) ResponseType = tuple[str, int] | tuple[str, int, dict[str, str]] @@ -384,7 +391,7 @@ def wrapper( | Callable[ [MockHttpServerHandler, RequestDetails], ResponseType ] - ) + ), ): self.routes[(method, server, path)] = handler return handler @@ -408,8 +415,8 @@ def handle_request( parsed_path = urllib.parse.urlparse(path) headers = {k.lower(): v for k, v in dict(handler.headers).items()} query_params = urllib.parse.parse_qs(parsed_path.query) - if 'content-length' in headers: - body = handler.rfile.read(int(headers['content-length'])).decode() + if "content-length" in headers: + body = handler.rfile.read(int(headers["content-length"])).decode() else: body = None @@ -420,8 +427,10 @@ def handle_request( ) self.requests[key].append(request_details) if key not in self.routes: - error_message = (f"No route handler for {key}\n\n" - f"Available routes:\n{self.routes}") + error_message = ( + f"No route handler for {key}\n\n" + f"Available routes:\n{self.routes}" + ) handler.send_error(404, message=error_message) return @@ -458,9 +467,9 @@ def handle_request( ) or accept_header == "*/*" ): - content_type = 'application/json' + content_type = "application/json" elif accept_header.startswith("application/x-www-form-urlencoded"): - content_type = 'application/x-www-form-urlencoded' + content_type = "application/x-www-form-urlencoded" else: handler.send_error( 415, f"Unsupported accept header: {accept_header}" @@ -470,8 +479,8 @@ def handle_request( data = response.encode() handler.send_response(status) - handler.send_header('Content-Type', content_type) - handler.send_header('Content-Length', str(len(data))) + handler.send_header("Content-Type", content_type) + handler.send_header("Content-Length", str(len(data))) if additional_headers is not None: for header, value in additional_headers.items(): handler.send_header(header, value) @@ -479,11 +488,11 @@ def handle_request( handler.wfile.write(data) def start(self): - assert not hasattr(self, '_http_runner') + assert not hasattr(self, "_http_runner") self._http_runner = threading.Thread(target=self._http_worker) self._http_runner.start() self.has_started.wait() - self.url = f'http://{self._address[0]}:{self._address[1]}/' + self.url = f"http://{self._address[0]}:{self._address[1]}/" def __enter__(self): self.start() @@ -491,7 +500,7 @@ def __enter__(self): def _http_worker(self): self._http_server = http.server.HTTPServer( - ('localhost', 0), self.handler_type + ("localhost", 0), self.handler_type ) self._http_server.owner = self self._address = self._http_server.server_address diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index 149a928e166..1fd6323e935 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -256,7 +256,7 @@ class TestHttpExtAuth(tb.ExtAuthTestCase): }}; CONFIGURE CURRENT DATABASE SET - cfg::current_email_provider_name := "email_hosting_is_easy"; + current_email_provider_name := "email_hosting_is_easy"; CONFIGURE CURRENT DATABASE SET ext::auth::AuthConfig::auth_signing_key := '{SIGNING_KEY}'; @@ -2817,6 +2817,47 @@ async def test_http_auth_ext_local_password_register_form_02(self): self.assertEqual(status, 400) + async def test_http_auth_ext_local_password_register_form_no_smtp(self): + await self.con.query( + """ + CONFIGURE CURRENT DATABASE RESET + current_email_provider_name; + """, + ) + await self._wait_for_db_config( + "cfg::current_email_provider_name", is_reset=True + ) + try: + with self.http_con() as http_con: + email = f"{uuid.uuid4()}@example.com" + form_data = { + "provider": "builtin::local_emailpassword", + "email": email, + "password": "test_password", + "challenge": str(uuid.uuid4()), + } + form_data_encoded = urllib.parse.urlencode(form_data).encode() + + _, _, status = self.http_con_request( + http_con, + None, + path="register", + method="POST", + body=form_data_encoded, + headers={ + "Content-Type": "application/x-www-form-urlencoded" + }, + ) + + self.assertEqual(status, 201) + finally: + await self.con.query( + """ + CONFIGURE CURRENT DATABASE SET + current_email_provider_name := "email_hosting_is_easy"; + """, + ) + async def test_http_auth_ext_local_password_register_json_02(self): with self.http_con() as http_con: provider_name = "builtin::local_emailpassword"