diff --git a/edb/server/http/src/python.rs b/edb/server/http/src/python.rs index e48be467b0d..d42f38bf4c5 100644 --- a/edb/server/http/src/python.rs +++ b/edb/server/http/src/python.rs @@ -218,10 +218,15 @@ async fn request_sse( return Err(format!("Failed to read response body: {e:?}")); } }; + + // Note that we use semaphores here in a strange way, but basically we + // want to have per-stream backpressure to avoid buffering messages + // indefinitely. let Ok(permit) = backpressure.acquire().await else { break; }; permit.forget(); + if rpc_pipe .write(RustToPythonMessage::SSEEvent(id, chunk)) .await @@ -464,8 +469,14 @@ async fn execute( drop(permit); } RequestSse(id, url, method, body, headers) => { - // Ensure we send the end message whenever this block exits - defer!(_ = rpc_pipe.write(RustToPythonMessage::SSEEnd(id))); + // Ensure we send the end message whenever this block exits (though + // we need to spawn a task to do so) + defer!({ + let rpc_pipe = rpc_pipe.clone(); + tokio::task::spawn_local(async move { + _ = rpc_pipe.write(RustToPythonMessage::SSEEnd(id)).await; + }); + }); let Ok(permit) = permit_manager.acquire().await else { return; }; diff --git a/tests/test_http.py b/tests/test_http.py index f1438d48dd6..1afe0ffc34c 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -122,9 +122,11 @@ async def test_immediate_connection_drop(self): server""" async def mock_drop_server( - _reader: asyncio.StreamReader, writer: asyncio.StreamWriter + reader: asyncio.StreamReader, writer: asyncio.StreamWriter ): - # Close connection immediately without sending any response + # Close connection immediately after reading a byte without sending + # any response + await reader.read(1) writer.close() await writer.wait_closed() @@ -135,21 +137,43 @@ async def mock_drop_server( try: with http.HttpClient(100) as client: with self.assertRaisesRegex( - Exception, "Connection reset by peer" + Exception, "Connection reset by peer|IncompleteMessage" ): await client.get(url) finally: server.close() await server.wait_closed() + async def test_streaming_get_with_no_sse(self): + with http.HttpClient(100) as client: + example_request = ( + 'GET', + self.base_url, + '/test-get-with-sse', + ) + url = f"{example_request[1]}{example_request[2]}" + self.mock_server.register_route_handler(*example_request)( + lambda _handler, request: ( + "\"ok\"", + 200, + ) + ) + result = await client.stream_sse(url, method="GET") + self.assertEqual(result.status_code, 200) + self.assertEqual(result.json(), "ok") + + +class HttpSSETest(tb.BaseHttpTest): async def test_immediate_connection_drop_streaming(self): """Test handling of a connection that is dropped immediately by the server""" async def mock_drop_server( - _reader: asyncio.StreamReader, writer: asyncio.StreamWriter + reader: asyncio.StreamReader, writer: asyncio.StreamWriter ): - # Close connection immediately without sending any response + # Close connection immediately after reading a byte without sending + # any response + await reader.read(1) writer.close() await writer.wait_closed() @@ -160,31 +184,13 @@ async def mock_drop_server( try: with http.HttpClient(100) as client: with self.assertRaisesRegex( - Exception, "Connection reset by peer" + Exception, "Connection reset by peer|IncompleteMessage" ): await client.stream_sse(url) finally: server.close() await server.wait_closed() - async def test_streaming_get_with_no_sse(self): - with http.HttpClient(100) as client: - example_request = ( - 'GET', - self.base_url, - '/test-get-with-sse', - ) - url = f"{example_request[1]}{example_request[2]}" - self.mock_server.register_route_handler(*example_request)( - lambda _handler, request: ( - "\"ok\"", - 200, - ) - ) - result = await client.stream_sse(url, method="GET") - self.assertEqual(result.status_code, 200) - self.assertEqual(result.json(), "ok") - async def test_sse_with_mock_server(self): """Since the regular mock server doesn't support SSE, we need to test with a real socket. We handle just enough HTTP to get the job done.""" @@ -256,3 +262,70 @@ async def client_task(): await asyncio.wait_for(client_future, timeout=5.0) assert is_closed + + async def test_sse_with_mock_server_close(self): + """Try to close the server-side stream and see if the client detects + an end for the iterator. Note that this is technically not correct SSE: + the client should actually try to reconnect after the specified retry + interval, _but_ we don't handle retries yet.""" + + is_closed = False + + async def mock_sse_server( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + nonlocal is_closed + + await reader.readline() + + headers = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/event-stream\r\n" + b"Cache-Control: no-cache\r\n" + b"Connection: keep-alive\r\n\r\n" + ) + writer.write(headers) + await writer.drain() + + for i in range(3): + writer.write(b": test comment that should be ignored\n\n") + await writer.drain() + + writer.write( + f"event: message\ndata: Event {i + 1}\n\n".encode() + ) + await writer.drain() + await asyncio.sleep(0.1) + + await writer.drain() + writer.close() + await writer.wait_closed() + + is_closed = True + + server = await asyncio.start_server(mock_sse_server, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + url = f'http://{addr[0]}:{addr[1]}/sse' + + async def client_task(): + with http.HttpClient(100) as client: + response = await client.stream_sse(url, method="GET") + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'text/event-stream' + assert isinstance(response, http.ResponseSSE) + + events = [] + async for event in response: + self.assertEqual(event.event, 'message') + events.append(event) + + assert len(events) == 3 + assert events[0].data == 'Event 1' + assert events[1].data == 'Event 2' + assert events[2].data == 'Event 3' + + async with server: + client_future = asyncio.create_task(client_task()) + await asyncio.wait_for(client_future, timeout=5.0) + + assert is_closed