Skip to content

Commit

Permalink
(#117) (#45) make Watcher match bluesky spec
Browse files Browse the repository at this point in the history
(#117) (#45) update classes to use WatchableAsyncStatus.
  • Loading branch information
dperl-dls committed Apr 15, 2024
1 parent 9c84cb7 commit 865356a
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 77 deletions.
20 changes: 11 additions & 9 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import functools
import time
from dataclasses import replace
from dataclasses import asdict, replace
from typing import (
AsyncIterator,
Awaitable,
Expand Down Expand Up @@ -53,7 +53,7 @@ def _run_callbacks(self, task: asyncio.Task):

def exception(self, timeout: float | None = 0.0) -> BaseException | None:
if timeout != 0.0:
raise Exception(
raise ValueError(
"cannot honour any timeout other than 0 in an asynchronous function"
)
if self.task.done():
Expand Down Expand Up @@ -119,18 +119,20 @@ def __init__(
super().__init__(awaitable)

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for self._last_update in iterator:
async for update in iterator:
self._last_update = replace(
update, time_elapsed=time.monotonic() - self._start
)
for watcher in self._watchers:
self._update_watcher(watcher, self._last_update)

def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]):
watcher(replace(update, time_elapsed_s=time.monotonic() - self._start))
watcher(**asdict(update))

def watch(self, watchers: Sequence[Watcher]):
for watcher in watchers:
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)
def watch(self, watcher:Watcher):
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)

@classmethod
def wrap(
Expand Down
32 changes: 15 additions & 17 deletions src/ophyd_async/core/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .async_status import AsyncStatus, WatchableAsyncStatus
from .device import Device
from .signal import SignalR
from .utils import DEFAULT_TIMEOUT, merge_gathered_dicts
from .utils import DEFAULT_TIMEOUT, WatcherUpdate, merge_gathered_dicts

T = TypeVar("T")

Expand Down Expand Up @@ -300,28 +300,26 @@ async def _prepare(self, value: T) -> None:
exposure=self._trigger_info.livetime,
)

@AsyncStatus.wrap
async def kickoff(self) -> None:
self._fly_status = WatchableAsyncStatus(self._fly(), self._watchers)
async def kickoff(self):
self._fly_status = WatchableAsyncStatus(
self._observe_writer_indicies(self._last_frame), self._watchers
)
self._fly_start = time.monotonic()

async def _fly(self) -> None:
await self._observe_writer_indicies(self._last_frame)
return self._fly_status

async def _observe_writer_indicies(self, end_observation: int):
async for index in self.writer.observe_indices_written(
self._frame_writing_timeout
):
for watcher in self._watchers:
watcher(
name=self.name,
current=index,
initial=self._initial_frame,
target=end_observation,
unit="",
precision=0,
time_elapsed=time.monotonic() - self._fly_start,
)
yield WatcherUpdate(
name=self.name,
current=index,
initial=self._initial_frame,
target=end_observation,
unit="",
precision=0,
time_elapsed=time.monotonic() - self._fly_start,
)
if index >= end_observation:
break

Expand Down
36 changes: 28 additions & 8 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
List,
Optional,
ParamSpec,
Protocol,
Type,
TypeAlias,
TypeVar,
Union,
)
Expand Down Expand Up @@ -85,16 +85,36 @@ def __str__(self) -> str:

@dataclass(frozen=True)
class WatcherUpdate(Generic[T]):
name: str
"""A dataclass such that, when expanded, it provides the kwargs for a watcher"""

current: T
initial: T
target: T
units: str
precision: float
time_elapsed_s: float


Watcher: TypeAlias = Callable[[WatcherUpdate[T]], Any]
name: str | None = None
unit: str | None = None
precision: float | None = None
fraction: float | None = None
time_elapsed: float | None = None
time_remaining: float | None = None


C = TypeVar("C", contravariant=True)


class Watcher(Protocol, Generic[C]):
@staticmethod
def __call__(
*,
current: C,
initial: C,
target: C,
name: str | None,
unit: str | None,
precision: float | None,
fraction: float | None,
time_elapsed: float | None,
time_remaining: float | None,
) -> Any: ...


async def wait_for_connection(**coros: Awaitable[None]):
Expand Down
48 changes: 21 additions & 27 deletions src/ophyd_async/epics/motion/motor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio
import time
from typing import Callable, List, Optional
from dataclasses import replace
from typing import Optional

from bluesky.protocols import Movable, Stoppable

from ophyd_async.core import StandardReadable, WatchableAsyncStatus
from ophyd_async.core.signal import observe_value
from ophyd_async.core.utils import WatcherUpdate

from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x

Expand Down Expand Up @@ -35,34 +38,23 @@ def set_name(self, name: str):
# Readback should be named the same as its parent in read()
self.readback.set_name(name)

async def _move(self, new_position: float, watchers: List[Callable] = []):
async def _move(self, new_position: float) -> WatcherUpdate[float]:
self._set_success = True
start = time.monotonic()
old_position, units, precision = await asyncio.gather(
self.setpoint.get_value(),
self.units.get_value(),
self.precision.get_value(),
)

def update_watchers(current_position: float):
for watcher in watchers:
watcher(
name=self.name,
current=current_position,
initial=old_position,
target=new_position,
unit=units,
precision=precision,
time_elapsed=time.monotonic() - start,
)

self.readback.subscribe_value(update_watchers)
try:
await self.setpoint.set(new_position)
finally:
self.readback.clear_sub(update_watchers)
await self.setpoint.set(new_position)
if not self._set_success:
raise RuntimeError("Motor was stopped")
return WatcherUpdate(
initial=old_position,
current=old_position,
target=new_position,
unit=units,
precision=precision,
)

def move(self, new_position: float, timeout: Optional[float] = None):
"""Commandline only synchronous move of a Motor"""
Expand All @@ -72,12 +64,14 @@ def move(self, new_position: float, timeout: Optional[float] = None):
raise RuntimeError("Will deadlock run engine if run in a plan")
call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore

def set(
self, new_position: float, timeout: Optional[float] = None
) -> WatchableAsyncStatus:
watchers: List[Callable] = []
coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout)
return WatchableAsyncStatus(coro, watchers)
@WatchableAsyncStatus.wrap
async def set(self, new_position: float, timeout: Optional[float] = None):
start_time = time.monotonic()
update: WatcherUpdate[float] = await self._move(new_position)
async for readback in observe_value(self.readback):
yield replace(
update, current=readback, time_elapsed=start_time - time.monotonic()
)

async def stop(self, success=False):
self._set_success = success
Expand Down
Loading

0 comments on commit 865356a

Please sign in to comment.