diff --git a/edb/testbase/http.py b/edb/testbase/http.py index 01edb9193227..900acf2df484 100644 --- a/edb/testbase/http.py +++ b/edb/testbase/http.py @@ -18,8 +18,15 @@ from __future__ import annotations +from typing import ( + Any, + Callable, + Optional, +) +import http.server import json +import threading import urllib.parse import urllib.request @@ -302,3 +309,177 @@ def assert_graphql_query_result( assert_data_shape.assert_data_shape( res, result, self.fail, message=msg) return res + + +class MockHttpServerHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + self.close_connection = False + server, path = self.path.lstrip('/').split('/', 1) + server = urllib.parse.unquote(server) + self.server.owner.handle_request('GET', server, path, self) + + def do_POST(self): + self.close_connection = False + server, path = self.path.lstrip('/').split('/', 1) + server = urllib.parse.unquote(server) + self.server.owner.handle_request('POST', server, path, self) + + def log_message(self, *args): + pass + + +ResponseType = tuple[str, int] | tuple[str, int, dict[str, str]] + + +class MockHttpServer: + def __init__(self) -> None: + self.has_started = threading.Event() + self.routes: dict[ + tuple[str, str, str], + ResponseType | Callable[[MockHttpServerHandler], ResponseType], + ] = {} + self.requests: dict[tuple[str, str, str], list[dict[str, Any]]] = {} + self.url: Optional[str] = None + + def get_base_url(self) -> str: + if self.url is None: + raise RuntimeError("mock server is not running") + return self.url + + def register_route_handler( + self, + method: str, + server: str, + path: str, + ): + def wrapper( + handler: ( + ResponseType | Callable[[MockHttpServerHandler], ResponseType] + ) + ): + self.routes[(method, server, path)] = handler + return handler + + return wrapper + + def handle_request( + self, + method: str, + server: str, + path: str, + handler: MockHttpServerHandler, + ): + # `handler` is documented here: + # https://docs.python.org/3/library/http.server.html#http.server.BaseHTTPRequestHandler + key = (method, server, path) + if key not in self.requests: + self.requests[key] = [] + + # Parse and save the request details + parsed_path = urllib.parse.urlparse(path) + headers = {k.lower(): v for k, v in dict(handler.headers).items()} + query_params = urllib.parse.parse_qs(parsed_path.query) + if 'content-length' in headers: + body = handler.rfile.read(int(headers['content-length'])).decode() + else: + body = None + + request_details = { + 'headers': headers, + 'query_params': query_params, + 'body': body, + } + self.requests[key].append(request_details) + + if key not in self.routes: + handler.send_error(404) + return + + registered_handler = self.routes[key] + + if callable(registered_handler): + try: + handler_result = registered_handler(handler) + if len(handler_result) == 2: + response, status = handler_result + additional_headers = None + elif len(handler_result) == 3: + response, status, additional_headers = handler_result + except Exception: + handler.send_error(500) + raise + else: + if len(registered_handler) == 2: + response, status = registered_handler + additional_headers = None + elif len(registered_handler) == 3: + response, status, additional_headers = registered_handler + + if "headers" in request_details and isinstance( + request_details["headers"], dict + ): + accept_header = request_details["headers"].get( + "accept", "application/json" + ) + else: + accept_header = "application/json" + + if ( + accept_header.startswith("application/json") + or ( + accept_header.startswith("application/") + and "vnd." in accept_header + and "+json" in accept_header + ) + or accept_header == "*/*" + ): + content_type = 'application/json' + elif accept_header.startswith("application/x-www-form-urlencoded"): + content_type = 'application/x-www-form-urlencoded' + else: + handler.send_error( + 415, f"Unsupported accept header: {accept_header}" + ) + return + + data = response.encode() + + handler.send_response(status) + handler.send_header('Content-Type', content_type) + handler.send_header('Content-Length', str(len(data))) + if additional_headers is not None: + for header, value in additional_headers.items(): + handler.send_header(header, value) + handler.end_headers() + handler.wfile.write(data) + + def start(self): + assert not hasattr(self, '_http_runner') + self._http_runner = threading.Thread(target=self._http_worker) + self._http_runner.start() + self.has_started.wait() + self.url = f'http://{self._address[0]}:{self._address[1]}/' + + def __enter__(self): + self.start() + return self + + def _http_worker(self): + self._http_server = http.server.HTTPServer( + ('localhost', 0), MockHttpServerHandler + ) + self._http_server.owner = self + self._address = self._http_server.server_address + self.has_started.set() + self._http_server.serve_forever(poll_interval=0.01) + self._http_server.server_close() + + def stop(self): + self._http_server.shutdown() + if self._http_runner is not None: + self._http_runner.join() + self._http_runner = None + + def __exit__(self, *exc): + self.stop() + self.url = None diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index a2015484b745..e58057dd57dd 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -23,15 +23,13 @@ import json import base64 import datetime -import http.server -import threading import argon2 import os import pickle import re import hashlib -from typing import Any, Callable, Optional +from typing import Any, Optional from jwcrypto import jwt, jwk from edgedb import QueryAssertionError @@ -206,179 +204,6 @@ def utcnow(): return datetime.datetime.now(datetime.timezone.utc) -class MockHttpServerHandler(http.server.BaseHTTPRequestHandler): - def do_GET(self): - self.close_connection = False - server, path = self.path.lstrip('/').split('/', 1) - server = urllib.parse.unquote(server) - self.server.owner.handle_request('GET', server, path, self) - - def do_POST(self): - self.close_connection = False - server, path = self.path.lstrip('/').split('/', 1) - server = urllib.parse.unquote(server) - self.server.owner.handle_request('POST', server, path, self) - - def log_message(self, *args): - pass - - -ResponseType = tuple[str, int] | tuple[str, int, dict[str, str]] - - -class MockAuthProvider: - def __init__(self) -> None: - self.has_started = threading.Event() - self.routes: dict[ - tuple[str, str, str], - ResponseType | Callable[[MockHttpServerHandler], ResponseType], - ] = {} - self.requests: dict[tuple[str, str, str], list[dict[str, Any]]] = {} - self.url: Optional[str] = None - - def get_base_url(self) -> str: - if self.url is None: - raise RuntimeError("mock server is not running") - return self.url - - def register_route_handler( - self, - method: str, - server: str, - path: str, - ): - def wrapper( - handler: ( - ResponseType | Callable[[MockHttpServerHandler], ResponseType] - ) - ): - self.routes[(method, server, path)] = handler - return handler - - return wrapper - - def handle_request( - self, - method: str, - server: str, - path: str, - handler: MockHttpServerHandler, - ): - # `handler` is documented here: - # https://docs.python.org/3/library/http.server.html#http.server.BaseHTTPRequestHandler - key = (method, server, path) - if key not in self.requests: - self.requests[key] = [] - - # Parse and save the request details - parsed_path = urllib.parse.urlparse(path) - headers = {k.lower(): v for k, v in dict(handler.headers).items()} - query_params = urllib.parse.parse_qs(parsed_path.query) - if 'content-length' in headers: - body = handler.rfile.read(int(headers['content-length'])).decode() - else: - body = None - - request_details = { - 'headers': headers, - 'query_params': query_params, - 'body': body, - } - self.requests[key].append(request_details) - - if key not in self.routes: - handler.send_error(404) - return - - registered_handler = self.routes[key] - - if callable(registered_handler): - try: - handler_result = registered_handler(handler) - if len(handler_result) == 2: - response, status = handler_result - additional_headers = None - elif len(handler_result) == 3: - response, status, additional_headers = handler_result - except Exception: - handler.send_error(500) - raise - else: - if len(registered_handler) == 2: - response, status = registered_handler - additional_headers = None - elif len(registered_handler) == 3: - response, status, additional_headers = registered_handler - - if "headers" in request_details and isinstance( - request_details["headers"], dict - ): - accept_header = request_details["headers"].get( - "accept", "application/json" - ) - else: - accept_header = "application/json" - - if ( - accept_header.startswith("application/json") - or ( - accept_header.startswith("application/") - and "vnd." in accept_header - and "+json" in accept_header - ) - or accept_header == "*/*" - ): - content_type = 'application/json' - elif accept_header.startswith("application/x-www-form-urlencoded"): - content_type = 'application/x-www-form-urlencoded' - else: - handler.send_error( - 415, f"Unsupported accept header: {accept_header}" - ) - return - - data = response.encode() - - handler.send_response(status) - handler.send_header('Content-Type', content_type) - handler.send_header('Content-Length', str(len(data))) - if additional_headers is not None: - for header, value in additional_headers.items(): - handler.send_header(header, value) - handler.end_headers() - handler.wfile.write(data) - - def start(self): - assert not hasattr(self, '_http_runner') - self._http_runner = threading.Thread(target=self._http_worker) - self._http_runner.start() - self.has_started.wait() - self.url = f'http://{self._address[0]}:{self._address[1]}/' - - def __enter__(self): - self.start() - return self - - def _http_worker(self): - self._http_server = http.server.HTTPServer( - ('localhost', 0), MockHttpServerHandler - ) - self._http_server.owner = self - self._address = self._http_server.server_address - self.has_started.set() - self._http_server.serve_forever(poll_interval=0.01) - self._http_server.server_close() - - def stop(self): - self._http_server.shutdown() - self._http_runner.join() - self._http_runner = None - - def __exit__(self, *exc): - self.stop() - self.url = None - - SIGNING_KEY = 'a' * 32 GITHUB_SECRET = 'b' * 32 GOOGLE_SECRET = 'c' * 32 @@ -491,10 +316,10 @@ def setUpClass(cls): cls._wait_for_db_config('ext::auth::AuthConfig::providers') ) - mock_provider: MockAuthProvider + mock_provider: tb.MockHttpServer def setUp(self): - self.mock_provider = MockAuthProvider() + self.mock_provider = tb.MockHttpServer() self.mock_provider.start() HTTP_TEST_PORT.set(self.mock_provider.get_base_url()) @@ -1452,7 +1277,8 @@ async def test_http_auth_ext_google_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = self.mock_provider.requests[discovery_request] + requests_for_discovery = ( + self.mock_provider.requests[discovery_request]) self.assertEqual(len(requests_for_discovery), 2) requests_for_token = self.mock_provider.requests[token_request] @@ -1548,7 +1374,8 @@ async def test_http_auth_ext_google_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = self.mock_provider.requests[discovery_request] + requests_for_discovery = ( + self.mock_provider.requests[discovery_request]) self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -1618,7 +1445,8 @@ async def test_http_auth_ext_azure_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = self.mock_provider.requests[discovery_request] + requests_for_discovery = ( + self.mock_provider.requests[discovery_request]) self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -1750,7 +1578,8 @@ async def test_http_auth_ext_azure_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = self.mock_provider.requests[discovery_request] + requests_for_discovery = ( + self.mock_provider.requests[discovery_request]) self.assertEqual(len(requests_for_discovery), 2) requests_for_token = self.mock_provider.requests[token_request] @@ -1830,7 +1659,8 @@ async def test_http_auth_ext_apple_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = self.mock_provider.requests[discovery_request] + requests_for_discovery = ( + self.mock_provider.requests[discovery_request]) self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query( @@ -1967,7 +1797,8 @@ async def test_http_auth_ext_apple_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = self.mock_provider.requests[discovery_request] + requests_for_discovery = ( + self.mock_provider.requests[discovery_request]) self.assertEqual(len(requests_for_discovery), 2) requests_for_token = self.mock_provider.requests[token_request] @@ -2253,7 +2084,8 @@ async def test_http_auth_ext_slack_callback_01(self) -> None: self.assertEqual(url.hostname, server_url.hostname) self.assertEqual(url.path, f"{server_url.path}/some/path") - requests_for_discovery = self.mock_provider.requests[discovery_request] + requests_for_discovery = ( + self.mock_provider.requests[discovery_request]) self.assertEqual(len(requests_for_discovery), 2) requests_for_token = self.mock_provider.requests[token_request] @@ -2349,7 +2181,8 @@ async def test_http_auth_ext_slack_authorize_01(self): ) self.assertEqual(qs.get("client_id"), [client_id]) - requests_for_discovery = self.mock_provider.requests[discovery_request] + requests_for_discovery = ( + self.mock_provider.requests[discovery_request]) self.assertEqual(len(requests_for_discovery), 1) pkce = await self.con.query(