diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86f323c8..071e84f8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,3 +32,5 @@ repos: - types-attrs - pydantic - typing_extensions + - anyio + - trio diff --git a/pyproject.toml b/pyproject.toml index 4111073e..def54ead 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ test = [ "pyinstaller>=4.0", "pytest>=6.0", "pytest-cov", + "pytest-asyncio", "wrapt", "msgspec", "toolz", @@ -110,6 +111,7 @@ exclude = [ "src/psygnal/qt.py", "src/psygnal/_pyinstaller_util", "src/psygnal/_throttler.py", + "src/psygnal/_async.py", ] [tool.cibuildwheel] @@ -171,6 +173,7 @@ docstring-code-format = true [tool.pytest.ini_options] minversion = "6.0" testpaths = ["tests"] +asyncio_default_fixture_loop_scope = "function" filterwarnings = [ "error", "ignore:The distutils package is deprecated:DeprecationWarning:", diff --git a/src/psygnal/__init__.py b/src/psygnal/__init__.py index 8ffd9483..68c5a180 100644 --- a/src/psygnal/__init__.py +++ b/src/psygnal/__init__.py @@ -31,8 +31,10 @@ "debounced", "emit_queued", "evented", + "get_async_backend", "get_evented_namespace", "is_evented", + "set_async_backend", "throttled", ] @@ -48,6 +50,7 @@ stacklevel=2, ) +from ._async import get_async_backend, set_async_backend from ._evented_decorator import evented from ._exceptions import EmitLoopError from ._group import EmissionInfo, SignalGroup diff --git a/src/psygnal/_async.py b/src/psygnal/_async.py new file mode 100644 index 00000000..42c7c53e --- /dev/null +++ b/src/psygnal/_async.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from math import inf +from typing import TYPE_CHECKING, Any, overload + +if TYPE_CHECKING: + import anyio.streams.memory + import trio + from typing_extensions import Literal, TypeAlias + + from psygnal._weak_callback import WeakCallback + + SupportedBackend: TypeAlias = Literal["asyncio", "anyio", "trio"] + QueueItem: TypeAlias = tuple["WeakCallback", tuple[Any, ...]] + + +_ASYNC_BACKEND: _AsyncBackend | None = None + + +def get_async_backend() -> _AsyncBackend | None: + """Get the current async backend. Returns None if no backend is set.""" + return _ASYNC_BACKEND + + +@overload +def set_async_backend(backend: Literal["asyncio"]) -> AsyncioBackend: ... +@overload +def set_async_backend(backend: Literal["anyio"]) -> AnyioBackend: ... +@overload +def set_async_backend(backend: Literal["trio"]) -> TrioBackend: ... +def set_async_backend(backend: SupportedBackend = "asyncio") -> _AsyncBackend: + """Set the async backend to use. Must be one of: 'asyncio', 'anyio', 'trio'. + + This should be done as early as possible, and *must* be called before calling + `SignalInstance.connect` with a coroutine function. + """ + global _ASYNC_BACKEND + + if _ASYNC_BACKEND is not None and _ASYNC_BACKEND._backend != backend: + # allow setting the same backend multiple times, for tests + raise RuntimeError(f"Async backend already set to: {_ASYNC_BACKEND._backend}") + + if backend == "asyncio": + _ASYNC_BACKEND = AsyncioBackend() + elif backend == "anyio": + _ASYNC_BACKEND = AnyioBackend() + elif backend == "trio": + _ASYNC_BACKEND = TrioBackend() + else: + raise RuntimeError( + f"Async backend not supported: {backend}. " + "Must be one of: 'asyncio', 'anyio', 'trio'" + ) + + return _ASYNC_BACKEND + + +class _AsyncBackend(ABC): + def __init__(self, backend: str): + self._backend = backend + self._running = False + + @property + def running(self) -> bool: + return self._running + + @abstractmethod + def _put(self, item: QueueItem) -> None: ... + + @abstractmethod + async def _get(self) -> QueueItem: ... + + @abstractmethod + async def run(self) -> None: ... + + async def call_back(self, item: QueueItem) -> None: + cb, args = item + if func := cb.dereference(): + await func(*args) + + +class AsyncioBackend(_AsyncBackend): + def __init__(self) -> None: + super().__init__("asyncio") + import asyncio + + self._asyncio = asyncio + self._queue: asyncio.Queue[tuple] = asyncio.Queue() + self._task = asyncio.create_task(self.run()) + self._loop = asyncio.get_running_loop() + + def _put(self, item: QueueItem) -> None: + self._queue.put_nowait(item) + + async def _get(self) -> QueueItem: + return await self._queue.get() + + async def run(self) -> None: + if self.running: + return + + self._running = True + try: + while True: + item = await self._get() + await self.call_back(item) + except self._asyncio.CancelledError: + pass + except RuntimeError as e: + if not self._loop.is_closed(): + raise e + + +class AnyioBackend(_AsyncBackend): + _send_stream: anyio.streams.memory.MemoryObjectSendStream[QueueItem] + _receive_stream: anyio.streams.memory.MemoryObjectReceiveStream[QueueItem] + + def __init__(self) -> None: + super().__init__("anyio") + import anyio + + self._send_stream, self._receive_stream = anyio.create_memory_object_stream( + max_buffer_size=inf + ) + + def _put(self, item: QueueItem) -> None: + self._send_stream.send_nowait(item) + + async def _get(self) -> QueueItem: + return await self._receive_stream.receive() + + async def run(self) -> None: + if self.running: + return + + self._running = True + async with self._receive_stream: + async for item in self._receive_stream: + await self.call_back(item) + + +class TrioBackend(_AsyncBackend): + _send_channel: trio._channel.MemorySendChannel[QueueItem] + _receive_channel: trio.abc.ReceiveChannel[QueueItem] + + def __init__(self) -> None: + super().__init__("trio") + import trio + + self._send_channel, self._receive_channel = trio.open_memory_channel( + max_buffer_size=inf + ) + + def _put(self, item: tuple) -> None: + self._send_channel.send_nowait(item) + + async def _get(self) -> tuple: + return await self._receive_channel.receive() + + async def run(self) -> None: + if self.running: + return + + self._running = True + async for item in self._receive_channel: + await self.call_back(item) diff --git a/src/psygnal/_weak_callback.py b/src/psygnal/_weak_callback.py index aa233300..f5a65e53 100644 --- a/src/psygnal/_weak_callback.py +++ b/src/psygnal/_weak_callback.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import sys import weakref from functools import partial @@ -16,12 +17,15 @@ ) from warnings import warn +from ._async import get_async_backend from ._mypyc import mypyc_attr if TYPE_CHECKING: import toolz from typing_extensions import TypeAlias, TypeGuard # py310 + from ._async import _AsyncBackend + RefErrorChoice: TypeAlias = Literal["raise", "warn", "ignore"] __all__ = ["WeakCallback", "weak_callback"] @@ -124,14 +128,40 @@ def _on_delete(weak_cb): kwargs = cb.keywords cb = cb.func + is_coro = inspect.iscoroutinefunction(cb) + if is_coro: + if (backend := get_async_backend()) is None: + raise RuntimeError("No async backend set: call `set_async_backend()`") + if not backend.running: + raise RuntimeError( + "Async backend not running (launch `get_async_backend().run()` " + "in a background task)" + ) + if isinstance(cb, FunctionType): - return ( - StrongFunction(cb, max_args, args, kwargs, priority=priority) - if strong_func - else WeakFunction( + # NOTE: I know it looks like this should be easy to express in much shorter + # syntax ... but mypyc will likely fail at runtime. + # Make sure to test compiled version if you change this. + if strong_func: + if is_coro: + return StrongCoroutineFunction( + cb, max_args, args, kwargs, priority=priority + ) + return StrongFunction(cb, max_args, args, kwargs, priority=priority) + else: + if is_coro: + return WeakCoroutineFunction( + cb, + max_args, + args, + kwargs, + finalize, + on_ref_error=on_ref_error, + priority=priority, + ) + return WeakFunction( cb, max_args, args, kwargs, finalize, on_ref_error, priority=priority ) - ) if isinstance(cb, MethodType): if getattr(cb, "__name__", None) == "__setitem__": @@ -145,6 +175,11 @@ def _on_delete(weak_cb): return WeakSetitem( obj, key, max_args, finalize, on_ref_error, priority=priority ) + + if is_coro: + return WeakCoroutineMethod( + cb, max_args, args, kwargs, finalize, on_ref_error, priority=priority + ) return WeakMethod( cb, max_args, args, kwargs, finalize, on_ref_error, priority=priority ) @@ -335,6 +370,8 @@ def _cb(_: weakref.ReferenceType) -> None: class StrongFunction(WeakCallback): """Wrapper around a strong function reference.""" + _f: Callable + def __init__( self, obj: Callable, @@ -580,3 +617,37 @@ def cb(self, args: tuple[Any, ...] = ()) -> None: def dereference(self) -> partial | None: obj = self._obj_ref() return None if obj is None else partial(obj.__setitem__, self._itemkey) + + +# --------------------------- Coroutines --------------------------- + + +class WeakCoroutineFunction(WeakFunction): + def cb(self, args: tuple[Any, ...] = ()) -> None: + if self._f() is None: + raise ReferenceError("weakly-referenced object no longer exists") + if self._max_args is not None: + args = args[: self._max_args] + + cast("_AsyncBackend", get_async_backend())._put((self, args)) + + +class StrongCoroutineFunction(StrongFunction): + """Wrapper around a strong coroutine function reference.""" + + def cb(self, args: tuple[Any, ...] = ()) -> None: + if self._max_args is not None: + args = args[: self._max_args] + + cast("_AsyncBackend", get_async_backend())._put((self, args)) + + +class WeakCoroutineMethod(WeakMethod): + def cb(self, args: tuple[Any, ...] = ()) -> None: + if self._obj_ref() is None or self._func_ref() is None: + raise ReferenceError("weakly-referenced object no longer exists") + + if self._max_args is not None: + args = args[: self._max_args] + + cast("_AsyncBackend", get_async_backend())._put((self, args)) diff --git a/tests/test_coroutines.py b/tests/test_coroutines.py new file mode 100644 index 00000000..506cff1e --- /dev/null +++ b/tests/test_coroutines.py @@ -0,0 +1,78 @@ +import asyncio +import gc +from typing import Any +from unittest.mock import Mock + +import pytest + +from psygnal import _async +from psygnal._weak_callback import WeakCallback, weak_callback + + +@pytest.mark.parametrize( + "type_", + [ + "coroutinefunc", + "weak_coroutinefunc", + "coroutinemethod", + ], +) +@pytest.mark.asyncio +async def test_slot_types(type_: str, capsys: Any) -> None: + backend = _async.set_async_backend("asyncio") + assert backend is _async.get_async_backend() is not None + while not backend.running: + await asyncio.sleep(0) + + mock = Mock() + final_mock = Mock() + + obj: Any + if type_ in {"coroutinefunc", "weak_coroutinefunc"}: + + async def obj(x: int) -> int: + mock(x) + return x + + cb = weak_callback( + obj, strong_func=(type_ == "coroutinefunc"), finalize=final_mock + ) + elif type_ == "coroutinemethod": + + class MyObj: + async def coroutine_method(self, x: int) -> int: + mock(x) + return x + + obj = MyObj() + cb = weak_callback(obj.coroutine_method, finalize=final_mock) + + assert isinstance(cb, WeakCallback) + assert isinstance(cb.slot_repr(), str) + assert cb.dereference() is not None + + cb.cb((2,)) + await asyncio.sleep(0.01) + mock.assert_called_once_with(2) + + mock.reset_mock() + assert await cb(4) == 4 + mock.assert_called_once_with(4) + + del obj + gc.collect() + await asyncio.sleep(0.01) + + if type_ == "coroutinefunc": # strong_func + cb.cb((4,)) + await asyncio.sleep(0.01) + mock.assert_called_with(4) + + else: + final_mock.assert_called_once_with(cb) + assert cb.dereference() is None + with pytest.raises(ReferenceError): + cb.cb((2,)) + await asyncio.sleep(0.01) + with pytest.raises(ReferenceError): + await cb(2)