From 3e20106292fc5b754277426a94ceb8bf11f9cf1c Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Wed, 18 Sep 2024 11:52:15 +0000 Subject: [PATCH] added signal value and signal status --- src/ophyd_async/core/_signal.py | 47 ++++++++++++++++++++++----------- tests/epics/test_motor.py | 7 +++-- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 1ed839ea02..567d4ce36c 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -516,7 +516,7 @@ async def observe_signals_values( do_something_with(value) """ - q: asyncio.Queue[T | Status] = asyncio.Queue() + q: asyncio.Queue[tuple[SignalR[T], T | Status]] = asyncio.Queue() if timeout is None: get_value = q.get else: @@ -524,22 +524,39 @@ async def observe_signals_values( async def get_value(): return await asyncio.wait_for(q.get(), timeout) - if done_status is not None: - done_status.add_callback(q.put_nowait) + def wrapped_signal_put(signal: SignalR[T]): + def queue_value(value: T): + q.put_nowait((signal, value)) + + def queue_status(status: Status): + q.put_nowait((signal, status)) + + def clear_signals(): + signal.clear_sub(queue_value) + signal.clear_sub(queue_status) + + return queue_value, queue_status, clear_signals + + clear_signals = [] for signal in signals: - signal.subscribe_value(q.put_nowait) - try: - while True: - item = await get_value() - if done_status and item is done_status: - if exc := done_status.exception(): - raise exc - else: - break + queue_value, queue_status, clear_signal = wrapped_signal_put(signal) + clear_signals.append(clear_signal) + if done_status is not None: + done_status.add_callback(queue_status) + signal.subscribe_value(queue_value) + try: + while True: + item = await get_value() + if done_status and item is done_status: + if exc := done_status.exception(): + raise exc else: - yield item - finally: - signal.clear_sub(q.put_nowait) + break + else: + yield item + finally: + for clear_signal in clear_signals: + clear_signal() class _ValueChecker(Generic[T]): diff --git a/tests/epics/test_motor.py b/tests/epics/test_motor.py index c965706451..792d8d8aec 100644 --- a/tests/epics/test_motor.py +++ b/tests/epics/test_motor.py @@ -280,10 +280,13 @@ async def test_prepare( set_mock_value(sim_motor.high_limit_travel, 20) set_mock_value(sim_motor.max_velocity, 10) fake_set_signal = SignalRW(MockSignalBackend(float)) + fake_other_set_signal = SignalRW(MockSignalBackend(float)) async def wait_for_set(_): - async for value in observe_signals_values(fake_set_signal, timeout=1): - if value == target_position: + async for signal, value in observe_signals_values( + fake_other_set_signal, fake_set_signal, timeout=1 + ): + if signal == fake_set_signal and value == target_position: break sim_motor.set = AsyncMock(side_effect=wait_for_set)