From 63e6fc41ad57feeb3dd6259e40b69e7cf505789a Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Thu, 31 Oct 2024 13:19:13 -0600 Subject: [PATCH] Full http in Rust support, remove httpx (#7927) Makes use of HttpClient throughout the entire server process: - Auth - AI extension (incl. SSE) - std::net::http Notes: - `httpx` is no longer a dep (kept as a dev-dependency) - Uses an `EdgeDB ${version}` User-Agent for all HTTP requests - All HTTP requests are subject to `http_max_connections` config option. This option is refreshed from the std::net worker. - SSE supports backpressure, and queue is fixed at 100 messages for now - Streaming and non-streaming requests use separate APIs to simplify the implementation - One thread per tenant is used if HTTP is required (we boot this thread only when needed). Eventually we'll move to a one-runtime-per-tenant and this thread will disappear. --- Cargo.lock | 80 ++-- edb/server/http.py | 419 +++++++++++++++++--- edb/server/http/Cargo.toml | 3 +- edb/server/http/src/python.rs | 345 +++++++++++++--- edb/server/net_worker.py | 9 +- edb/server/protocol/ai_ext.py | 83 ++-- edb/server/protocol/auth_ext/base.py | 5 +- edb/server/protocol/auth_ext/http.py | 4 +- edb/server/protocol/auth_ext/http_client.py | 64 --- edb/server/protocol/auth_ext/oauth.py | 6 +- edb/server/tenant.py | 13 +- pyproject.toml | 4 +- tests/test_http.py | 258 ++++++++++++ 13 files changed, 1038 insertions(+), 255 deletions(-) delete mode 100644 edb/server/protocol/auth_ext/http_client.py create mode 100644 tests/test_http.py diff --git a/Cargo.lock b/Cargo.lock index 5de089324a0..0cde3ad83d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -352,7 +352,7 @@ dependencies = [ "pretty_assertions", "pyo3", "rand", - "rstest 0.22.0", + "rstest", "scopeguard", "serde", "serde-pickle", @@ -673,6 +673,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "factorial" version = "0.2.1" @@ -963,10 +974,11 @@ name = "http" version = "0.1.0" dependencies = [ "derive_more", + "eventsource-stream", "futures", "pyo3", "reqwest", - "rstest 0.23.0", + "rstest", "scopeguard", "tokio", "tracing", @@ -1247,6 +1259,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.7.4" @@ -1323,6 +1341,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1524,7 +1552,7 @@ dependencies = [ "pyo3", "rand", "roaring", - "rstest 0.22.0", + "rstest", "scopeguard", "serde", "serde-pickle", @@ -1879,6 +1907,7 @@ dependencies = [ "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "windows-registry", ] @@ -1908,18 +1937,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "rstest" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b423f0e62bdd61734b67cd21ff50871dfaeb9cc74f869dcd6af974fbcb19936" -dependencies = [ - "futures", - "futures-timer", - "rstest_macros 0.22.0", - "rustc_version", -] - [[package]] name = "rstest" version = "0.23.0" @@ -1928,28 +1945,10 @@ checksum = "0a2c585be59b6b5dd66a9d2084aa1d8bd52fbdb806eafdeffb52791147862035" dependencies = [ "futures", "futures-timer", - "rstest_macros 0.23.0", + "rstest_macros", "rustc_version", ] -[[package]] -name = "rstest_macros" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5e1711e7d14f74b12a58411c542185ef7fb7f2e7f8ee6e2940a883628522b42" -dependencies = [ - "cfg-if", - "glob", - "proc-macro-crate", - "proc-macro2", - "quote", - "regex", - "relative-path", - "rustc_version", - "syn", - "unicode-ident", -] - [[package]] name = "rstest_macros" version = "0.23.0" @@ -2822,6 +2821,19 @@ version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +[[package]] +name = "wasm-streams" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.72" diff --git a/edb/server/http.py b/edb/server/http.py index a01c2f82899..b49f7f5d5e5 100644 --- a/edb/server/http.py +++ b/edb/server/http.py @@ -23,6 +23,9 @@ Any, Mapping, Optional, + Union, + Self, + Callable, ) import asyncio @@ -31,59 +34,193 @@ import os import json as json_lib import urllib.parse +import time +from http import HTTPStatus as HTTPStatus from edb.server._http import Http logger = logging.getLogger("edb.server") +HeaderType = Optional[Union[list[tuple[str, str]], dict[str, str]]] + + +@dataclasses.dataclass(frozen=True) +class HttpStat: + response_time_ms: int + error_code: int + response_body_size: int + response_content_type: str + request_body_size: int + request_content_type: str + method: str + streaming: bool + + +StatCallback = Callable[[HttpStat], None] class HttpClient: - def __init__(self, limit: int): - self._client = Http(limit) - self._fd = self._client._fd + def __init__( + self, + limit: int, + user_agent: str = "EdgeDB", + stat_callback: Optional[StatCallback] = None, + ): self._task = None + self._client = None + self._limit = limit self._skip_reads = 0 - self._loop = asyncio.get_running_loop() - self._task = self._loop.create_task(self._boot(self._loop)) + self._loop: Optional[asyncio.AbstractEventLoop] = ( + asyncio.get_running_loop() + ) + self._task = None + self._streaming: dict[int, asyncio.Queue[Any]] = {} self._next_id = 0 self._requests: dict[int, asyncio.Future] = {} + self._user_agent = user_agent + self._stat_callback = stat_callback def __del__(self) -> None: - if self._task: + self.close() + + def __enter__(self) -> HttpClient: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.close() + + def close(self) -> None: + if self._task is not None: self._task.cancel() self._task = None + self._loop = None + self._client = None + + def _ensure_task(self): + if self._loop is None: + raise Exception("HttpClient was closed") + if self._task is None: + self._client = Http(self._limit) + self._task = self._loop.create_task( + self._boot(self._loop, self._client._fd) + ) + + def _ensure_client(self): + if self._client is None: + raise Exception("HttpClient was closed") + return self._client + + def _safe_close(self, id): + if self._client is not None: + self._client._close(id) + + def _safe_ack(self, id): + if self._client is not None: + self._client._ack_sse(id) def _update_limit(self, limit: int): - self._client._update_limit(limit) + if self._client is not None and limit != self._limit: + self._limit = limit + self._client._update_limit(limit) + + def _process_headers(self, headers: HeaderType) -> list[tuple[str, str]]: + if headers is None: + return [] + if isinstance(headers, Mapping): + return [(k, v) for k, v in headers.items()] + if isinstance(headers, list): + return headers + print(headers) + raise ValueError(f"Invalid headers type: {type(headers)}") + + def _process_content( + self, + headers: list[tuple[str, str]], + data: bytes | str | dict[str, str] | None = None, + json: Any | None = None, + ) -> bytes: + if json is not None: + data = json_lib.dumps(json).encode('utf-8') + headers.append(('Content-Type', 'application/json')) + elif isinstance(data, str): + data = data.encode('utf-8') + elif isinstance(data, dict): + data = urllib.parse.urlencode(data).encode('utf-8') + headers.append( + ('Content-Type', 'application/x-www-form-urlencoded') + ) + elif data is None: + data = bytes() + elif isinstance(data, bytes): + pass + else: + raise ValueError(f"Invalid content type: {type(data)}") + return data + + def _process_path(self, path: str) -> str: + return path + + def with_context( + self, + *, + base_url: Optional[str] = None, + headers: HeaderType = None, + url_munger: Optional[Callable[[str], str]] = None, + ) -> Self: + """Create an HttpClient with common optional base URL and headers that + will be applied to all requests.""" + return HttpClientContext( + http_client=self, + base_url=base_url, + headers=headers, + url_munger=url_munger, + ) # type: ignore async def request( self, *, method: str, - url: str, - content: bytes | None, - headers: list[tuple[str, str]] | None, - ) -> tuple[int, bytes, dict[str, str]]: - if content is None: - content = bytes() - if headers is None: - headers = [] + path: str, + headers: HeaderType = None, + data: bytes | str | dict[str, str] | None = None, + json: Any | None = None, + ) -> tuple[int, bytearray, dict[str, str]]: + self._ensure_task() + path = self._process_path(path) + headers_list = self._process_headers(headers) + headers_list.append(("User-Agent", self._user_agent)) + data = self._process_content(headers_list, data, json) id = self._next_id self._next_id += 1 self._requests[id] = asyncio.Future() + start_time = time.time() try: - self._client._request(id, url, method, content, headers) + self._ensure_client()._request(id, path, method, data, headers_list) resp = await self._requests[id] + if self._stat_callback: + status_code, body, headers = resp + self._stat_callback( + HttpStat( + response_time_ms=int((time.time() - start_time) * 1000), + error_code=status_code, + response_body_size=len(body), + response_content_type=dict(headers_list).get( + 'content-type', '' + ), + request_body_size=len(data), + request_content_type=dict(headers_list).get( + 'content-type', '' + ), + method=method, + streaming=False, + ) + ) return resp finally: del self._requests[id] - async def get( - self, path: str, *, headers: dict[str, str] | None = None - ) -> Response: - headers_list = [(k, v) for k, v in headers.items()] if headers else None + async def get(self, path: str, *, headers: HeaderType = None) -> Response: result = await self.request( - method="GET", url=path, content=None, headers=headers_list + method="GET", path=path, data=None, headers=headers ) return Response.from_tuple(result) @@ -91,33 +228,80 @@ async def post( self, path: str, *, - headers: dict[str, str] | None = None, + headers: HeaderType = None, data: bytes | str | dict[str, str] | None = None, json: Any | None = None, ) -> Response: - if json is not None: - data = json_lib.dumps(json).encode('utf-8') - headers = headers or {} - headers['Content-Type'] = 'application/json' - elif isinstance(data, str): - data = data.encode('utf-8') - elif isinstance(data, dict): - data = urllib.parse.urlencode(data).encode('utf-8') - headers = headers or {} - headers['Content-Type'] = 'application/x-www-form-urlencoded' - - headers_list = [(k, v) for k, v in headers.items()] if headers else None result = await self.request( - method="POST", url=path, content=data, headers=headers_list + method="POST", path=path, data=data, json=json, headers=headers ) return Response.from_tuple(result) - async def _boot(self, loop: asyncio.AbstractEventLoop) -> None: - logger.info("Python-side HTTP client booted") + async def stream_sse( + self, + path: str, + *, + method: str = "POST", + headers: HeaderType = None, + data: bytes | str | dict[str, str] | None = None, + json: Any | None = None, + ) -> Response | ResponseSSE: + self._ensure_task() + path = self._process_path(path) + headers_list = self._process_headers(headers) + headers_list.append(("User-Agent", self._user_agent)) + data = self._process_content(headers_list, data, json) + + id = self._next_id + self._next_id += 1 + self._requests[id] = asyncio.Future() + start_time = time.time() + try: + self._ensure_client()._request_sse( + id, path, method, data, headers_list + ) + resp = await self._requests[id] + if self._stat_callback: + if id in self._streaming: + status_code, headers = resp + body = b'' + else: + status_code, body, headers = resp + self._stat_callback( + HttpStat( + response_time_ms=int((time.time() - start_time) * 1000), + error_code=status_code, + response_body_size=len(body), + response_content_type=dict(headers_list).get( + 'content-type', '' + ), + request_body_size=len(data), + request_content_type=dict(headers_list).get( + 'content-type', '' + ), + method=method, + streaming=id in self._streaming, + ) + ) + if id in self._streaming: + # Valid to call multiple times + cancel = lambda: self._safe_close(id) + # Acknowledge SSE message (for backpressure) + ack = lambda: self._safe_ack(id) + return ResponseSSE.from_tuple( + resp, self._streaming[id], cancel, ack + ) + return Response.from_tuple(resp) + finally: + del self._requests[id] + + async def _boot(self, loop: asyncio.AbstractEventLoop, fd: int) -> None: + logger.info(f"HTTP client initialized, user_agent={self._user_agent}") reader = asyncio.StreamReader(loop=loop) reader_protocol = asyncio.StreamReaderProtocol(reader) - fd = os.fdopen(self._client._fd, 'rb') - transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd) + transport, _ = await loop.connect_read_pipe( + lambda: reader_protocol, os.fdopen(fd, 'rb') + ) try: while len(await reader.read(1)) == 1: if not self._client or not self._task: @@ -133,13 +317,89 @@ async def _boot(self, loop: asyncio.AbstractEventLoop) -> None: transport.close() def _process_message(self, msg): - msg_type, id, data = msg + try: + msg_type, id, data = msg + if msg_type == 0: # Error + if id in self._requests: + self._requests[id].set_exception(Exception(data)) + if id in self._streaming: + self._streaming[id].put_nowait(None) + del self._streaming[id] + elif msg_type == 1: # Response + if id in self._requests: + self._requests[id].set_result(data) + elif msg_type == 2: # SSEStart + if id in self._requests: + self._streaming[id] = asyncio.Queue() + self._requests[id].set_result(data) + elif msg_type == 3: # SSEEvent + if id in self._streaming: + self._streaming[id].put_nowait(data) + elif msg_type == 4: # SSEEnd + if id in self._streaming: + self._streaming[id].put_nowait(None) + del self._streaming[id] + except Exception as e: + logger.error(f"Error processing message: {e}", exc_info=True) + raise + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *args) -> None: # type: ignore + pass - if id in self._requests: - if msg_type == 1: - self._requests[id].set_result(data) - elif msg_type == 0: - self._requests[id].set_exception(Exception(data)) + +class HttpClientContext(HttpClient): + def __init__( + self, + http_client: HttpClient, + url_munger: Callable[[str], str] | None = None, + headers: HeaderType = None, + base_url: str | None = None, + ): + self._task = None + self.url_munger = url_munger + self.http_client = http_client + self.base_url = base_url + self.headers = super()._process_headers(headers) + + def _process_headers(self, headers): + headers = super()._process_headers(headers) + headers += self.headers + return headers + + def _process_path(self, path): + path = super()._process_path(path) + if self.base_url is not None: + path = self.base_url + path + if self.url_munger is not None: + path = self.url_munger(path) + return path + + async def request( + self, *, method, path, headers=None, data=None, json=None + ): + path = self._process_path(path) + headers = self._process_headers(headers) + return await self.http_client.request( + method=method, path=path, headers=headers, data=data, json=json + ) + + async def stream_sse( + self, path, *, method="POST", headers=None, data=None, json=None + ): + path = self._process_path(path) + headers = self._process_headers(headers) + return await self.http_client.stream_sse( + path, method=method, headers=headers, data=data, json=json + ) + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *args) -> None: # type: ignore + pass class CaseInsensitiveDict(dict): @@ -171,14 +431,15 @@ def update(self, *args, **kwargs: str) -> None: self[key] = value -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Response: status_code: int - body: bytes + body: bytearray headers: CaseInsensitiveDict + is_streaming: bool = False @classmethod - def from_tuple(cls, data: Tuple[int, bytes, dict[str, str]]): + def from_tuple(cls, data: Tuple[int, bytearray, dict[str, str]]): status_code, body, headers_list = data headers = CaseInsensitiveDict([(k, v) for k, v in headers_list.items()]) return cls(status_code, body, headers) @@ -186,6 +447,68 @@ def from_tuple(cls, data: Tuple[int, bytes, dict[str, str]]): def json(self): return json_lib.loads(self.body.decode('utf-8')) + def bytes(self): + return bytes(self.body) + @property def text(self) -> str: return self.body.decode('utf-8') + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + pass + + +@dataclasses.dataclass(frozen=True) +class ResponseSSE: + status_code: int + headers: CaseInsensitiveDict + _stream: asyncio.Queue = dataclasses.field(repr=False) + _cancel: Callable[[], None] = dataclasses.field(repr=False) + _ack: Callable[[], None] = dataclasses.field(repr=False) + is_streaming: bool = True + + @classmethod + def from_tuple( + cls, + data: Tuple[int, dict[str, str]], + stream: asyncio.Queue, + cancel: Callable[[], None], + ack: Callable[[], None], + ): + status_code, headers = data + headers = CaseInsensitiveDict([(k, v) for k, v in headers.items()]) + return cls(status_code, headers, stream, cancel, ack) + + @dataclasses.dataclass(frozen=True) + class SSEEvent: + event: str + data: str + id: Optional[str] = None + + def close(self): + self._cancel() + + def __del__(self): + self.close() + + def __aiter__(self): + return self + + async def __anext__(self): + next = await self._stream.get() + try: + if next is None: + raise StopAsyncIteration + id, data, event = next + return self.SSEEvent(event, data, id) + finally: + self._ack() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + self.close() diff --git a/edb/server/http/Cargo.toml b/edb/server/http/Cargo.toml index 6c865ac8a47..0557767b227 100644 --- a/edb/server/http/Cargo.toml +++ b/edb/server/http/Cargo.toml @@ -16,8 +16,9 @@ pyo3 = { workspace = true, optional = true } tokio.workspace = true tracing = "0" tracing-subscriber = "0" -reqwest = { version = "0.12", features = ["gzip", "deflate"] } +reqwest = { version = "0.12", features = ["gzip", "deflate", "stream"] } scopeguard = "1" +eventsource-stream = "0.2.3" futures = "0" diff --git a/edb/server/http/src/python.rs b/edb/server/http/src/python.rs index dbd035770a1..e48be467b0d 100644 --- a/edb/server/http/src/python.rs +++ b/edb/server/http/src/python.rs @@ -1,25 +1,38 @@ -use futures::future::poll_fn; +use eventsource_stream::Eventsource; +use futures::{future::poll_fn, TryStreamExt}; use pyo3::{exceptions::PyException, prelude::*, types::PyByteArray}; -use reqwest::Method; -use scopeguard::ScopeGuard; +use reqwest::{header::HeaderValue, Method}; +use scopeguard::{defer, guard, ScopeGuard}; use std::{ - cell::RefCell, collections::HashMap, os::fd::IntoRawFd, pin::Pin, rc::Rc, sync::Mutex, thread, + cell::RefCell, + collections::HashMap, + os::fd::IntoRawFd, + pin::Pin, + rc::Rc, + sync::{Arc, Mutex}, + thread, time::Duration, }; use tokio::{ io::AsyncWrite, sync::{AcquireError, Semaphore, SemaphorePermit}, - task::LocalSet, + task::{JoinHandle, LocalSet}, }; use tracing::{error, info, trace}; pyo3::create_exception!(_http, InternalError, PyException); +/// The backlog for SSE message +const SSE_QUEUE_SIZE: usize = 100; + type PythonConnId = u64; #[derive(Debug)] enum RustToPythonMessage { Response(PythonConnId, (u16, Vec, HashMap)), + SSEStart(PythonConnId, (u16, HashMap)), + SSEEvent(PythonConnId, eventsource_stream::Event), + SSEEnd(PythonConnId), Error(PythonConnId, String), } @@ -34,6 +47,11 @@ impl ToPyObject for RustToPythonMessage { (*status, PyByteArray::new_bound(py, &body), headers), ) .to_object(py), + SSEStart(conn, (status, headers)) => (2, conn, (status, headers)).to_object(py), + SSEEvent(conn, message) => { + (3, conn, (&message.id, &message.data, &message.event)).to_object(py) + } + SSEEnd(conn) => (4, conn, ()).to_object(py), } } } @@ -44,6 +62,12 @@ enum PythonToRustMessage { UpdateLimit(usize), /// Perform a request Request(PythonConnId, String, String, Vec, Vec<(String, String)>), + /// Perform a request with SSE + RequestSse(PythonConnId, String, String, Vec, Vec<(String, String)>), + /// Close an SSE connection + Close(PythonConnId), + /// Acknowledge an SSE message + Ack(PythonConnId), } type PipeSender = tokio::net::unix::pipe::Sender; @@ -99,9 +123,9 @@ async fn request( method: String, body: Vec, headers: Vec<(String, String)>, -) -> Result<(reqwest::StatusCode, Vec, HashMap), String> { +) -> Result { let method = - Method::from_bytes(method.as_bytes()).map_err(|e| format!("Invalid HTTP method: {}", e))?; + Method::from_bytes(method.as_bytes()).map_err(|e| format!("Invalid HTTP method: {e:?}"))?; let mut req = client.request(method, url); @@ -116,25 +140,108 @@ async fn request( let resp = req .send() .await - .map_err(|e| format!("Request failed: {}", e))?; + .map_err(|e| format!("Request failed: {e:?}"))?; - let status = resp.status(); - - let headers = resp - .headers() - .iter() - .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) - .collect(); + Ok(resp) +} +async fn request_bytes( + client: reqwest::Client, + url: String, + method: String, + body: Vec, + headers: Vec<(String, String)>, +) -> Result<(reqwest::StatusCode, Vec, HashMap), String> { + let resp = request(client, url, method, body, headers).await?; + let status = resp.status(); + let headers = process_headers(resp.headers()); let body = resp .bytes() .await - .map_err(|e| format!("Failed to read response body: {}", e))? + .map_err(|e| format!("Failed to read response body: {e:?}"))? .to_vec(); Ok((status, body, headers)) } +async fn request_sse( + client: reqwest::Client, + id: PythonConnId, + backpressure: Arc, + url: String, + method: String, + body: Vec, + headers: Vec<(String, String)>, + rpc_pipe: Rc, +) -> Result<(), String> { + trace!("Entering SSE"); + let guard = guard((), |_| trace!("Exiting SSE due to cancellation")); + let response = request(client, url, method, body, headers).await?; + + if response.headers().get("content-type") + != Some(&HeaderValue::from_static("text/event-stream")) + { + let headers = process_headers(response.headers()); + let status = response.status(); + let body = match response.bytes().await { + Ok(bytes) => bytes.to_vec(), + Err(e) => { + return Err(format!("Failed to read response body: {e:?}")); + } + }; + _ = rpc_pipe + .write(RustToPythonMessage::Response( + id, + (status.as_u16(), body, headers), + )) + .await; + + ScopeGuard::into_inner(guard); + return Ok(()); + } + + let headers = process_headers(response.headers()); + let status = response.status(); + _ = rpc_pipe + .write(RustToPythonMessage::SSEStart( + id, + (status.as_u16(), headers.clone()), + )) + .await; + + let mut stream = response.bytes_stream().eventsource(); + loop { + let chunk = match stream.try_next().await { + Ok(None) => break, + Ok(Some(chunk)) => chunk, + Err(e) => { + return Err(format!("Failed to read response body: {e:?}")); + } + }; + let Ok(permit) = backpressure.acquire().await else { + break; + }; + permit.forget(); + if rpc_pipe + .write(RustToPythonMessage::SSEEvent(id, chunk)) + .await + .is_err() + { + break; + } + } + + ScopeGuard::into_inner(guard); + Ok(()) +} + +fn process_headers(headers: &reqwest::header::HeaderMap) -> HashMap { + headers + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect() +} + #[derive(Debug, Clone, Copy)] struct PermitCount { active: usize, @@ -262,6 +369,11 @@ impl PermitManager { } } +struct HttpTask { + task: JoinHandle<()>, + backpressure: Arc, +} + async fn run_and_block(capacity: usize, rpc_pipe: RpcPipe) { let rpc_pipe = Rc::new(rpc_pipe); @@ -274,6 +386,7 @@ async fn run_and_block(capacity: usize, rpc_pipe: RpcPipe) { let client = client.build().unwrap(); let permit_manager = Rc::new(PermitManager::new(capacity)); + let tasks = Arc::new(Mutex::new(HashMap::::new())); loop { let Some(rpc) = poll_fn(|cx| rpc_pipe.python_to_rust.borrow_mut().poll_recv(cx)).await @@ -284,34 +397,111 @@ async fn run_and_block(capacity: usize, rpc_pipe: RpcPipe) { let client = client.clone(); trace!("Received RPC: {rpc:?}"); let rpc_pipe = rpc_pipe.clone(); - let permit_manager = permit_manager.clone(); - tokio::task::spawn_local(async move { - use PythonToRustMessage::*; - match rpc { - UpdateLimit(limit) => { - permit_manager.update_limit(limit); + // Allocate a task ID and backpressure object if we're initiating a + // request. This would be less awkward if we allocated in the Rust side + // of the code rather than the Python side. + let (id, backpressure) = match rpc { + PythonToRustMessage::Request(id, ..) | PythonToRustMessage::RequestSse(id, ..) => { + (Some(id), Some(Semaphore::new(SSE_QUEUE_SIZE).into())) + } + _ => (None, None), + }; + let task = tokio::task::spawn_local(execute( + id.clone(), + backpressure.clone(), + tasks.clone(), + rpc, + permit_manager.clone(), + client, + rpc_pipe, + )); + if let (Some(id), Some(backpressure)) = (id, backpressure) { + tasks + .lock() + .unwrap() + .insert(id, HttpTask { task, backpressure }); + } + } +} + +async fn execute( + id: Option, + backpressure: Option>, + tasks_clone: Arc>>, + rpc: PythonToRustMessage, + permit_manager: Rc, + client: reqwest::Client, + rpc_pipe: Rc, +) { + // If a request task was booted by this request, remove it from the list of + // tasks when we exit. + if let Some(id) = id { + defer!(_ = tasks_clone.lock().unwrap().remove(&id)); + } + + use PythonToRustMessage::*; + match rpc { + UpdateLimit(limit) => { + permit_manager.update_limit(limit); + } + Request(id, url, method, body, headers) => { + let Ok(permit) = permit_manager.acquire().await else { + return; + }; + match request_bytes(client, url, method, body, headers).await { + Ok((status, body, headers)) => { + _ = rpc_pipe + .write(RustToPythonMessage::Response( + id, + (status.as_u16(), body, headers), + )) + .await; } - Request(id, url, method, body, headers) => { - let Ok(permit) = permit_manager.acquire().await else { - return; - }; - match request(client, url, method, body, headers).await { - Ok((status, body, headers)) => { - _ = rpc_pipe - .write(RustToPythonMessage::Response( - id, - (status.as_u16(), body, headers), - )) - .await; - } - Err(err) => { - _ = rpc_pipe.write(RustToPythonMessage::Error(id, err)).await; - } - } - drop(permit); + Err(err) => { + _ = rpc_pipe.write(RustToPythonMessage::Error(id, err)).await; } } - }); + drop(permit); + } + RequestSse(id, url, method, body, headers) => { + // Ensure we send the end message whenever this block exits + defer!(_ = rpc_pipe.write(RustToPythonMessage::SSEEnd(id))); + let Ok(permit) = permit_manager.acquire().await else { + return; + }; + match request_sse( + client, + id, + backpressure.unwrap(), + url, + method, + body, + headers, + rpc_pipe.clone(), + ) + .await + { + Ok(..) => {} + Err(err) => { + _ = rpc_pipe + .write(RustToPythonMessage::Error(id, format!("SSE error: {err}"))) + .await; + } + } + drop(permit); + } + Ack(id) => { + let lock = tasks_clone.lock().unwrap(); + if let Some(task) = lock.get(&id) { + task.backpressure.add_permits(1); + } + } + Close(id) => { + let Some(task) = tasks_clone.lock().unwrap().remove(&id) else { + return; + }; + task.task.abort(); + } } } @@ -326,27 +516,31 @@ impl Http { let (txpr, rxpr) = tokio::sync::mpsc::unbounded_channel(); let (txfd, rxfd) = std::sync::mpsc::channel(); - thread::spawn(move || { - info!("Rust-side Http thread booted"); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_time() - .enable_io() - .build() - .unwrap(); - let _guard = rt.enter(); - let (txn, rxn) = tokio::net::unix::pipe::pipe().unwrap(); - let fd = rxn.into_nonblocking_fd().unwrap().into_raw_fd() as u64; - txfd.send(fd).unwrap(); - let local = LocalSet::new(); - - let rpc_pipe = RpcPipe { - python_to_rust: rxpr.into(), - rust_to_python: txrp, - rust_to_python_notify: txn.into(), - }; - - local.block_on(&rt, run_and_block(max_capacity, rpc_pipe)); - }); + thread::Builder::new() + .name("edgedb-http".to_string()) + .spawn(move || { + defer!(info!("Rust-side Http thread exiting")); + info!("Rust-side Http thread booted"); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_time() + .enable_io() + .build() + .unwrap(); + let _guard = rt.enter(); + let (txn, rxn) = tokio::net::unix::pipe::pipe().unwrap(); + let fd = rxn.into_nonblocking_fd().unwrap().into_raw_fd() as u64; + txfd.send(fd).unwrap(); + let local = LocalSet::new(); + + let rpc_pipe = RpcPipe { + python_to_rust: rxpr.into(), + rust_to_python: txrp, + rust_to_python_notify: txn.into(), + }; + + local.block_on(&rt, run_and_block(max_capacity, rpc_pipe)); + }) + .expect("Failed to create HTTP thread"); let notify_fd = rxfd.recv().unwrap(); Http { @@ -374,6 +568,33 @@ impl Http { .map_err(|_| internal_error("In shutdown")) } + fn _request_sse( + &self, + id: PythonConnId, + url: String, + method: String, + body: Vec, + headers: Vec<(String, String)>, + ) -> PyResult<()> { + self.python_to_rust + .send(PythonToRustMessage::RequestSse( + id, url, method, body, headers, + )) + .map_err(|_| internal_error("In shutdown")) + } + + fn _close(&self, id: PythonConnId) -> PyResult<()> { + self.python_to_rust + .send(PythonToRustMessage::Close(id)) + .map_err(|_| internal_error("In shutdown")) + } + + fn _ack_sse(&self, id: PythonConnId) -> PyResult<()> { + self.python_to_rust + .send(PythonToRustMessage::Ack(id)) + .map_err(|_| internal_error("In shutdown")) + } + fn _update_limit(&self, limit: usize) -> PyResult<()> { self.python_to_rust .send(PythonToRustMessage::UpdateLimit(limit)) diff --git a/edb/server/net_worker.py b/edb/server/net_worker.py index df632092a79..784eea38724 100644 --- a/edb/server/net_worker.py +++ b/edb/server/net_worker.py @@ -99,10 +99,7 @@ async def _http_task(tenant: edbtenant.Tenant, http_client) -> None: def create_http(tenant: edbtenant.Tenant): - http_max_connections = tenant._server.config_lookup( - 'http_max_connections', tenant.get_sys_config() - ) - return HttpClient(http_max_connections) + return tenant.get_http_client(originator="std::net") async def http(server: edbserver.BaseServer) -> None: @@ -166,8 +163,8 @@ async def handle_request( ) response = await client.request( method=request.method, - url=request.url, - content=request.body, + path=request.url, + data=request.body, headers=headers, ) response_status, response_bytes, response_hdict = response diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index c5f662c441f..4e35494a848 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -34,19 +34,14 @@ import asyncio import contextlib import contextvars -import http import itertools import json import logging import uuid -import httpx -import httpx_sse - import tiktoken from mistral_common.tokens.tokenizers import mistral as mistral_tokenizer - from edb import errors from edb.common import asyncutil from edb.common import debug @@ -54,7 +49,7 @@ from edb.common import markup from edb.common import uuidgen -from edb.server import compiler +from edb.server import compiler, http from edb.server.compiler import sertypes from edb.server.protocol import execute from edb.server.protocol import request_scheduler as rs @@ -297,9 +292,9 @@ async def _ext_ai_index_builder_controller_loop( try: while True: + models = [] + sleep_timer: rs.Timer = rs.Timer(None, False) try: - models = [] - sleep_timer: rs.Timer = rs.Timer(None, False) async with tenant.with_pgcon(dbname) as pgconn: models = await _ext_ai_fetch_active_models(pgconn) if models: @@ -309,6 +304,7 @@ async def _ext_ai_index_builder_controller_loop( provider_contexts = _prepare_provider_contexts( db, pgconn, + tenant.get_http_client(originator="ai/index"), models, provider_schedulers, naptime, @@ -388,6 +384,7 @@ async def _ext_ai_unlock( def _prepare_provider_contexts( db: dbview.Database, pgconn: pgcon.PGConnection, + http_client: http.HttpClient, models: list[tuple[int, str, str]], provider_schedulers: dict[str, ProviderScheduler], naptime: float, @@ -433,6 +430,7 @@ def _prepare_provider_contexts( naptime=naptime, db=db, pgconn=pgconn, + http_client=http_client, provider_models=provider_models, ) @@ -473,6 +471,7 @@ class ProviderContext(rs.Context): db: dbview.Database pgconn: pgcon.PGConnection + http_client: http.HttpClient provider_models: list[str] @@ -494,6 +493,7 @@ async def get_params( return await _generate_embeddings_params( context.db, context.pgconn, + context.http_client, self.provider_name, context.provider_models, self.model_excluded_ids, @@ -518,6 +518,7 @@ def finalize(self, execution_report: rs.ExecutionReport) -> None: @dataclass(frozen=True, kw_only=True) class EmbeddingsParams(rs.Params[EmbeddingsData]): pgconn: pgcon.PGConnection + http_client: http.HttpClient provider: ProviderConfig model_name: str inputs: list[tuple[PendingEmbedding, str]] @@ -546,6 +547,7 @@ async def run(self) -> Optional[rs.Result[EmbeddingsData]]: self.params.model_name, [input[1] for input in self.params.inputs], self.params.shortening, + self.params.http_client ) result.pgconn = self.params.pgconn result.pending_entries = [ @@ -600,6 +602,7 @@ async def finalize(self) -> None: async def _generate_embeddings_params( db: dbview.Database, pgconn: pgcon.PGConnection, + http_client: http.HttpClient, provider_name: str, provider_models: list[str], model_excluded_ids: dict[str, list[str]], @@ -724,6 +727,7 @@ async def _generate_embeddings_params( input_entries.append(pending_entry) if model_name in model_tokenizers: + tokenizer = model_tokenizers[model_name] max_batch_tokens = model_max_batch_tokens[model_name] if isinstance(tokens_rate_limit, int): # If the rate limit is lower than the batch limit, use that @@ -753,6 +757,7 @@ async def _generate_embeddings_params( inputs=inputs, token_count=batch_token_count, shortening=shortening, + http_client=http_client, )) else: @@ -769,6 +774,7 @@ async def _generate_embeddings_params( inputs=inputs, token_count=total_token_count, shortening=shortening, + http_client=http_client, )) return embeddings_params @@ -975,6 +981,7 @@ async def _generate_embeddings( model_name: str, inputs: list[str], shortening: Optional[int], + http_client: http.HttpClient, ) -> EmbeddingsResult: task_name = _task_name.get() count = len(inputs) @@ -986,7 +993,7 @@ async def _generate_embeddings( if provider.api_style == ApiStyle.OpenAI: return await _generate_openai_embeddings( - provider, model_name, inputs, shortening, + provider, model_name, inputs, shortening, http_client ) else: raise RuntimeError( @@ -1000,6 +1007,7 @@ async def _generate_openai_embeddings( model_name: str, inputs: list[str], shortening: Optional[int], + http_client: http.HttpClient, ) -> EmbeddingsResult: headers = { @@ -1007,7 +1015,7 @@ async def _generate_openai_embeddings( } if provider.name == "builtin::openai" and provider.client_id: headers["OpenAI-Organization"] = provider.client_id - client = httpx.AsyncClient( + client = http_client.with_context( headers=headers, base_url=provider.api_url, ) @@ -1040,7 +1048,7 @@ async def _generate_openai_embeddings( ) return EmbeddingsResult( - data=(error if error else EmbeddingsData(result.content)), + data=(error if error else EmbeddingsData(result.bytes())), limits=_read_openai_limits(result), ) @@ -1113,6 +1121,7 @@ async def _start_chat( request: protocol.HttpRequest, response: protocol.HttpResponse, provider: ProviderConfig, + http_client: http.HttpClient, model_name: str, messages: list[dict[str, Any]], stream: bool, @@ -1120,11 +1129,11 @@ async def _start_chat( if provider.api_style == "OpenAI": await _start_openai_chat( protocol, request, response, - provider, model_name, messages, stream) + provider, http_client, model_name, messages, stream) elif provider.api_style == "Anthropic": await _start_anthropic_chat( protocol, request, response, - provider, model_name, messages, stream) + provider, http_client, model_name, messages, stream) else: raise RuntimeError( f"unsupported model provider API style: {provider.api_style}, " @@ -1134,32 +1143,41 @@ async def _start_chat( @contextlib.asynccontextmanager async def aconnect_sse( - client: httpx.AsyncClient, + client: http.HttpClient, method: str, url: str, **kwargs: Any, -) -> AsyncIterator[httpx_sse.EventSource]: +) -> AsyncIterator[http.ResponseSSE]: headers = kwargs.pop("headers", {}) headers["Accept"] = "text/event-stream" headers["Cache-Control"] = "no-store" - stream = client.stream(method, url, headers=headers, **kwargs) - async with stream as response: + stm = await client.stream_sse( + method=method, + path=url, + headers=headers, + **kwargs + ) + if isinstance(stm, http.Response): + raise AIProviderError( + f"API call to generate chat completions failed with status " + f"{stm.status_code}: {stm.text}" + ) + async with stm as response: if response.status_code >= 400: - await response.aread() + # Unlikely that we have a streaming response with a non-200 result raise AIProviderError( f"API call to generate chat completions failed with status " - f"{response.status_code}: {response.text}" + f"{response.status_code}" ) - else: - yield httpx_sse.EventSource(response) + yield response async def _start_openai_like_chat( protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, - client: httpx.AsyncClient, + client: http.HttpClient, model_name: str, messages: list[dict[str, Any]], stream: bool, @@ -1175,7 +1193,7 @@ async def _start_openai_like_chat( "stream": True, } ) as event_source: - async for sse in event_source.aiter_sse(): + async for sse in event_source: if not response.sent: response.status = http.HTTPStatus.OK response.content_type = b'text/event-stream' @@ -1265,7 +1283,6 @@ async def _start_openai_like_chat( ) if result.status_code >= 400: - await result.aread() raise AIProviderError( f"API call to generate chat completions failed with status " f"{result.status_code}: {result.text}" @@ -1284,6 +1301,7 @@ async def _start_openai_chat( request: protocol.HttpRequest, response: protocol.HttpResponse, provider: ProviderConfig, + http_client: http.HttpClient, model_name: str, messages: list[dict[str, Any]], stream: bool, @@ -1295,7 +1313,7 @@ async def _start_openai_chat( if provider.name == "builtin::openai" and provider.client_id: headers["OpenAI-Organization"] = provider.client_id - client = httpx.AsyncClient( + client = http_client.with_context( base_url=provider.api_url, headers=headers, ) @@ -1316,6 +1334,7 @@ async def _start_anthropic_chat( request: protocol.HttpRequest, response: protocol.HttpResponse, provider: ProviderConfig, + http_client: http.HttpClient, model_name: str, messages: list[dict[str, Any]], stream: bool, @@ -1328,7 +1347,7 @@ async def _start_anthropic_chat( headers["anthropic-version"] = "2023-06-01" headers["anthropic-beta"] = "messages-2023-12-15" - client = httpx.AsyncClient( + client = http_client.with_context( headers={ "anthropic-version": "2023-06-01", "anthropic-beta": "messages-2023-12-15", @@ -1360,7 +1379,7 @@ async def _start_anthropic_chat( "max_tokens": 4096, } ) as event_source: - async for sse in event_source.aiter_sse(): + async for sse in event_source: if not response.sent: response.status = http.HTTPStatus.OK response.content_type = b'text/event-stream' @@ -1421,7 +1440,6 @@ async def _start_anthropic_chat( ) if result.status_code >= 400: - await result.aread() raise AIProviderError( f"API call to generate chat completions failed with status " f"{result.status_code}: {result.text}" @@ -1497,6 +1515,8 @@ async def _handle_rag_request( tenant: srv_tenant.Tenant, ) -> None: try: + http_client = tenant.get_http_client(originator="ai/rag") + body = json.loads(request.body) if not isinstance(body, dict): raise TypeError( @@ -1600,6 +1620,7 @@ async def _handle_rag_request( vector_query = await _generate_embeddings_for_type( db, + http_client, ctx_query, content=query, ) @@ -1694,6 +1715,7 @@ async def _handle_rag_request( request, response, provider, + http_client, model, messages, stream, @@ -1744,6 +1766,7 @@ async def _handle_embeddings_request( model_name, inputs, shortening=None, + http_client=tenant.get_http_client(originator="ai/embeddings") ) if isinstance(result.data, rs.Error): raise AIProviderError(result.data.message) @@ -1907,6 +1930,7 @@ async def _get_model_annotation_as_int( async def _generate_embeddings_for_type( db: dbview.Database, + http_client: http.HttpClient, type_query: str, content: str, ) -> bytes: @@ -2000,7 +2024,8 @@ async def _generate_embeddings_for_type( else: shortening = None result = await _generate_embeddings( - provider, index["model"], [content], shortening=shortening) + provider, index["model"], [content], shortening=shortening, + http_client=http_client) if isinstance(result.data, rs.Error): raise AIProviderError(result.data.message) return result.data.embeddings diff --git a/edb/server/protocol/auth_ext/base.py b/edb/server/protocol/auth_ext/base.py index d42bac0cbf8..59e957cc5a4 100644 --- a/edb/server/protocol/auth_ext/base.py +++ b/edb/server/protocol/auth_ext/base.py @@ -25,7 +25,8 @@ from jwcrypto import jwt, jwk from datetime import datetime -from . import data, errors, http_client +from . import data, errors +from edb.server.http import HttpClient class BaseProvider: @@ -37,7 +38,7 @@ def __init__( client_secret: str, *, additional_scope: str | None, - http_factory: Callable[..., http_client.AuthHttpClient], + http_factory: Callable[..., HttpClient], ): self.name = name self.issuer_url = issuer_url diff --git a/edb/server/protocol/auth_ext/http.py b/edb/server/protocol/auth_ext/http.py index 2837edfd241..9f9450c0ef9 100644 --- a/edb/server/protocol/auth_ext/http.py +++ b/edb/server/protocol/auth_ext/http.py @@ -302,7 +302,7 @@ async def handle_authorize( db=self.db, provider_name=provider_name, url_munger=self._get_url_munger(request), - http_client=self.tenant.get_http_client(), + http_client=self.tenant.get_http_client(originator="auth"), ) await pkce.create(self.db, challenge) authorize_url = await oauth_client.get_authorize_url( @@ -402,7 +402,7 @@ async def handle_callback( db=self.db, provider_name=provider_name, url_munger=self._get_url_munger(request), - http_client=self.tenant.get_http_client(), + http_client=self.tenant.get_http_client(originator="auth"), ) ( identity, diff --git a/edb/server/protocol/auth_ext/http_client.py b/edb/server/protocol/auth_ext/http_client.py deleted file mode 100644 index 563605b5900..00000000000 --- a/edb/server/protocol/auth_ext/http_client.py +++ /dev/null @@ -1,64 +0,0 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2023-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Any, Callable, Self - -from edb.server import http - - -class AuthHttpClient: - def __init__( - self, - http_client: http.HttpClient, - url_munger: Callable[[str], str] | None = None, - base_url: str | None = None, - ): - self.url_munger = url_munger - self.http_client = http_client - self.base_url = base_url - - async def post( - self, - path: str, - *, - headers: dict[str, str] | None = None, - data: bytes | str | dict[str, str] | None = None, - json: Any | None = None, - ) -> http.Response: - if self.base_url: - path = self.base_url + path - if self.url_munger: - path = self.url_munger(path) - return await self.http_client.post( - path, headers=headers, data=data, json=json - ) - - async def get( - self, path: str, *, headers: dict[str, str] | None = None - ) -> http.Response: - if self.base_url: - path = self.base_url + path - if self.url_munger: - path = self.url_munger(path) - return await self.http_client.get(path, headers=headers) - - async def __aenter__(self) -> Self: - return self - - async def __aexit__(self, *args) -> None: # type: ignore - pass diff --git a/edb/server/protocol/auth_ext/oauth.py b/edb/server/protocol/auth_ext/oauth.py index 73d1cf03666..1ed14ff1ccd 100644 --- a/edb/server/protocol/auth_ext/oauth.py +++ b/edb/server/protocol/auth_ext/oauth.py @@ -24,7 +24,7 @@ from edb.server.http import HttpClient from . import github, google, azure, apple, discord, slack -from . import config, errors, util, data, base, http_client as _http_client +from . import config, errors, util, data, base class Client: @@ -40,8 +40,8 @@ def __init__( ): self.db = db - http_factory = lambda *args, **kwargs: _http_client.AuthHttpClient( - *args, url_munger=url_munger, http_client=http_client, **kwargs # type: ignore + http_factory = lambda *args, **kwargs: http_client.with_context( + *args, url_munger=url_munger, **kwargs ) provider_config = self._get_provider_config(provider_name) diff --git a/edb/server/tenant.py b/edb/server/tenant.py index 71673081055..7b7b0a9a9a1 100644 --- a/edb/server/tenant.py +++ b/edb/server/tenant.py @@ -246,9 +246,18 @@ def set_server(self, server: edbserver.BaseServer) -> None: self._server = server self.__loop = server.get_loop() - def get_http_client(self) -> HttpClient: + def get_http_client(self, *, originator: str) -> HttpClient: if self._http_client is None: - self._http_client = HttpClient(HTTP_MAX_CONNECTIONS) + http_max_connections = self._server.config_lookup( + 'http_max_connections', self.get_sys_config() + ) + self._http_client = HttpClient( + http_max_connections, + user_agent=f"EdgeDB {buildmeta.get_version_string(short=True)}", + stat_callback=lambda stat: logger.debug( + f"HTTP stat: {originator} {stat}" + ), + ) return self._http_client def on_switch_over(self): diff --git a/pyproject.toml b/pyproject.toml index 3f5cb634bc3..3d47378a600 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,6 @@ dependencies = [ 'psutil~=5.8', 'setproctitle~=1.2', - 'httpx~=0.24.1', - 'httpx-sse~=0.4.0', 'hishel==0.0.24', 'webauthn~=2.0.0', 'argon2-cffi~=23.1.0', @@ -83,6 +81,8 @@ test = [ 'sphinxcontrib-serializinghtml<1.1.10', 'sphinxcontrib-qthelp<1.0.7', 'sphinx_code_tabs~=0.5.3', + + 'httpx~=0.24.1', ] docs = [ diff --git a/tests/test_http.py b/tests/test_http.py new file mode 100644 index 00000000000..f1438d48dd6 --- /dev/null +++ b/tests/test_http.py @@ -0,0 +1,258 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import json +import random + +from edb.server import http +from edb.testbase import http as tb + + +class HttpTest(tb.BaseHttpTest): + def setUp(self): + super().setUp() + self.mock_server = tb.MockHttpServer() + self.mock_server.start() + self.base_url = self.mock_server.get_base_url().rstrip("/") + + def tearDown(self): + if self.mock_server is not None: + self.mock_server.stop() + self.mock_server = None + super().tearDown() + + async def test_get(self): + with http.HttpClient(100) as client: + example_request = ( + 'GET', + self.base_url, + '/test-get-01', + ) + url = f"{example_request[1]}{example_request[2]}" + self.mock_server.register_route_handler(*example_request)( + ( + json.dumps( + { + "message": "Hello, world!", + } + ), + 200, + {"Content-Type": "application/json"}, + ) + ) + + result = await client.get(url) + self.assertEqual(result.status_code, 200) + self.assertEqual(result.json(), {"message": "Hello, world!"}) + + async def test_post(self): + with http.HttpClient(100) as client: + example_request = ( + 'POST', + self.base_url, + '/test-post-01', + ) + url = f"{example_request[1]}{example_request[2]}" + self.mock_server.register_route_handler(*example_request)( + lambda _handler, request: ( + request.body, + 200, + ) + ) + + random_data = [hex(x) for x in random.randbytes(10)] + result = await client.post( + url, json={"message": f"Hello, world! {random_data}"} + ) + self.assertEqual(result.status_code, 200) + self.assertEqual( + result.json(), {"message": f"Hello, world! {random_data}"} + ) + + async def test_post_with_headers(self): + with http.HttpClient(100) as client: + example_request = ( + 'POST', + self.base_url, + '/test-post-with-headers', + ) + url = f"{example_request[1]}{example_request[2]}" + self.mock_server.register_route_handler(*example_request)( + lambda _handler, request: ( + request.body, + 200, + {"X-Test": request.headers["x-test"] + "!"}, + ) + ) + random_data = [hex(x) for x in random.randbytes(10)] + result = await client.post( + url, + json={"message": f"Hello, world! {random_data}"}, + headers={"X-Test": "test"}, + ) + self.assertEqual(result.status_code, 200) + self.assertEqual( + result.json(), {"message": f"Hello, world! {random_data}"} + ) + self.assertEqual(result.headers["X-Test"], "test!") + + async def test_bad_url(self): + with http.HttpClient(100) as client: + with self.assertRaisesRegex(Exception, "Scheme"): + await client.get("httpx://uh-oh") + + async def test_immediate_connection_drop(self): + """Test handling of a connection that is dropped immediately by the + server""" + + async def mock_drop_server( + _reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + # Close connection immediately without sending any response + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(mock_drop_server, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + url = f'http://{addr[0]}:{addr[1]}/drop' + + try: + with http.HttpClient(100) as client: + with self.assertRaisesRegex( + Exception, "Connection reset by peer" + ): + await client.get(url) + finally: + server.close() + await server.wait_closed() + + async def test_immediate_connection_drop_streaming(self): + """Test handling of a connection that is dropped immediately by the + server""" + + async def mock_drop_server( + _reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + # Close connection immediately without sending any response + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(mock_drop_server, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + url = f'http://{addr[0]}:{addr[1]}/drop' + + try: + with http.HttpClient(100) as client: + with self.assertRaisesRegex( + Exception, "Connection reset by peer" + ): + await client.stream_sse(url) + finally: + server.close() + await server.wait_closed() + + async def test_streaming_get_with_no_sse(self): + with http.HttpClient(100) as client: + example_request = ( + 'GET', + self.base_url, + '/test-get-with-sse', + ) + url = f"{example_request[1]}{example_request[2]}" + self.mock_server.register_route_handler(*example_request)( + lambda _handler, request: ( + "\"ok\"", + 200, + ) + ) + result = await client.stream_sse(url, method="GET") + self.assertEqual(result.status_code, 200) + self.assertEqual(result.json(), "ok") + + async def test_sse_with_mock_server(self): + """Since the regular mock server doesn't support SSE, we need to test + with a real socket. We handle just enough HTTP to get the job done.""" + + is_closed = False + + async def mock_sse_server( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + nonlocal is_closed + + await reader.readline() + + headers = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/event-stream\r\n" + b"Cache-Control: no-cache\r\n" + b"Connection: keep-alive\r\n\r\n" + ) + writer.write(headers) + await writer.drain() + + for i in range(3): + writer.write(b": test comment that should be ignored\n\n") + await writer.drain() + + writer.write( + f"event: message\ndata: Event {i + 1}\n\n".encode() + ) + await writer.drain() + await asyncio.sleep(0.1) + + # Write enough messages that we get a broken pipe. The response gets + # closed below and will refuse any further messages. + try: + for _ in range(50): + writer.writelines([b"event: message", b"data: XX", b""]) + await writer.drain() + writer.close() + await writer.wait_closed() + except Exception: + is_closed = True + + server = await asyncio.start_server(mock_sse_server, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + url = f'http://{addr[0]}:{addr[1]}/sse' + + async def client_task(): + with http.HttpClient(100) as client: + response = await client.stream_sse(url, method="GET") + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'text/event-stream' + assert isinstance(response, http.ResponseSSE) + + events = [] + async for event in response: + self.assertEqual(event.event, 'message') + events.append(event) + if len(events) == 3: + break + + assert len(events) == 3 + assert events[0].data == 'Event 1' + assert events[1].data == 'Event 2' + assert events[2].data == 'Event 3' + + async with server: + client_future = asyncio.create_task(client_task()) + await asyncio.wait_for(client_future, timeout=5.0) + + assert is_closed