Skip to content

Commit

Permalink
(#117) (#45) update classes to use WatchableAsyncStatus.
Browse files Browse the repository at this point in the history
(#117) (#45) fix typing issues
  • Loading branch information
dperl-dls committed Apr 15, 2024
1 parent b63de38 commit 9c84cb7
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 28 deletions.
3 changes: 2 additions & 1 deletion src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
ShapeProvider,
StaticDirectoryProvider,
)
from .async_status import AsyncStatus
from .async_status import AsyncStatus, WatchableAsyncStatus
from .detector import (
DetectorControl,
DetectorTrigger,
Expand Down Expand Up @@ -74,6 +74,7 @@
"set_sim_value",
"wait_for_value",
"AsyncStatus",
"WatchableAsyncStatus",
"DirectoryInfo",
"DirectoryProvider",
"NameProvider",
Expand Down
47 changes: 34 additions & 13 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,23 @@
import functools
import time
from dataclasses import replace
from typing import AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast
from typing import (
AsyncIterator,
Awaitable,
Callable,
Generic,
Sequence,
Type,
TypeVar,
cast,
)

from bluesky.protocols import Status

from .utils import Callback, P, T, Watcher, WatcherUpdate

AS = TypeVar("AS")
AS = TypeVar("AS", bound="AsyncStatus")
WAS = TypeVar("WAS", bound="WatchableAsyncStatus")


class AsyncStatusBase(Status):
Expand Down Expand Up @@ -93,11 +103,20 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface."""

def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
self._watchers: list[Watcher]
def __init__(
self,
iterator_or_awaitable: Awaitable | AsyncIterator[WatcherUpdate[T]],
watchers: list[Watcher] = [],
):
self._watchers: list[Watcher] = watchers
self._start = time.monotonic()
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator))
awaitable = (
iterator_or_awaitable
if isinstance(iterator_or_awaitable, Awaitable)
else self._notify_watchers_from(iterator_or_awaitable)
)
super().__init__(awaitable)

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for self._last_update in iterator:
Expand All @@ -107,17 +126,19 @@ async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]])
def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]):
watcher(replace(update, time_elapsed_s=time.monotonic() - self._start))

def watch(self, watcher: Watcher):
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)
def watch(self, watchers: Sequence[Watcher]):
for watcher in watchers:
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)

@classmethod
def wrap(
cls: Type[AS], f: Callable[P, AsyncIterator[WatcherUpdate[T]]]
) -> Callable[P, AS]:
cls: Type[WAS],
f: Callable[P, Awaitable] | Callable[P, AsyncIterator[WatcherUpdate[T]]],
) -> Callable[P, WAS]:
@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
return cls(f(*args, **kwargs))

return cast(Callable[P, AS], wrap_f)
return cast(Callable[P, WAS], wrap_f)
6 changes: 3 additions & 3 deletions src/ophyd_async/core/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
WritesStreamAssets,
)

from .async_status import AsyncStatus
from .async_status import AsyncStatus, WatchableAsyncStatus
from .device import Device
from .signal import SignalR
from .utils import DEFAULT_TIMEOUT, merge_gathered_dicts
Expand Down Expand Up @@ -189,7 +189,7 @@ def __init__(
self._trigger_info: Optional[TriggerInfo] = None
# For kickoff
self._watchers: List[Callable] = []
self._fly_status: Optional[AsyncStatus] = None
self._fly_status: Optional[WatchableAsyncStatus] = None
self._fly_start: float

self._intial_frame: int
Expand Down Expand Up @@ -302,7 +302,7 @@ async def _prepare(self, value: T) -> None:

@AsyncStatus.wrap
async def kickoff(self) -> None:
self._fly_status = AsyncStatus(self._fly(), self._watchers)
self._fly_status = WatchableAsyncStatus(self._fly(), self._watchers)
self._fly_start = time.monotonic()

async def _fly(self) -> None:
Expand Down
13 changes: 10 additions & 3 deletions src/ophyd_async/epics/demo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
import numpy as np
from bluesky.protocols import Movable, Stoppable

from ophyd_async.core import AsyncStatus, Device, StandardReadable, observe_value
from ophyd_async.core import (
Device,
StandardReadable,
WatchableAsyncStatus,
observe_value,
)

from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x

Expand Down Expand Up @@ -105,10 +110,12 @@ def move(self, new_position: float, timeout: Optional[float] = None):
call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore

# TODO: this fails if we call from the cli, but works if we "ipython await" it
def set(self, new_position: float, timeout: Optional[float] = None) -> AsyncStatus:
def set(
self, new_position: float, timeout: Optional[float] = None
) -> WatchableAsyncStatus:
watchers: List[Callable] = []
coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout)
return AsyncStatus(coro, watchers)
return WatchableAsyncStatus(coro, watchers)

async def stop(self, success=True):
self._set_success = success
Expand Down
8 changes: 5 additions & 3 deletions src/ophyd_async/epics/motion/motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from bluesky.protocols import Movable, Stoppable

from ophyd_async.core import AsyncStatus, StandardReadable
from ophyd_async.core import StandardReadable, WatchableAsyncStatus

from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x

Expand Down Expand Up @@ -72,10 +72,12 @@ def move(self, new_position: float, timeout: Optional[float] = None):
raise RuntimeError("Will deadlock run engine if run in a plan")
call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore

def set(self, new_position: float, timeout: Optional[float] = None) -> AsyncStatus:
def set(
self, new_position: float, timeout: Optional[float] = None
) -> WatchableAsyncStatus:
watchers: List[Callable] = []
coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout)
return AsyncStatus(coro, watchers)
return WatchableAsyncStatus(coro, watchers)

async def stop(self, success=False):
self._set_success = success
Expand Down
7 changes: 4 additions & 3 deletions tests/core/test_async_status_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from typing import AsyncIterator

import bluesky.plan_stubs as bps
import pytest
Expand Down Expand Up @@ -36,7 +37,7 @@ def __init__(
super().__init__(name)

@WatchableAsyncStatus.wrap
async def set(self, val):
async def set(self, val) -> AsyncIterator:
self._initial = await self.sig.get_value()
for point in self.values:
await asyncio.sleep(0.01)
Expand Down Expand Up @@ -123,7 +124,7 @@ async def test_asyncstatus_wraps_set_iterator(RE):
def watcher(update):
updates.append(update)

st.watch(watcher)
st.watch([watcher])
await st
assert st.done
assert st.success
Expand All @@ -139,7 +140,7 @@ async def test_asyncstatus_wraps_failing_set_iterator_(RE):
def watcher(update):
updates.append(update)

st.watch(watcher)
st.watch([watcher])
try:
await st
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion tests/epics/demo/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def wait_for_call(self, *args, **kwargs):
async def test_mover_moving_well(sim_mover: demo.Mover) -> None:
s = sim_mover.set(0.55)
watcher = Watcher()
s.watch(watcher)
s.watch([watcher])
done = Mock()
s.add_callback(done)
await watcher.wait_for_call(
Expand Down
2 changes: 1 addition & 1 deletion tests/epics/motion/test_motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None:
set_sim_put_proceeds(sim_motor.setpoint, False)
s = sim_motor.set(0.55)
watcher = Mock()
s.watch(watcher)
s.watch([watcher])
done = Mock()
s.add_callback(done)
await asyncio.sleep(A_BIT)
Expand Down

0 comments on commit 9c84cb7

Please sign in to comment.