Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wait_for_value_interface_change #582

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
115 changes: 95 additions & 20 deletions src/ophyd_async/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,75 @@ async def get_value():
signal.clear_sub(q.put_nowait)


async def observe_signals_values(
ZohebShaikh marked this conversation as resolved.
Show resolved Hide resolved
*signals: SignalR[T],
timeout: float | None = None,
done_status: Status | None = None,
) -> AsyncGenerator[tuple[SignalR[T], T], None]:
"""Subscribe to the value of a signal so it can be iterated from.

Parameters
----------
signals:
Call subscribe_value on this at the start, and clear_sub on it at the
end
timeout:
If given, how long to wait for each updated value in seconds. If an update
is not produced in this time then raise asyncio.TimeoutError
done_status:
If this status is complete, stop observing and make the iterator return.
If it raises an exception then this exception will be raised by the iterator.

Notes
-----
Example usage::

async for value in observe_value(sig):
do_something_with(value)
"""
q: asyncio.Queue[tuple[SignalR[T], T | Status]] = asyncio.Queue()
ZohebShaikh marked this conversation as resolved.
Show resolved Hide resolved
if timeout is None:
get_value = q.get
else:

async def get_value():
return await asyncio.wait_for(q.get(), timeout)

def wrapped_signal_put(signal: SignalR[T]):
def queue_value(value: T):
q.put_nowait((signal, value))

def queue_status(status: Status):
q.put_nowait((signal, status))

def clear_signals():
signal.clear_sub(queue_value)
signal.clear_sub(queue_status)

return queue_value, queue_status, clear_signals

clear_signals = []
for signal in signals:
queue_value, queue_status, clear_signal = wrapped_signal_put(signal)
clear_signals.append(clear_signal)
if done_status is not None:
done_status.add_callback(queue_status)
signal.subscribe_value(queue_value)
ZohebShaikh marked this conversation as resolved.
Show resolved Hide resolved
try:
while True:
item = await get_value()
if done_status and item is done_status:
if exc := done_status.exception():
raise exc
else:
break
else:
yield item # type: ignore
finally:
for clear_signal in clear_signals:
clear_signal()
ZohebShaikh marked this conversation as resolved.
Show resolved Hide resolved


class _ValueChecker(Generic[T]):
def __init__(self, matcher: Callable[[T], bool], matcher_name: str):
self._last_value: T | None = None
Expand All @@ -514,7 +583,7 @@ async def wait_for_value(self, signal: SignalR[T], timeout: float | None):

async def wait_for_value(
signal: SignalR[T],
match: T | Callable[[T], bool],
match_value: T | Callable[[T], bool],
timeout: float | None,
):
"""Wait for a signal to have a matching value.
Expand All @@ -540,35 +609,36 @@ async def wait_for_value(

wait_for_value(device.num_captured, lambda v: v > 45, timeout=1)
"""
if callable(match):
checker = _ValueChecker(match, match.__name__) # type: ignore

if callable(match_value):
checker = _ValueChecker(match_value, match_value.__name__) # type: ignore
else:
checker = _ValueChecker(lambda v: v == match, repr(match))
checker = _ValueChecker(lambda v: v == match_value, repr(match_value))
await checker.wait_for_value(signal, timeout)


async def set_and_wait_for_other_value(
set_signal: SignalW[T],
set_value: T,
read_signal: SignalR[S],
read_value: S,
match_signal: SignalR[S],
match_value: S | Callable[[S], bool],
timeout: float = DEFAULT_TIMEOUT,
set_timeout: float | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We decided to add wait_for_set_completion = True (or something like that) to both functions and always return the status, but then pass False from the areaDetector utility function

) -> AsyncStatus:
):
"""Set a signal and monitor another signal until it has the specified value.

This function sets a set_signal to a specified set_value and waits for
a read_signal to have the read_value.
a match_signal to have the match_value.

Parameters
----------
signal:
The signal to set
set_value:
The value to set it to
read_signal:
match_signal:
The signal to monitor
read_value:
match_value:
The value to wait for
timeout:
How long to wait for the signal to have the value
Expand All @@ -582,37 +652,36 @@ async def set_and_wait_for_other_value(
set_and_wait_for_value(device.acquire, 1, device.acquire_rbv, 1)
"""
# Start monitoring before the set to avoid a race condition
values_gen = observe_value(read_signal)
values_gen = observe_value(match_signal)

