Skip to content

Commit

Permalink
(#117) (#45) add timeouts
Browse files Browse the repository at this point in the history
and add a bit more to tests
  • Loading branch information
dperl-dls committed Apr 15, 2024
1 parent 865356a commit 3dc36c9
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 60 deletions.
40 changes: 26 additions & 14 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Awaitable,
Callable,
Generic,
Sequence,
SupportsFloat,
Type,
TypeVar,
cast,
Expand Down Expand Up @@ -105,42 +105,54 @@ class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):

def __init__(
self,
iterator_or_awaitable: Awaitable | AsyncIterator[WatcherUpdate[T]],
watchers: list[Watcher] = [],
iterator: AsyncIterator[WatcherUpdate[T]],
timeout_s: float = 0.0,
):
self._watchers: list[Watcher] = watchers
self._watchers: list[Watcher] = []
self._start = time.monotonic()
self._timeout = self._start + timeout_s if timeout_s else None
self._last_update: WatcherUpdate[T] | None = None
awaitable = (
iterator_or_awaitable
if isinstance(iterator_or_awaitable, Awaitable)
else self._notify_watchers_from(iterator_or_awaitable)
)
super().__init__(awaitable)
super().__init__(self._notify_watchers_from(iterator))

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for update in iterator:
if self._timeout and time.monotonic() > self._timeout:
raise TimeoutError()
self._last_update = replace(
update, time_elapsed=time.monotonic() - self._start
)
for watcher in self._watchers:
self._update_watcher(watcher, self._last_update)

def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]):
watcher(**asdict(update))
vals = asdict(
update, dict_factory=lambda d: {k: v for k, v in d if v is not None}
)
watcher(**vals)

def watch(self, watcher:Watcher):
def watch(self, watcher: Watcher):
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)

@classmethod
def wrap(
cls: Type[WAS],
f: Callable[P, Awaitable] | Callable[P, AsyncIterator[WatcherUpdate[T]]],
f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
timeout_s: float = 0.0,
) -> Callable[P, WAS]:
"""Wrap an AsyncIterator in a WatchableAsyncStatus. If it takes
'timeout_s' as an argument, this must be a float and it will be propagated
to the status."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
return cls(f(*args, **kwargs))
# We can't type this more properly because Concatenate/ParamSpec doesn't
# yet support keywords
# https://peps.python.org/pep-0612/#concatenating-keyword-parameters
_timeout = kwargs.get("timeout_s")
assert isinstance(_timeout, SupportsFloat) or _timeout is None
timeout = _timeout or 0.0
return cls(f(*args, **kwargs), timeout_s=float(timeout))

return cast(Callable[P, WAS], wrap_f)
8 changes: 4 additions & 4 deletions src/ophyd_async/core/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,12 @@ async def _prepare(self, value: T) -> None:
exposure=self._trigger_info.livetime,
)

async def kickoff(self):
def kickoff(self, timeout_s=0.0):
self._fly_start = time.monotonic()
self._fly_status = WatchableAsyncStatus(
self._observe_writer_indicies(self._last_frame), self._watchers
self._observe_writer_indicies(self._last_frame), timeout_s
)
self._fly_start = time.monotonic()
return self._fly_status
return self._fly_status

async def _observe_writer_indicies(self, end_observation: int):
async for index in self.writer.observe_indices_written(
Expand Down
22 changes: 21 additions & 1 deletion src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ParamSpec,
Protocol,
Type,
TypeAlias,
TypeVar,
Union,
)
Expand Down Expand Up @@ -101,7 +102,7 @@ class WatcherUpdate(Generic[T]):
C = TypeVar("C", contravariant=True)


class Watcher(Protocol, Generic[C]):
class _ClsWatcher(Protocol, Generic[C]):
@staticmethod
def __call__(
*,
Expand All @@ -117,6 +118,25 @@ def __call__(
) -> Any: ...


class _InsWatcher(Protocol, Generic[C]):
def __call__(
self,
*,
current: C,
initial: C,
target: C,
name: str | None,
unit: str | None,
precision: float | None,
fraction: float | None,
time_elapsed: float | None,
time_remaining: float | None,
) -> Any: ...


Watcher: TypeAlias = _ClsWatcher | _InsWatcher


async def wait_for_connection(**coros: Awaitable[None]):
"""Call many underlying signals, accumulating exceptions and returning them
Expand Down
48 changes: 25 additions & 23 deletions src/ophyd_async/epics/demo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import subprocess
import sys
import time
from dataclasses import replace
from enum import Enum
from pathlib import Path
from typing import Callable, List, Optional
from typing import Optional

import numpy as np
from bluesky.protocols import Movable, Stoppable
Expand All @@ -20,6 +21,7 @@
WatchableAsyncStatus,
observe_value,
)
from ophyd_async.core.utils import WatcherUpdate

from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x

Expand Down Expand Up @@ -74,32 +76,26 @@ def set_name(self, name: str):
# Readback should be named the same as its parent in read()
self.readback.set_name(name)

async def _move(self, new_position: float, watchers: List[Callable] = []):
async def _move(self, new_position: float):
self._set_success = True
# time.monotonic won't go backwards in case of NTP corrections
start = time.monotonic()
old_position, units, precision = await asyncio.gather(
self.setpoint.get_value(),
self.units.get_value(),
self.precision.get_value(),
)
# Wait for the value to set, but don't wait for put completion callback
await self.setpoint.set(new_position, wait=False)
async for current_position in observe_value(self.readback):
for watcher in watchers:
watcher(
name=self.name,
current=current_position,
initial=old_position,
target=new_position,
unit=units,
precision=precision,
time_elapsed=time.monotonic() - start,
)
if np.isclose(current_position, new_position):
break
if not self._set_success:
raise RuntimeError("Motor was stopped")
# return a template to set() which it can use to yield progress updates
return WatcherUpdate(
initial=old_position,
current=old_position,
target=new_position,
unit=units,
precision=precision,
)

