diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 446f797..e307b15 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,7 +55,7 @@ jobs: mypy pycrdt_websocket tests - name: Run tests run: | - pytest -v + pytest -v --color=yes check_release: runs-on: ubuntu-latest diff --git a/pycrdt_websocket/__init__.py b/pycrdt_websocket/__init__.py index 34e6235..51a2a90 100644 --- a/pycrdt_websocket/__init__.py +++ b/pycrdt_websocket/__init__.py @@ -1,7 +1,8 @@ from .asgi_server import ASGIServer as ASGIServer from .websocket_provider import WebsocketProvider as WebsocketProvider from .websocket_server import WebsocketServer as WebsocketServer -from .websocket_server import YRoom as YRoom +from .websocket_server import exception_logger as exception_logger +from .yroom import YRoom as YRoom from .yutils import YMessageType as YMessageType __version__ = "0.13.0" diff --git a/pycrdt_websocket/websocket_server.py b/pycrdt_websocket/websocket_server.py index 0c5f0bd..1846346 100644 --- a/pycrdt_websocket/websocket_server.py +++ b/pycrdt_websocket/websocket_server.py @@ -3,6 +3,7 @@ from contextlib import AsyncExitStack from functools import partial from logging import Logger, getLogger +from typing import Callable from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group from anyio.abc import TaskGroup, TaskStatus @@ -17,11 +18,16 @@ class WebsocketServer: auto_clean_rooms: bool rooms: dict[str, YRoom] _started: Event | None = None + _stopped: Event _task_group: TaskGroup | None = None __start_lock: Lock | None = None def __init__( - self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log: Logger | None = None + self, + rooms_ready: bool = True, + auto_clean_rooms: bool = True, + exception_handler: Callable[[Exception, Logger], bool] | None = None, + log: Logger | None = None, ) -> None: """Initialize the object. @@ -41,12 +47,16 @@ def __init__( Arguments: rooms_ready: Whether rooms are ready to be synchronized when opened. auto_clean_rooms: Whether rooms should be deleted when no client is there anymore. + exception_handler: An optional callback to call when an exception is raised, that + returns True if the exception was handled. log: An optional logger. """ self.rooms_ready = rooms_ready self.auto_clean_rooms = auto_clean_rooms + self.exception_handler = exception_handler self.log = log or getLogger(__name__) self.rooms = {} + self._stopped = Event() @property def started(self) -> Event: @@ -146,17 +156,15 @@ async def serve(self, websocket: Websocket) -> None: "`await websocket_server.start()`" ) - async with create_task_group() as tg: - tg.start_soon(self._serve, websocket, tg) - - async def _serve(self, websocket: Websocket, tg: TaskGroup): - room = await self.get_room(websocket.path) - await self.start_room(room) - await room.serve(websocket) - - if self.auto_clean_rooms and not room.clients: - await self.delete_room(room=room) - tg.cancel_scope.cancel() + try: + async with create_task_group(): + room = await self.get_room(websocket.path) + await self.start_room(room) + await room.serve(websocket) + if self.auto_clean_rooms and not room.clients: + await self.delete_room(room=room) + except Exception as exception: + self._handle_exception(exception) async def __aenter__(self) -> WebsocketServer: async with self._start_lock: @@ -164,10 +172,9 @@ async def __aenter__(self) -> WebsocketServer: raise RuntimeError("WebsocketServer already running") async with AsyncExitStack() as exit_stack: - tg = create_task_group() - self._task_group = await exit_stack.enter_async_context(tg) + self._task_group = await exit_stack.enter_async_context(create_task_group()) self._exit_stack = exit_stack.pop_all() - await tg.start(partial(self.start, from_context_manager=True)) + await self._task_group.start(partial(self.start, from_context_manager=True)) return self @@ -175,6 +182,13 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): await self.stop() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) + def _handle_exception(self, exception: Exception) -> None: + exception_handled = False + if self.exception_handler is not None: + exception_handled = self.exception_handler(exception, self.log) + if not exception_handled: + raise exception + async def start( self, *, @@ -190,24 +204,37 @@ async def start( task_status.started() self.started.set() assert self._task_group is not None - # wait forever - self._task_group.start_soon(Event().wait) + # wait until stopped + self._task_group.start_soon(self._stopped.wait) return async with self._start_lock: if self._task_group is not None: raise RuntimeError("WebsocketServer already running") - async with create_task_group() as self._task_group: - task_status.started() - self.started.set() - # wait forever - self._task_group.start_soon(Event().wait) + while True: + try: + async with create_task_group() as self._task_group: + if not self.started.is_set(): + task_status.started() + self.started.set() + # wait until stopped + self._task_group.start_soon(self._stopped.wait) + return + except Exception as exception: + self._handle_exception(exception) async def stop(self) -> None: """Stop the WebSocket server.""" if self._task_group is None: raise RuntimeError("WebsocketServer not running") + self._stopped.set() self._task_group.cancel_scope.cancel() self._task_group = None + + +def exception_logger(exception: Exception, log: Logger) -> bool: + """An exception handler that logs the exception and discards it.""" + log.error("WebsocketServer exception", exc_info=exception) + return True # the exception was handled diff --git a/tests/conftest.py b/tests/conftest.py index caba90e..561a186 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,7 @@ async def yws_server(request, unused_tcp_port, websocket_server_api): ) await ensure_server_running("localhost", unused_tcp_port) pytest.port = unused_tcp_port - yield unused_tcp_port + yield unused_tcp_port, websocket_server shutdown_event.set() except Exception: pass diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..3bb7358 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,18 @@ +import pytest +from anyio import sleep + +from pycrdt_websocket import exception_logger + +pytestmark = pytest.mark.anyio + + +@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True) +@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True) +async def test_server_restart(yws_server): + port, server = yws_server + + async def raise_error(): + raise RuntimeError("foo") + + server._task_group.start_soon(raise_error) + await sleep(0.1)