Skip to content

Commit

Permalink
SSE working
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Oct 25, 2024
1 parent cf60975 commit 467263f
Show file tree
Hide file tree
Showing 3 changed files with 318 additions and 96 deletions.
97 changes: 61 additions & 36 deletions edb/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ def __init__(self, limit: int):
self._requests: dict[int, asyncio.Future] = {}

def __del__(self) -> None:
self.close()

def __enter__(self) -> HttpClient:
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()

def close(self) -> None:
if self._task:
self._task.cancel()
self._task = None
Expand All @@ -70,18 +79,22 @@ def _process_headers(self, headers: HeaderType) -> list[tuple[str, str]]:
return headers
raise ValueError(f"Invalid headers type: {type(headers)}")

def _process_content(self,
headers: list[tuple[str, str]],
data: bytes | str | dict[str, str] | None = None,
json: Any | None = None) -> bytes:
def _process_content(
self,
headers: list[tuple[str, str]],
data: bytes | str | dict[str, str] | None = None,
json: Any | None = None,
) -> bytes:
if json is not None:
data = json_lib.dumps(json).encode('utf-8')
headers.append(('Content-Type', 'application/json'))
elif isinstance(data, str):
data = data.encode('utf-8')
elif isinstance(data, dict):
data = urllib.parse.urlencode(data).encode('utf-8')
headers.append(('Content-Type', 'application/x-www-form-urlencoded'))
headers.append(
('Content-Type', 'application/x-www-form-urlencoded')
)
elif data is None:
data = bytes()
elif isinstance(data, bytes):
Expand Down Expand Up @@ -111,9 +124,7 @@ async def request(
finally:
del self._requests[id]

async def get(
self, path: str, *, headers: HeaderType = None
) -> Response:
async def get(self, path: str, *, headers: HeaderType = None) -> Response:
headers_list = self._process_headers(headers)
result = await self.request(
method="GET", path=path, data=None, headers=headers_list
Expand Down Expand Up @@ -151,14 +162,14 @@ async def stream_sse(
self._next_id += 1
self._requests[id] = asyncio.Future()
try:
self._client._request_sse(id, path=path, method=method, content=data, headers=headers_list)
self._client._request_sse(id, path, method, data, headers_list)
resp = await self._requests[id]
return resp
if id in self._streaming:
return ResponseSSE.from_tuple(resp, self._streaming[id])
return Response.from_tuple(resp)
finally:
del self._requests[id]

return Response.from_tuple(result)

async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
logger.info("Python-side HTTP client booted")
reader = asyncio.StreamReader(loop=loop)
Expand All @@ -180,28 +191,35 @@ async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
transport.close()

def _process_message(self, msg):
msg_type, id, *data = msg

if id in self._requests:
try:
msg_type, id, data = msg
if msg_type == 0: # Error
self._requests[id].set_exception(Exception(data[0]))
if id in self._requests:
self._requests[id].set_exception(Exception(data[0]))
elif msg_type == 1: # Response
self._requests[id].set_result(data[0])
if id in self._requests:
self._requests[id].set_result(data)
elif msg_type == 2: # SSEStart
self._requests[id].set_result(data[0])
if id in self._requests:
self._streaming[id] = asyncio.Queue()
self._requests[id].set_result(data)
elif msg_type == 3: # SSEEvent
# Set the result and re-arm the future
self._streaming[id][0].set_result(data[0])
self._streaming[id][0] = asyncio.Future()
if id in self._streaming:
self._streaming[id].put_nowait(data)
elif msg_type == 4: # SSEEnd
self._streaming[id][0].set_result(None)
del self._streaming[id]
if id in self._streaming:
self._streaming[id].put_nowait(None)
del self._streaming[id]
except Exception as e:
logger.error(f"Error processing message: {e}", exc_info=True)
raise


class CaseInsensitiveDict(dict):
def __init__(self, data: Optional[list[Tuple[str, str]]] = None):
super().__init__()
if data:
for k, v in data:
for k, v in data.items():
self[k.lower()] = v

def __setitem__(self, key: str, value: str):
Expand All @@ -226,11 +244,12 @@ def update(self, *args, **kwargs: str) -> None:
self[key] = value


@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class Response:
status_code: int
body: bytes
headers: CaseInsensitiveDict
is_streaming: bool = False

@classmethod
def from_tuple(cls, data: Tuple[int, bytes, dict[str, str]]):
Expand All @@ -245,28 +264,34 @@ def json(self):
def text(self) -> str:
return self.body.decode('utf-8')


@dataclasses.dataclass(frozen=True)
class ResponseSSE:
status_code: int
headers: CaseInsensitiveDict
_stream: asyncio.Queue = dataclasses.field(repr=False)
is_streaming: bool = True

def __init__(self, status_code: int, headers: dict[str, str], stream: list[asyncio.Future]):
self.status_code = status_code
self.headers = CaseInsensitiveDict(headers)
self._buffer = b''
self._stream = stream

@dataclasses.dataclass
@classmethod
def from_tuple(
cls, data: Tuple[int, dict[str, str]], stream: asyncio.Queue
):
status_code, headers = data
headers = CaseInsensitiveDict(headers)
return cls(status_code, headers, stream)

@dataclasses.dataclass(frozen=True)
class SSEEvent:
event: str
data: str
id: Optional[str] = None
retry: Optional[int] = None

async def __aiter__(self):
def __aiter__(self):
return self

async def __anext__(self):
next = await self._stream[0]
next = await self._stream.get()
if next is None:
raise StopAsyncIteration
return next
id, data, event = next
return self.SSEEvent(event, data, id)
132 changes: 72 additions & 60 deletions edb/server/http/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ impl ToPyObject for RustToPythonMessage {
.to_object(py),
SSEStart(conn, (status, headers)) => (2, conn, (status, headers)).to_object(py),
SSEEvent(conn, message) => {
(3, conn, &message.id, &message.data, &message.event).to_object(py)
(3, conn, (&message.id, &message.data, &message.event)).to_object(py)
}
SSEEnd(conn) => (4, conn).to_object(py),
SSEEnd(conn) => (4, conn, ()).to_object(py),
}
}
}
Expand Down Expand Up @@ -119,8 +119,9 @@ async fn request(
body: Vec<u8>,
headers: Vec<(String, String)>,
) -> Result<reqwest::Response, String> {
eprintln!("request: {method} {url}");
let method =
Method::from_bytes(method.as_bytes()).map_err(|e| format!("Invalid HTTP method: {}", e))?;
Method::from_bytes(method.as_bytes()).map_err(|e| format!("Invalid HTTP method: {e:?}"))?;

let mut req = client.request(method, url);

Expand All @@ -135,7 +136,7 @@ async fn request(
let resp = req
.send()
.await
.map_err(|e| format!("Request failed: {}", e))?;
.map_err(|e| format!("Request failed: {e:?}"))?;

Ok(resp)
}
Expand All @@ -153,12 +154,74 @@ async fn request_bytes(
let body = resp
.bytes()
.await
.map_err(|e| format!("Failed to read response body: {}", e))?
.map_err(|e| format!("Failed to read response body: {e:?}"))?
.to_vec();

Ok((status, body, headers))
}

async fn request_sse(
client: reqwest::Client,
id: PythonConnId,
url: String,
method: String,
body: Vec<u8>,
headers: Vec<(String, String)>,
rpc_pipe: Rc<RpcPipe>,
) -> Result<(), String> {
let response = request(client, url, method, body, headers).await?;

if response.headers().get("content-type")
!= Some(&HeaderValue::from_static("text/event-stream"))
{
let headers = process_headers(response.headers());
let status = response.status();
let body = match response.bytes().await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
return Err(format!("Failed to read response body: {e:?}"));
}
};
_ = rpc_pipe
.write(RustToPythonMessage::Response(
id,
(status.as_u16(), body, headers),
))
.await;

return Ok(());
}

let headers = process_headers(response.headers());
let status = response.status();
_ = rpc_pipe
.write(RustToPythonMessage::SSEStart(
id,
(status.as_u16(), headers.clone()),
))
.await;

let mut stream = response.bytes_stream().eventsource();
loop {
let chunk = match stream.try_next().await {
Ok(None) => break,
Ok(Some(chunk)) => chunk,
Err(e) => {
return Err(format!("Failed to read response body: {e:?}"));
}
};
if rpc_pipe
.write(RustToPythonMessage::SSEEvent(id, chunk))
.await
.is_err()
{
break;
}
}

Ok(())
}

