diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0e4ab802..3b0555a7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -119,16 +119,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: astral-sh/setup-uv@v6 with: - python-version: "3.x" - - name: install - run: | - python -m pip install -U pip - python -m pip install -e . - python -m pip install pytest pytest-mypy-plugins + enable-cache: true + cache-dependency-glob: "**/pyproject.toml" - name: test - run: pytest typesafety --mypy-only-local-stub --color=yes + run: uv run --no-dev --group test-typing pytest typesafety --mypy-only-local-stub --color=yes benchmarks: runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3f1f14bb..54e9221a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,3 +32,5 @@ repos: - types-attrs - pydantic - typing_extensions + - anyio + - trio diff --git a/docs/guides/async.md b/docs/guides/async.md new file mode 100644 index 00000000..0857fd9f --- /dev/null +++ b/docs/guides/async.md @@ -0,0 +1,164 @@ +# Usage with `async/await` + +Psygnal can be made to work with asynchronous functions (those defined with +`async def`) by setting an async backend. The pattern is slightly different +depending on the async framework you are using, but the general idea is the same. + +We currently support: + +- [x] [`asyncio`](https://docs.python.org/3/library/asyncio.html) +- [x] [`anyio`](https://anyio.readthedocs.io/) +- [x] [`trio`](https://trio.readthedocs.io/) + +## Premise + +Assume you have a class that emits signals, and an async function that you want +to use as a callback: + +```python +from psygnal import Signal + + +class MyObj: + value_changed = Signal(str) + + def set_value(self, value: str) -> None: + self.value_changed.emit(value) + + +async def on_value_changed(new_value: str) -> None: + """Callback function that will be called when the value changes.""" + print(f"The new value is {new_value!r}") +``` + +## Connecting Async Callbacks + +To connect the `value_changed` signal to the `on_value_changed` async function, +we need to: + +1. Set up the async backend using [`set_async_backend()`][psygnal.set_async_backend], + *inside* an async context. +2. Wait for the backend to be ready. +3. Connect the async function to the signal. + +Then whenever `set_value()` is called, the `on_value_changed` async function will be +called asynchronously. + +!!! tip "Order matters!" + + Failure to call `set_async_backend()` before connecting an async callback + will result in `RuntimeError`. + + Failure to wait for the backend to be ready before connecting an async + callback will result in a `RuntimeWarning`, and the callback will not + be called. + +=== "asyncio" + + ```python + import asyncio + + from psygnal import set_async_backend + + + async def main() -> None: + backend = set_async_backend("asyncio") # (1)! + + # Set up the async backend and wait for it to be ready + await backend.running.wait() # (2)! + + # Create an instance of MyObj and connect the async callback + obj = MyObj() + obj.value_changed.connect(on_value_changed) # (3)! + + # Set a value to trigger the callback + obj.set_value("hello!") + + # Give the callback time to execute + await asyncio.sleep(0.01) + + + if __name__ == "__main__": + asyncio.run(main()) + ``` + + 1. Call `psygnal.set_async_backend("asyncio")`. This immediately creates + a task to process the queues. + 2. Wait for the backend to be ready before connecting the signal. + 3. Connect the signal to the async callback function. + +=== "AnyIO" + + ```python + import anyio + + from psygnal import set_async_backend + + + async def main() -> None: + backend = set_async_backend("anyio") # (1)! + + async with anyio.create_task_group() as tg: + # Set up the async backend and wait for it to be ready before connecting + tg.start_soon(backend.run) # (2)! + await backend.running.wait() # (3)! + + # Create an instance of MyObj and connect the async callback + obj = MyObj() + obj.value_changed.connect(on_value_changed) # (4)! + + # Set a value to trigger the callback + obj.set_value("hello!") + + # Give the callback time to execute + await anyio.sleep(0.01) + + tg.cancel_scope.cancel() + + + if __name__ == "__main__": + anyio.run(main) + ``` + + 1. Call `psygnal.set_async_backend("anyio")` to create send/receive queues. + 2. Start watching the queues in the background using `backend.run()`. + 3. Wait for the backend to be ready before connecting the signal. + 4. Connect the signal to the async callback function. + +=== "trio" + + ```python + import trio + + from psygnal import set_async_backend + + + async def main() -> None: + backend = set_async_backend("trio") # (1)! + + async with trio.open_nursery() as nursery: + # Set up the async backend and wait for it to be ready before connecting + nursery.start_soon(backend.run) # (2)! + await backend.running.wait() # (3)! + + # Create an instance of MyObj and connect the async callback + obj = MyObj() + obj.value_changed.connect(on_value_changed) # (4)! + + # Set a value to trigger the callback + obj.set_value("hello!") + + # Give the callback time to execute + await trio.sleep(0.01) + + nursery.cancel_scope.cancel() + + + if __name__ == "__main__": + trio.run(main) + ``` + + 1. Call `psygnal.set_async_backend("trio")` to create send/receive channels. + 2. Start watching the channels in the background using `backend.run()`. + 3. Wait for the backend to be ready before connecting the signal. + 4. Connect the signal to the async callback function. diff --git a/mkdocs.yml b/mkdocs.yml index 3522acd8..b97e4bd1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,6 +24,7 @@ nav: - Evented Dataclasses: guides/dataclasses.md - Evented Pydantic Model: guides/model.md - Throttling & Debouncing: guides/throttler.md + - Coroutines (async/await): guides/async.md - Testing: guides/testing.md - Debugging: guides/debugging.md @@ -54,6 +55,8 @@ theme: # - toc.follow # - content.code.annotate - content.tabs.link + - content.code.copy + - content.code.annotate markdown_extensions: - admonition @@ -64,6 +67,8 @@ markdown_extensions: alternate_style: true - toc: permalink: "#" + - pymdownx.tasklist: + custom_checkbox: true plugins: - search diff --git a/pyproject.toml b/pyproject.toml index 7d66bd29..903b93f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,18 +36,21 @@ proxy = ["wrapt"] pydantic = ["pydantic"] [dependency-groups] +test-min = ["pytest>=6.0", "pytest-cov", "pytest-asyncio"] test = [ + { include-group = "test-min" }, "dask[array]>=2024.0.0", "attrs", "numpy >1.21.6", "pydantic", "pyinstaller>=4.0", - "pytest>=6.0", - "pytest-cov", "wrapt", "msgspec", "toolz", + "anyio", + "trio", ] +test-typing = [{ include-group = "test-min" }, "pytest-mypy-plugins"] testqt = [{ include-group = "test" }, "pytest-qt", "qtpy"] test-codspeed = [{ include-group = "test" }, "pytest-codspeed"] docs = [ @@ -114,6 +117,7 @@ exclude = [ "src/psygnal/qt.py", "src/psygnal/_pyinstaller_util", "src/psygnal/_throttler.py", + "src/psygnal/_async.py", "src/psygnal/testing.py", ] @@ -177,6 +181,7 @@ docstring-code-format = true [tool.pytest.ini_options] minversion = "6.0" testpaths = ["tests"] +asyncio_default_fixture_loop_scope = "function" addopts = ["--color=yes"] filterwarnings = [ "error", @@ -184,6 +189,8 @@ filterwarnings = [ "ignore:.*BackendFinder.find_spec()", # pyinstaller import "ignore:.*not using a cooperative constructor:pytest.PytestDeprecationWarning:", "ignore:Failed to disconnect::pytestqt", + "ignore:.*unclosed.*socket.*:ResourceWarning:", # asyncio internal socket cleanup + "ignore:.*unclosed event loop.*:ResourceWarning:", # asyncio internal event loop cleanup ] # https://mypy.readthedocs.io/en/stable/config_file.html diff --git a/src/psygnal/__init__.py b/src/psygnal/__init__.py index a24fa73b..f6293a36 100644 --- a/src/psygnal/__init__.py +++ b/src/psygnal/__init__.py @@ -32,8 +32,10 @@ "debounced", "emit_queued", "evented", + "get_async_backend", "get_evented_namespace", "is_evented", + "set_async_backend", "throttled", ] @@ -49,6 +51,7 @@ stacklevel=2, ) +from ._async import get_async_backend, set_async_backend from ._evented_decorator import evented from ._exceptions import EmitLoopError from ._group import EmissionInfo, PathStep, SignalGroup diff --git a/src/psygnal/_async.py b/src/psygnal/_async.py new file mode 100644 index 00000000..5f2dfd06 --- /dev/null +++ b/src/psygnal/_async.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from math import inf +from typing import TYPE_CHECKING, overload + +if TYPE_CHECKING: + from collections.abc import Coroutine + from typing import Any, Protocol + + import anyio.streams.memory + import trio + from typing_extensions import Literal, TypeAlias + + from psygnal._weak_callback import WeakCallback + + SupportedBackend: TypeAlias = Literal["asyncio", "anyio", "trio"] + QueueItem: TypeAlias = tuple["WeakCallback", tuple[Any, ...]] + + class EventLike(Protocol): + def is_set(self) -> bool: + """Return ``True`` if the flag is set, ``False`` if not.""" + ... + + async def wait(self) -> Coroutine | bool | None: + """Wait until the flag is set.""" + ... + + +_ASYNC_BACKEND: _AsyncBackend | None = None + + +def get_async_backend() -> _AsyncBackend | None: + """Get the current async backend. Returns None if no backend is set.""" + return _ASYNC_BACKEND + + +def clear_async_backend() -> None: + """Clear the current async backend. Primarily for testing purposes.""" + global _ASYNC_BACKEND + if _ASYNC_BACKEND is not None: + # Cancel any running tasks if it's asyncio and loop is not closed + if isinstance(_ASYNC_BACKEND, AsyncioBackend): + _ASYNC_BACKEND.close() + # Close anyio streams + elif isinstance(_ASYNC_BACKEND, AnyioBackend): + _ASYNC_BACKEND.close() + # Close trio channels + elif isinstance(_ASYNC_BACKEND, TrioBackend): + if hasattr(_ASYNC_BACKEND, "_send_channel"): + _ASYNC_BACKEND._send_channel.close() + # Note: trio receive channels don't have a close method + _ASYNC_BACKEND = None + + +@overload +def set_async_backend(backend: Literal["asyncio"]) -> AsyncioBackend: ... +@overload +def set_async_backend(backend: Literal["anyio"]) -> AnyioBackend: ... +@overload +def set_async_backend(backend: Literal["trio"]) -> TrioBackend: ... +def set_async_backend(backend: SupportedBackend = "asyncio") -> _AsyncBackend: + """Set the async backend to use. Must be one of: 'asyncio', 'anyio', 'trio'. + + This should be done as early as possible, and *must* be called before calling + `SignalInstance.connect` with a coroutine function. + """ + global _ASYNC_BACKEND + + if _ASYNC_BACKEND and _ASYNC_BACKEND._backend != backend: # pragma: no cover + # allow setting the same backend multiple times, for tests + raise RuntimeError(f"Async backend already set to: {_ASYNC_BACKEND._backend}") + + if backend == "asyncio": + _ASYNC_BACKEND = AsyncioBackend() + elif backend == "anyio": + _ASYNC_BACKEND = AnyioBackend() + elif backend == "trio": + _ASYNC_BACKEND = TrioBackend() + else: # pragma: no cover + raise RuntimeError( + f"Async backend not supported: {backend}. " + "Must be one of: 'asyncio', 'anyio', 'trio'" + ) + + return _ASYNC_BACKEND + + +class _AsyncBackend(ABC): + def __init__(self, backend: str): + self._backend = backend + + @property + @abstractmethod + def running(self) -> EventLike: ... + + @abstractmethod + def put(self, item: QueueItem) -> None: ... + + @abstractmethod + async def run(self) -> None: ... + + async def call_back(self, item: QueueItem) -> None: + cb, args = item + if func := cb.dereference(): + await func(*args) + + +class AsyncioBackend(_AsyncBackend): + def __init__(self) -> None: + super().__init__("asyncio") + import asyncio + + self._asyncio = asyncio + self._queue: asyncio.Queue[tuple] = asyncio.Queue() + self._task = asyncio.create_task(self.run()) + self._loop = asyncio.get_running_loop() + self._running = asyncio.Event() + + @property + def running(self) -> EventLike: + """Return the event indicating if the backend is running.""" + return self._running + + def put(self, item: QueueItem) -> None: + self._queue.put_nowait(item) + + def close(self) -> None: + """Close the asyncio backend and cancel tasks.""" + if hasattr(self, "_task") and not self._task.done(): + self._task.cancel() + + async def run(self) -> None: + if self._running.is_set(): + return + + self._running.set() + try: + while True: + item = await self._queue.get() + try: + await self.call_back(item) + except Exception: + # Log the exception but continue running + # This prevents one bad callback from crashing the backend + import traceback + + traceback.print_exc() + except self._asyncio.CancelledError: + pass + except RuntimeError as e: # pragma: no cover + if not self._loop.is_closed(): + raise e + finally: + self._running.clear() + + +class AnyioBackend(_AsyncBackend): + _send_stream: anyio.streams.memory.MemoryObjectSendStream[QueueItem] + _receive_stream: anyio.streams.memory.MemoryObjectReceiveStream[QueueItem] + + def __init__(self) -> None: + super().__init__("anyio") + import anyio + + self._anyio = anyio + self._send_stream, self._receive_stream = anyio.create_memory_object_stream( + max_buffer_size=inf + ) + self._running = anyio.Event() + + @property + def running(self) -> EventLike: + """Return the event indicating if the backend is running.""" + return self._running + + def put(self, item: QueueItem) -> None: + self._send_stream.send_nowait(item) + + def close(self) -> None: + """Close the anyio streams.""" + if hasattr(self, "_send_stream"): + self._send_stream.close() + if hasattr(self, "_receive_stream"): + self._receive_stream.close() + + async def run(self) -> None: + if self._running.is_set(): + return # pragma: no cover + + self._running.set() + try: + async with self._receive_stream: + async for item in self._receive_stream: + try: + await self.call_back(item) + except Exception: + # Log the exception but continue running + import traceback + + traceback.print_exc() + finally: + self._running = self._anyio.Event() + # Ensure streams are closed + self.close() + + +class TrioBackend(_AsyncBackend): + _send_channel: trio._channel.MemorySendChannel[QueueItem] + _receive_channel: trio.abc.ReceiveChannel[QueueItem] + + def __init__(self) -> None: + super().__init__("trio") + import trio + + self._trio = trio + self._send_channel, self._receive_channel = trio.open_memory_channel( + max_buffer_size=inf + ) + self._running = self._trio.Event() + + @property + def running(self) -> EventLike: + """Return the event indicating if the backend is running.""" + return self._running + + def put(self, item: tuple) -> None: + self._send_channel.send_nowait(item) + + async def run(self) -> None: + if self._running.is_set(): + return # pragma: no cover + + self._running.set() + try: + async for item in self._receive_channel: + try: + await self.call_back(item) + except Exception: + # Log the exception but continue running + import traceback + + traceback.print_exc() + finally: + self._running = self._trio.Event() diff --git a/src/psygnal/_weak_callback.py b/src/psygnal/_weak_callback.py index fd2c3872..d06f1c55 100644 --- a/src/psygnal/_weak_callback.py +++ b/src/psygnal/_weak_callback.py @@ -1,6 +1,8 @@ from __future__ import annotations +import inspect import sys +import warnings import weakref from functools import partial from types import BuiltinMethodType, FunctionType, MethodType, MethodWrapperType @@ -16,12 +18,15 @@ ) from warnings import warn +from ._async import get_async_backend from ._mypyc import mypyc_attr if TYPE_CHECKING: import toolz from typing_extensions import TypeAlias, TypeGuard # py310 + from ._async import _AsyncBackend + RefErrorChoice: TypeAlias = Literal["raise", "warn", "ignore"] __all__ = ["WeakCallback", "weak_callback"] @@ -124,14 +129,46 @@ def _on_delete(weak_cb): kwargs = cb.keywords cb = cb.func + is_coro = inspect.iscoroutinefunction(cb) + if is_coro: + if (backend := get_async_backend()) is None: + raise RuntimeError( + "Cannot create async callback yet... No async backend set. " + "Please call `psygnal.set_async_backend()` before connecting." + ) + if not backend.running.is_set(): + warnings.warn( + f"\n\nConnection of async {cb.__name__!r} will not do anything!\n" + "Async backend not running. Launch `get_async_backend().run()` " + "in a background task and wait for `backend.running`", + RuntimeWarning, + stacklevel=2, + ) + if isinstance(cb, FunctionType): - return ( - StrongFunction(cb, max_args, args, kwargs, priority=priority) - if strong_func - else WeakFunction( + # NOTE: I know it looks like this should be easy to express in much shorter + # syntax ... but mypyc will likely fail at runtime. + # Make sure to test compiled version if you change this. + if strong_func: + if is_coro: + return StrongCoroutineFunction( + cb, max_args, args, kwargs, priority=priority + ) + return StrongFunction(cb, max_args, args, kwargs, priority=priority) + else: + if is_coro: + return WeakCoroutineFunction( + cb, + max_args, + args, + kwargs, + finalize, + on_ref_error=on_ref_error, + priority=priority, + ) + return WeakFunction( cb, max_args, args, kwargs, finalize, on_ref_error, priority=priority ) - ) if isinstance(cb, MethodType): if getattr(cb, "__name__", None) == "__setitem__": @@ -145,6 +182,11 @@ def _on_delete(weak_cb): return WeakSetitem( obj, key, max_args, finalize, on_ref_error, priority=priority ) + + if is_coro: + return WeakCoroutineMethod( + cb, max_args, args, kwargs, finalize, on_ref_error, priority=priority + ) return WeakMethod( cb, max_args, args, kwargs, finalize, on_ref_error, priority=priority ) @@ -335,6 +377,8 @@ def _cb(_: weakref.ReferenceType) -> None: class StrongFunction(WeakCallback): """Wrapper around a strong function reference.""" + _f: Callable + def __init__( self, obj: Callable, @@ -580,3 +624,37 @@ def cb(self, args: tuple[Any, ...] = ()) -> None: def dereference(self) -> partial | None: obj = self._obj_ref() return None if obj is None else partial(obj.__setitem__, self._itemkey) + + +# --------------------------- Coroutines --------------------------- + + +class WeakCoroutineFunction(WeakFunction): + def cb(self, args: tuple[Any, ...] = ()) -> None: + if self._f() is None: + raise ReferenceError("weakly-referenced object no longer exists") + if self._max_args is not None: + args = args[: self._max_args] + + cast("_AsyncBackend", get_async_backend()).put((self, args)) + + +class StrongCoroutineFunction(StrongFunction): + """Wrapper around a strong coroutine function reference.""" + + def cb(self, args: tuple[Any, ...] = ()) -> None: + if self._max_args is not None: + args = args[: self._max_args] + + cast("_AsyncBackend", get_async_backend()).put((self, args)) + + +class WeakCoroutineMethod(WeakMethod): + def cb(self, args: tuple[Any, ...] = ()) -> None: + if self._obj_ref() is None or self._func_ref() is None: + raise ReferenceError("weakly-referenced object no longer exists") + + if self._max_args is not None: + args = args[: self._max_args] + + cast("_AsyncBackend", get_async_backend()).put((self, args)) diff --git a/tests/test_coroutines.py b/tests/test_coroutines.py new file mode 100644 index 00000000..bd8ecc95 --- /dev/null +++ b/tests/test_coroutines.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +import asyncio +import gc +import importlib.util +import signal +from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol +from unittest.mock import Mock + +import pytest +import pytest_asyncio + +from psygnal import _async +from psygnal._weak_callback import WeakCallback, weak_callback + +if TYPE_CHECKING: + from collections.abc import Iterator + +# Available backends for parametrization +AVAILABLE_BACKENDS = ["asyncio"] +if importlib.util.find_spec("trio") is not None: + AVAILABLE_BACKENDS.append("trio") +if importlib.util.find_spec("anyio") is not None: + AVAILABLE_BACKENDS.append("anyio") + + +class BackendTestRunner(Protocol): + """Protocol for backend-specific test runners.""" + + @property + def backend_name(self) -> Literal["asyncio", "anyio", "trio"]: + """Name of the backend being used.""" + ... + + async def sleep(self, duration: float) -> None: + """Sleep for the given duration using backend-specific sleep.""" + ... + + def run_with_backend(self, test_func: Callable[[], Any]) -> Any: + """Run a test function with proper backend setup and teardown. Synchronous.""" + ... + + +class AsyncioTestRunner: + """Test runner for asyncio backend.""" + + @property + def backend_name(self) -> Literal["asyncio"]: + return "asyncio" + + async def sleep(self, duration: float) -> None: + await asyncio.sleep(duration) + + def run_with_backend(self, test_func: Callable[[], Any]) -> Any: + """Run test with asyncio backend.""" + + async def _run_test() -> Any: + _async.clear_async_backend() + backend = _async.set_async_backend("asyncio") + + # Wait for backend to be running + await self._wait_for_backend_running(backend) + + try: + return await test_func() + finally: + # Cleanup + if hasattr(backend, "_task") and not backend._task.done(): + backend._task.cancel() + try: + await backend._task + except asyncio.CancelledError: + pass + _async.clear_async_backend() + + return asyncio.run(_run_test()) + + async def _wait_for_backend_running( + self, backend: _async._AsyncBackend, timeout: float = 1.0 + ) -> None: + """Wait for backend to be running with a timeout.""" + start_time = asyncio.get_event_loop().time() + while not backend.running.is_set(): + if asyncio.get_event_loop().time() - start_time > timeout: + raise TimeoutError("Backend did not start running within timeout") + await asyncio.sleep(0) + + +class AnyioTestRunner: + """Test runner for anyio backend.""" + + @property + def backend_name(self) -> Literal["anyio"]: + return "anyio" + + async def sleep(self, duration: float) -> None: + import anyio + + await anyio.sleep(duration) + + def run_with_backend(self, test_func: Callable[[], Any]) -> Any: + """Run test with anyio backend using structured concurrency.""" + import anyio + + async def _run_test(): + _async.clear_async_backend() + backend = _async.set_async_backend("anyio") + + result = None + async with anyio.create_task_group() as tg: + tg.start_soon(backend.run) + + # Wait for backend to be running + await backend.running.wait() + + try: + result = await test_func() + finally: + # Cancel task group to shutdown properly + tg.cancel_scope.cancel() + + _async.clear_async_backend() + return result + + return anyio.run(_run_test) + + +class TrioTestRunner: + """Test runner for trio backend.""" + + @property + def backend_name(self) -> Literal["trio"]: + return "trio" + + async def sleep(self, duration: float) -> None: + import trio + + await trio.sleep(duration) + + def run_with_backend(self, test_func: Callable[[], Any]) -> Any: + """Run test with trio backend using structured concurrency.""" + + # On Windows asyncio has probably left its FD installed + try: + signal.set_wakeup_fd(-1) # restore default + except (ValueError, AttributeError): + pass # not the main thread or not supported + + import trio + + result = None + + async def _trio_main(): + nonlocal result + _async.clear_async_backend() + backend = _async.set_async_backend("trio") + + # Use a timeout to prevent hanging + with trio.move_on_after(5.0) as cancel_scope: + async with trio.open_nursery() as nursery: + nursery.start_soon(backend.run) + + # Wait for backend to be running + await backend.running.wait() + + try: + result = await test_func() + finally: + # Cancel nursery to shutdown properly + nursery.cancel_scope.cancel() + + # Check if we timed out + if cancel_scope.cancelled_caught: + raise TimeoutError("Test timed out") + + _async.clear_async_backend() + + # Run in trio context + trio.run(_trio_main) + return result + + +async def mock_call_count( + mock: Mock, runner: BackendTestRunner, max_iterations: int = 100 +) -> None: + """Wait for callback execution with backend-specific sleep.""" + for _ in range(max_iterations): + await runner.sleep(0.01) + if mock.call_count > 0: + break + + +@pytest_asyncio.fixture +async def clean_async_backend(): + """Fixture to ensure clean async backend state.""" + _async.clear_async_backend() + yield + _async.clear_async_backend() + + +@pytest.fixture(params=AVAILABLE_BACKENDS) +def runner( + request: pytest.FixtureRequest, clean_async_backend: None +) -> Iterator[BackendTestRunner]: + """Get the backend runner for the specified backend.""" + mapping: dict[str, type[BackendTestRunner]] = { + "asyncio": AsyncioTestRunner, + "anyio": AnyioTestRunner, + "trio": TrioTestRunner, + } + yield mapping[request.param]() + + +# Parametrized tests for all backends +@pytest.mark.parametrize( + "slot_type", + [ + "coroutinefunc", + "weak_coroutinefunc", + "coroutinemethod", + ], +) +def test_slot_types_all_backends(runner: BackendTestRunner, slot_type: str) -> None: + """Test async slot types with all available backends.""" + + async def _test_slot_type(): + mock = Mock() + final_mock = Mock() + + if slot_type in {"coroutinefunc", "weak_coroutinefunc"}: + + async def test_obj(x: int) -> int: + mock(x) + return x + + cb = weak_callback( + test_obj, + strong_func=(slot_type == "coroutinefunc"), + finalize=final_mock, + ) + elif slot_type == "coroutinemethod": + + class MyObj: + async def coroutine_method(self, x: int) -> int: + mock(x) + return x + + obj = MyObj() + cb = weak_callback(obj.coroutine_method, finalize=final_mock) + + assert isinstance(cb, WeakCallback) + assert isinstance(cb.slot_repr(), str) + assert cb.dereference() is not None + + # Test callback execution + cb.cb((2,)) + await mock_call_count(mock, runner) + mock.assert_called_once_with(2) + + # Test direct await + mock.reset_mock() + result = await cb(4) + assert result == 4 + mock.assert_called_once_with(4) + + # Test weak reference cleanup + if slot_type in {"coroutinefunc", "weak_coroutinefunc"}: + del test_obj + else: + del obj + gc.collect() + + if slot_type == "coroutinefunc": # strong_func + cb.cb((4,)) + await mock_call_count(mock, runner) + mock.assert_called_with(4) + else: + await mock_call_count(final_mock, runner) + final_mock.assert_called_once_with(cb) + assert cb.dereference() is None + with pytest.raises(ReferenceError): + cb.cb((2,)) + with pytest.raises(ReferenceError): + await cb(2) + + # Run the test with the appropriate backend + runner.run_with_backend(_test_slot_type) + + +def test_backend_error_conditions(runner: BackendTestRunner) -> None: + """Test backend error conditions and exception handling.""" + + async def _test_errors(): + mock = Mock() + + async def test_coro(x: int) -> int: + if x == 999: + raise ValueError("Test error") + mock(x) + return x + + cb = weak_callback(test_coro, strong_func=True) + + # Test normal execution + cb.cb((5,)) + await mock_call_count(mock, runner) + mock.assert_called_once_with(5) + + # Test error case - should not crash the backend + cb.cb((999,)) + await runner.sleep(0.1) # Give time for error to be handled + + # Backend should still work after error + mock.reset_mock() + cb.cb((10,)) + await mock_call_count(mock, runner) + mock.assert_called_once_with(10) + + # Run the test with the backend runner + runner.run_with_backend(_test_errors) + + +@pytest.mark.usefixtures("clean_async_backend") +@pytest.mark.asyncio +async def test_run_method_early_return() -> None: + """Test that run() method returns early if backend is already running.""" + backend = _async.set_async_backend("asyncio") + + # Wait for backend to be running + start_time = asyncio.get_event_loop().time() + while not backend.running.is_set(): + if asyncio.get_event_loop().time() - start_time > 1.0: + raise TimeoutError("Backend did not start running within timeout") + await asyncio.sleep(0) + + # Now calling run() again should return early + await backend.run() + + # Backend should still be running + assert backend.running.is_set() + + +@pytest.mark.parametrize("backend_name", AVAILABLE_BACKENDS) +def test_high_level_api(backend_name: Literal["trio", "asyncio", "anyio"]) -> None: + """Test the exact usage pattern shown in the feature summary documentation.""" + + def run_example() -> None: + """The example from the feature summary, adapted for testing.""" + + async def example_main() -> None: + # Step 1: Set Backend Early (Once Per Application) + backend = _async.set_async_backend(backend_name) + + # Step 2: Launch Backend in Your Event Loop (backend-specific) + if backend_name == "asyncio": + import asyncio + + # Start the backend as a background task + async_backend = _async.get_async_backend() + assert async_backend is not None + task = asyncio.create_task(async_backend.run()) + + # Wait for backend to be running + await backend.running.wait() + + elif backend_name == "anyio": + import anyio + + async with anyio.create_task_group() as tg: + # Start the backend in the task group + async_backend = _async.get_async_backend() + assert async_backend is not None + tg.start_soon(async_backend.run) + + # Wait for backend to be running + await backend.running.wait() + + # Run the actual example + await run_signal_example() + + # Cancel to exit cleanly + tg.cancel_scope.cancel() + return + + elif backend_name == "trio": + import trio + + async with trio.open_nursery() as nursery: + # Start the backend in the nursery + async_backend = _async.get_async_backend() + assert async_backend is not None + nursery.start_soon(async_backend.run) + + # Wait for backend to be running + await backend.running.wait() + + # Run the actual example + await run_signal_example() + + # Cancel to exit cleanly + nursery.cancel_scope.cancel() + return + + # For asyncio, run the example after backend is started + try: + await run_signal_example() + finally: + if backend_name == "asyncio": + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def run_signal_example() -> None: + """Step 3: Connect Async Callbacks - the exact example from docs.""" + from psygnal import Signal + + class MyObj: + value_changed = Signal(str) + + def set_value(self, value: str) -> None: + self.value_changed.emit(value) + + # Track calls for testing + mock = Mock() + + async def on_value_changed(new_value: str) -> None: + mock(new_value) + + obj = MyObj() + obj.value_changed.connect(on_value_changed) + obj.set_value("hello!") + + # Wait for callback to execute + max_wait = 100 + for _ in range(max_wait): + if mock.call_count > 0: + break + if backend_name == "asyncio": + await asyncio.sleep(0.01) + elif backend_name == "anyio": + import anyio + + await anyio.sleep(0.01) + elif backend_name == "trio": + import trio + + await trio.sleep(0.01) + + # Verify the callback was called with the expected value + assert mock.call_count == 1 + assert mock.call_args[0][0] == "hello!" + + # Run the example with the appropriate backend + if backend_name == "asyncio": + return asyncio.run(example_main()) + elif backend_name == "anyio": + import anyio + + return anyio.run(example_main) + elif backend_name == "trio": + import trio + + return trio.run(example_main) + + # Clear any existing backend before test + _async.clear_async_backend() + try: + run_example() + finally: + # Clean up after test + _async.clear_async_backend()