Skip to content

Commit

Permalink
HTTP streaming fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Nov 12, 2024
1 parent 6e0cdaa commit fadfb7c
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 26 deletions.
15 changes: 13 additions & 2 deletions edb/server/http/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,15 @@ async fn request_sse(
return Err(format!("Failed to read response body: {e:?}"));
}
};

// Note that we use semaphores here in a strange way, but basically we
// want to have per-stream backpressure to avoid buffering messages
// indefinitely.
let Ok(permit) = backpressure.acquire().await else {
break;
};
permit.forget();

if rpc_pipe
.write(RustToPythonMessage::SSEEvent(id, chunk))
.await
Expand Down Expand Up @@ -464,8 +469,14 @@ async fn execute(
drop(permit);
}
RequestSse(id, url, method, body, headers) => {
// Ensure we send the end message whenever this block exits
defer!(_ = rpc_pipe.write(RustToPythonMessage::SSEEnd(id)));
// Ensure we send the end message whenever this block exits (though
// we need to spawn a task to do so)
defer!({
let rpc_pipe = rpc_pipe.clone();
tokio::task::spawn_local(async move {
_ = rpc_pipe.write(RustToPythonMessage::SSEEnd(id)).await;
});
});
let Ok(permit) = permit_manager.acquire().await else {
return;
};
Expand Down
121 changes: 97 additions & 24 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,11 @@ async def test_immediate_connection_drop(self):
server"""

async def mock_drop_server(
_reader: asyncio.StreamReader, writer: asyncio.StreamWriter
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
):
# Close connection immediately without sending any response
# Close connection immediately after reading a byte without sending
# any response
await reader.read(1)
writer.close()
await writer.wait_closed()

Expand All @@ -135,21 +137,43 @@ async def mock_drop_server(
try:
with http.HttpClient(100) as client:
with self.assertRaisesRegex(
Exception, "Connection reset by peer"
Exception, "Connection reset by peer|IncompleteMessage"
):
await client.get(url)
finally:
server.close()
await server.wait_closed()

async def test_streaming_get_with_no_sse(self):
with http.HttpClient(100) as client:
example_request = (
'GET',
self.base_url,
'/test-get-with-sse',
)
url = f"{example_request[1]}{example_request[2]}"
self.mock_server.register_route_handler(*example_request)(
lambda _handler, request: (
"\"ok\"",
200,
)
)
result = await client.stream_sse(url, method="GET")
self.assertEqual(result.status_code, 200)
self.assertEqual(result.json(), "ok")


class HttpSSETest(tb.BaseHttpTest):
async def test_immediate_connection_drop_streaming(self):
"""Test handling of a connection that is dropped immediately by the
server"""

async def mock_drop_server(
_reader: asyncio.StreamReader, writer: asyncio.StreamWriter
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
):
# Close connection immediately without sending any response
# Close connection immediately after reading a byte without sending
# any response
await reader.read(1)
writer.close()
await writer.wait_closed()

Expand All @@ -160,31 +184,13 @@ async def mock_drop_server(
try:
with http.HttpClient(100) as client:
with self.assertRaisesRegex(
Exception, "Connection reset by peer"
Exception, "Connection reset by peer|IncompleteMessage"
):
await client.stream_sse(url)
finally:
server.close()
await server.wait_closed()

async def test_streaming_get_with_no_sse(self):
with http.HttpClient(100) as client:
example_request = (
'GET',
self.base_url,
'/test-get-with-sse',
)
url = f"{example_request[1]}{example_request[2]}"
self.mock_server.register_route_handler(*example_request)(
lambda _handler, request: (
"\"ok\"",
200,
)
)
result = await client.stream_sse(url, method="GET")
self.assertEqual(result.status_code, 200)
self.assertEqual(result.json(), "ok")

async def test_sse_with_mock_server(self):
"""Since the regular mock server doesn't support SSE, we need to test
with a real socket. We handle just enough HTTP to get the job done."""
Expand Down Expand Up @@ -256,3 +262,70 @@ async def client_task():
await asyncio.wait_for(client_future, timeout=5.0)

assert is_closed

async def test_sse_with_mock_server_close(self):
"""Try to close the server-side stream and see if the client detects
an end for the iterator. Note that this is technically not correct SSE:
the client should actually try to reconnect after the specified retry
interval, _but_ we don't handle retries yet."""

is_closed = False

async def mock_sse_server(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
):
nonlocal is_closed

await reader.readline()

headers = (
b"HTTP/1.1 200 OK\r\n"
b"Content-Type: text/event-stream\r\n"
b"Cache-Control: no-cache\r\n"
b"Connection: keep-alive\r\n\r\n"
)
writer.write(headers)
await writer.drain()

for i in range(3):
writer.write(b": test comment that should be ignored\n\n")
await writer.drain()

writer.write(
f"event: message\ndata: Event {i + 1}\n\n".encode()
)
await writer.drain()
await asyncio.sleep(0.1)

await writer.drain()
writer.close()
await writer.wait_closed()

is_closed = True

server = await asyncio.start_server(mock_sse_server, '127.0.0.1', 0)
addr = server.sockets[0].getsockname()
url = f'http://{addr[0]}:{addr[1]}/sse'

async def client_task():
with http.HttpClient(100) as client:
response = await client.stream_sse(url, method="GET")
assert response.status_code == 200
assert response.headers['Content-Type'] == 'text/event-stream'
assert isinstance(response, http.ResponseSSE)

events = []
async for event in response:
self.assertEqual(event.event, 'message')
events.append(event)

assert len(events) == 3
assert events[0].data == 'Event 1'
assert events[1].data == 'Event 2'
assert events[2].data == 'Event 3'

async with server:
client_future = asyncio.create_task(client_task())
await asyncio.wait_for(client_future, timeout=5.0)

assert is_closed

0 comments on commit fadfb7c

Please sign in to comment.