# Get the initial value from the monitor to make sure we've created it
current_value = await anext(values_gen)

status = set_signal.set(set_value, timeout=set_timeout)
set_signal.set(set_value, timeout=set_timeout)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DominicOram this is what @ZohebShaikh was alluding to in #402 (comment).

If we choose A, then we need to decide what to do with the Status. Something needs to keep track of it an await it, we can't drop it on the floor as we are doing here or we get teardown errors in the tests, so we have to return it. The calling code then looks like:

status = await set_and_wait_for_other_value(dev.acquire, 1, dev.acquire_rbv, 1)
# do something that you can do as soon as the thing is acquiring
await status
# now the device has finished acquiring

If we choose B, then we keep a track of status here, then await it at the end of the function. The calling code becomes:

await set_and_wait_for_other_value(dev.acquire, 1, dev.acquire_rbv, 1)
# now the device has finished acquiring,

What was your actual use case for this? Will the write_signal ever take significantly longer to caput-callback than the read_signal will take to change to the match_value? I know we have this case for the areaDetector acquire PV above, but we decided to not use this function as it was clearer to write:

arm_status = dev.acquire.set(1)
await wait_for_value(dev.acquire_rbv, 1)
# do something we can do when the device is armed
await arm_status
# now the device is disarmed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My original usecase for this was the actually the areaDetector one. We tried the code as suggested:

arm_status = dev.acquire.set(1)
await wait_for_value(dev.acquire_rbv, 1)
await arm_status

But found that the wait_for_value was failing. This is because for this particular configuration the detector immediately takes data and does so so quickly that it has finished by the time we start monitoring the RBV. See #453 (comment). My main motivation for putting this into ophyd-async was to provide a function that would always protect against the potential race condition so we're less like to see it with people taking the naive approach.

I believe for my use case the callback has already returned on the set before the RBV goes to the expected value but I would have to check. Either way I think what I would expect to happen is that set_and_wait_for_other_value returns a gather of both statuses. If you want to do some in between these then you need to do that yourself. Maybe we should have a chat about it?


# If the value was the same as before no need to wait for it to change
if current_value != read_value:
if current_value != match_value:

async def _wait_for_value():
async for value in values_gen:
if value == read_value:
if value == match_value:
break

try:
await asyncio.wait_for(_wait_for_value(), timeout)
except asyncio.TimeoutError as e:
raise TimeoutError(
f"{read_signal.name} didn't match {read_value} in {timeout}s"
f"{match_signal.name} didn't match {match_value} in {timeout}s"
) from e

return status


async def set_and_wait_for_value(
signal: SignalRW[T],
value: T,
match_value: T | Callable[[T], bool] | None = None,
timeout: float = DEFAULT_TIMEOUT,
status_timeout: float | None = None,
) -> AsyncStatus:
):
"""Set a signal and monitor it until it has that value.

Useful for busy record, or other Signals with pattern:
Expand All @@ -626,6 +695,9 @@ async def set_and_wait_for_value(
The signal to set
value:
The value to set it to
match_value:
The expected value of the signal after the operation.
Used to verify that the set operation was successful.
timeout:
How long to wait for the signal to have the value
status_timeout:
Expand All @@ -637,6 +709,9 @@ async def set_and_wait_for_value(

set_and_wait_for_value(device.acquire, 1)
"""
return await set_and_wait_for_other_value(
signal, value, signal, value, timeout, status_timeout
if match_value is None:
match_value = value

await set_and_wait_for_other_value(
signal, value, signal, match_value, timeout, status_timeout
)
6 changes: 4 additions & 2 deletions src/ophyd_async/epics/adaravis/_aravis_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
DetectorControl,
DetectorTrigger,
TriggerInfo,
set_and_wait_for_value,
)
from ophyd_async.core._status import AsyncStatus
from ophyd_async.epics import adcore
from ophyd_async.epics.adcore._core_logic import (
start_acquiring_driver_and_ensure_status,
)

