Skip to content

Commit

Permalink
Move MockAuthProvider to testbase.http
Browse files Browse the repository at this point in the history
There's nothing specific to the auth extension in the mock HTTP server
implementation and it will be useful in tests of other HTTP extensions,
so move it to `testbase.http` and rename to `MockHttpServer`.
  • Loading branch information
elprans committed Apr 16, 2024
1 parent d7b66b4 commit 45d9358
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 186 deletions.
181 changes: 181 additions & 0 deletions edb/testbase/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,15 @@


from __future__ import annotations
from typing import (
Any,
Callable,
Optional,
)

import http.server
import json
import threading
import urllib.parse
import urllib.request

Expand Down Expand Up @@ -302,3 +309,177 @@ def assert_graphql_query_result(
assert_data_shape.assert_data_shape(
res, result, self.fail, message=msg)
return res


class MockHttpServerHandler(http.server.BaseHTTPRequestHandler):
def do_GET(self):
self.close_connection = False
server, path = self.path.lstrip('/').split('/', 1)
server = urllib.parse.unquote(server)
self.server.owner.handle_request('GET', server, path, self)

def do_POST(self):
self.close_connection = False
server, path = self.path.lstrip('/').split('/', 1)
server = urllib.parse.unquote(server)
self.server.owner.handle_request('POST', server, path, self)

def log_message(self, *args):
pass


ResponseType = tuple[str, int] | tuple[str, int, dict[str, str]]


class MockHttpServer:
def __init__(self) -> None:
self.has_started = threading.Event()
self.routes: dict[
tuple[str, str, str],
ResponseType | Callable[[MockHttpServerHandler], ResponseType],
] = {}
self.requests: dict[tuple[str, str, str], list[dict[str, Any]]] = {}
self.url: Optional[str] = None

def get_base_url(self) -> str:
if self.url is None:
raise RuntimeError("mock server is not running")
return self.url

def register_route_handler(
self,
method: str,
server: str,
path: str,
):
def wrapper(
handler: (
ResponseType | Callable[[MockHttpServerHandler], ResponseType]
)
):
self.routes[(method, server, path)] = handler
return handler

return wrapper

def handle_request(
self,
method: str,
server: str,
path: str,
handler: MockHttpServerHandler,
):
# `handler` is documented here:
# https://docs.python.org/3/library/http.server.html#http.server.BaseHTTPRequestHandler
key = (method, server, path)
if key not in self.requests:
self.requests[key] = []

# Parse and save the request details
parsed_path = urllib.parse.urlparse(path)
headers = {k.lower(): v for k, v in dict(handler.headers).items()}
query_params = urllib.parse.parse_qs(parsed_path.query)
if 'content-length' in headers:
body = handler.rfile.read(int(headers['content-length'])).decode()
else:
body = None

request_details = {
'headers': headers,
'query_params': query_params,
'body': body,
}
self.requests[key].append(request_details)

if key not in self.routes:
handler.send_error(404)
return

registered_handler = self.routes[key]

if callable(registered_handler):
try:
handler_result = registered_handler(handler)
if len(handler_result) == 2:
response, status = handler_result
additional_headers = None
elif len(handler_result) == 3:
response, status, additional_headers = handler_result
except Exception:
handler.send_error(500)
raise
else:
if len(registered_handler) == 2:
response, status = registered_handler
additional_headers = None
elif len(registered_handler) == 3:
response, status, additional_headers = registered_handler

if "headers" in request_details and isinstance(
request_details["headers"], dict
):
accept_header = request_details["headers"].get(
"accept", "application/json"
)
else:
accept_header = "application/json"

if (
accept_header.startswith("application/json")
or (
accept_header.startswith("application/")
and "vnd." in accept_header
and "+json" in accept_header
)
or accept_header == "*/*"
):
content_type = 'application/json'
elif accept_header.startswith("application/x-www-form-urlencoded"):
content_type = 'application/x-www-form-urlencoded'
else:
handler.send_error(
415, f"Unsupported accept header: {accept_header}"
)
return

data = response.encode()

handler.send_response(status)
handler.send_header('Content-Type', content_type)
handler.send_header('Content-Length', str(len(data)))
if additional_headers is not None:
for header, value in additional_headers.items():
handler.send_header(header, value)
handler.end_headers()
handler.wfile.write(data)

def start(self):
assert not hasattr(self, '_http_runner')
self._http_runner = threading.Thread(target=self._http_worker)
self._http_runner.start()
self.has_started.wait()
self.url = f'http://{self._address[0]}:{self._address[1]}/'

def __enter__(self):
self.start()
return self

def _http_worker(self):
self._http_server = http.server.HTTPServer(
('localhost', 0), MockHttpServerHandler
)
self._http_server.owner = self
self._address = self._http_server.server_address
self.has_started.set()
self._http_server.serve_forever(poll_interval=0.01)
self._http_server.server_close()

def stop(self):
self._http_server.shutdown()
if self._http_runner is not None:
self._http_runner.join()
self._http_runner = None

def __exit__(self, *exc):
self.stop()
self.url = None
Loading

0 comments on commit 45d9358

Please sign in to comment.