Skip to content

Commit

Permalink
WIP: full http
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Oct 25, 2024
1 parent e838e48 commit cf60975
Show file tree
Hide file tree
Showing 6 changed files with 379 additions and 127 deletions.
80 changes: 46 additions & 34 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

143 changes: 112 additions & 31 deletions edb/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Any,
Mapping,
Optional,
Union,
)

import asyncio
Expand All @@ -37,6 +38,9 @@
logger = logging.getLogger("edb.server")


HeaderType = Optional[Union[list[tuple[str, str]], dict[str, str]]]


class HttpClient:
def __init__(self, limit: int):
self._client = Http(limit)
Expand All @@ -45,6 +49,7 @@ def __init__(self, limit: int):
self._skip_reads = 0
self._loop = asyncio.get_running_loop()
self._task = self._loop.create_task(self._boot(self._loop))
self._streaming = {}
self._next_id = 0
self._requests: dict[int, asyncio.Future] = {}

Expand All @@ -56,62 +61,104 @@ def __del__(self) -> None:
def _update_limit(self, limit: int):
self._client._update_limit(limit)

def _process_headers(self, headers: HeaderType) -> list[tuple[str, str]]:
if headers is None:
return []
if isinstance(headers, Mapping):
return [(k, v) for k, v in headers.items()]
if isinstance(headers, list):
return headers
raise ValueError(f"Invalid headers type: {type(headers)}")

def _process_content(self,
headers: list[tuple[str, str]],
data: bytes | str | dict[str, str] | None = None,
json: Any | None = None) -> bytes:
if json is not None:
data = json_lib.dumps(json).encode('utf-8')
headers.append(('Content-Type', 'application/json'))
elif isinstance(data, str):
data = data.encode('utf-8')
elif isinstance(data, dict):
data = urllib.parse.urlencode(data).encode('utf-8')
headers.append(('Content-Type', 'application/x-www-form-urlencoded'))
elif data is None:
data = bytes()
elif isinstance(data, bytes):
pass
else:
raise ValueError(f"Invalid content type: {type(data)}")
return data

async def request(
self,
*,
method: str,
url: str,
content: bytes | None,
headers: list[tuple[str, str]] | None,
path: str,
headers: HeaderType = None,
data: bytes | str | dict[str, str] | None = None,
json: Any | None = None,
) -> tuple[int, bytes, dict[str, str]]:
if content is None:
content = bytes()
if headers is None:
headers = []
headers_list = self._process_headers(headers)
data = self._process_content(headers_list, data, json)
id = self._next_id
self._next_id += 1
self._requests[id] = asyncio.Future()
try:
self._client._request(id, url, method, content, headers)
self._client._request(id, path, method, data, headers_list)
resp = await self._requests[id]
return resp
finally:
del self._requests[id]

async def get(
self, path: str, *, headers: dict[str, str] | None = None
self, path: str, *, headers: HeaderType = None
) -> Response:
headers_list = [(k, v) for k, v in headers.items()] if headers else None
headers_list = self._process_headers(headers)
result = await self.request(
method="GET", url=path, content=None, headers=headers_list
method="GET", path=path, data=None, headers=headers_list
)
return Response.from_tuple(result)

async def post(
self,
path: str,
*,
headers: dict[str, str] | None = None,
headers: HeaderType = None,
data: bytes | str | dict[str, str] | None = None,
json: Any | None = None,
) -> Response:
if json is not None:
data = json_lib.dumps(json).encode('utf-8')
headers = headers or {}
headers['Content-Type'] = 'application/json'
elif isinstance(data, str):
data = data.encode('utf-8')
elif isinstance(data, dict):
data = urllib.parse.urlencode(data).encode('utf-8')
headers = headers or {}
headers['Content-Type'] = 'application/x-www-form-urlencoded'

headers_list = [(k, v) for k, v in headers.items()] if headers else None
headers_list = self._process_headers(headers)
data = self._process_content(headers_list, data, json)
result = await self.request(
method="POST", url=path, content=data, headers=headers_list
method="POST", path=path, data=data, headers=headers_list
)
return Response.from_tuple(result)

async def stream_sse(
self,
path: str,
*,
method: str = "POST",
headers: HeaderType = None,
data: bytes | str | dict[str, str] | None = None,
json: Any | None = None,
) -> Response | ResponseSSE:
headers_list = self._process_headers(headers)
data = self._process_content(headers_list, data, json)

id = self._next_id
self._next_id += 1
self._requests[id] = asyncio.Future()
try:
self._client._request_sse(id, path=path, method=method, content=data, headers=headers_list)
resp = await self._requests[id]
return resp
finally:
del self._requests[id]

return Response.from_tuple(result)

async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
logger.info("Python-side HTTP client booted")
reader = asyncio.StreamReader(loop=loop)
Expand All @@ -133,14 +180,22 @@ async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
transport.close()

def _process_message(self, msg):
msg_type, id, data = msg
msg_type, id, *data = msg

if id in self._requests:
if msg_type == 1:
self._requests[id].set_result(data)
elif msg_type == 0:
self._requests[id].set_exception(Exception(data))

if msg_type == 0: # Error
self._requests[id].set_exception(Exception(data[0]))
elif msg_type == 1: # Response
self._requests[id].set_result(data[0])
elif msg_type == 2: # SSEStart
self._requests[id].set_result(data[0])
elif msg_type == 3: # SSEEvent
# Set the result and re-arm the future
self._streaming[id][0].set_result(data[0])
self._streaming[id][0] = asyncio.Future()
elif msg_type == 4: # SSEEnd
self._streaming[id][0].set_result(None)
del self._streaming[id]

class CaseInsensitiveDict(dict):
def __init__(self, data: Optional[list[Tuple[str, str]]] = None):
Expand Down Expand Up @@ -189,3 +244,29 @@ def json(self):
@property
def text(self) -> str:
return self.body.decode('utf-8')

class ResponseSSE:
status_code: int
headers: CaseInsensitiveDict

def __init__(self, status_code: int, headers: dict[str, str], stream: list[asyncio.Future]):
self.status_code = status_code
self.headers = CaseInsensitiveDict(headers)
self._buffer = b''
self._stream = stream

@dataclasses.dataclass
class SSEEvent:
event: str
data: str
id: Optional[str] = None
retry: Optional[int] = None

async def __aiter__(self):
return self

async def __anext__(self):
next = await self._stream[0]
if next is None:
raise StopAsyncIteration
return next
3 changes: 2 additions & 1 deletion edb/server/http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ pyo3 = { workspace = true, optional = true }
tokio.workspace = true
tracing = "0"
tracing-subscriber = "0"
reqwest = { version = "0.12", features = ["gzip", "deflate"] }
reqwest = { version = "0.12", features = ["gzip", "deflate", "stream"] }
scopeguard = "1"
eventsource-stream = "0.2.3"

futures = "0"

Expand Down
Loading

0 comments on commit cf60975

Please sign in to comment.