From d5105d6570d91eb2b75ad7e2a0dd054dfbef5652 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sat, 6 Jul 2024 19:49:12 +0200 Subject: [PATCH 1/5] Servers: Common implementation for both TCP and UDP servers --- docs/source/_extensions/sphinx_easynetwork.py | 26 +- src/easynetwork/lowlevel/_utils.py | 17 ++ src/easynetwork/servers/_base.py | 277 +++++++++++++++++- src/easynetwork/servers/abc.py | 51 ++-- src/easynetwork/servers/async_tcp.py | 253 ++++------------ src/easynetwork/servers/async_udp.py | 234 +++++---------- src/easynetwork/servers/standalone_tcp.py | 15 +- src/easynetwork/servers/standalone_udp.py | 12 +- .../test_async/test_server/base.py | 46 ++- .../test_async/test_server/test_tcp.py | 59 ++-- .../test_async/test_server/test_udp.py | 38 ++- .../test_sync/test_server/test_standalone.py | 26 -- .../test_concurrency/conftest.py | 1 + tests/scripts/async_server_test.py | 6 +- tests/unit_test/test_tools/test_utils.py | 17 ++ 15 files changed, 594 insertions(+), 484 deletions(-) diff --git a/docs/source/_extensions/sphinx_easynetwork.py b/docs/source/_extensions/sphinx_easynetwork.py index ee91addd..c69de638 100644 --- a/docs/source/_extensions/sphinx_easynetwork.py +++ b/docs/source/_extensions/sphinx_easynetwork.py @@ -3,32 +3,40 @@ v0.1.0: Initial v0.1.1: Fix base is not replaced if the class is generic. -v0.2.0 (current): Log when an object does not have a docstring. +v0.2.0: Log when an object does not have a docstring. +v0.2.1 (current): Add base class to replace. """ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, get_origin +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, get_args, get_origin if TYPE_CHECKING: from sphinx.application import Sphinx -from easynetwork.servers._base import BaseStandaloneNetworkServerImpl -from easynetwork.servers.abc import AbstractNetworkServer +from easynetwork.servers._base import BaseAsyncNetworkServerImpl, BaseStandaloneNetworkServerImpl +from easynetwork.servers.abc import AbstractAsyncNetworkServer, AbstractNetworkServer logger = logging.getLogger(__name__) -def _replace_base_in_place(klass: type, bases: list[type], base_to_replace: type, base_to_set_instead: type) -> None: +def _replace_base_in_place( + klass: type, + bases: list[type], + base_to_replace: type, + base_to_set_instead: Callable[[tuple[Any, ...]], Any], +) -> None: if issubclass(klass, base_to_replace): for index, base in enumerate(bases): - if get_origin(base) is base_to_replace: - bases[index] = base_to_set_instead + if base is base_to_replace or get_origin(base) is base_to_replace: + bases[index] = base_to_set_instead(get_args(base)) def autodoc_process_bases(app: Sphinx, name: str, obj: type, options: dict[str, Any], bases: list[type]) -> None: - _replace_base_in_place(obj, bases, BaseStandaloneNetworkServerImpl, AbstractNetworkServer) + _replace_base_in_place(obj, bases, BaseAsyncNetworkServerImpl, lambda _: AbstractAsyncNetworkServer) + _replace_base_in_place(obj, bases, BaseStandaloneNetworkServerImpl, lambda _: AbstractNetworkServer) def _is_magic_method(name: str) -> bool: @@ -47,7 +55,7 @@ def setup(app: Sphinx) -> dict[str, Any]: app.connect("autodoc-process-docstring", autodoc_process_docstring) return { - "version": "0.2.0", + "version": "0.2.1", "parallel_read_safe": True, "parallel_write_safe": True, } diff --git a/src/easynetwork/lowlevel/_utils.py b/src/easynetwork/lowlevel/_utils.py index b2914efa..f1cb46a3 100644 --- a/src/easynetwork/lowlevel/_utils.py +++ b/src/easynetwork/lowlevel/_utils.py @@ -16,6 +16,7 @@ __all__ = [ "ElapsedTime", + "Flag", "WarnCallback", "adjust_leftover_buffer", "check_real_socket_state", @@ -395,6 +396,22 @@ def __call__( ) -> None: ... +class Flag: + __slots__ = ("__value", "__weakref__") + + def __init__(self) -> None: + self.__value: bool = False + + def is_set(self) -> bool: + return self.__value + + def set(self) -> None: + self.__value = True + + def clear(self) -> None: + self.__value = False + + class ElapsedTime: __slots__ = ("_current_time_func", "_start_time", "_end_time") diff --git a/src/easynetwork/servers/_base.py b/src/easynetwork/servers/_base.py index 42bcaf8c..4dd52533 100644 --- a/src/easynetwork/servers/_base.py +++ b/src/easynetwork/servers/_base.py @@ -16,27 +16,44 @@ from __future__ import annotations -__all__ = ["BaseStandaloneNetworkServerImpl"] +__all__ = ["BaseAsyncNetworkServerImpl", "BaseStandaloneNetworkServerImpl"] import concurrent.futures import contextlib +import dataclasses +import logging import threading as _threading -from collections.abc import Callable, Mapping, Sequence -from typing import Any, Generic, TypeVar +from abc import abstractmethod +from collections.abc import Awaitable, Callable, Mapping, Sequence +from types import TracebackType +from typing import Any, Generic, NoReturn, Protocol, Self, TypeVar from ..exceptions import ServerAlreadyRunning, ServerClosedError from ..lowlevel import _utils from ..lowlevel._lock import ForkSafeLock -from ..lowlevel.api_async.backend.abc import AsyncBackend, ThreadsPortal +from ..lowlevel.api_async.backend.abc import AsyncBackend, CancelScope, Task, TaskGroup, ThreadsPortal from ..lowlevel.api_async.backend.utils import BuiltinAsyncBackendLiteral, ensure_backend -from ..lowlevel.socket import SocketAddress from .abc import AbstractAsyncNetworkServer, AbstractNetworkServer, SupportsEventSet + +class _SupportsAclose(Protocol): + def is_closing(self) -> bool: ... + def aclose(self) -> Awaitable[object]: ... + + +_T_Address = TypeVar("_T_Address") _T_Return = TypeVar("_T_Return") _T_Default = TypeVar("_T_Default") _T_AsyncServer = TypeVar("_T_AsyncServer", bound=AbstractAsyncNetworkServer) +############################################################################################################## +# +# BLOCKING SERVER +# +############################################################################################################## + + class BaseStandaloneNetworkServerImpl(AbstractNetworkServer, Generic[_T_AsyncServer]): __slots__ = ( "__server_factory", @@ -188,6 +205,250 @@ async def serve_forever() -> None: backend.bootstrap(serve_forever, runner_options=runner_options) - @_utils.inherit_doc(AbstractNetworkServer) - def get_addresses(self) -> Sequence[SocketAddress]: - return self._run_sync_or(lambda portal, server: portal.run_sync(server.get_addresses), ()) + +_T_LowLevelServer = TypeVar("_T_LowLevelServer", bound=_SupportsAclose) + + +############################################################################################################## +# +# ASYNCHRONOUS SERVER +# +############################################################################################################## + + +@dataclasses.dataclass(repr=False, eq=False, frozen=True, slots=True) +class _BindServer(contextlib.AbstractContextManager[None]): + attach: Callable[[], None] + detach: Callable[[], None] + + def __enter__(self) -> None: + self.attach() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.detach() + + +class BaseAsyncNetworkServerImpl(AbstractAsyncNetworkServer, Generic[_T_LowLevelServer, _T_Address]): + """ + An asynchronous network server for TCP connections. + """ + + __slots__ = ( + "__backend", + "__servers", + "__servers_factory_cb", + "__servers_factory_scope", + "__initialize_service_cb", + "__lowlevel_serve_cb", + "__server_activation_lock", + "__server_close_lock", + "__server_close_guard", + "__is_shutdown", + "__server_tasks", + "__server_run_scope", + "__active_tasks", + "__logger", + ) + + def __init__( + self, + *, + backend: AsyncBackend | BuiltinAsyncBackendLiteral | None, + servers_factory: Callable[[Self], Awaitable[Sequence[_T_LowLevelServer]]], + initialize_service: Callable[[Self, contextlib.AsyncExitStack], Awaitable[None]], + lowlevel_serve: Callable[[Self, _T_LowLevelServer, TaskGroup], Awaitable[NoReturn]], + logger: logging.Logger, + ) -> None: + """ + Parameters: + backend: The :term:`asynchronous backend interface` to use. + """ + super().__init__() + + backend = ensure_backend(backend) + + self.__backend: AsyncBackend = backend + self.__servers_factory_cb: Callable[[Self], Awaitable[Sequence[_T_LowLevelServer]]] | None = servers_factory + self.__initialize_service_cb: Callable[[Self, contextlib.AsyncExitStack], Awaitable[None]] = initialize_service + self.__lowlevel_serve_cb: Callable[[Self, _T_LowLevelServer, TaskGroup], Awaitable[NoReturn]] = lowlevel_serve + + self.__servers_factory_scope: CancelScope | None = None + self.__server_run_scope: CancelScope | None = None + self.__server_activation_lock = backend.create_lock() + self.__server_close_lock = backend.create_lock() + self.__server_close_guard = _utils.ResourceGuard("Cannot close server during serve_forever() setup.") + + self.__servers: list[_T_LowLevelServer] = [] + self.__is_shutdown = backend.create_event() + self.__is_shutdown.set() + self.__server_tasks: list[Task[NoReturn]] = [] + self.__logger: logging.Logger = logger + self.__active_tasks: int = 0 + + @_utils.inherit_doc(AbstractAsyncNetworkServer) + def is_serving(self) -> bool: + return bool(self.__server_tasks) and all(not t.done() for t in self.__server_tasks) and self.is_listening() + + @_utils.inherit_doc(AbstractAsyncNetworkServer) + def is_listening(self) -> bool: + return bool(self.__servers) and all(not server.is_closing() for server in self.__servers) + + @_utils.inherit_doc(AbstractAsyncNetworkServer) + async def server_activate(self) -> None: + async with self.__server_activation_lock: + assert self.__servers_factory_scope is None # nosec assert_used + if (servers_factory := self.__servers_factory_cb) is None: + raise ServerClosedError("Closed server") + if self.__servers: + return + listeners: list[_T_LowLevelServer] = [] + try: + with self.__backend.open_cancel_scope() as self.__servers_factory_scope: + await self.__backend.coro_yield() + listeners.extend(await servers_factory(self)) # type: ignore[arg-type] + if self.__servers_factory_scope.cancelled_caught(): + raise ServerClosedError("Server has been closed") + finally: + self.__servers_factory_scope = None + if not listeners: + raise OSError("empty listeners list") + self.__servers[:] = listeners + + @_utils.inherit_doc(AbstractAsyncNetworkServer) + async def server_close(self) -> None: + async with contextlib.AsyncExitStack() as exit_stack: + await exit_stack.enter_async_context(self.__server_close_lock) + exit_stack.enter_context(self.__server_close_guard) + + if self.__servers_factory_scope is not None: + self.__servers_factory_scope.cancel() + self.__servers_factory_cb = None + + exit_stack.callback(self.__servers.clear) + exit_stack.push_async_callback(self.__close_all_servers, self.__backend, self.__servers[:]) + + async with self.__backend.create_task_group() as group: + for task in self.__server_tasks: + task.cancel() + group.start_soon(task.wait) + + @classmethod + async def __close_all_servers(cls, backend: AsyncBackend, servers: Sequence[_T_LowLevelServer]) -> None: + async with backend.create_task_group() as group: + for server in servers: + group.start_soon(cls.__close_server, server) + + @classmethod + async def __close_server(cls, server: _T_LowLevelServer) -> None: + await server.aclose() + + @_utils.inherit_doc(AbstractAsyncNetworkServer) + async def shutdown(self) -> None: + if self.__server_run_scope is not None: + self.__server_run_scope.cancel() + await self.__is_shutdown.wait() + + @_utils.inherit_doc(AbstractAsyncNetworkServer) + async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: + async with contextlib.AsyncExitStack() as server_exit_stack: + # Wake up server + if not self.__is_shutdown.is_set(): + raise ServerAlreadyRunning("Server is already running") + self.__is_shutdown = is_shutdown = self.__backend.create_event() + server_exit_stack.callback(is_shutdown.set) + self.__server_run_scope = server_exit_stack.enter_context(self.__backend.open_cancel_scope()) + + def reset_scope() -> None: + self.__server_run_scope = None + + server_exit_stack.callback(reset_scope) + ################ + + # Bind and activate + await self.server_activate() + assert len(self.__servers) > 0 # nosec assert_used + ################### + + with self.__server_close_guard: + + # Final teardown + server_exit_stack.callback(self.__logger.info, "Server stopped") + ################ + + # Initialize service + initialize_service = self.__initialize_service_cb + await initialize_service(self, server_exit_stack) # type: ignore[arg-type] + ############################ + + # Setup task group + self.__active_tasks = 0 + server_exit_stack.callback(self.__server_tasks.clear) + task_group = await server_exit_stack.enter_async_context(self.__backend.create_task_group()) + server_exit_stack.callback(self.__logger.info, "Server loop break, waiting for remaining tasks...") + ################## + + # Enable listener + self.__server_tasks[:] = [await task_group.start(self.__serve, server, task_group) for server in self.__servers] + self.__logger.info("Start serving at %s", ", ".join(map(str, self.get_addresses()))) + ################# + + # Server is up + if is_up_event is not None: + is_up_event.set() + ############## + + # Main loop + try: + await self.__backend.sleep_forever() + finally: + reset_scope() + + @abstractmethod + def get_addresses(self) -> Sequence[_T_Address]: + """ + Returns all interfaces to which the server is bound. + + Returns: + A sequence of network socket address. + If the server is not serving (:meth:`is_serving` returns :data:`False`), an empty sequence is returned. + """ + raise NotImplementedError + + def _bind_server(self) -> _BindServer: + return _BindServer(self.__attach_server, self.__detach_server) + + async def __serve( + self, + server: _T_LowLevelServer, + task_group: TaskGroup, + ) -> NoReturn: + lowlevel_serve = self.__lowlevel_serve_cb + with _BindServer(self.__attach_server, self.__detach_server): + await lowlevel_serve(self, server, task_group) # type: ignore[arg-type] + + def __attach_server(self) -> None: + self.__active_tasks += 1 + + def __detach_server(self) -> None: + self.__active_tasks -= 1 + if self.__active_tasks < 0: + raise AssertionError("self.__active_tasks < 0") + if not self.__active_tasks and self.__server_run_scope is not None: + self.__server_run_scope.cancel() + + def _with_lowlevel_servers(self, f: Callable[[Sequence[_T_LowLevelServer]], _T_Return]) -> _T_Return: + servers = tuple(self.__servers) + return f(servers) + + @_utils.inherit_doc(AbstractAsyncNetworkServer) + def backend(self) -> AsyncBackend: + return self.__backend + + @property + def logger(self) -> logging.Logger: + return self.__logger diff --git a/src/easynetwork/servers/abc.py b/src/easynetwork/servers/abc.py index 669aae71..9ecb8965 100644 --- a/src/easynetwork/servers/abc.py +++ b/src/easynetwork/servers/abc.py @@ -23,12 +23,10 @@ ] from abc import ABCMeta, abstractmethod -from collections.abc import Sequence from types import TracebackType from typing import Protocol, Self from ..lowlevel.api_async.backend.abc import AsyncBackend -from ..lowlevel.socket import SocketAddress class SupportsEventSet(Protocol): @@ -108,17 +106,6 @@ def shutdown(self, timeout: float | None = ...) -> None: """ raise NotImplementedError - @abstractmethod - def get_addresses(self) -> Sequence[SocketAddress]: - """ - Returns all interfaces to which the server is bound. Thread-safe. - - Returns: - A sequence of network socket address. - If the server is not serving (:meth:`is_serving` returns :data:`False`), an empty sequence is returned. - """ - raise NotImplementedError - class AbstractAsyncNetworkServer(metaclass=ABCMeta): """ @@ -128,6 +115,8 @@ class AbstractAsyncNetworkServer(metaclass=ABCMeta): __slots__ = ("__weakref__",) async def __aenter__(self) -> Self: + """Calls :meth:`server_activate`.""" + await self.server_activate() return self async def __aexit__( @@ -142,7 +131,7 @@ async def __aexit__( @abstractmethod def is_serving(self) -> bool: """ - Checks whether the server is up and accepting new clients. + Checks whether the server is up (:meth:`is_listening` returns :data:`True`) and accepting new clients. """ raise NotImplementedError @@ -151,12 +140,36 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = ...) -> """ Starts the server's main loop. + Further calls to :meth:`is_serving` will return :data:`True` until the loop is stopped. + Parameters: is_up_event: If given, will be triggered when the server is ready to accept new clients. Raises: ServerClosedError: The server is closed. ServerAlreadyRunning: Another task already called :meth:`serve_forever`. + ServerNotActivated: :meth:`server_activate` must be used before calling :meth:`serve_forever`. + """ + raise NotImplementedError + + @abstractmethod + def is_listening(self) -> bool: + """ + Checks whether the server is up. + """ + raise NotImplementedError + + @abstractmethod + async def server_activate(self) -> None: + """ + Opens all listeners. + + This method MUST be idempotent. Further calls to :meth:`is_listening` will return :data:`True`. + + To stop and close the listeners, you can use :meth:`stop_listening`. + + Raises: + ServerClosedError: The server is closed. """ raise NotImplementedError @@ -179,16 +192,6 @@ async def shutdown(self) -> None: """ raise NotImplementedError - @abstractmethod - def get_addresses(self) -> Sequence[SocketAddress]: - """ - Returns all interfaces to which the server is bound. - - Returns: - A sequence of network socket address. - If the server is not serving (:meth:`is_serving` returns :data:`False`), an empty sequence is returned. - """ - @abstractmethod def backend(self) -> AsyncBackend: """ diff --git a/src/easynetwork/servers/async_tcp.py b/src/easynetwork/servers/async_tcp.py index c06f0c29..b3435869 100644 --- a/src/easynetwork/servers/async_tcp.py +++ b/src/easynetwork/servers/async_tcp.py @@ -21,16 +21,15 @@ import contextlib import logging import weakref -from collections import deque from collections.abc import AsyncIterator, Callable, Coroutine, Iterator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Generic, NoReturn, final from .._typevars import _T_Request, _T_Response -from ..exceptions import ClientClosedError, ServerAlreadyRunning, ServerClosedError +from ..exceptions import ClientClosedError from ..lowlevel import _utils, constants from ..lowlevel._final import runtime_final_class -from ..lowlevel.api_async.backend.abc import AsyncBackend, CancelScope, IEvent, Task, TaskGroup -from ..lowlevel.api_async.backend.utils import BuiltinAsyncBackendLiteral, ensure_backend +from ..lowlevel.api_async.backend.abc import AsyncBackend, TaskGroup +from ..lowlevel.api_async.backend.utils import BuiltinAsyncBackendLiteral from ..lowlevel.api_async.servers import stream as _stream_server from ..lowlevel.api_async.transports.abc import AsyncListener, AsyncStreamTransport from ..lowlevel.socket import ( @@ -45,7 +44,7 @@ set_tcp_nodelay, ) from ..protocol import AnyStreamProtocolType -from .abc import AbstractAsyncNetworkServer, SupportsEventSet +from . import _base from .handlers import AsyncStreamClient, AsyncStreamRequestHandler, INETClientAttribute from .misc import build_lowlevel_stream_server_handler @@ -53,25 +52,20 @@ from ssl import SSLContext -class AsyncTCPNetworkServer(AbstractAsyncNetworkServer, Generic[_T_Request, _T_Response]): +class AsyncTCPNetworkServer( + _base.BaseAsyncNetworkServerImpl[_stream_server.AsyncStreamServer[_T_Request, _T_Response], SocketAddress], + Generic[_T_Request, _T_Response], +): """ An asynchronous network server for TCP connections. """ __slots__ = ( - "__backend", - "__servers", "__listeners_factory", - "__listeners_factory_scope", "__protocol", "__request_handler", - "__is_shutdown", "__max_recv_size", - "__servers_tasks", - "__server_run_scope", - "__active_tasks", "__client_connection_log_level", - "__logger", ) def __init__( @@ -129,7 +123,13 @@ def __init__( See Also: :ref:`SSL/TLS security considerations ` """ - super().__init__() + super().__init__( + backend=backend, + servers_factory=type(self).__activate_listeners, + initialize_service=type(self).__initialize_service, + lowlevel_serve=type(self).__lowlevel_serve, + logger=logger or logging.getLogger(__name__), + ) from ..lowlevel._stream import _check_any_protocol @@ -138,7 +138,7 @@ def __init__( if not isinstance(request_handler, AsyncStreamRequestHandler): raise TypeError(f"Expected an AsyncStreamRequestHandler object, got {request_handler!r}") - backend = ensure_backend(backend) + backend = self.backend() if backlog is None: backlog = 100 @@ -163,8 +163,7 @@ def __init__( if ssl_standard_compatible is None: ssl_standard_compatible = True - self.__backend: AsyncBackend = backend - self.__listeners_factory: Callable[[], Coroutine[Any, Any, Sequence[AsyncListener[AsyncStreamTransport]]]] | None + self.__listeners_factory: Callable[[], Coroutine[Any, Any, Sequence[AsyncListener[AsyncStreamTransport]]]] if ssl: self.__listeners_factory = _utils.make_callback( self.__create_ssl_over_tcp_listeners, @@ -186,23 +185,15 @@ def __init__( backlog=backlog, reuse_port=reuse_port, ) - self.__listeners_factory_scope: CancelScope | None = None - self.__server_run_scope: CancelScope | None = None - self.__servers: tuple[_stream_server.AsyncStreamServer[_T_Request, _T_Response], ...] | None = None self.__protocol: AnyStreamProtocolType[_T_Response, _T_Request] = protocol self.__request_handler: AsyncStreamRequestHandler[_T_Request, _T_Response] = request_handler - self.__is_shutdown: IEvent = backend.create_event() - self.__is_shutdown.set() self.__max_recv_size: int = max_recv_size - self.__servers_tasks: deque[Task[NoReturn]] = deque() - self.__logger: logging.Logger = logger or logging.getLogger(__name__) self.__client_connection_log_level: int if log_client_connection: self.__client_connection_log_level = logging.INFO else: self.__client_connection_log_level = logging.DEBUG - self.__active_tasks: int = 0 @staticmethod async def __create_ssl_over_tcp_listeners( @@ -236,152 +227,33 @@ async def __create_ssl_over_tcp_listeners( for listener in listeners ] - @_utils.inherit_doc(AbstractAsyncNetworkServer) - def is_serving(self) -> bool: - return self.__servers is not None and all(not server.is_closing() for server in self.__servers) - - def stop_listening(self) -> None: - """ - Schedules the shutdown of all listener sockets. - - After that, all new connections will be refused, but the server will continue to run and handle - previously accepted connections. - - Further calls to :meth:`is_serving` will return :data:`False`. - """ - with contextlib.ExitStack() as exit_stack: - for listener_task in self.__servers_tasks: - exit_stack.callback(listener_task.cancel) - del listener_task - - @_utils.inherit_doc(AbstractAsyncNetworkServer) - async def server_close(self) -> None: - if self.__listeners_factory_scope is not None: - self.__listeners_factory_scope.cancel() - self.__listeners_factory = None - await self.__close_servers() - - async def __close_servers(self) -> None: - async with contextlib.AsyncExitStack() as exit_stack: - server_close_group = await exit_stack.enter_async_context(self.__backend.create_task_group()) - - servers, self.__servers = self.__servers, None - if servers is not None: - exit_stack.push_async_callback(self.__backend.cancel_shielded_coro_yield) - for server in servers: - exit_stack.callback(server_close_group.start_soon, server.aclose) - del server - - for server_task in self.__servers_tasks: - server_task.cancel() - exit_stack.push_async_callback(server_task.wait) - del server_task - - if self.__server_run_scope is not None: - self.__server_run_scope.cancel() - - await self.__backend.cancel_shielded_coro_yield() - - @_utils.inherit_doc(AbstractAsyncNetworkServer) - async def shutdown(self) -> None: - if self.__server_run_scope is not None: - self.__server_run_scope.cancel() - await self.__is_shutdown.wait() - - @_utils.inherit_doc(AbstractAsyncNetworkServer) - async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: - async with contextlib.AsyncExitStack() as server_exit_stack: - # Wake up server - if not self.__is_shutdown.is_set(): - raise ServerAlreadyRunning("Server is already running") - self.__is_shutdown = is_shutdown = self.__backend.create_event() - server_exit_stack.callback(is_shutdown.set) - self.__server_run_scope = server_exit_stack.enter_context(self.__backend.open_cancel_scope()) - - def reset_scope() -> None: - self.__server_run_scope = None - - server_exit_stack.callback(reset_scope) - ################ - - # Bind and activate - assert self.__servers is None # nosec assert_used - assert self.__listeners_factory_scope is None # nosec assert_used - if self.__listeners_factory is None: - raise ServerClosedError("Closed server") - listeners: list[AsyncListener[AsyncStreamTransport]] = [] - try: - with self.__backend.open_cancel_scope() as self.__listeners_factory_scope: - await self.__backend.coro_yield() - listeners.extend(await self.__listeners_factory()) - if self.__listeners_factory_scope.cancelled_caught(): - raise ServerClosedError("Server has been closed during task setup") - finally: - self.__listeners_factory_scope = None - if not listeners: - raise OSError("empty listeners list") - self.__servers = tuple( - _stream_server.AsyncStreamServer( - listener, - self.__protocol, - max_recv_size=self.__max_recv_size, - ) - for listener in listeners - ) - del listeners - ################### - - # Final teardown - server_exit_stack.callback(self.__logger.info, "Server stopped") - server_exit_stack.push_async_callback(self.__close_servers) - ################ - - # Initialize request handler - await self.__request_handler.service_init( - await server_exit_stack.enter_async_context(contextlib.AsyncExitStack()), - weakref.proxy(self), + async def __activate_listeners(self) -> list[_stream_server.AsyncStreamServer[_T_Request, _T_Response]]: + return [ + _stream_server.AsyncStreamServer( + listener, + self.__protocol, + max_recv_size=self.__max_recv_size, ) - ############################ - - # Setup task group - self.__active_tasks = 0 - server_exit_stack.callback(self.__servers_tasks.clear) - task_group = await server_exit_stack.enter_async_context(self.__backend.create_task_group()) - server_exit_stack.callback(self.__logger.info, "Server loop break, waiting for remaining tasks...") - ################## - - # Enable listener - self.__servers_tasks.extend([await task_group.start(self.__serve, server, task_group) for server in self.__servers]) - self.__logger.info("Start serving at %s", ", ".join(map(str, self.get_addresses()))) - ################# - - # Server is up - if is_up_event is not None: - is_up_event.set() - ############## - - # Main loop - try: - await self.__backend.sleep_forever() - finally: - reset_scope() + for listener in await self.__listeners_factory() + ] + + async def __initialize_service(self, server_exit_stack: contextlib.AsyncExitStack) -> None: + await self.__request_handler.service_init( + await server_exit_stack.enter_async_context(contextlib.AsyncExitStack()), + weakref.proxy(self), + ) - async def __serve( + async def __lowlevel_serve( self, server: _stream_server.AsyncStreamServer[_T_Request, _T_Response], task_group: TaskGroup, ) -> NoReturn: - self.__attach_server() - try: - async with contextlib.aclosing(server): - handler = build_lowlevel_stream_server_handler( - self.__client_initializer, - self.__request_handler, - logger=self.__logger, - ) - await server.serve(handler, task_group) - finally: - self.__detach_server() + handler = build_lowlevel_stream_server_handler( + self.__client_initializer, + self.__request_handler, + logger=self.logger, + ) + await server.serve(handler, task_group) @contextlib.asynccontextmanager async def __client_initializer( @@ -389,15 +261,14 @@ async def __client_initializer( lowlevel_client: _stream_server.ConnectedStreamClient[_T_Response], ) -> AsyncIterator[AsyncStreamClient[_T_Response] | None]: async with contextlib.AsyncExitStack() as client_exit_stack: - self.__attach_server() - client_exit_stack.callback(self.__detach_server) + client_exit_stack.enter_context(self._bind_server()) client_address = lowlevel_client.extra(INETSocketAttribute.peername, None) if client_address is None: # The remote host closed the connection before starting the task. # See this test for details: # test____serve_forever____accept_client____client_sent_RST_packet_right_after_accept - self.__logger.warning("A client connection was interrupted just after listener.accept()") + self.logger.warning("A client connection was interrupted just after listener.accept()") yield None return @@ -416,7 +287,7 @@ async def __client_initializer( # We expect a TLS close handshake, so we must (try to) properly close the transport before await client_exit_stack.enter_async_context(contextlib.aclosing(lowlevel_client)) - logger: logging.Logger = self.__logger + logger: logging.Logger = self.logger client = _ConnectedClientAPI(client_address, lowlevel_client) del lowlevel_client @@ -431,16 +302,6 @@ async def __client_initializer( _utils.remove_traceback_frames_in_place(exc, 1) raise - def __attach_server(self) -> None: - self.__active_tasks += 1 - - def __detach_server(self) -> None: - self.__active_tasks -= 1 - if self.__active_tasks < 0: - raise AssertionError("self.__active_tasks < 0") - if not self.__active_tasks and self.__server_run_scope is not None: - self.__server_run_scope.cancel() - @staticmethod def __set_socket_linger_if_not_closed(socket: ISocket) -> None: with contextlib.suppress(OSError): @@ -454,7 +315,7 @@ def __suppress_and_log_remaining_exception(self, client_address: SocketAddress) yield except* ClientClosedError as excgrp: _utils.remove_traceback_frames_in_place(excgrp, 1) # Removes the 'yield' frame just above - self.__logger.warning( + self.logger.warning( "There have been attempts to do operation on closed client %s", client_address, exc_info=excgrp, @@ -466,18 +327,18 @@ def __suppress_and_log_remaining_exception(self, client_address: SocketAddress) pass except Exception as exc: _utils.remove_traceback_frames_in_place(exc, 1) # Removes the 'yield' frame just above - self.__logger.error("-" * 40) - self.__logger.error("Exception occurred during processing of request from %s", client_address, exc_info=exc) - self.__logger.error("-" * 40) + self.logger.error("-" * 40) + self.logger.error("Exception occurred during processing of request from %s", client_address, exc_info=exc) + self.logger.error("-" * 40) - @_utils.inherit_doc(AbstractAsyncNetworkServer) + @_utils.inherit_doc(_base.BaseAsyncNetworkServerImpl) def get_addresses(self) -> Sequence[SocketAddress]: - if (servers := self.__servers) is None: - return () - return tuple( - new_socket_address(server.extra(INETSocketAttribute.sockname), server.extra(INETSocketAttribute.family)) - for server in servers - if not server.is_closing() + return self._with_lowlevel_servers( + lambda servers: tuple( + new_socket_address(server.extra(INETSocketAttribute.sockname), server.extra(INETSocketAttribute.family)) + for server in servers + if not server.is_closing() + ) ) def get_sockets(self) -> Sequence[SocketProxy]: @@ -488,13 +349,9 @@ def get_sockets(self) -> Sequence[SocketProxy]: If the server is not running, an empty sequence is returned. """ - if (servers := self.__servers) is None: - return () - return tuple(SocketProxy(server.extra(INETSocketAttribute.socket)) for server in servers) - - @_utils.inherit_doc(AbstractAsyncNetworkServer) - def backend(self) -> AsyncBackend: - return self.__backend + return self._with_lowlevel_servers( + lambda servers: tuple(SocketProxy(server.extra(INETSocketAttribute.socket)) for server in servers) + ) @final diff --git a/src/easynetwork/servers/async_udp.py b/src/easynetwork/servers/async_udp.py index 3afc2109..9283634b 100644 --- a/src/easynetwork/servers/async_udp.py +++ b/src/easynetwork/servers/async_udp.py @@ -22,41 +22,40 @@ import logging import types import weakref -from collections import deque from collections.abc import Callable, Coroutine, Mapping, Sequence from typing import Any, Generic, NoReturn, final from .._typevars import _T_Request, _T_Response -from ..exceptions import ClientClosedError, ServerAlreadyRunning, ServerClosedError +from ..exceptions import ClientClosedError from ..lowlevel import _utils from ..lowlevel._final import runtime_final_class -from ..lowlevel.api_async.backend.abc import AsyncBackend, CancelScope, IEvent, Task, TaskGroup -from ..lowlevel.api_async.backend.utils import BuiltinAsyncBackendLiteral, ensure_backend +from ..lowlevel.api_async.backend.abc import AsyncBackend, TaskGroup +from ..lowlevel.api_async.backend.utils import BuiltinAsyncBackendLiteral from ..lowlevel.api_async.servers import datagram as _datagram_server from ..lowlevel.api_async.transports.abc import AsyncDatagramListener from ..lowlevel.socket import INETSocketAttribute, SocketAddress, SocketProxy, new_socket_address from ..protocol import DatagramProtocol -from .abc import AbstractAsyncNetworkServer, SupportsEventSet +from . import _base from .handlers import AsyncDatagramClient, AsyncDatagramRequestHandler, INETClientAttribute from .misc import build_lowlevel_datagram_server_handler -class AsyncUDPNetworkServer(AbstractAsyncNetworkServer, Generic[_T_Request, _T_Response]): +class AsyncUDPNetworkServer( + _base.BaseAsyncNetworkServerImpl[ + _datagram_server.AsyncDatagramServer[_T_Request, _T_Response, tuple[Any, ...]], + SocketAddress, + ], + Generic[_T_Request, _T_Response], +): """ An asynchronous network server for UDP communication. """ __slots__ = ( - "__backend", - "__servers", "__listeners_factory", - "__listeners_factory_scope", "__protocol", "__request_handler", - "__is_shutdown", - "__servers_tasks", - "__server_run_scope", - "__logger", + "__service_available", ) def __init__( @@ -86,166 +85,73 @@ def __init__( This option is not supported on Windows. logger: If given, the logger instance to use. """ - super().__init__() + super().__init__( + backend=backend, + servers_factory=type(self).__activate_listeners, + initialize_service=type(self).__initialize_service, + lowlevel_serve=type(self).__lowlevel_serve, + logger=logger or logging.getLogger(__name__), + ) if not isinstance(protocol, DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") if not isinstance(request_handler, AsyncDatagramRequestHandler): raise TypeError(f"Expected an AsyncDatagramRequestHandler object, got {request_handler!r}") - backend = ensure_backend(backend) + backend = self.backend() - self.__backend: AsyncBackend = backend - self.__listeners_factory: Callable[[], Coroutine[Any, Any, Sequence[AsyncDatagramListener[tuple[Any, ...]]]]] | None + self.__listeners_factory: Callable[[], Coroutine[Any, Any, Sequence[AsyncDatagramListener[tuple[Any, ...]]]]] self.__listeners_factory = _utils.make_callback( backend.create_udp_listeners, host, port, reuse_port=reuse_port, ) - self.__listeners_factory_scope: CancelScope | None = None - self.__server_run_scope: CancelScope | None = None - self.__servers: tuple[_datagram_server.AsyncDatagramServer[_T_Request, _T_Response, tuple[Any, ...]], ...] | None - self.__servers = None self.__protocol: DatagramProtocol[_T_Response, _T_Request] = protocol self.__request_handler: AsyncDatagramRequestHandler[_T_Request, _T_Response] = request_handler - self.__is_shutdown: IEvent = backend.create_event() - self.__is_shutdown.set() - self.__servers_tasks: deque[Task[NoReturn]] = deque() - self.__logger: logging.Logger = logger or logging.getLogger(__name__) - - @_utils.inherit_doc(AbstractAsyncNetworkServer) - def is_serving(self) -> bool: - return self.__servers is not None and all(not server.is_closing() for server in self.__servers) - - @_utils.inherit_doc(AbstractAsyncNetworkServer) - async def server_close(self) -> None: - if self.__listeners_factory_scope is not None: - self.__listeners_factory_scope.cancel() - self.__listeners_factory = None - await self.__close_servers() - - async def __close_servers(self) -> None: - async with contextlib.AsyncExitStack() as exit_stack: - server_close_group = await exit_stack.enter_async_context(self.__backend.create_task_group()) - - servers, self.__servers = self.__servers, None - if servers is not None: - exit_stack.push_async_callback(self.__backend.cancel_shielded_coro_yield) - for server in servers: - exit_stack.callback(server_close_group.start_soon, server.aclose) - del server - - for server_task in self.__servers_tasks: - server_task.cancel() - exit_stack.push_async_callback(server_task.wait) - del server_task - - if self.__server_run_scope is not None: - self.__server_run_scope.cancel() - - await self.__backend.cancel_shielded_coro_yield() - - @_utils.inherit_doc(AbstractAsyncNetworkServer) - async def shutdown(self) -> None: - if self.__server_run_scope is not None: - self.__server_run_scope.cancel() - await self.__is_shutdown.wait() - - @_utils.inherit_doc(AbstractAsyncNetworkServer) - async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: - async with contextlib.AsyncExitStack() as server_exit_stack: - # Wake up server - if not self.__is_shutdown.is_set(): - raise ServerAlreadyRunning("Server is already running") - self.__is_shutdown = is_shutdown = self.__backend.create_event() - server_exit_stack.callback(is_shutdown.set) - self.__server_run_scope = server_exit_stack.enter_context(self.__backend.open_cancel_scope()) - - def reset_scope() -> None: - self.__server_run_scope = None - - server_exit_stack.callback(reset_scope) - ################ - - # Bind and activate - assert self.__servers is None # nosec assert_used - assert self.__listeners_factory_scope is None # nosec assert_used - if self.__listeners_factory is None: - raise ServerClosedError("Closed server") - listeners: list[AsyncDatagramListener[tuple[Any, ...]]] = [] - try: - with self.__backend.open_cancel_scope() as self.__listeners_factory_scope: - await self.__backend.coro_yield() - listeners.extend(await self.__listeners_factory()) - if self.__listeners_factory_scope.cancelled_caught(): - raise ServerClosedError("Server has been closed during task setup") - finally: - self.__listeners_factory_scope = None - if not listeners: - raise OSError("empty listeners list") - self.__servers = tuple(_datagram_server.AsyncDatagramServer(listener, self.__protocol) for listener in listeners) - del listeners - ################### - - # Final teardown - server_exit_stack.callback(self.__logger.info, "Server stopped") - ################ - - # Initialize request handler - await self.__request_handler.service_init( - await server_exit_stack.enter_async_context(contextlib.AsyncExitStack()), - weakref.proxy(self), - ) - server_exit_stack.push_async_callback(self.__close_servers) - ############################ - - # Setup task group - server_exit_stack.callback(self.__servers_tasks.clear) - task_group: TaskGroup = await server_exit_stack.enter_async_context(self.__backend.create_task_group()) - server_exit_stack.callback(self.__logger.info, "Server loop break, waiting for remaining tasks...") - ################## - - # Enable listener - self.__servers_tasks.extend( - [ - await task_group.start( - server.serve, - build_lowlevel_datagram_server_handler(self.__client_initializer, self.__request_handler), - task_group, - ) - for server in self.__servers - ] + self.__service_available = _utils.Flag() + + async def __activate_listeners(self) -> list[_datagram_server.AsyncDatagramServer[_T_Request, _T_Response, tuple[Any, ...]]]: + return [ + _datagram_server.AsyncDatagramServer( + listener, + self.__protocol, ) - self.__logger.info("Start serving at %s", ", ".join(map(str, self.get_addresses()))) - ################# + for listener in await self.__listeners_factory() + ] + + async def __initialize_service(self, server_exit_stack: contextlib.AsyncExitStack) -> None: + await self.__request_handler.service_init( + await server_exit_stack.enter_async_context(contextlib.AsyncExitStack()), + weakref.proxy(self), + ) - # Server is up - if is_up_event is not None: - is_up_event.set() - ############## + self.__service_available.set() + server_exit_stack.callback(self.__service_available.clear) - # Main loop - try: - await self.__backend.sleep_forever() - finally: - reset_scope() + async def __lowlevel_serve( + self, + server: _datagram_server.AsyncDatagramServer[_T_Request, _T_Response, tuple[Any, ...]], + task_group: TaskGroup, + ) -> NoReturn: + handler = build_lowlevel_datagram_server_handler(self.__client_initializer, self.__request_handler) + await server.serve(handler, task_group) def __client_initializer( self, lowlevel_client: _datagram_server.DatagramClientContext[_T_Response, tuple[Any, ...]], ) -> _ClientContext: - return _ClientContext(lowlevel_client, self.__logger) + return _ClientContext(lowlevel_client, self.__service_available, self.logger) - @_utils.inherit_doc(AbstractAsyncNetworkServer) + @_utils.inherit_doc(_base.BaseAsyncNetworkServerImpl) def get_addresses(self) -> Sequence[SocketAddress]: - if (servers := self.__servers) is None: - return () - return tuple( - new_socket_address(server.extra(INETSocketAttribute.sockname), server.extra(INETSocketAttribute.family)) - for server in servers - if not server.is_closing() + return self._with_lowlevel_servers( + lambda servers: tuple( + new_socket_address(server.extra(INETSocketAttribute.sockname), server.extra(INETSocketAttribute.family)) + for server in servers + if not server.is_closing() + ) ) def get_sockets(self) -> Sequence[SocketProxy]: @@ -256,13 +162,9 @@ def get_sockets(self) -> Sequence[SocketProxy]: If the server is not running, an empty sequence is returned. """ - if (servers := self.__servers) is None: - return () - return tuple(SocketProxy(server.extra(INETSocketAttribute.socket)) for server in servers) - - @_utils.inherit_doc(AbstractAsyncNetworkServer) - def backend(self) -> AsyncBackend: - return self.__backend + return self._with_lowlevel_servers( + lambda servers: tuple(SocketProxy(server.extra(INETSocketAttribute.socket)) for server in servers) + ) @final @@ -270,15 +172,21 @@ def backend(self) -> AsyncBackend: class _ClientAPI(AsyncDatagramClient[_T_Response]): __slots__ = ( "__context", + "__service_available", "__h", "__extra_attributes_cache", ) - def __init__(self, context: _datagram_server.DatagramClientContext[_T_Response, tuple[Any, ...]]) -> None: + def __init__( + self, + context: _datagram_server.DatagramClientContext[_T_Response, tuple[Any, ...]], + service_available: _utils.Flag, + ) -> None: super().__init__() self.__context: _datagram_server.DatagramClientContext[_T_Response, tuple[Any, ...]] = context self.__h: int | None = None self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] | None = None + self.__service_available: _utils.Flag = service_available def __repr__(self) -> str: return f"" @@ -294,12 +202,19 @@ def __eq__(self, other: object) -> bool: return self.__context == other.__context def is_closing(self) -> bool: - return self.__context.server.is_closing() + return self.__is_closing(self.__service_available, self.__context.server) + + @staticmethod + def __is_closing( + service_available: _utils.Flag, + server: _datagram_server.AsyncDatagramServer[Any, _T_Response, tuple[Any, ...]], + ) -> bool: + return (not service_available.is_set()) or server.is_closing() async def send_packet(self, packet: _T_Response, /) -> None: server = self.__context.server address = self.__context.address - if server.is_closing(): + if self.__is_closing(self.__service_available, server): raise ClientClosedError("Closed client") await server.send_packet_to(packet, address) @@ -328,19 +243,22 @@ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: class _ClientContext: __slots__ = ( "__lowlevel_client", + "__service_available", "__logger", ) def __init__( self, lowlevel_client: _datagram_server.DatagramClientContext[Any, tuple[Any, ...]], + service_available: _utils.Flag, logger: logging.Logger, ) -> None: self.__lowlevel_client: _datagram_server.DatagramClientContext[Any, tuple[Any, ...]] = lowlevel_client + self.__service_available: _utils.Flag = service_available self.__logger: logging.Logger = logger async def __aenter__(self) -> AsyncDatagramClient[_T_Response]: - return _ClientAPI(self.__lowlevel_client) + return _ClientAPI(self.__lowlevel_client, self.__service_available) async def __aexit__( self, diff --git a/src/easynetwork/servers/standalone_tcp.py b/src/easynetwork/servers/standalone_tcp.py index d9a37a27..b4d856d2 100644 --- a/src/easynetwork/servers/standalone_tcp.py +++ b/src/easynetwork/servers/standalone_tcp.py @@ -27,7 +27,7 @@ from .._typevars import _T_Request, _T_Response from ..lowlevel.api_async.backend.abc import AsyncBackend from ..lowlevel.api_async.backend.utils import BuiltinAsyncBackendLiteral -from ..lowlevel.socket import SocketProxy +from ..lowlevel.socket import SocketAddress, SocketProxy from ..protocol import AnyStreamProtocolType from . import _base from .async_tcp import AsyncTCPNetworkServer @@ -98,16 +98,15 @@ def __init__( runner_options=runner_options, ) - def stop_listening(self) -> None: + def get_addresses(self) -> Sequence[SocketAddress]: """ - Schedules the shutdown of all listener sockets. Thread-safe. + Returns all interfaces to which the server is bound. Thread-safe. - After that, all new connections will be refused, but the server will continue to run and handle - previously accepted connections. - - Further calls to :meth:`is_serving` will return :data:`False`. + Returns: + A sequence of network socket address. + If the server is not serving (:meth:`is_serving` returns :data:`False`), an empty sequence is returned. """ - self._run_sync_or(lambda portal, server: portal.run_sync(server.stop_listening), None) + return self._run_sync_or(lambda portal, server: portal.run_sync(server.get_addresses), ()) def get_sockets(self) -> Sequence[SocketProxy]: """Gets the listeners sockets. Thread-safe. diff --git a/src/easynetwork/servers/standalone_udp.py b/src/easynetwork/servers/standalone_udp.py index 7f6db9b1..e17ee631 100644 --- a/src/easynetwork/servers/standalone_udp.py +++ b/src/easynetwork/servers/standalone_udp.py @@ -27,7 +27,7 @@ from .._typevars import _T_Request, _T_Response from ..lowlevel.api_async.backend.abc import AsyncBackend from ..lowlevel.api_async.backend.utils import BuiltinAsyncBackendLiteral -from ..lowlevel.socket import SocketProxy +from ..lowlevel.socket import SocketAddress, SocketProxy from ..protocol import DatagramProtocol from . import _base from .async_udp import AsyncUDPNetworkServer @@ -81,6 +81,16 @@ def __init__( runner_options=runner_options, ) + def get_addresses(self) -> Sequence[SocketAddress]: + """ + Returns all interfaces to which the server is bound. Thread-safe. + + Returns: + A sequence of network socket address. + If the server is not serving (:meth:`is_serving` returns :data:`False`), an empty sequence is returned. + """ + return self._run_sync_or(lambda portal, server: portal.run_sync(server.get_addresses), ()) + def get_sockets(self) -> Sequence[SocketProxy]: """Gets the listeners sockets. Thread-safe. diff --git a/tests/functional_test/test_communication/test_async/test_server/base.py b/tests/functional_test/test_communication/test_async/test_server/base.py index 217419c2..a6bfc51a 100644 --- a/tests/functional_test/test_communication/test_async/test_server/base.py +++ b/tests/functional_test/test_communication/test_async/test_server/base.py @@ -89,14 +89,6 @@ async def test____server_close____idempotent(self, server: AbstractAsyncNetworkS await server.server_close() await server.server_close() - async def test____server_close____while_server_is_running( - self, - server: AbstractAsyncNetworkServer, - run_server: asyncio.Event, - ) -> None: - await run_server.wait() - await server.server_close() - @pytest.mark.usefixtures("run_server") async def test____serve_forever____error_already_running(self, server: AbstractAsyncNetworkServer) -> None: with pytest.raises(ServerAlreadyRunning): @@ -122,26 +114,6 @@ async def test____serve_forever____shutdown_during_setup( await server.shutdown() assert not event.is_set() - async def test____serve_forever____server_close_during_setup( - self, - server: AbstractAsyncNetworkServer, - enable_eager_tasks: bool, - ) -> None: - event = asyncio.Event() - - async def serve() -> None: - with pytest.raises(ServerClosedError): - await server.serve_forever(is_up_event=event) - - async with asyncio.TaskGroup() as tg: - _ = tg.create_task(serve()) - if not enable_eager_tasks: - await asyncio.sleep(0) - assert not event.is_set() - async with asyncio.timeout(1): - await server.server_close() - assert not event.is_set() - async def test____serve_forever____without_is_up_event( self, server: AbstractAsyncNetworkServer, @@ -166,3 +138,21 @@ async def test____serve_forever____concurrent_shutdown( await run_server.wait() await asyncio.gather(*[server.shutdown() for _ in range(10)]) + + async def test____server_activate____server_close_during_activation( + self, + server_not_activated: AbstractAsyncNetworkServer, + enable_eager_tasks: bool, + ) -> None: + async def serve() -> None: + with pytest.raises(ServerClosedError): + await server_not_activated.server_activate() + + async with asyncio.TaskGroup() as tg: + _ = tg.create_task(serve()) + if not enable_eager_tasks: + await asyncio.sleep(0) + assert not server_not_activated.is_listening() + async with asyncio.timeout(1): + await server_not_activated.server_close() + assert not server_not_activated.is_listening() diff --git a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py index 48a34bdb..4b678459 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py @@ -7,7 +7,7 @@ import ssl from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Sequence from socket import IPPROTO_TCP, TCP_NODELAY -from typing import Any, Literal +from typing import Any from weakref import WeakValueDictionary from easynetwork.exceptions import ( @@ -134,7 +134,7 @@ async def handle(self, client: AsyncStreamClient[str]) -> AsyncGenerator[None, s case "__os_error__": raise OSError("Server issue.") case "__stop_listening__": - self.server.stop_listening() + await self.server.server_close() await client.send_packet("successfully stop listening") case "__wait__": while True: @@ -362,6 +362,35 @@ def ssl_standard_compatible(request: Any) -> bool | None: def asyncio_backend() -> AsyncIOBackend: return AsyncIOBackend() + @pytest_asyncio.fixture + @staticmethod + async def server_not_activated( + asyncio_backend: AsyncIOBackend, + request_handler: MyAsyncTCPRequestHandler, + localhost_ip: str, + stream_protocol: AnyStreamProtocolType[str, str], + caplog: pytest.LogCaptureFixture, + logger_crash_threshold_level: dict[str, int], + ) -> AsyncIterator[MyAsyncTCPServer]: + server = MyAsyncTCPServer( + localhost_ip, + 0, + stream_protocol, + request_handler, + asyncio_backend, + backlog=1, + logger=LOGGER, + ) + try: + assert not server.is_listening() + assert not server.get_sockets() + assert not server.get_addresses() + caplog.set_level(logging.INFO, LOGGER.name) + logger_crash_threshold_level[LOGGER.name] = logging.WARNING + yield server + finally: + await server.server_close() + @pytest_asyncio.fixture @staticmethod async def server( @@ -391,8 +420,9 @@ async def server( ssl_standard_compatible=ssl_standard_compatible, logger=LOGGER, ) as server: - assert not server.get_sockets() - assert not server.get_addresses() + assert server.is_listening() + assert server.get_sockets() + assert server.get_addresses() caplog.set_level(logging.INFO, LOGGER.name) logger_crash_threshold_level[LOGGER.name] = logging.WARNING yield server @@ -544,11 +574,14 @@ async def test____serve_forever____empty_listener_list( request_handler: MyAsyncTCPRequestHandler, stream_protocol: AnyStreamProtocolType[str, str], ) -> None: - async with MyAsyncTCPServer(None, 0, stream_protocol, request_handler, NoListenerErrorBackend()) as s: + s = MyAsyncTCPServer(None, 0, stream_protocol, request_handler, NoListenerErrorBackend()) + try: with pytest.raises(OSError, match=r"^empty listeners list$"): - await s.serve_forever() + await s.server_activate() assert not s.get_sockets() + finally: + await s.server_close() @pytest.mark.usefixtures("run_server_and_wait") async def test____serve_forever____server_assignment( @@ -624,22 +657,14 @@ async def test____serve_forever____disable_nagle_algorithm( # (c.f. https://stackoverflow.com/a/31835137) assert tcp_nodelay_state != 0 - @pytest.mark.parametrize("action", ["shutdown", "server_close"]) - async def test____serve_forever____close_during_loop____kill_client_tasks( + async def test____serve_forever____shutdown_during_loop____kill_client_tasks( self, - action: Literal["shutdown", "server_close"], server: MyAsyncTCPServer, client_factory: Callable[[], Awaitable[tuple[asyncio.StreamReader, asyncio.StreamWriter]]], ) -> None: reader, _ = await client_factory() - match action: - case "shutdown": - await server.shutdown() - case "server_close": - await server.server_close() - case _: - pytest.fail("Invalid argument") + await server.shutdown() await asyncio.sleep(0.3) with contextlib.suppress(ConnectionError): @@ -941,9 +966,9 @@ async def test____serve_forever____request_handler_ask_to_stop_accepting_new_con writer.write(b"__stop_listening__\n") assert await reader.readline() == b"successfully stop listening\n" + await asyncio.sleep(0.1) assert not server.is_serving() - await asyncio.sleep(0.1) with pytest.raises(ExceptionGroup) as exc_info: await client_factory() diff --git a/tests/functional_test/test_communication/test_async/test_server/test_udp.py b/tests/functional_test/test_communication/test_async/test_server/test_udp.py index 339bc93b..d1a0306b 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_udp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_udp.py @@ -250,6 +250,34 @@ def request_handler(request: Any) -> AsyncDatagramRequestHandler[str, str]: def asyncio_backend() -> AsyncIOBackend: return AsyncIOBackend() + @pytest_asyncio.fixture + @staticmethod + async def server_not_activated( + asyncio_backend: AsyncIOBackend, + request_handler: AsyncDatagramRequestHandler[str, str], + localhost_ip: str, + datagram_protocol: DatagramProtocol[str, str], + caplog: pytest.LogCaptureFixture, + logger_crash_threshold_level: dict[str, int], + ) -> AsyncIterator[MyAsyncUDPServer]: + server = MyAsyncUDPServer( + localhost_ip, + 0, + datagram_protocol, + request_handler, + asyncio_backend, + logger=LOGGER, + ) + try: + assert not server.is_listening() + assert not server.get_sockets() + assert not server.get_addresses() + caplog.set_level(logging.INFO, LOGGER.name) + logger_crash_threshold_level[LOGGER.name] = logging.WARNING + yield server + finally: + await server.server_close() + @pytest_asyncio.fixture @staticmethod async def server( @@ -268,8 +296,9 @@ async def server( asyncio_backend, logger=LOGGER, ) as server: - assert not server.get_sockets() - assert not server.get_addresses() + assert server.is_listening() + assert server.get_sockets() + assert server.get_addresses() caplog.set_level(logging.INFO, LOGGER.name) logger_crash_threshold_level[LOGGER.name] = logging.WARNING yield server @@ -314,11 +343,14 @@ async def test____serve_forever____empty_listener_list( request_handler: MyAsyncUDPRequestHandler, datagram_protocol: DatagramProtocol[str, str], ) -> None: - async with MyAsyncUDPServer(None, 0, datagram_protocol, request_handler, NoListenerErrorBackend()) as s: + s = MyAsyncUDPServer(None, 0, datagram_protocol, request_handler, NoListenerErrorBackend()) + try: with pytest.raises(OSError, match=r"^empty listeners list$"): await s.serve_forever() assert not s.get_sockets() + finally: + await s.server_close() @pytest.mark.usefixtures("run_server_and_wait") async def test____serve_forever____server_assignment( diff --git a/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py b/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py index 00af4a6d..ff2ed042 100644 --- a/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py +++ b/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py @@ -57,13 +57,6 @@ def test____server_close____idempotent(self, server: AbstractNetworkServer) -> N server.server_close() server.server_close() - @pytest.mark.usefixtures("start_server") - def test____server_close____while_server_is_running(self, server: AbstractNetworkServer) -> None: - server.server_close() - - with pytest.raises(ServerClosedError): - server.serve_forever() - @pytest.mark.usefixtures("start_server") def test____serve_forever____error_server_already_running(self, server: AbstractNetworkServer) -> None: with pytest.raises(ServerAlreadyRunning): @@ -92,13 +85,11 @@ def test____serve_forever____serve_several_times(self, server: AbstractNetworkSe with server: for _ in range(3): assert not server.is_serving() - assert not server.get_addresses() server_thread = NetworkServerThread(server, daemon=True) server_thread.start() try: assert server.is_serving() - assert len(server.get_addresses()) > 0 time.sleep(0.5) finally: server_thread.join() @@ -124,12 +115,6 @@ def client(server: StandaloneTCPNetworkServer[str, str], start_server: None) -> with socket.create_connection(("localhost", port)) as client: yield client - def test____stop_listening____default_to_noop(self, server: StandaloneTCPNetworkServer[str, str]) -> None: - with server: - assert not server.get_sockets() - assert not server.get_addresses() - server.stop_listening() - def test____socket_property____server_is_not_running(self, server: StandaloneTCPNetworkServer[str, str]) -> None: with server: assert len(server.get_sockets()) == 0 @@ -140,17 +125,6 @@ def test____socket_property____server_is_running(self, server: StandaloneTCPNetw assert len(server.get_sockets()) > 0 assert len(server.get_addresses()) > 0 - @pytest.mark.usefixtures("start_server", "client") - def test____stop_listening____stop_accepting_new_connection(self, server: StandaloneTCPNetworkServer[str, str]) -> None: - assert server.is_serving() - assert len(server.get_sockets()) > 0 - assert len(server.get_addresses()) > 0 - - server.stop_listening() - assert not server.is_serving() - assert len(server.get_sockets()) > 0 # Sockets are closed, but always available until server_close() call - assert len(server.get_addresses()) == 0 - class TestStandaloneUDPNetworkServer(BaseTestStandaloneNetworkServer): @pytest.fixture diff --git a/tests/functional_test/test_concurrency/conftest.py b/tests/functional_test/test_concurrency/conftest.py index 3bc69401..1ec50d35 100644 --- a/tests/functional_test/test_concurrency/conftest.py +++ b/tests/functional_test/test_concurrency/conftest.py @@ -49,6 +49,7 @@ def _run_server(server: AbstractNetworkServer) -> None: def _retrieve_server_address(server: AbstractNetworkServer) -> tuple[str, int]: + assert isinstance(server, (StandaloneTCPNetworkServer, StandaloneUDPNetworkServer)) address = server.get_addresses()[0] if isinstance(address, IPv4SocketAddress): return "127.0.0.1", address.port diff --git a/tests/scripts/async_server_test.py b/tests/scripts/async_server_test.py index 5cf0533d..4ae6ffbd 100644 --- a/tests/scripts/async_server_test.py +++ b/tests/scripts/async_server_test.py @@ -22,8 +22,6 @@ async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Abst self.server = server async def handle(self, client: AsyncBaseClientInterface[str]) -> AsyncGenerator[None, str]: - from easynetwork.servers.async_tcp import AsyncTCPNetworkServer - request: str = yield logger.debug(f"Received {request!r} from {client!r}") match request: @@ -31,8 +29,8 @@ async def handle(self, client: AsyncBaseClientInterface[str]) -> AsyncGenerator[ raise RuntimeError("requested error") case "wait:": request = (yield) + " after wait" - case "self_kill:" if isinstance(self.server, AsyncTCPNetworkServer): - self.server.stop_listening() + case "self_kill:": + await self.server.server_close() await client.send_packet("stop_listening() done") return await client.send_packet(request.upper()) diff --git a/tests/unit_test/test_tools/test_utils.py b/tests/unit_test/test_tools/test_utils.py index ef958e66..bdb04af2 100644 --- a/tests/unit_test/test_tools/test_utils.py +++ b/tests/unit_test/test_tools/test_utils.py @@ -17,6 +17,7 @@ from easynetwork.lowlevel._final import runtime_final_class from easynetwork.lowlevel._utils import ( ElapsedTime, + Flag, ResourceGuard, adjust_leftover_buffer, check_real_socket_state, @@ -808,6 +809,22 @@ def test____iterate_exceptions____recursive_yield_exceptions_in_group() -> None: assert all_exceptions == list(itertools.chain(sub_excgrp1.exceptions, sub_excgrp2.exceptions)) +def test____Flag___set_and_reset() -> None: + # Arrange + flag = Flag() + assert not flag.is_set() + + # Act & Assert + flag.set() + assert flag.is_set() + flag.set() + assert flag.is_set() + flag.clear() + assert not flag.is_set() + flag.clear() + assert not flag.is_set() + + def test____ElapsedTime____catch_elapsed_time(mocker: MockerFixture) -> None: # Arrange now: float = 798546132.0 From afcc36ed4afe26b5179a4d14fad65fad047c2993 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sun, 7 Jul 2024 11:14:53 +0200 Subject: [PATCH 2/5] Little fix in examples --- .../_include/examples/howto/tcp_servers/background_server.py | 1 + .../_include/examples/howto/tcp_servers/background_server_ssl.py | 1 + .../_include/examples/howto/udp_servers/background_server.py | 1 + 3 files changed, 3 insertions(+) diff --git a/docs/source/_include/examples/howto/tcp_servers/background_server.py b/docs/source/_include/examples/howto/tcp_servers/background_server.py index 8c446874..73d1893e 100644 --- a/docs/source/_include/examples/howto/tcp_servers/background_server.py +++ b/docs/source/_include/examples/howto/tcp_servers/background_server.py @@ -63,6 +63,7 @@ async def main() -> None: await client(host, port, "Hello world 3") await server.shutdown() + await server_task if __name__ == "__main__": diff --git a/docs/source/_include/examples/howto/tcp_servers/background_server_ssl.py b/docs/source/_include/examples/howto/tcp_servers/background_server_ssl.py index 51f68306..16f9ffd4 100644 --- a/docs/source/_include/examples/howto/tcp_servers/background_server_ssl.py +++ b/docs/source/_include/examples/howto/tcp_servers/background_server_ssl.py @@ -71,6 +71,7 @@ async def main() -> None: await client(host, port, "Hello world 3") await server.shutdown() + await server_task if __name__ == "__main__": diff --git a/docs/source/_include/examples/howto/udp_servers/background_server.py b/docs/source/_include/examples/howto/udp_servers/background_server.py index a7d2df80..f39ea9df 100644 --- a/docs/source/_include/examples/howto/udp_servers/background_server.py +++ b/docs/source/_include/examples/howto/udp_servers/background_server.py @@ -60,6 +60,7 @@ async def main() -> None: await client(host, port, "Hello world 3") await server.shutdown() + await server_task if __name__ == "__main__": From 6f74e53c21446a0e0f5a5a839228bb95dbb3e2b4 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sun, 7 Jul 2024 11:51:36 +0200 Subject: [PATCH 3/5] Fixed wrong fixture used for test --- .../test_communication/test_async/test_server/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/functional_test/test_communication/test_async/test_server/base.py b/tests/functional_test/test_communication/test_async/test_server/base.py index a6bfc51a..deebfd74 100644 --- a/tests/functional_test/test_communication/test_async/test_server/base.py +++ b/tests/functional_test/test_communication/test_async/test_server/base.py @@ -101,17 +101,17 @@ async def test____serve_forever____error_closed_server(self, server: AbstractAsy async def test____serve_forever____shutdown_during_setup( self, - server: AbstractAsyncNetworkServer, + server_not_activated: AbstractAsyncNetworkServer, enable_eager_tasks: bool, ) -> None: event = asyncio.Event() async with asyncio.TaskGroup() as tg: - _ = tg.create_task(server.serve_forever(is_up_event=event)) + _ = tg.create_task(server_not_activated.serve_forever(is_up_event=event)) if not enable_eager_tasks: await asyncio.sleep(0) assert not event.is_set() async with asyncio.timeout(1): - await server.shutdown() + await server_not_activated.shutdown() assert not event.is_set() async def test____serve_forever____without_is_up_event( From be2fc00ce275f7bb74b93af99295c47c01b8585a Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sun, 7 Jul 2024 12:25:52 +0200 Subject: [PATCH 4/5] Re-added previously removed tests + Fixed standalone servers' server_close() blocking if serve_forever() is running --- src/easynetwork/servers/_base.py | 10 ++++------ .../test_async/test_server/base.py | 13 +++++++++++++ .../test_sync/test_server/test_standalone.py | 8 ++++++++ 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/easynetwork/servers/_base.py b/src/easynetwork/servers/_base.py index 4dd52533..cddc6099 100644 --- a/src/easynetwork/servers/_base.py +++ b/src/easynetwork/servers/_base.py @@ -37,7 +37,9 @@ class _SupportsAclose(Protocol): + @abstractmethod def is_closing(self) -> bool: ... + @abstractmethod def aclose(self) -> Awaitable[object]: ... @@ -115,10 +117,6 @@ def is_serving(self) -> bool: def server_close(self) -> None: with self.__close_lock.get(), contextlib.ExitStack() as stack: stack.callback(self.__is_closed.set) - - # Ensure we are not in the interval between the server shutdown and the scheduler shutdown - stack.callback(self.__is_shutdown.wait) - self._run_sync_or(lambda portal, server: portal.run_coroutine(server.server_close), None) @_utils.inherit_doc(AbstractNetworkServer) @@ -171,12 +169,12 @@ def serve_forever( # locks_stack is used to acquire locks until # serve_forever() coroutine creates the thread portal locks_stack = server_exit_stack.enter_context(contextlib.ExitStack()) - locks_stack.enter_context(self.__close_lock.get()) - locks_stack.enter_context(self.__bootstrap_lock.get()) + locks_stack.enter_context(self.__close_lock.get()) if self.__is_closed.is_set(): raise ServerClosedError("Closed server") + locks_stack.enter_context(self.__bootstrap_lock.get()) if not self.__is_shutdown.is_set(): raise ServerAlreadyRunning("Server is already running") diff --git a/tests/functional_test/test_communication/test_async/test_server/base.py b/tests/functional_test/test_communication/test_async/test_server/base.py index deebfd74..db925c68 100644 --- a/tests/functional_test/test_communication/test_async/test_server/base.py +++ b/tests/functional_test/test_communication/test_async/test_server/base.py @@ -89,6 +89,19 @@ async def test____server_close____idempotent(self, server: AbstractAsyncNetworkS await server.server_close() await server.server_close() + async def test____server_close____while_server_is_running( + self, + server: AbstractAsyncNetworkServer, + run_server: asyncio.Event, + ) -> None: + await run_server.wait() + await server.server_close() + await asyncio.sleep(0.5) + + # There is no client so the server loop should stop by itself + assert not server.is_serving() + assert not server.is_listening() + @pytest.mark.usefixtures("run_server") async def test____serve_forever____error_already_running(self, server: AbstractAsyncNetworkServer) -> None: with pytest.raises(ServerAlreadyRunning): diff --git a/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py b/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py index ff2ed042..35445a53 100644 --- a/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py +++ b/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py @@ -45,6 +45,14 @@ def test____shutdown____default_to_noop(self, server: AbstractNetworkServer) -> with server: server.shutdown() + @pytest.mark.usefixtures("start_server") + def test____server_close____while_server_is_running(self, server: AbstractNetworkServer) -> None: + server.server_close() + time.sleep(0.5) + + # There is no client so the server loop should stop by itself + assert not server.is_serving() + @pytest.mark.usefixtures("start_server") def test____shutdown____while_server_is_running(self, server: AbstractNetworkServer) -> None: assert server.is_serving() From 92eb9cfcde936c188bf319dc6d45db947b5e9ad2 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sun, 7 Jul 2024 12:35:30 +0200 Subject: [PATCH 5/5] Fix invalid docstrings --- src/easynetwork/servers/abc.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/easynetwork/servers/abc.py b/src/easynetwork/servers/abc.py index 9ecb8965..a0674098 100644 --- a/src/easynetwork/servers/abc.py +++ b/src/easynetwork/servers/abc.py @@ -148,7 +148,6 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = ...) -> Raises: ServerClosedError: The server is closed. ServerAlreadyRunning: Another task already called :meth:`serve_forever`. - ServerNotActivated: :meth:`server_activate` must be used before calling :meth:`serve_forever`. """ raise NotImplementedError @@ -166,8 +165,6 @@ async def server_activate(self) -> None: This method MUST be idempotent. Further calls to :meth:`is_listening` will return :data:`True`. - To stop and close the listeners, you can use :meth:`stop_listening`. - Raises: ServerClosedError: The server is closed. """