fn process_headers(headers: &reqwest::header::HeaderMap) -> HashMap<String, String> {
headers
.iter()
Expand Down Expand Up @@ -378,63 +441,12 @@ async fn execute(
let Ok(permit) = permit_manager.acquire().await else {
return;
};
match request(client, url, method, body, headers).await {
Ok(response) => {
// We may still receive a non-SSE response, so we need to check
// the content type.
if response.headers().get("content-type") != Some(&HeaderValue::from_static("text/event-stream")) {
let headers = process_headers(response.headers());
let status = response.status();
let body = match response.bytes().await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
_ = rpc_pipe.write(RustToPythonMessage::Error(id, format!("Failed to read response body: {}", e))).await;
return;
}
};
_ = rpc_pipe
.write(RustToPythonMessage::Response(
id,
(status.as_u16(), body, headers),
))
.await;
return;
}

let headers = process_headers(response.headers());
let status = response.status();
match request_sse(client, id, url, method, body, headers, rpc_pipe.clone()).await {
Ok(..) => {}
Err(err) => {
_ = rpc_pipe
.write(RustToPythonMessage::SSEStart(
id,
(status.as_u16(), headers),
))
.write(RustToPythonMessage::Error(id, format!("SSE error: {err}")))
.await;
let mut stream = response.bytes_stream().eventsource();
loop {
let chunk = match stream.try_next().await {
Ok(None) => break,
Ok(Some(chunk)) => chunk,
Err(e) => {
_ = rpc_pipe
.write(RustToPythonMessage::Error(
id,
format!("Failed to read response body: {}", e),
))
.await;
return;
}
};
if rpc_pipe
.write(RustToPythonMessage::SSEEvent(id, chunk))
.await
.is_err()
{
return;
}
}
}
Err(err) => {
_ = rpc_pipe.write(RustToPythonMessage::Error(id, err)).await;
}
}
drop(permit);
Expand Down
Loading

0 comments on commit 467263f

Please sign in to comment.