def move(self, new_position: float, timeout: Optional[float] = None):
"""Commandline only synchronous move of a Motor"""
Expand All @@ -109,13 +105,19 @@ def move(self, new_position: float, timeout: Optional[float] = None):
raise RuntimeError("Will deadlock run engine if run in a plan")
call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore

# TODO: this fails if we call from the cli, but works if we "ipython await" it
def set(
self, new_position: float, timeout: Optional[float] = None
) -> WatchableAsyncStatus:
watchers: List[Callable] = []
coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout)
return WatchableAsyncStatus(coro, watchers)
@WatchableAsyncStatus.wrap
async def set(self, new_position: float):
update = await self._move(new_position)
start = time.monotonic()
async for current_position in observe_value(self.readback):
yield replace(
update,
name=self.name,
current=current_position,
time_elapsed=time.monotonic() - start,
)
if np.isclose(current_position, new_position):
return

async def stop(self, success=True):
self._set_success = success
Expand Down
24 changes: 16 additions & 8 deletions src/ophyd_async/epics/motion/motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, prefix: str, name="") -> None:
self.setpoint = epics_signal_rw(float, prefix + ".VAL")
self.readback = epics_signal_r(float, prefix + ".RBV")
self.velocity = epics_signal_rw(float, prefix + ".VELO")
self.done_moving = epics_signal_r(bool, prefix + ".DMOV")
self.units = epics_signal_r(str, prefix + ".EGU")
self.precision = epics_signal_r(int, prefix + ".PREC")
# Signals that collide with standard methods should have a trailing underscore
Expand All @@ -45,7 +46,7 @@ async def _move(self, new_position: float) -> WatcherUpdate[float]:
self.units.get_value(),
self.precision.get_value(),
)
await self.setpoint.set(new_position)
await self.setpoint.set(new_position, wait=False)
if not self._set_success:
raise RuntimeError("Motor was stopped")
return WatcherUpdate(
Expand All @@ -65,17 +66,24 @@ def move(self, new_position: float, timeout: Optional[float] = None):
call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore

@WatchableAsyncStatus.wrap
async def set(self, new_position: float, timeout: Optional[float] = None):
start_time = time.monotonic()
update: WatcherUpdate[float] = await self._move(new_position)
async for readback in observe_value(self.readback):
async def set(self, new_position: float, timeout_s: float = 0.0):
update = await self._move(new_position)
start = time.monotonic()
async for current_position in observe_value(self.readback):
if not self._set_success:
raise RuntimeError("Motor was stopped")
yield replace(
update, current=readback, time_elapsed=start_time - time.monotonic()
update,
name=self.name,
current=current_position,
time_elapsed=time.monotonic() - start,
)
if await self.done_moving.get_value():
return

async def stop(self, success=False):
self._set_success = success
# Put with completion will never complete as we are waiting for completion on
# the move above, so need to pass wait=False
status = self.stop_.trigger(wait=False)
await status
await self.stop_.trigger(wait=False)
await self.readback._backend.put(await self.readback.get_value())
2 changes: 2 additions & 0 deletions tests/core/test_device_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def test_async_device_connector_run_engine_same_event_loop():
async def set_up_device():
async with DeviceCollector(sim=True):
sim_motor = motor.Motor("BLxxI-MO-TABLE-01:X")
sim_motor.set
return sim_motor

loop = asyncio.new_event_loop()
Expand All @@ -95,6 +96,7 @@ async def set_up_device():
RE = RunEngine(call_returns_result=True, loop=loop)

def my_plan():
sim_motor.done_moving._backend._set_value(True) # type: ignore
yield from bps.mov(sim_motor, 3.14)

RE(my_plan())
Expand Down
33 changes: 28 additions & 5 deletions tests/epics/demo/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,36 @@ async def sim_sensor():
yield sim_sensor


class Watcher:
class DemoWatcher:
def __init__(self) -> None:
self._event = asyncio.Event()
self._mock = Mock()

def __call__(self, *args, **kwargs):
self._mock(*args, **kwargs)
def __call__(
self,
*args,
current: float,
initial: float,
target: float,
name: str | None = None,
unit: str | None = None,
precision: float | None = None,
fraction: float | None = None,
time_elapsed: float | None = None,
time_remaining: float | None = None,
**kwargs,
):
self._mock(
*args,
current=current,
initial=initial,
target=target,
name=name,
unit=unit,
precision=precision,
time_elapsed=time_elapsed,
**kwargs,
)
self._event.set()

async def wait_for_call(self, *args, **kwargs):
Expand All @@ -60,8 +83,8 @@ async def wait_for_call(self, *args, **kwargs):

async def test_mover_moving_well(sim_mover: demo.Mover) -> None:
s = sim_mover.set(0.55)
watcher = Watcher()
s.watch([watcher])
watcher = DemoWatcher()
s.watch(watcher)
done = Mock()
s.add_callback(done)
await watcher.wait_for_call(
Expand Down
Loading

0 comments on commit 3dc36c9

Please sign in to comment.