From c2430fce15318cbfbe745110965367ba3c205bb4 Mon Sep 17 00:00:00 2001 From: zhenthebuilder Date: Fri, 19 Apr 2024 19:34:29 -0700 Subject: [PATCH] Retry logic --- replit_river/client.py | 54 +++++---------------- replit_river/client_session.py | 1 + replit_river/client_transport.py | 82 ++++++++++++++++++++++++++++++-- replit_river/error_schema.py | 1 + replit_river/server_transport.py | 8 +++- replit_river/session.py | 29 ++++++++++- tests/unit/test_handshake.py | 6 +++ 7 files changed, 132 insertions(+), 49 deletions(-) diff --git a/replit_river/client.py b/replit_river/client.py index 1f9fd95..532d06e 100644 --- a/replit_river/client.py +++ b/replit_river/client.py @@ -1,15 +1,9 @@ -import asyncio -import logging from collections.abc import AsyncIterable, AsyncIterator from typing import Any, Callable, Optional, Union -import nanoid # type: ignore -from websockets.client import WebSocketClientProtocol from replit_river.client_session import ClientSession from replit_river.client_transport import ClientTransport -from replit_river.error_schema import RiverException -from replit_river.task_manager import BackgroundTaskManager from replit_river.transport_options import TransportOptions from .rpc import ( @@ -23,49 +17,23 @@ class Client: def __init__( self, - # websocket constructor to reconnect - websockets: WebSocketClientProtocol, + websocket_uri: str, client_id: str, server_id: str, transport_options: TransportOptions, ) -> None: - self._ws = websockets - self._client_id = client_id - self._server_id = server_id - self._instance_id = str(nanoid.generate()) - self._transport_options = transport_options self._transport = ClientTransport( - transport_id=self._server_id, + websocket_uri=websocket_uri, + client_id=client_id, + server_id=server_id, transport_options=transport_options, - is_server=True, ) - self._client_session: Optional[ClientSession] = None - self._task_manager = BackgroundTaskManager() - asyncio.create_task(self._task_manager.create_task(self._create_session())) async def close(self) -> None: - if self._client_session is not None: - await self._task_manager.cancel_all_tasks() - await self._client_session.close() + await self._transport.close_all_sessions() - async def _create_session(self) -> None: - try: - logging.debug("Client start creating session") - client_session = await self._transport.create_client_session( - self._client_id, self._server_id, self._instance_id, self._ws - ) - except RiverException as e: - # TODO: we need some retry logic here. - logging.error(f"Error creating session: {e}") - return - self._client_session = client_session - logging.debug("client start serving messages") - await self._client_session.start_serve_messages() - - async def _wait_for_handshake(self) -> ClientSession: - while self._client_session is None: - await asyncio.sleep(0.1) - return self._client_session + async def _get_or_create_session(self) -> ClientSession: + return await self._transport._get_or_create_session() async def send_rpc( self, @@ -76,7 +44,7 @@ async def send_rpc( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> ResponseType: - session = await self._wait_for_handshake() + session = await self._get_or_create_session() return await session.send_rpc( service_name, procedure_name, @@ -97,7 +65,7 @@ async def send_upload( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> ResponseType: - session = await self._wait_for_handshake() + session = await self._get_or_create_session() return await session.send_upload( service_name, procedure_name, @@ -118,7 +86,7 @@ async def send_subscription( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> AsyncIterator[Union[ResponseType, ErrorType]]: - session = await self._wait_for_handshake() + session = await self._get_or_create_session() return session.send_subscription( service_name, procedure_name, @@ -139,7 +107,7 @@ async def send_stream( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> AsyncIterator[Union[ResponseType, ErrorType]]: - session = await self._wait_for_handshake() + session = await self._get_or_create_session() return session.send_stream( service_name, procedure_name, diff --git a/replit_river/client_session.py b/replit_river/client_session.py index 93b0758..3fc0da9 100644 --- a/replit_river/client_session.py +++ b/replit_river/client_session.py @@ -19,6 +19,7 @@ class ClientSession(Session): + async def send_rpc( self, service_name: str, diff --git a/replit_river/client_transport.py b/replit_river/client_transport.py index de52c9a..0bea714 100644 --- a/replit_river/client_transport.py +++ b/replit_river/client_transport.py @@ -1,15 +1,20 @@ +import asyncio import logging +from typing import Optional +import nanoid from pydantic import ValidationError from websockets import ( WebSocketCommonProtocol, ) +import websockets from websockets.exceptions import ConnectionClosed from replit_river.client_session import ClientSession from replit_river.error_schema import ( ERROR_CODE_STREAM_CLOSED, ERROR_HANDSHAKE, + ERROR_SESSION, RiverException, ) from replit_river.messages import ( @@ -27,10 +32,78 @@ InvalidTransportMessageException, ) from replit_river.session import Session +from replit_river.task_manager import BackgroundTaskManager from replit_river.transport import PROTOCOL_VERSION, Transport +from replit_river.transport_options import TransportOptions class ClientTransport(Transport): + def __init__( + self, + websocket_uri: str, + client_id: str, + server_id: str, + transport_options: TransportOptions, + ): + super().__init__( + transport_id=server_id, + transport_options=transport_options, + is_server=False, + ) + self._websocket_uri = websocket_uri + self._client_id = client_id + self._server_id = server_id + self._instance_id = str(nanoid.generate()) + + async def _get_existing_session(self) -> Optional[ClientSession]: + if not self._sessions: + return None + if len(self._sessions) > 1: + raise RiverException( + "session_error", + "More than one session found in client, " + "should only be one", + ) + session = list(self._sessions.values())[0] + if isinstance(session, ClientSession): + return session + else: + raise RiverException( + "session_error", f"Client session type wrong, got {type(session)}" + ) + + async def _create_session(self) -> ClientSession: + try: + logging.debug("Client start creating session") + self._ws = await websockets.connect(self._websocket_uri) + client_session = await self.create_client_session( + self._client_id, self._server_id, self._instance_id, self._ws + ) + except RiverException as e: + logging.error(f"Error creating session: {e}") + raise RiverException(ERROR_SESSION, "Error creating session") + logging.debug("client start serving messages") + await client_session.start_serve_messages() + return client_session + + async def _get_or_create_session(self) -> ClientSession: + async with self._session_lock: + existing_session = await self._get_existing_session() + if not existing_session: + return await self._create_session() + if not await existing_session.is_websocket_open(): + logging.debug("Client session exists but websocket closed, reconnect one") + self._ws = await websockets.connect(self._websocket_uri) + await existing_session.replace_with_new_websocket(self._ws) + return existing_session + + async def _on_websocket_closed(self) -> None: + session = await self._get_existing_session() + if session and session.is_session_open(): + # TODO: do the retry correctly here + logging.error("Client session websocket closed, retrying") + self._ws = await websockets.connect(self._websocket_uri) + await session.replace_with_new_websocket(self._ws) + async def _send_handshake_request( self, transport_id: str, @@ -60,12 +133,13 @@ async def _send_handshake_request( ), ws=websocket, prefix_bytes=self._transport_options.get_prefix_bytes(), + websocket_closed_callback=self._on_websocket_closed, ) except ConnectionClosed as e: logging.error(f"connection closed error during handshake : {e}") async def close_session_callback(self, session: Session) -> None: - pass + logging.info(f"Client session {session._instance_id} closed") async def create_client_session( self, @@ -90,8 +164,10 @@ async def create_client_session( try: data = await websocket.recv() except ConnectionClosed as e: - # TODO: handle this here - pass + logging.error( + "Connection closed during waiting for handshake " f"response : {e}" + ) + await self._on_websocket_closed() try: first_message = parse_transport_msg(data, self._transport_options) except IgnoreTransportMessageException as e: diff --git a/replit_river/error_schema.py b/replit_river/error_schema.py index 1ab7f7c..c8d87a0 100644 --- a/replit_river/error_schema.py +++ b/replit_river/error_schema.py @@ -4,6 +4,7 @@ ERROR_CODE_STREAM_CLOSED = "stream_closed" ERROR_HANDSHAKE = "handshake_failed" +ERROR_SESSION = "session_error" class RiverError(BaseModel): diff --git a/replit_river/server_transport.py b/replit_river/server_transport.py index 2bcc1d1..48faefe 100644 --- a/replit_river/server_transport.py +++ b/replit_river/server_transport.py @@ -128,7 +128,13 @@ async def _send_handshake_response( serviceName=request_message.serviceName, procedureName=request_message.procedureName, ) - await send_transport_message(response_message, websocket) + + async def websocket_closed_callback() -> None: + logging.error("websocket closed before handshake response") + + await send_transport_message( + response_message, websocket, websocket_closed_callback + ) return response_message async def _establish_handshake( diff --git a/replit_river/session.py b/replit_river/session.py index b2fc549..da753a5 100644 --- a/replit_river/session.py +++ b/replit_river/session.py @@ -1,4 +1,5 @@ import asyncio +import enum import logging from typing import Any, Callable, Coroutine, Dict, Optional, Tuple @@ -7,7 +8,6 @@ from aiochannel import Channel, ChannelClosed from websockets import WebSocketCommonProtocol from websockets.exceptions import ConnectionClosedError -from websockets.server import WebSocketServerProtocol from replit_river.message_buffer import MessageBuffer from replit_river.messages import ( @@ -32,6 +32,12 @@ ) +class SessionState(enum.Enum): + ACTIVE = 0 + CLOSING = 1 + CLOSED = 2 + + class Session(object): """A transport object that handles the websocket connection with a client.""" @@ -45,14 +51,20 @@ def __init__( close_session_callback: Callable[["Session"], Coroutine[Any, Any, None]], is_server: bool, handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]], + close_websocket_callback: Optional[ + Callable[["Session"], Coroutine[Any, Any, None]] + ] = None, ) -> None: self._transport_id = transport_id self._to_id = to_id self._instance_id = instance_id self._handlers = handlers - # ws should only be set while session creation, and replacing + + self._state = SessionState.ACTIVE + self._ws_lock = asyncio.Lock() self._ws = websocket + self._close_websocket_callback = close_websocket_callback self._close_session_callback = close_session_callback self._is_server = is_server @@ -67,6 +79,13 @@ def __init__( self._close_session_after_time_secs: Optional[float] = None asyncio.create_task(self._task_manager.create_task(self._heartbeat(self._ws))) + async def is_session_open(self) -> bool: + return self._state == SessionState.ACTIVE + + async def is_websocket_open(self) -> bool: + async with self._ws_lock: + return self._ws.open + async def serve(self) -> None: """Serve messages from the websocket.""" try: @@ -301,6 +320,8 @@ async def close_websocket(self, websocket: WebSocketCommonProtocol) -> None: "closing websocket" ) async with self._ws_lock: + if self._close_websocket_callback: + await self._close_websocket_callback(self) if websocket: await websocket.close() @@ -380,6 +401,9 @@ async def start_serve_messages(self) -> None: async def close(self) -> None: """Close the session and all associated streams.""" logging.info(f"Closing session from {self._transport_id} to {self._to_id}") + if self._state == SessionState.CLOSING or self._state == SessionState.CLOSED: + return + self._state = SessionState.CLOSING for previous_input in self._streams.values(): previous_input.close() async with self._stream_lock: @@ -388,3 +412,4 @@ async def close(self) -> None: await self.close_websocket(self._ws) # Clear the session in transports await self._close_session_callback(self) + self._state = SessionState.CLOSED diff --git a/tests/unit/test_handshake.py b/tests/unit/test_handshake.py index e69de29..63af66c 100644 --- a/tests/unit/test_handshake.py +++ b/tests/unit/test_handshake.py @@ -0,0 +1,6 @@ +# tests +# cleanup +# no hanging websocket and session +# websocket/session/transport + +# no hanging tasks