Skip to content

Commit 1ee3dfa

Browse files
committed
Fix accidental truncation of RPC timeouts
1 parent 92c3da0 commit 1ee3dfa

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

replit_river/client_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def send_rpc(
6262
# Handle potential errors during communication
6363
try:
6464
try:
65-
async with asyncio.timeout(int(timeout.total_seconds())):
65+
async with asyncio.timeout(timeout.total_seconds()):
6666
response = await output.get()
6767
except asyncio.TimeoutError as e:
6868
# TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT

tests/test_timeout.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import asyncio
2+
from datetime import timedelta
3+
4+
import grpc
5+
import grpc.aio
6+
import pytest
7+
8+
from replit_river.client import Client
9+
from replit_river.error_schema import ERROR_CODE_CANCEL, RiverException
10+
from tests.common_handlers import (
11+
rpc_method_handler,
12+
)
13+
from tests.conftest import (
14+
HandlerMapping,
15+
deserialize_error,
16+
deserialize_response,
17+
serialize_response,
18+
)
19+
20+
21+
async def rpc_slow_handler(duration: float, context: grpc.aio.ServicerContext) -> str:
22+
await asyncio.sleep(duration)
23+
return "I'm sleepy"
24+
25+
26+
def serialize_request(request: float) -> dict:
27+
return {"data": request}
28+
29+
30+
def deserialize_request(request: dict) -> float:
31+
return request.get("data") or 0
32+
33+
34+
slow_rpc: HandlerMapping = {
35+
("test_service", "rpc_method"): (
36+
"rpc",
37+
rpc_method_handler(rpc_slow_handler, deserialize_request, serialize_response),
38+
),
39+
}
40+
41+
42+
@pytest.mark.asyncio
43+
@pytest.mark.parametrize("handlers", [{**slow_rpc}])
44+
async def test_no_timeout(client: Client) -> None:
45+
res = await client.send_rpc(
46+
"test_service",
47+
"rpc_method",
48+
0.01,
49+
serialize_request,
50+
deserialize_response,
51+
deserialize_error,
52+
timeout=timedelta(seconds=0.1),
53+
)
54+
assert res == "I'm sleepy"
55+
56+
57+
@pytest.mark.asyncio
58+
@pytest.mark.parametrize("handlers", [{**slow_rpc}])
59+
async def test_timeout(client: Client) -> None:
60+
with pytest.raises(RiverException) as exc_info:
61+
await client.send_rpc(
62+
"test_service",
63+
"rpc_method",
64+
2.0,
65+
serialize_request,
66+
deserialize_response,
67+
deserialize_error,
68+
timeout=timedelta(seconds=0.01),
69+
)
70+
assert exc_info.value.code == ERROR_CODE_CANCEL

0 commit comments

Comments
 (0)