From 9c84cb71475b863ef361ca526283b965e6710495 Mon Sep 17 00:00:00 2001 From: David Perl Date: Mon, 8 Apr 2024 08:30:53 +0100 Subject: [PATCH] (#117) (#45) update classes to use WatchableAsyncStatus. (#117) (#45) fix typing issues --- src/ophyd_async/core/__init__.py | 3 +- src/ophyd_async/core/async_status.py | 47 ++++++++++++++++++------- src/ophyd_async/core/detector.py | 6 ++-- src/ophyd_async/epics/demo/__init__.py | 13 +++++-- src/ophyd_async/epics/motion/motor.py | 8 +++-- tests/core/test_async_status_wrapper.py | 7 ++-- tests/epics/demo/test_demo.py | 2 +- tests/epics/motion/test_motor.py | 2 +- 8 files changed, 60 insertions(+), 28 deletions(-) diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 103638019d..a0af742f44 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -5,7 +5,7 @@ ShapeProvider, StaticDirectoryProvider, ) -from .async_status import AsyncStatus +from .async_status import AsyncStatus, WatchableAsyncStatus from .detector import ( DetectorControl, DetectorTrigger, @@ -74,6 +74,7 @@ "set_sim_value", "wait_for_value", "AsyncStatus", + "WatchableAsyncStatus", "DirectoryInfo", "DirectoryProvider", "NameProvider", diff --git a/src/ophyd_async/core/async_status.py b/src/ophyd_async/core/async_status.py index f449ab6803..1ce58e60a1 100644 --- a/src/ophyd_async/core/async_status.py +++ b/src/ophyd_async/core/async_status.py @@ -4,13 +4,23 @@ import functools import time from dataclasses import replace -from typing import AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast +from typing import ( + AsyncIterator, + Awaitable, + Callable, + Generic, + Sequence, + Type, + TypeVar, + cast, +) from bluesky.protocols import Status from .utils import Callback, P, T, Watcher, WatcherUpdate -AS = TypeVar("AS") +AS = TypeVar("AS", bound="AsyncStatus") +WAS = TypeVar("WAS", bound="WatchableAsyncStatus") class AsyncStatusBase(Status): @@ -93,11 +103,20 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: class WatchableAsyncStatus(AsyncStatusBase, Generic[T]): """Convert AsyncIterator of WatcherUpdates to bluesky Status interface.""" - def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]): - self._watchers: list[Watcher] + def __init__( + self, + iterator_or_awaitable: Awaitable | AsyncIterator[WatcherUpdate[T]], + watchers: list[Watcher] = [], + ): + self._watchers: list[Watcher] = watchers self._start = time.monotonic() self._last_update: WatcherUpdate[T] | None = None - super().__init__(self._notify_watchers_from(iterator)) + awaitable = ( + iterator_or_awaitable + if isinstance(iterator_or_awaitable, Awaitable) + else self._notify_watchers_from(iterator_or_awaitable) + ) + super().__init__(awaitable) async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]): async for self._last_update in iterator: @@ -107,17 +126,19 @@ async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]) def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]): watcher(replace(update, time_elapsed_s=time.monotonic() - self._start)) - def watch(self, watcher: Watcher): - self._watchers.append(watcher) - if self._last_update: - self._update_watcher(watcher, self._last_update) + def watch(self, watchers: Sequence[Watcher]): + for watcher in watchers: + self._watchers.append(watcher) + if self._last_update: + self._update_watcher(watcher, self._last_update) @classmethod def wrap( - cls: Type[AS], f: Callable[P, AsyncIterator[WatcherUpdate[T]]] - ) -> Callable[P, AS]: + cls: Type[WAS], + f: Callable[P, Awaitable] | Callable[P, AsyncIterator[WatcherUpdate[T]]], + ) -> Callable[P, WAS]: @functools.wraps(f) - def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: + def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS: return cls(f(*args, **kwargs)) - return cast(Callable[P, AS], wrap_f) + return cast(Callable[P, WAS], wrap_f) diff --git a/src/ophyd_async/core/detector.py b/src/ophyd_async/core/detector.py index 39bd2ac98d..bfe94de35c 100644 --- a/src/ophyd_async/core/detector.py +++ b/src/ophyd_async/core/detector.py @@ -31,7 +31,7 @@ WritesStreamAssets, ) -from .async_status import AsyncStatus +from .async_status import AsyncStatus, WatchableAsyncStatus from .device import Device from .signal import SignalR from .utils import DEFAULT_TIMEOUT, merge_gathered_dicts @@ -189,7 +189,7 @@ def __init__( self._trigger_info: Optional[TriggerInfo] = None # For kickoff self._watchers: List[Callable] = [] - self._fly_status: Optional[AsyncStatus] = None + self._fly_status: Optional[WatchableAsyncStatus] = None self._fly_start: float self._intial_frame: int @@ -302,7 +302,7 @@ async def _prepare(self, value: T) -> None: @AsyncStatus.wrap async def kickoff(self) -> None: - self._fly_status = AsyncStatus(self._fly(), self._watchers) + self._fly_status = WatchableAsyncStatus(self._fly(), self._watchers) self._fly_start = time.monotonic() async def _fly(self) -> None: diff --git a/src/ophyd_async/epics/demo/__init__.py b/src/ophyd_async/epics/demo/__init__.py index 73833c1731..22be71f1eb 100644 --- a/src/ophyd_async/epics/demo/__init__.py +++ b/src/ophyd_async/epics/demo/__init__.py @@ -14,7 +14,12 @@ import numpy as np from bluesky.protocols import Movable, Stoppable -from ophyd_async.core import AsyncStatus, Device, StandardReadable, observe_value +from ophyd_async.core import ( + Device, + StandardReadable, + WatchableAsyncStatus, + observe_value, +) from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x @@ -105,10 +110,12 @@ def move(self, new_position: float, timeout: Optional[float] = None): call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore # TODO: this fails if we call from the cli, but works if we "ipython await" it - def set(self, new_position: float, timeout: Optional[float] = None) -> AsyncStatus: + def set( + self, new_position: float, timeout: Optional[float] = None + ) -> WatchableAsyncStatus: watchers: List[Callable] = [] coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout) - return AsyncStatus(coro, watchers) + return WatchableAsyncStatus(coro, watchers) async def stop(self, success=True): self._set_success = success diff --git a/src/ophyd_async/epics/motion/motor.py b/src/ophyd_async/epics/motion/motor.py index 49a997e1d0..24078c8775 100644 --- a/src/ophyd_async/epics/motion/motor.py +++ b/src/ophyd_async/epics/motion/motor.py @@ -4,7 +4,7 @@ from bluesky.protocols import Movable, Stoppable -from ophyd_async.core import AsyncStatus, StandardReadable +from ophyd_async.core import StandardReadable, WatchableAsyncStatus from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x @@ -72,10 +72,12 @@ def move(self, new_position: float, timeout: Optional[float] = None): raise RuntimeError("Will deadlock run engine if run in a plan") call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore - def set(self, new_position: float, timeout: Optional[float] = None) -> AsyncStatus: + def set( + self, new_position: float, timeout: Optional[float] = None + ) -> WatchableAsyncStatus: watchers: List[Callable] = [] coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout) - return AsyncStatus(coro, watchers) + return WatchableAsyncStatus(coro, watchers) async def stop(self, success=False): self._set_success = success diff --git a/tests/core/test_async_status_wrapper.py b/tests/core/test_async_status_wrapper.py index d584061aae..1e88e0c7d0 100644 --- a/tests/core/test_async_status_wrapper.py +++ b/tests/core/test_async_status_wrapper.py @@ -1,4 +1,5 @@ import asyncio +from typing import AsyncIterator import bluesky.plan_stubs as bps import pytest @@ -36,7 +37,7 @@ def __init__( super().__init__(name) @WatchableAsyncStatus.wrap - async def set(self, val): + async def set(self, val) -> AsyncIterator: self._initial = await self.sig.get_value() for point in self.values: await asyncio.sleep(0.01) @@ -123,7 +124,7 @@ async def test_asyncstatus_wraps_set_iterator(RE): def watcher(update): updates.append(update) - st.watch(watcher) + st.watch([watcher]) await st assert st.done assert st.success @@ -139,7 +140,7 @@ async def test_asyncstatus_wraps_failing_set_iterator_(RE): def watcher(update): updates.append(update) - st.watch(watcher) + st.watch([watcher]) try: await st except Exception: diff --git a/tests/epics/demo/test_demo.py b/tests/epics/demo/test_demo.py index 5b8847e842..4ce7512ff5 100644 --- a/tests/epics/demo/test_demo.py +++ b/tests/epics/demo/test_demo.py @@ -61,7 +61,7 @@ async def wait_for_call(self, *args, **kwargs): async def test_mover_moving_well(sim_mover: demo.Mover) -> None: s = sim_mover.set(0.55) watcher = Watcher() - s.watch(watcher) + s.watch([watcher]) done = Mock() s.add_callback(done) await watcher.wait_for_call( diff --git a/tests/epics/motion/test_motor.py b/tests/epics/motion/test_motor.py index 7706099295..c4c327858d 100644 --- a/tests/epics/motion/test_motor.py +++ b/tests/epics/motion/test_motor.py @@ -30,7 +30,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: set_sim_put_proceeds(sim_motor.setpoint, False) s = sim_motor.set(0.55) watcher = Mock() - s.watch(watcher) + s.watch([watcher]) done = Mock() s.add_callback(done) await asyncio.sleep(A_BIT)