Skip to content

Commit

Permalink
Fix accidental truncation of RPC timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
cbrewster committed Dec 10, 2024
1 parent 92c3da0 commit 1ee3dfa
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
2 changes: 1 addition & 1 deletion replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def send_rpc(
# Handle potential errors during communication
try:
try:
async with asyncio.timeout(int(timeout.total_seconds())):
async with asyncio.timeout(timeout.total_seconds()):
response = await output.get()
except asyncio.TimeoutError as e:
# TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT
Expand Down
70 changes: 70 additions & 0 deletions tests/test_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import asyncio
from datetime import timedelta

import grpc
import grpc.aio
import pytest

from replit_river.client import Client
from replit_river.error_schema import ERROR_CODE_CANCEL, RiverException
from tests.common_handlers import (
rpc_method_handler,
)
from tests.conftest import (
HandlerMapping,
deserialize_error,
deserialize_response,
serialize_response,
)


async def rpc_slow_handler(duration: float, context: grpc.aio.ServicerContext) -> str:
await asyncio.sleep(duration)
return "I'm sleepy"


def serialize_request(request: float) -> dict:
return {"data": request}


def deserialize_request(request: dict) -> float:
return request.get("data") or 0


slow_rpc: HandlerMapping = {
("test_service", "rpc_method"): (
"rpc",
rpc_method_handler(rpc_slow_handler, deserialize_request, serialize_response),
),
}


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**slow_rpc}])
async def test_no_timeout(client: Client) -> None:
res = await client.send_rpc(
"test_service",
"rpc_method",
0.01,
serialize_request,
deserialize_response,
deserialize_error,
timeout=timedelta(seconds=0.1),
)
assert res == "I'm sleepy"


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**slow_rpc}])
async def test_timeout(client: Client) -> None:
with pytest.raises(RiverException) as exc_info:
await client.send_rpc(
"test_service",
"rpc_method",
2.0,
serialize_request,
deserialize_response,
deserialize_error,
timeout=timedelta(seconds=0.01),
)
assert exc_info.value.code == ERROR_CODE_CANCEL

0 comments on commit 1ee3dfa

Please sign in to comment.