from ._aravis_io import AravisDriverIO, AravisTriggerMode, AravisTriggerSource

Expand Down Expand Up @@ -48,7 +50,7 @@ async def prepare(self, trigger_info: TriggerInfo):
)

async def arm(self):
self._arm_status = await set_and_wait_for_value(self._drv.acquire, True)
self._arm_status = await start_acquiring_driver_and_ensure_status(self._drv)

async def wait_for_idle(self):
if self._arm_status:
Expand Down
6 changes: 4 additions & 2 deletions src/ophyd_async/epics/adcore/_core_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
DetectorControl,
set_and_wait_for_value,
)
from ophyd_async.core._signal import wait_for_value
from ophyd_async.epics.adcore._utils import convert_ad_dtype_to_np

from ._core_io import ADBaseIO, DetectorState
Expand Down Expand Up @@ -90,8 +91,9 @@ async def start_acquiring_driver_and_ensure_status(
An AsyncStatus that can be awaited to set driver.acquire to True and perform
subsequent raising (if applicable) due to detector state.
"""

status = await set_and_wait_for_value(driver.acquire, True, timeout=timeout)
status = driver.acquire.set(True, timeout=timeout)
await wait_for_value(driver.acquire, True, timeout=None)
await set_and_wait_for_value(driver.acquire, True, timeout=timeout)

async def complete_acquisition() -> None:
"""NOTE: possible race condition here between the callback from
Expand Down
17 changes: 9 additions & 8 deletions tests/core/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,32 +180,33 @@ async def wait_and_set_proceeds():
set_mock_put_proceeds(signal, True)

async def check_set_and_wait():
await (await set_and_wait_for_value(signal, 1, timeout=0.1))
await set_and_wait_for_value(signal, 1, timeout=0.1)
await wait_for_value(signal, 1, timeout=0.1)

await asyncio.gather(wait_and_set_proceeds(), check_set_and_wait())
assert await signal.get_value() == 1


async def test_set_and_wait_for_value_different_set_and_read():
set_signal = epics_signal_rw(int, "pva://set", name="set-signal")
read_signal = epics_signal_r(str, "pva://read", name="read-signal")
match_signal = epics_signal_r(str, "pva://read", name="read-signal")
await set_signal.connect(mock=True)
await read_signal.connect(mock=True)
await match_signal.connect(mock=True)

do_read_set = Event()

callback_on_mock_put(set_signal, lambda *args, **kwargs: do_read_set.set())

async def wait_and_set_read():
await do_read_set.wait()
set_mock_value(read_signal, "test")
set_mock_value(match_signal, "test")

async def check_set_and_wait():
await (
await set_and_wait_for_other_value(
set_signal, 1, read_signal, "test", timeout=100
)
await set_and_wait_for_other_value(
set_signal, 1, match_signal, "test", timeout=100
)
await wait_for_value(set_signal, 1, timeout=100)
await wait_for_value(match_signal, "test", timeout=100)

await asyncio.gather(wait_and_set_read(), check_set_and_wait())
assert await set_signal.get_value() == 1
Expand Down
9 changes: 6 additions & 3 deletions tests/epics/test_motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
SignalRW,
callback_on_mock_put,
mock_puts_blocked,
observe_value,
set_mock_put_proceeds,
set_mock_value,
)
from ophyd_async.core._signal import observe_signals_values
from ophyd_async.epics import motor

# Long enough for multiple asyncio event loop cycles to run so
Expand Down Expand Up @@ -279,10 +279,13 @@ async def test_prepare(
set_mock_value(sim_motor.high_limit_travel, 20)
set_mock_value(sim_motor.max_velocity, 10)
fake_set_signal = SignalRW(MockSignalBackend(float))
fake_other_set_signal = SignalRW(MockSignalBackend(float))

async def wait_for_set(_):
async for value in observe_value(fake_set_signal, timeout=1):
if value == target_position:
async for signal, value in observe_signals_values(
fake_other_set_signal, fake_set_signal, timeout=1
):
if signal == fake_set_signal and value == target_position:
break

sim_motor.set = AsyncMock(side_effect=wait_for_set)
Expand Down
Loading