Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in set_and_wait_for_value #457

Merged
merged 6 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
assert_reading,
assert_value,
observe_value,
set_and_wait_for_other_value,
set_and_wait_for_value,
soft_signal_r_and_setter,
soft_signal_rw,
Expand Down Expand Up @@ -135,6 +136,7 @@
"assert_value",
"observe_value",
"set_and_wait_for_value",
"set_and_wait_for_other_value",
"soft_signal_r_and_setter",
"soft_signal_rw",
"wait_for_value",
Expand Down
86 changes: 73 additions & 13 deletions src/ophyd_async/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Optional,
Tuple,
Type,
TypeVar,
Union,
)

Expand All @@ -33,6 +34,8 @@
from ._status import AsyncStatus
from ._utils import DEFAULT_TIMEOUT, CalculatableTimeout, CalculateTimeout, Callback, T

S = TypeVar("S")


def _add_timeout(func):
@functools.wraps(func)
Expand Down Expand Up @@ -524,7 +527,9 @@ async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]):


async def wait_for_value(
signal: SignalR[T], match: Union[T, Callable[[T], bool]], timeout: Optional[float]
signal: SignalR[T],
match: Union[T, Callable[[T], bool]],
timeout: Optional[float],
):
"""Wait for a signal to have a matching value.

Expand Down Expand Up @@ -556,6 +561,66 @@ async def wait_for_value(
await checker.wait_for_value(signal, timeout)


async def set_and_wait_for_other_value(
set_signal: SignalRW[T],
set_value: T,
read_signal: SignalR[S],
read_value: S,
timeout: float = DEFAULT_TIMEOUT,
set_timeout: Optional[float] = None,
) -> AsyncStatus:
"""Set a signal and monitor another signal until it has the specified value.

This function sets a set_signal to a specified set_value and waits for
a read_signal to have the read_value.

Parameters
----------
signal:
The signal to set
set_value:
The value to set it to
read_signal:
The signal to monitor
read_value:
The value to wait for
timeout:
How long to wait for the signal to have the value
set_timeout:
How long to wait for the set to complete

Notes
-----
Example usage::

set_and_wait_for_value(device.acquire, 1, device.acquire_rbv, 1)
"""
# Start monitoring before the set to avoid a race condition
values_gen = observe_value(read_signal)

# Get the initial value from the monitor to make sure we've created it
current_value = await anext(values_gen)

status = set_signal.set(set_value, timeout=set_timeout)

# If the value was the same as before no need to wait for it to change
if current_value != read_value:

async def _wait_for_value():
async for value in values_gen:
if value == read_value:
break

try:
await asyncio.wait_for(_wait_for_value(), timeout)
except asyncio.TimeoutError as e:
raise TimeoutError(
f"{read_signal.name} didn't match {read_value} in {timeout}s"
) from e

return status


async def set_and_wait_for_value(
signal: SignalRW[T],
value: T,
Expand All @@ -565,19 +630,14 @@ async def set_and_wait_for_value(
"""Set a signal and monitor it until it has that value.

Useful for busy record, or other Signals with pattern:

- Set Signal with wait=True and stash the Status
- Read the same Signal to check the operation has started
- Return the Status so calling code can wait for operation to complete

This function sets a signal to a specified value, optionally with or without a
ca/pv put callback, and waits for the readback value of the signal to match the
value it was set to.
- Set Signal with wait=True and stash the Status
- Read the same Signal to check the operation has started
- Return the Status so calling code can wait for operation to complete

Parameters
----------
signal:
The signal to set and monitor
The signal to set
value:
The value to set it to
timeout:
Expand All @@ -591,6 +651,6 @@ async def set_and_wait_for_value(

set_and_wait_for_value(device.acquire, 1)
"""
status = signal.set(value, timeout=status_timeout)
await wait_for_value(signal, value, timeout=timeout)
return status
return await set_and_wait_for_other_value(
signal, value, signal, value, timeout, status_timeout
)
73 changes: 60 additions & 13 deletions tests/core/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import re
import time
from asyncio import Event
from unittest.mock import ANY

import numpy
Expand All @@ -23,6 +24,8 @@
assert_configuration,
assert_reading,
assert_value,
callback_on_mock_put,
set_and_wait_for_other_value,
set_and_wait_for_value,
set_mock_put_proceeds,
set_mock_value,
Expand Down Expand Up @@ -182,32 +185,76 @@ async def time_taken_by(coro) -> float:
return time.monotonic() - start


async def test_set_and_wait_for_value():
async def test_set_and_wait_for_value_same_set_as_read():
signal = epics_signal_rw(int, "pva://pv", name="signal")
await signal.connect(mock=True)
assert await signal.get_value() == 0
set_mock_put_proceeds(signal, False)

do_read_set = Event()
callback_on_mock_put(signal, lambda *args, **kwargs: do_read_set.set())

async def wait_and_set_proceeds():
await asyncio.sleep(0.1)
await do_read_set.wait()
set_mock_put_proceeds(signal, True)
await asyncio.sleep(0.01)

async def check_set_and_wait():
st = await set_and_wait_for_value(signal, 1, timeout=100)
await st
await asyncio.sleep(0.01)
await (await set_and_wait_for_value(signal, 1, timeout=0.1))

assert (
0.1
< await time_taken_by(
asyncio.gather(wait_and_set_proceeds(), check_set_and_wait())
)
< 0.15
)
await asyncio.gather(wait_and_set_proceeds(), check_set_and_wait())
assert await signal.get_value() == 1


async def test_set_and_wait_for_value_different_set_and_read():
set_signal = epics_signal_rw(int, "pva://set", name="set-signal")
read_signal = epics_signal_r(str, "pva://read", name="read-signal")
await set_signal.connect(mock=True)
await read_signal.connect(mock=True)

do_read_set = Event()

callback_on_mock_put(set_signal, lambda *args, **kwargs: do_read_set.set())

async def wait_and_set_read():
await do_read_set.wait()
set_mock_value(read_signal, "test")

async def check_set_and_wait():
await (
await set_and_wait_for_other_value(
set_signal, 1, read_signal, "test", timeout=100
)
)

await asyncio.gather(wait_and_set_read(), check_set_and_wait())
assert await set_signal.get_value() == 1


async def test_set_and_wait_for_value_different_set_and_read_times_out():
set_signal = epics_signal_rw(int, "pva://set", name="set-signal")
read_signal = epics_signal_r(str, "pva://read", name="read-signal")
await set_signal.connect(mock=True)
await read_signal.connect(mock=True)

do_read_set = Event()

callback_on_mock_put(set_signal, lambda *args, **kwargs: do_read_set.set())

async def wait_and_set_read():
await do_read_set.wait()
set_mock_value(read_signal, "not_test")

async def check_set_and_wait():
await (
await set_and_wait_for_other_value(
set_signal, 1, read_signal, "test", timeout=0.1
)
)

with pytest.raises(TimeoutError):
await asyncio.gather(wait_and_set_read(), check_set_and_wait())


async def test_wait_for_value_with_value():
signal = epics_signal_rw(str, read_pv="pva://signal", name="signal")
await signal.connect(mock=True)
Expand Down
Loading