From 1ee3dfaec033f93d83564fb1bf6fc7bb4702e447 Mon Sep 17 00:00:00 2001 From: Connor Brewster Date: Tue, 10 Dec 2024 11:05:33 -0600 Subject: [PATCH] Fix accidental truncation of RPC timeouts --- replit_river/client_session.py | 2 +- tests/test_timeout.py | 70 ++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 tests/test_timeout.py diff --git a/replit_river/client_session.py b/replit_river/client_session.py index 1deb280..ecd4337 100644 --- a/replit_river/client_session.py +++ b/replit_river/client_session.py @@ -62,7 +62,7 @@ async def send_rpc( # Handle potential errors during communication try: try: - async with asyncio.timeout(int(timeout.total_seconds())): + async with asyncio.timeout(timeout.total_seconds()): response = await output.get() except asyncio.TimeoutError as e: # TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT diff --git a/tests/test_timeout.py b/tests/test_timeout.py new file mode 100644 index 0000000..94e033f --- /dev/null +++ b/tests/test_timeout.py @@ -0,0 +1,70 @@ +import asyncio +from datetime import timedelta + +import grpc +import grpc.aio +import pytest + +from replit_river.client import Client +from replit_river.error_schema import ERROR_CODE_CANCEL, RiverException +from tests.common_handlers import ( + rpc_method_handler, +) +from tests.conftest import ( + HandlerMapping, + deserialize_error, + deserialize_response, + serialize_response, +) + + +async def rpc_slow_handler(duration: float, context: grpc.aio.ServicerContext) -> str: + await asyncio.sleep(duration) + return "I'm sleepy" + + +def serialize_request(request: float) -> dict: + return {"data": request} + + +def deserialize_request(request: dict) -> float: + return request.get("data") or 0 + + +slow_rpc: HandlerMapping = { + ("test_service", "rpc_method"): ( + "rpc", + rpc_method_handler(rpc_slow_handler, deserialize_request, serialize_response), + ), +} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**slow_rpc}]) +async def test_no_timeout(client: Client) -> None: + res = await client.send_rpc( + "test_service", + "rpc_method", + 0.01, + serialize_request, + deserialize_response, + deserialize_error, + timeout=timedelta(seconds=0.1), + ) + assert res == "I'm sleepy" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**slow_rpc}]) +async def test_timeout(client: Client) -> None: + with pytest.raises(RiverException) as exc_info: + await client.send_rpc( + "test_service", + "rpc_method", + 2.0, + serialize_request, + deserialize_response, + deserialize_error, + timeout=timedelta(seconds=0.01), + ) + assert exc_info.value.code == ERROR_CODE_CANCEL