Skip to content

Commit

Permalink
Retry logic
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenthebuilder committed Apr 20, 2024
1 parent fa2555d commit c2430fc
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 49 deletions.
54 changes: 11 additions & 43 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import asyncio
import logging
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any, Callable, Optional, Union

import nanoid # type: ignore
from websockets.client import WebSocketClientProtocol

from replit_river.client_session import ClientSession
from replit_river.client_transport import ClientTransport
from replit_river.error_schema import RiverException
from replit_river.task_manager import BackgroundTaskManager
from replit_river.transport_options import TransportOptions

from .rpc import (
Expand All @@ -23,49 +17,23 @@
class Client:
def __init__(
self,
# websocket constructor to reconnect
websockets: WebSocketClientProtocol,
websocket_uri: str,
client_id: str,
server_id: str,
transport_options: TransportOptions,
) -> None:
self._ws = websockets
self._client_id = client_id
self._server_id = server_id
self._instance_id = str(nanoid.generate())
self._transport_options = transport_options
self._transport = ClientTransport(
transport_id=self._server_id,
websocket_uri=websocket_uri,
client_id=client_id,
server_id=server_id,
transport_options=transport_options,
is_server=True,
)
self._client_session: Optional[ClientSession] = None
self._task_manager = BackgroundTaskManager()
asyncio.create_task(self._task_manager.create_task(self._create_session()))

async def close(self) -> None:
if self._client_session is not None:
await self._task_manager.cancel_all_tasks()
await self._client_session.close()
await self._transport.close_all_sessions()

async def _create_session(self) -> None:
try:
logging.debug("Client start creating session")
client_session = await self._transport.create_client_session(
self._client_id, self._server_id, self._instance_id, self._ws
)
except RiverException as e:
# TODO: we need some retry logic here.
logging.error(f"Error creating session: {e}")
return
self._client_session = client_session
logging.debug("client start serving messages")
await self._client_session.start_serve_messages()

async def _wait_for_handshake(self) -> ClientSession:
while self._client_session is None:
await asyncio.sleep(0.1)
return self._client_session
async def _get_or_create_session(self) -> ClientSession:
return await self._transport._get_or_create_session()

async def send_rpc(
self,
Expand All @@ -76,7 +44,7 @@ async def send_rpc(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
session = await self._wait_for_handshake()
session = await self._get_or_create_session()
return await session.send_rpc(
service_name,
procedure_name,
Expand All @@ -97,7 +65,7 @@ async def send_upload(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
session = await self._wait_for_handshake()
session = await self._get_or_create_session()
return await session.send_upload(
service_name,
procedure_name,
Expand All @@ -118,7 +86,7 @@ async def send_subscription(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
session = await self._wait_for_handshake()
session = await self._get_or_create_session()
return session.send_subscription(
service_name,
procedure_name,
Expand All @@ -139,7 +107,7 @@ async def send_stream(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
session = await self._wait_for_handshake()
session = await self._get_or_create_session()
return session.send_stream(
service_name,
procedure_name,
Expand Down
1 change: 1 addition & 0 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


class ClientSession(Session):

async def send_rpc(
self,
service_name: str,
Expand Down
82 changes: 79 additions & 3 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import asyncio
import logging
from typing import Optional

import nanoid
from pydantic import ValidationError
from websockets import (
WebSocketCommonProtocol,
)
import websockets
from websockets.exceptions import ConnectionClosed

from replit_river.client_session import ClientSession
from replit_river.error_schema import (
ERROR_CODE_STREAM_CLOSED,
ERROR_HANDSHAKE,
ERROR_SESSION,
RiverException,
)
from replit_river.messages import (
Expand All @@ -27,10 +32,78 @@
InvalidTransportMessageException,
)
from replit_river.session import Session
from replit_river.task_manager import BackgroundTaskManager
from replit_river.transport import PROTOCOL_VERSION, Transport
from replit_river.transport_options import TransportOptions


class ClientTransport(Transport):
def __init__(
self,
websocket_uri: str,
client_id: str,
server_id: str,
transport_options: TransportOptions,
):
super().__init__(
transport_id=server_id,
transport_options=transport_options,
is_server=False,
)
self._websocket_uri = websocket_uri
self._client_id = client_id
self._server_id = server_id
self._instance_id = str(nanoid.generate())

async def _get_existing_session(self) -> Optional[ClientSession]:
if not self._sessions:
return None
if len(self._sessions) > 1:
raise RiverException(
"session_error",
"More than one session found in client, " + "should only be one",
)
session = list(self._sessions.values())[0]
if isinstance(session, ClientSession):
return session
else:
raise RiverException(
"session_error", f"Client session type wrong, got {type(session)}"
)

async def _create_session(self) -> ClientSession:
try:
logging.debug("Client start creating session")
self._ws = await websockets.connect(self._websocket_uri)
client_session = await self.create_client_session(
self._client_id, self._server_id, self._instance_id, self._ws
)
except RiverException as e:
logging.error(f"Error creating session: {e}")
raise RiverException(ERROR_SESSION, "Error creating session")
logging.debug("client start serving messages")
await client_session.start_serve_messages()
return client_session

async def _get_or_create_session(self) -> ClientSession:
async with self._session_lock:
existing_session = await self._get_existing_session()
if not existing_session:
return await self._create_session()
if not await existing_session.is_websocket_open():
logging.debug("Client session exists but websocket closed, reconnect one")
self._ws = await websockets.connect(self._websocket_uri)
await existing_session.replace_with_new_websocket(self._ws)
return existing_session

async def _on_websocket_closed(self) -> None:
session = await self._get_existing_session()
if session and session.is_session_open():
# TODO: do the retry correctly here
logging.error("Client session websocket closed, retrying")
self._ws = await websockets.connect(self._websocket_uri)
await session.replace_with_new_websocket(self._ws)

async def _send_handshake_request(
self,
transport_id: str,
Expand Down Expand Up @@ -60,12 +133,13 @@ async def _send_handshake_request(
),
ws=websocket,
prefix_bytes=self._transport_options.get_prefix_bytes(),
websocket_closed_callback=self._on_websocket_closed,
)
except ConnectionClosed as e:
logging.error(f"connection closed error during handshake : {e}")

async def close_session_callback(self, session: Session) -> None:
pass
logging.info(f"Client session {session._instance_id} closed")

async def create_client_session(
self,
Expand All @@ -90,8 +164,10 @@ async def create_client_session(
try:
data = await websocket.recv()
except ConnectionClosed as e:
# TODO: handle this here
pass
logging.error(
"Connection closed during waiting for handshake " f"response : {e}"
)
await self._on_websocket_closed()
try:
first_message = parse_transport_msg(data, self._transport_options)
except IgnoreTransportMessageException as e:
Expand Down
1 change: 1 addition & 0 deletions replit_river/error_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

ERROR_CODE_STREAM_CLOSED = "stream_closed"
ERROR_HANDSHAKE = "handshake_failed"
ERROR_SESSION = "session_error"


class RiverError(BaseModel):
Expand Down
8 changes: 7 additions & 1 deletion replit_river/server_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,13 @@ async def _send_handshake_response(
serviceName=request_message.serviceName,
procedureName=request_message.procedureName,
)
await send_transport_message(response_message, websocket)

async def websocket_closed_callback() -> None:
logging.error("websocket closed before handshake response")

await send_transport_message(
response_message, websocket, websocket_closed_callback
)
return response_message

async def _establish_handshake(
Expand Down
29 changes: 27 additions & 2 deletions replit_river/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import enum
import logging
from typing import Any, Callable, Coroutine, Dict, Optional, Tuple

Expand All @@ -7,7 +8,6 @@
from aiochannel import Channel, ChannelClosed
from websockets import WebSocketCommonProtocol
from websockets.exceptions import ConnectionClosedError
from websockets.server import WebSocketServerProtocol

from replit_river.message_buffer import MessageBuffer
from replit_river.messages import (
Expand All @@ -32,6 +32,12 @@
)


class SessionState(enum.Enum):
ACTIVE = 0
CLOSING = 1
CLOSED = 2


class Session(object):
"""A transport object that handles the websocket connection with a client."""

Expand All @@ -45,14 +51,20 @@ def __init__(
close_session_callback: Callable[["Session"], Coroutine[Any, Any, None]],
is_server: bool,
handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]],
close_websocket_callback: Optional[
Callable[["Session"], Coroutine[Any, Any, None]]
] = None,
) -> None:
self._transport_id = transport_id
self._to_id = to_id
self._instance_id = instance_id
self._handlers = handlers
# ws should only be set while session creation, and replacing

self._state = SessionState.ACTIVE

self._ws_lock = asyncio.Lock()
self._ws = websocket
self._close_websocket_callback = close_websocket_callback

self._close_session_callback = close_session_callback
self._is_server = is_server
Expand All @@ -67,6 +79,13 @@ def __init__(
self._close_session_after_time_secs: Optional[float] = None
asyncio.create_task(self._task_manager.create_task(self._heartbeat(self._ws)))

async def is_session_open(self) -> bool:
return self._state == SessionState.ACTIVE

async def is_websocket_open(self) -> bool:
async with self._ws_lock:
return self._ws.open

async def serve(self) -> None:
"""Serve messages from the websocket."""
try:
Expand Down Expand Up @@ -301,6 +320,8 @@ async def close_websocket(self, websocket: WebSocketCommonProtocol) -> None:
"closing websocket"
)
async with self._ws_lock:
if self._close_websocket_callback:
await self._close_websocket_callback(self)
if websocket:
await websocket.close()

Expand Down Expand Up @@ -380,6 +401,9 @@ async def start_serve_messages(self) -> None:
async def close(self) -> None:
"""Close the session and all associated streams."""
logging.info(f"Closing session from {self._transport_id} to {self._to_id}")
if self._state == SessionState.CLOSING or self._state == SessionState.CLOSED:
return
self._state = SessionState.CLOSING
for previous_input in self._streams.values():
previous_input.close()
async with self._stream_lock:
Expand All @@ -388,3 +412,4 @@ async def close(self) -> None:
await self.close_websocket(self._ws)
# Clear the session in transports
await self._close_session_callback(self)
self._state = SessionState.CLOSED
6 changes: 6 additions & 0 deletions tests/unit/test_handshake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# tests
# cleanup
# no hanging websocket and session
# websocket/session/transport

# no hanging tasks

0 comments on commit c2430fc

Please sign in to comment.