From c3f0528555245391eb6680e8bb6987460d0a4161 Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Tue, 17 Sep 2024 11:44:07 +0000 Subject: [PATCH 01/30] initial commit --- src/ophyd_async/core/_signal.py | 20 +++++++++---------- .../epics/adaravis/_aravis_controller.py | 3 +-- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 340298160b..9a54e8fd74 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -513,7 +513,7 @@ async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]): async def wait_for_value( signal: SignalR[T], - match: Union[T, Callable[[T], bool]], + match_value: T | Callable[[T], bool], timeout: Optional[float], ): """Wait for a signal to have a matching value. @@ -539,18 +539,18 @@ async def wait_for_value( wait_for_value(device.num_captured, lambda v: v > 45, timeout=1) """ - if callable(match): - checker = _ValueChecker(match, match.__name__) + if callable(match_value): + checker = _ValueChecker(match_value, match_value.__name__) else: - checker = _ValueChecker(lambda v: v == match, repr(match)) + checker = _ValueChecker(lambda v: v == match_value, repr(match_value)) await checker.wait_for_value(signal, timeout) async def set_and_wait_for_other_value( set_signal: SignalW[T], set_value: T, - read_signal: SignalR[S], - read_value: S, + match_signal: SignalR[S], + match_value: S, timeout: float = DEFAULT_TIMEOUT, set_timeout: Optional[float] = None, ) -> AsyncStatus: @@ -581,7 +581,7 @@ async def set_and_wait_for_other_value( set_and_wait_for_value(device.acquire, 1, device.acquire_rbv, 1) """ # Start monitoring before the set to avoid a race condition - values_gen = observe_value(read_signal) + values_gen = observe_value(match_signal) # Get the initial value from the monitor to make sure we've created it current_value = await anext(values_gen) @@ -589,18 +589,18 @@ async def set_and_wait_for_other_value( status = 
set_signal.set(set_value, timeout=set_timeout) # If the value was the same as before no need to wait for it to change - if current_value != read_value: + if current_value != match_value: async def _wait_for_value(): async for value in values_gen: - if value == read_value: + if value == match_value: break try: await asyncio.wait_for(_wait_for_value(), timeout) except asyncio.TimeoutError as e: raise TimeoutError( - f"{read_signal.name} didn't match {read_value} in {timeout}s" + f"{match_signal.name} didn't match {match_value} in {timeout}s" ) from e return status diff --git a/src/ophyd_async/epics/adaravis/_aravis_controller.py b/src/ophyd_async/epics/adaravis/_aravis_controller.py index 894a46c008..4a112929d3 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_controller.py +++ b/src/ophyd_async/epics/adaravis/_aravis_controller.py @@ -5,7 +5,6 @@ DetectorControl, DetectorTrigger, TriggerInfo, - set_and_wait_for_value, ) from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore @@ -48,7 +47,7 @@ async def prepare(self, trigger_info: TriggerInfo): ) async def arm(self): - self._arm_status = await set_and_wait_for_value(self._drv.acquire, True) + self._arm_status = self._drv.acquire.set(True) async def wait_for_idle(self): if self._arm_status: From 0b34872bc3cf81885033c4ed87af00d66cc61a7c Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Tue, 17 Sep 2024 12:18:18 +0000 Subject: [PATCH 02/30] added observe signals --- src/ophyd_async/core/_signal.py | 53 +++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 9a54e8fd74..127065a96f 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -489,6 +489,59 @@ async def get_value(): signal.clear_sub(q.put_nowait) +async def observe_signals_values( + *signals: SignalR[T], + timeout: float | None = None, + done_status: Status | None = None, +) -> AsyncGenerator[T, None]: + 
"""Subscribe to the value of a signal so it can be iterated from. + + Parameters + ---------- + signals: + Call subscribe_value on this at the start, and clear_sub on it at the + end + timeout: + If given, how long to wait for each updated value in seconds. If an update + is not produced in this time then raise asyncio.TimeoutError + done_status: + If this status is complete, stop observing and make the iterator return. + If it raises an exception then this exception will be raised by the iterator. + + Notes + ----- + Example usage:: + + async for value in observe_value(sig): + do_something_with(value) + """ + + q: asyncio.Queue[T | Status] = asyncio.Queue() + if timeout is None: + get_value = q.get + else: + + async def get_value(): + return await asyncio.wait_for(q.get(), timeout) + + if done_status is not None: + done_status.add_callback(q.put_nowait) + for signal in signals: + signal.subscribe_value(q.put_nowait) + try: + while True: + item = await get_value() + if done_status and item is done_status: + if exc := done_status.exception(): + raise exc + else: + break + else: + yield item + finally: + signal.clear_sub(q.put_nowait) + + class _ValueChecker(Generic[T]): def __init__(self, matcher: Callable[[T], bool], matcher_name: str): self._last_value: Optional[T] = None From caadacac1ad5418d12019420e5bab340bb2b3edd Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Tue, 17 Sep 2024 13:29:18 +0000 Subject: [PATCH 03/30] added observe_signals for coverage --- tests/epics/test_motor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/epics/test_motor.py b/tests/epics/test_motor.py index 51d778007d..c965706451 100644 --- a/tests/epics/test_motor.py +++ b/tests/epics/test_motor.py @@ -14,10 +14,10 @@ SignalRW, callback_on_mock_put, mock_puts_blocked, - observe_value, set_mock_put_proceeds, set_mock_value, ) +from ophyd_async.core._signal import observe_signals_values from ophyd_async.epics import motor # Long enough for multiple asyncio 
event loop cycles to run so @@ -282,7 +282,7 @@ async def test_prepare( fake_set_signal = SignalRW(MockSignalBackend(float)) async def wait_for_set(_): - async for value in observe_value(fake_set_signal, timeout=1): + async for value in observe_signals_values(fake_set_signal, timeout=1): if value == target_position: break From 822a02419147a1a17ac915ee79983a998153747d Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Tue, 17 Sep 2024 15:09:31 +0000 Subject: [PATCH 04/30] added type changes to interface --- src/ophyd_async/core/_signal.py | 29 ++++++++++++--------- src/ophyd_async/epics/adcore/_core_logic.py | 5 ++-- tests/core/test_signal.py | 17 ++++++------ 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 127065a96f..1ed839ea02 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -493,7 +493,7 @@ async def observe_signals_values( *signals: SignalR[T], timeout: float | None = None, done_status: Status | None = None, -) -> AsyncGenerator[T, None]: +) -> AsyncGenerator[Tuple[SignalR[T], T], None]: """Subscribe to the value of a signal so it can be iterated from. Parameters @@ -603,14 +603,14 @@ async def set_and_wait_for_other_value( set_signal: SignalW[T], set_value: T, match_signal: SignalR[S], - match_value: S, + match_value: S | Callable[[S], bool], timeout: float = DEFAULT_TIMEOUT, set_timeout: Optional[float] = None, -) -> AsyncStatus: +): """Set a signal and monitor another signal until it has the specified value. This function sets a set_signal to a specified set_value and waits for - a read_signal to have the read_value. + a match_signal to have the match_value. 
Parameters ---------- @@ -618,9 +618,9 @@ async def set_and_wait_for_other_value( The signal to set set_value: The value to set it to - read_signal: + match_signal: The signal to monitor - read_value: + match_value: The value to wait for timeout: How long to wait for the signal to have the value @@ -639,7 +639,7 @@ async def set_and_wait_for_other_value( # Get the initial value from the monitor to make sure we've created it current_value = await anext(values_gen) - status = set_signal.set(set_value, timeout=set_timeout) + set_signal.set(set_value, timeout=set_timeout) # If the value was the same as before no need to wait for it to change if current_value != match_value: @@ -656,15 +656,14 @@ async def _wait_for_value(): f"{match_signal.name} didn't match {match_value} in {timeout}s" ) from e - return status - async def set_and_wait_for_value( signal: SignalRW[T], value: T, + match_value: T | Callable[[T], bool] | None = None, timeout: float = DEFAULT_TIMEOUT, status_timeout: Optional[float] = None, -) -> AsyncStatus: +): """Set a signal and monitor it until it has that value. Useful for busy record, or other Signals with pattern: @@ -678,6 +677,9 @@ async def set_and_wait_for_value( The signal to set value: The value to set it to + match_value: + The expected value of the signal after the operation. + Used to verify that the set operation was successful. 
timeout: How long to wait for the signal to have the value status_timeout: @@ -689,6 +691,9 @@ async def set_and_wait_for_value( set_and_wait_for_value(device.acquire, 1) """ - return await set_and_wait_for_other_value( - signal, value, signal, value, timeout, status_timeout + if match_value is None: + match_value = value + + await set_and_wait_for_other_value( + signal, value, signal, match_value, timeout, status_timeout ) diff --git a/src/ophyd_async/epics/adcore/_core_logic.py b/src/ophyd_async/epics/adcore/_core_logic.py index 21b07406fb..b545ad27fa 100644 --- a/src/ophyd_async/epics/adcore/_core_logic.py +++ b/src/ophyd_async/epics/adcore/_core_logic.py @@ -8,6 +8,7 @@ DetectorControl, set_and_wait_for_value, ) +from ophyd_async.core._signal import wait_for_value from ophyd_async.epics.adcore._utils import convert_ad_dtype_to_np from ._core_io import ADBaseIO, DetectorState @@ -92,12 +93,12 @@ async def start_acquiring_driver_and_ensure_status( subsequent raising (if applicable) due to detector state. 
""" - status = await set_and_wait_for_value(driver.acquire, True, timeout=timeout) + await set_and_wait_for_value(driver.acquire, True, timeout=timeout) async def complete_acquisition() -> None: """NOTE: possible race condition here between the callback from set_and_wait_for_value and the detector state updating.""" - await status + await wait_for_value(driver.acquire, True, timeout=None) state = await driver.detector_state.get_value() if state not in good_states: raise ValueError( diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index d498e67531..acbbde6a24 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -178,7 +178,8 @@ async def wait_and_set_proceeds(): set_mock_put_proceeds(signal, True) async def check_set_and_wait(): - await (await set_and_wait_for_value(signal, 1, timeout=0.1)) + await set_and_wait_for_value(signal, 1, timeout=0.1) + await wait_for_value(signal, 1, timeout=0.1) await asyncio.gather(wait_and_set_proceeds(), check_set_and_wait()) assert await signal.get_value() == 1 @@ -186,9 +187,9 @@ async def check_set_and_wait(): async def test_set_and_wait_for_value_different_set_and_read(): set_signal = epics_signal_rw(int, "pva://set", name="set-signal") - read_signal = epics_signal_r(str, "pva://read", name="read-signal") + match_signal = epics_signal_r(str, "pva://read", name="read-signal") await set_signal.connect(mock=True) - await read_signal.connect(mock=True) + await match_signal.connect(mock=True) do_read_set = Event() @@ -196,14 +197,14 @@ async def test_set_and_wait_for_value_different_set_and_read(): async def wait_and_set_read(): await do_read_set.wait() - set_mock_value(read_signal, "test") + set_mock_value(match_signal, "test") async def check_set_and_wait(): - await ( - await set_and_wait_for_other_value( - set_signal, 1, read_signal, "test", timeout=100 - ) + await set_and_wait_for_other_value( + set_signal, 1, match_signal, "test", timeout=100 ) + await wait_for_value(set_signal, 1, 
timeout=100) + await wait_for_value(match_signal, "test", timeout=100) await asyncio.gather(wait_and_set_read(), check_set_and_wait()) assert await set_signal.get_value() == 1 From 76b1392b7a8f4615985a15741c6adcc3ae10d92b Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Tue, 17 Sep 2024 15:57:16 +0000 Subject: [PATCH 05/30] changed start_acquiring_driver_and_ensure_status --- src/ophyd_async/epics/adaravis/_aravis_controller.py | 5 ++++- src/ophyd_async/epics/adcore/_core_logic.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ophyd_async/epics/adaravis/_aravis_controller.py b/src/ophyd_async/epics/adaravis/_aravis_controller.py index 4a112929d3..f5404d3644 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_controller.py +++ b/src/ophyd_async/epics/adaravis/_aravis_controller.py @@ -8,6 +8,9 @@ ) from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore +from ophyd_async.epics.adcore._core_logic import ( + start_acquiring_driver_and_ensure_status, +) from ._aravis_io import AravisDriverIO, AravisTriggerMode, AravisTriggerSource @@ -47,7 +50,7 @@ async def prepare(self, trigger_info: TriggerInfo): ) async def arm(self): - self._arm_status = self._drv.acquire.set(True) + self._arm_status = await start_acquiring_driver_and_ensure_status(self._drv) async def wait_for_idle(self): if self._arm_status: diff --git a/src/ophyd_async/epics/adcore/_core_logic.py b/src/ophyd_async/epics/adcore/_core_logic.py index b545ad27fa..098df53b7c 100644 --- a/src/ophyd_async/epics/adcore/_core_logic.py +++ b/src/ophyd_async/epics/adcore/_core_logic.py @@ -92,13 +92,14 @@ async def start_acquiring_driver_and_ensure_status( An AsyncStatus that can be awaited to set driver.acquire to True and perform subsequent raising (if applicable) due to detector state. 
""" - + status = driver.acquire.set(True, timeout=timeout) + await wait_for_value(driver.acquire, True, timeout=None) await set_and_wait_for_value(driver.acquire, True, timeout=timeout) async def complete_acquisition() -> None: """NOTE: possible race condition here between the callback from set_and_wait_for_value and the detector state updating.""" - await wait_for_value(driver.acquire, True, timeout=None) + await status state = await driver.detector_state.get_value() if state not in good_states: raise ValueError( From 3e20106292fc5b754277426a94ceb8bf11f9cf1c Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Wed, 18 Sep 2024 11:52:15 +0000 Subject: [PATCH 06/30] added signal value and signal status --- src/ophyd_async/core/_signal.py | 47 ++++++++++++++++++++++----------- tests/epics/test_motor.py | 7 +++-- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 1ed839ea02..567d4ce36c 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -516,7 +516,7 @@ async def observe_signals_values( do_something_with(value) """ - q: asyncio.Queue[T | Status] = asyncio.Queue() + q: asyncio.Queue[tuple[SignalR[T], T | Status]] = asyncio.Queue() if timeout is None: get_value = q.get else: @@ -524,22 +524,39 @@ async def observe_signals_values( async def get_value(): return await asyncio.wait_for(q.get(), timeout) - if done_status is not None: - done_status.add_callback(q.put_nowait) + def wrapped_signal_put(signal: SignalR[T]): + def queue_value(value: T): + q.put_nowait((signal, value)) + + def queue_status(status: Status): + q.put_nowait((signal, status)) + + def clear_signals(): + signal.clear_sub(queue_value) + signal.clear_sub(queue_status) + + return queue_value, queue_status, clear_signals + + clear_signals = [] for signal in signals: - signal.subscribe_value(q.put_nowait) - try: - while True: - item = await get_value() - if done_status and item is done_status: - 
if exc := done_status.exception(): - raise exc - else: - break + queue_value, queue_status, clear_signal = wrapped_signal_put(signal) + clear_signals.append(clear_signal) + if done_status is not None: + done_status.add_callback(queue_status) + signal.subscribe_value(queue_value) + try: + while True: + item = await get_value() + if done_status and item is done_status: + if exc := done_status.exception(): + raise exc else: - yield item - finally: - signal.clear_sub(q.put_nowait) + break + else: + yield item + finally: + for clear_signal in clear_signals: + clear_signal() class _ValueChecker(Generic[T]): diff --git a/tests/epics/test_motor.py b/tests/epics/test_motor.py index c965706451..792d8d8aec 100644 --- a/tests/epics/test_motor.py +++ b/tests/epics/test_motor.py @@ -280,10 +280,13 @@ async def test_prepare( set_mock_value(sim_motor.high_limit_travel, 20) set_mock_value(sim_motor.max_velocity, 10) fake_set_signal = SignalRW(MockSignalBackend(float)) + fake_other_set_signal = SignalRW(MockSignalBackend(float)) async def wait_for_set(_): - async for value in observe_signals_values(fake_set_signal, timeout=1): - if value == target_position: + async for signal, value in observe_signals_values( + fake_other_set_signal, fake_set_signal, timeout=1 + ): + if signal == fake_set_signal and value == target_position: break sim_motor.set = AsyncMock(side_effect=wait_for_set) From 41da0f4768337472db972692156488a298f3213c Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Wed, 18 Sep 2024 13:47:37 +0000 Subject: [PATCH 07/30] made lint changes --- src/ophyd_async/core/_signal.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index c3991a23d3..3dfd946c83 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -494,7 +494,7 @@ async def observe_signals_values( *signals: SignalR[T], timeout: float | None = None, done_status: Status | None = None, -) -> 
AsyncGenerator[Tuple[SignalR[T], T], None]: +) -> AsyncGenerator[tuple[SignalR[T], T], None]: """Subscribe to the value of a signal so it can be iterated from. Parameters @@ -516,7 +516,6 @@ async def observe_signals_values( async for value in observe_value(sig): do_something_with(value) """ - q: asyncio.Queue[tuple[SignalR[T], T | Status]] = asyncio.Queue() if timeout is None: get_value = q.get @@ -554,7 +553,7 @@ def clear_signals(): else: break else: - yield item + yield item # type: ignore finally: for clear_signal in clear_signals: clear_signal() @@ -612,7 +611,7 @@ async def wait_for_value( """ if callable(match_value): - checker = _ValueChecker(match_value, match_value.__name__) # type: ignore + checker = _ValueChecker(match_value, match_value.__name__) # type: ignore else: checker = _ValueChecker(lambda v: v == match_value, repr(match_value)) await checker.wait_for_value(signal, timeout) @@ -624,10 +623,8 @@ async def set_and_wait_for_other_value( match_signal: SignalR[S], match_value: S | Callable[[S], bool], timeout: float = DEFAULT_TIMEOUT, - set_timeout: float | None = None, ): - """Set a signal and monitor another signal until it has the specified value. This function sets a set_signal to a specified set_value and waits for @@ -685,7 +682,6 @@ async def set_and_wait_for_value( timeout: float = DEFAULT_TIMEOUT, status_timeout: float | None = None, ): - """Set a signal and monitor it until it has that value. 
Useful for busy record, or other Signals with pattern: From 232d921d5dcb838656af592abbb36a19a16de964 Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Fri, 20 Sep 2024 10:00:53 +0100 Subject: [PATCH 08/30] initial commit --- src/ophyd_async/core/_signal.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 3dfd946c83..4cc6ae54a9 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -513,10 +513,10 @@ async def observe_signals_values( ----- Example usage:: - async for value in observe_value(sig): + async for value1,value2,value3 in observe_signals_values(sig1,sig2,..): do_something_with(value) """ - q: asyncio.Queue[tuple[SignalR[T], T | Status]] = asyncio.Queue() + q: asyncio.Queue[tuple[SignalR[T], T] | Status] = asyncio.Queue() if timeout is None: get_value = q.get else: @@ -524,26 +524,18 @@ async def observe_signals_values( async def get_value(): return await asyncio.wait_for(q.get(), timeout) - def wrapped_signal_put(signal: SignalR[T]): - def queue_value(value: T): - q.put_nowait((signal, value)) + cbs: dict[SignalR, Callback] = {} + for signal in signals: - def queue_status(status: Status): - q.put_nowait((signal, status)) + def queue_value(value: T, signal=signal): + q.put_nowait((signal, value)) - def clear_signals(): - signal.clear_sub(queue_value) - signal.clear_sub(queue_status) + cbs[signal] = queue_value + signal.subscribe_value(queue_value) - return queue_value, queue_status, clear_signals + if done_status is not None: + done_status.add_callback(q.put_nowait) - clear_signals = [] - for signal in signals: - queue_value, queue_status, clear_signal = wrapped_signal_put(signal) - clear_signals.append(clear_signal) - if done_status is not None: - done_status.add_callback(queue_status) - signal.subscribe_value(queue_value) try: while True: item = await get_value() @@ -555,8 +547,8 @@ def clear_signals(): 
else: yield item # type: ignore finally: - for clear_signal in clear_signals: - clear_signal() + for signal, cb in cbs.items(): + signal.clear_sub(cb) class _ValueChecker(Generic[T]): From dd10caf9807efb51f9c9ad6927fb37eb33046966 Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Thu, 17 Oct 2024 10:44:43 +0100 Subject: [PATCH 09/30] Refactor set_and_wait_for_value function to include an optional parameter for waiting for set completion --- src/ophyd_async/core/_signal.py | 21 ++++++++++++-- src/ophyd_async/epics/adcore/_core_logic.py | 4 ++- tests/core/test_signal.py | 31 +++++++++++++++++++++ 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 4cc6ae54a9..8bbef89ee9 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -616,6 +616,7 @@ async def set_and_wait_for_other_value( match_value: S | Callable[[S], bool], timeout: float = DEFAULT_TIMEOUT, set_timeout: float | None = None, + wait_for_set_completion: bool = True, ): """Set a signal and monitor another signal until it has the specified value. 
@@ -636,6 +637,8 @@ async def set_and_wait_for_other_value( How long to wait for the signal to have the value set_timeout: How long to wait for the set to complete + wait_for_set_completion: + This will wait for set completion #More info in TBD Notes ----- @@ -660,7 +663,11 @@ async def _wait_for_value(): break try: - await asyncio.wait_for(_wait_for_value(), timeout) + status = asyncio.wait_for(_wait_for_value(), timeout) + if wait_for_set_completion: + await status + else: + return status except asyncio.TimeoutError as e: raise TimeoutError( f"{match_signal.name} didn't match {match_value} in {timeout}s" @@ -673,6 +680,7 @@ async def set_and_wait_for_value( match_value: T | Callable[[T], bool] | None = None, timeout: float = DEFAULT_TIMEOUT, status_timeout: float | None = None, + wait_for_set_completion: bool = True, ): """Set a signal and monitor it until it has that value. @@ -694,7 +702,8 @@ async def set_and_wait_for_value( How long to wait for the signal to have the value status_timeout: How long the returned Status will wait for the set to complete - + wait_for_set_completion: + This will wait for set completion #More info in TBD Notes ----- Example usage:: @@ -705,5 +714,11 @@ async def set_and_wait_for_value( match_value = value await set_and_wait_for_other_value( - signal, value, signal, match_value, timeout, status_timeout + signal, + value, + signal, + match_value, + timeout, + status_timeout, + wait_for_set_completion, ) diff --git a/src/ophyd_async/epics/adcore/_core_logic.py b/src/ophyd_async/epics/adcore/_core_logic.py index 7e5537b54b..8472737c9f 100644 --- a/src/ophyd_async/epics/adcore/_core_logic.py +++ b/src/ophyd_async/epics/adcore/_core_logic.py @@ -93,7 +93,9 @@ async def start_acquiring_driver_and_ensure_status( """ status = driver.acquire.set(True, timeout=timeout) await wait_for_value(driver.acquire, True, timeout=None) - await set_and_wait_for_value(driver.acquire, True, timeout=timeout) + await set_and_wait_for_value( + driver.acquire, 
True, timeout=timeout, wait_for_set_completion=False + ) async def complete_acquisition() -> None: """NOTE: possible race condition here between the callback from diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 3502f54978..e6f724e551 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -32,6 +32,7 @@ soft_signal_rw, wait_for_value, ) +from ophyd_async.core._signal import observe_signals_values from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw from ophyd_async.plan_stubs import ensure_connected @@ -280,6 +281,33 @@ def less_than_42(v): assert await time_taken_by(wait_for_value(signal, less_than_42, timeout=2)) < 0.1 +async def test_set_and_wait_for_value(): ... + + +async def set_signal_value(signal_1: SignalRW, signal_2: SignalRW): + set_mock_value(signal_1, 123) + set_mock_value(signal_2, 323) + set_mock_value(signal_1, 1) + set_mock_value(signal_2, 2) + + +async def test_observe_values(): + signal_1 = epics_signal_rw(float, read_pv="pva://signal_1", name="signal_1") + signal_2 = epics_signal_rw(float, read_pv="pva://signal_2", name="signal_2") + await signal_1.connect(mock=True) + await signal_2.connect(mock=True) + output: str = "" + t = asyncio.create_task(set_signal_value(signal_1, signal_2)) + async for signal, value in observe_signals_values(signal_1, signal_2): + if signal is signal_1 and value == 1: + output += "Hello_from_1" + elif signal is signal_2 and value == 2: + output += "Hello_from_2" + break + await t + assert output == "Hello_from_1Hello_from_2" + + @pytest.mark.parametrize( "signal_method,signal_class", [(soft_signal_r_and_setter, SignalR), (soft_signal_rw, SignalRW)], @@ -298,6 +326,9 @@ async def test_create_soft_signal(signal_method, signal_class): assert (await signal.get_value()) == INITIAL_VALUE +# write code to add two numbers + + async def test_soft_signal_numpy(): float_signal = soft_signal_rw(numpy.float64, numpy.float64(1), "float_signal") int_signal = 
soft_signal_rw(numpy.int32, numpy.int32(1), "int_signal") From d9d2d2522303fcf8358660eaa88e4ab3e33accc9 Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Tue, 22 Oct 2024 10:42:59 +0100 Subject: [PATCH 10/30] WIP --- src/ophyd_async/core/_signal.py | 14 +++++------ src/ophyd_async/epics/adcore/_core_logic.py | 5 +--- tests/core/test_signal.py | 26 ++++++++++++++------- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 8bbef89ee9..3ddc140395 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -20,7 +20,7 @@ from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable from ._signal_backend import SignalBackend from ._soft_signal_backend import SignalMetadata, SoftSignalBackend -from ._status import AsyncStatus +from ._status import AsyncStatus, completed_status from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T S = TypeVar("S") @@ -617,7 +617,7 @@ async def set_and_wait_for_other_value( timeout: float = DEFAULT_TIMEOUT, set_timeout: float | None = None, wait_for_set_completion: bool = True, -): +) -> AsyncStatus: """Set a signal and monitor another signal until it has the specified value. 
This function sets a set_signal to a specified set_value and waits for @@ -663,15 +663,15 @@ async def _wait_for_value(): break try: - status = asyncio.wait_for(_wait_for_value(), timeout) + status = AsyncStatus(asyncio.wait_for(_wait_for_value(), timeout)) if wait_for_set_completion: await status - else: - return status + return status except asyncio.TimeoutError as e: raise TimeoutError( f"{match_signal.name} didn't match {match_value} in {timeout}s" ) from e + return completed_status() async def set_and_wait_for_value( @@ -681,7 +681,7 @@ async def set_and_wait_for_value( timeout: float = DEFAULT_TIMEOUT, status_timeout: float | None = None, wait_for_set_completion: bool = True, -): +) -> AsyncStatus: """Set a signal and monitor it until it has that value. Useful for busy record, or other Signals with pattern: @@ -713,7 +713,7 @@ async def set_and_wait_for_value( if match_value is None: match_value = value - await set_and_wait_for_other_value( + return await set_and_wait_for_other_value( signal, value, signal, diff --git a/src/ophyd_async/epics/adcore/_core_logic.py b/src/ophyd_async/epics/adcore/_core_logic.py index c35cc7f0c7..8d43ab8572 100644 --- a/src/ophyd_async/epics/adcore/_core_logic.py +++ b/src/ophyd_async/epics/adcore/_core_logic.py @@ -7,7 +7,6 @@ DetectorController, set_and_wait_for_value, ) -from ophyd_async.core._signal import wait_for_value from ophyd_async.epics.adcore._utils import convert_ad_dtype_to_np from ._core_io import ADBaseIO, DetectorState @@ -91,9 +90,7 @@ async def start_acquiring_driver_and_ensure_status( An AsyncStatus that can be awaited to set driver.acquire to True and perform subsequent raising (if applicable) due to detector state. 
""" - status = driver.acquire.set(True, timeout=timeout) - await wait_for_value(driver.acquire, True, timeout=None) - await set_and_wait_for_value( + status = await set_and_wait_for_value( driver.acquire, True, timeout=timeout, wait_for_set_completion=False ) diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index e6f724e551..39d8329d15 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -285,8 +285,6 @@ async def test_set_and_wait_for_value(): ... async def set_signal_value(signal_1: SignalRW, signal_2: SignalRW): - set_mock_value(signal_1, 123) - set_mock_value(signal_2, 323) set_mock_value(signal_1, 1) set_mock_value(signal_2, 2) @@ -296,16 +294,26 @@ async def test_observe_values(): signal_2 = epics_signal_rw(float, read_pv="pva://signal_2", name="signal_2") await signal_1.connect(mock=True) await signal_2.connect(mock=True) - output: str = "" + signal_changed: set[Signal] = set() t = asyncio.create_task(set_signal_value(signal_1, signal_2)) - async for signal, value in observe_signals_values(signal_1, signal_2): - if signal is signal_1 and value == 1: - output += "Hello_from_1" - elif signal is signal_2 and value == 2: - output += "Hello_from_2" + async for sig, value in observe_signals_values(signal_1, signal_2): + if sig is signal_1 and value == 1: + signal_changed.add(sig) + elif sig is signal_2 and value == 2: + signal_changed.add(sig) break await t - assert output == "Hello_from_1Hello_from_2" + assert signal_changed == {signal_1, signal_2} + + +async def test_observe_values_with_time_out(): + signal_1 = epics_signal_rw(float, read_pv="pva://signal_1", name="signal_1") + signal_2 = epics_signal_rw(float, read_pv="pva://signal_2", name="signal_2") + await signal_1.connect(mock=True) + await signal_2.connect(mock=True) + with pytest.raises(asyncio.TimeoutError): + async for _, _ in observe_signals_values(signal_1, signal_2, timeout=0.1): + ... 
@pytest.mark.parametrize( From e8e83bb2d771d1972077cb187ca04984a12503a1 Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Tue, 22 Oct 2024 13:58:59 +0100 Subject: [PATCH 11/30] added code review changes and tests --- src/ophyd_async/core/_signal.py | 6 ++--- tests/core/test_signal.py | 44 +++++++++++++++++++++++++++------ 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 3ddc140395..677c1db3da 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -545,7 +545,7 @@ def queue_value(value: T, signal=signal): else: break else: - yield item # type: ignore + yield cast(tuple[SignalR[T], T], item) finally: for signal, cb in cbs.items(): signal.clear_sub(cb) @@ -652,7 +652,7 @@ async def set_and_wait_for_other_value( # Get the initial value from the monitor to make sure we've created it current_value = await anext(values_gen) - set_signal.set(set_value, timeout=set_timeout) + status = set_signal.set(set_value, timeout=set_timeout) # If the value was the same as before no need to wait for it to change if current_value != match_value: @@ -663,7 +663,7 @@ async def _wait_for_value(): break try: - status = AsyncStatus(asyncio.wait_for(_wait_for_value(), timeout)) + await asyncio.wait_for(_wait_for_value(), timeout) if wait_for_set_completion: await status return status diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 39d8329d15..32c49ce2fc 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -33,6 +33,7 @@ wait_for_value, ) from ophyd_async.core._signal import observe_signals_values +from ophyd_async.core._status import AsyncStatus from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw from ophyd_async.plan_stubs import ensure_connected @@ -295,14 +296,43 @@ async def test_observe_values(): await signal_1.connect(mock=True) await signal_2.connect(mock=True) signal_changed: set[Signal] = set() - t = 
asyncio.create_task(set_signal_value(signal_1, signal_2)) - async for sig, value in observe_signals_values(signal_1, signal_2): - if sig is signal_1 and value == 1: - signal_changed.add(sig) - elif sig is signal_2 and value == 2: + done: bool = False + + async def coroutine(): + while not done: + await asyncio.sleep(0.1) + + done_status = AsyncStatus(coroutine()) + async for sig, _ in observe_signals_values( + signal_1, signal_2, timeout=1, done_status=done_status + ): + signal_changed.add(sig) + if len(signal_changed) == 2: + done = True + assert signal_changed == {signal_1, signal_2} + + +async def test_observe_values_raises_exception(): + signal_1 = epics_signal_rw(float, read_pv="pva://signal_1", name="signal_1") + signal_2 = epics_signal_rw(float, read_pv="pva://signal_2", name="signal_2") + await signal_1.connect(mock=True) + await signal_2.connect(mock=True) + signal_changed: set[Signal] = set() + done: bool = False + + async def coroutine(): + while not done: + await asyncio.sleep(0.1) + raise ValueError("Test exception") + + done_status = AsyncStatus(coroutine()) + with pytest.raises(ValueError, match="Test exception"): + async for sig, _ in observe_signals_values( + signal_1, signal_2, timeout=1, done_status=done_status + ): signal_changed.add(sig) - break - await t + if len(signal_changed) == 2: + done = True assert signal_changed == {signal_1, signal_2} From cda91ec4a863939664ceb9f9c7c9fcfa482e178b Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Tue, 22 Oct 2024 14:13:47 +0100 Subject: [PATCH 12/30] deleted timeout to increase coverage --- tests/core/test_signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 32c49ce2fc..30fd42ccfb 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -328,7 +328,7 @@ async def coroutine(): done_status = AsyncStatus(coroutine()) with pytest.raises(ValueError, match="Test exception"): async for sig, _ in 
observe_signals_values( - signal_1, signal_2, timeout=1, done_status=done_status + signal_1, signal_2, done_status=done_status ): signal_changed.add(sig) if len(signal_changed) == 2: From 492d30395bb6aaa8823c0d8e07b3e199e0e7912f Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Tue, 22 Oct 2024 14:52:13 +0100 Subject: [PATCH 13/30] used async for observe values in observe_value --- src/ophyd_async/core/_signal.py | 29 ++++------------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 677c1db3da..0404577ec1 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -463,31 +463,10 @@ async def observe_value( async for value in observe_value(sig): do_something_with(value) """ - - q: asyncio.Queue[T | Status] = asyncio.Queue() - if timeout is None: - get_value = q.get - else: - - async def get_value(): - return await asyncio.wait_for(q.get(), timeout) - - if done_status is not None: - done_status.add_callback(q.put_nowait) - - signal.subscribe_value(q.put_nowait) - try: - while True: - item = await get_value() - if done_status and item is done_status: - if exc := done_status.exception(): - raise exc - else: - break - else: - yield cast(T, item) - finally: - signal.clear_sub(q.put_nowait) + async for _, value in observe_signals_values( + signal, timeout=timeout, done_status=done_status + ): + yield value async def observe_signals_values( From 9b63dc617d092357077c1de5940a34430b915470 Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Thu, 24 Oct 2024 15:29:14 +0100 Subject: [PATCH 14/30] deleted test debug message --- src/ophyd_async/epics/adaravis/_aravis_controller.py | 7 +++---- tests/core/test_signal.py | 9 --------- tests/epics/test_motor.py | 9 +++------ 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/src/ophyd_async/epics/adaravis/_aravis_controller.py b/src/ophyd_async/epics/adaravis/_aravis_controller.py index 
e52dd1d1ae..67030cd691 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_controller.py +++ b/src/ophyd_async/epics/adaravis/_aravis_controller.py @@ -8,9 +8,6 @@ ) from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore -from ophyd_async.epics.adcore._core_logic import ( - start_acquiring_driver_and_ensure_status, -) from ._aravis_io import AravisDriverIO, AravisTriggerMode, AravisTriggerSource @@ -50,7 +47,9 @@ async def prepare(self, trigger_info: TriggerInfo): ) async def arm(self): - self._arm_status = await start_acquiring_driver_and_ensure_status(self._drv) + self._arm_status = await adcore.start_acquiring_driver_and_ensure_status( + self._drv + ) async def wait_for_idle(self): if self._arm_status: diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 30fd42ccfb..c6f4aa1dc5 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -183,7 +183,6 @@ async def wait_and_set_proceeds(): async def check_set_and_wait(): await set_and_wait_for_value(signal, 1, timeout=0.1) - await wait_for_value(signal, 1, timeout=0.1) await asyncio.gather(wait_and_set_proceeds(), check_set_and_wait()) assert await signal.get_value() == 1 @@ -207,8 +206,6 @@ async def check_set_and_wait(): await set_and_wait_for_other_value( set_signal, 1, match_signal, "test", timeout=100 ) - await wait_for_value(set_signal, 1, timeout=100) - await wait_for_value(match_signal, "test", timeout=100) await asyncio.gather(wait_and_set_read(), check_set_and_wait()) assert await set_signal.get_value() == 1 @@ -282,9 +279,6 @@ def less_than_42(v): assert await time_taken_by(wait_for_value(signal, less_than_42, timeout=2)) < 0.1 -async def test_set_and_wait_for_value(): ... 
- - async def set_signal_value(signal_1: SignalRW, signal_2: SignalRW): set_mock_value(signal_1, 1) set_mock_value(signal_2, 2) @@ -364,9 +358,6 @@ async def test_create_soft_signal(signal_method, signal_class): assert (await signal.get_value()) == INITIAL_VALUE -# write code to add two numbers - - async def test_soft_signal_numpy(): float_signal = soft_signal_rw(numpy.float64, numpy.float64(1), "float_signal") int_signal = soft_signal_rw(numpy.int32, numpy.int32(1), "int_signal") diff --git a/tests/epics/test_motor.py b/tests/epics/test_motor.py index 8346c4438c..12fdfd8c57 100644 --- a/tests/epics/test_motor.py +++ b/tests/epics/test_motor.py @@ -13,10 +13,10 @@ SignalRW, callback_on_mock_put, mock_puts_blocked, + observe_value, set_mock_put_proceeds, set_mock_value, ) -from ophyd_async.core._signal import observe_signals_values from ophyd_async.epics import motor # Long enough for multiple asyncio event loop cycles to run so @@ -279,13 +279,10 @@ async def test_prepare( set_mock_value(sim_motor.high_limit_travel, 20) set_mock_value(sim_motor.max_velocity, 10) fake_set_signal = SignalRW(MockSignalBackend(float)) - fake_other_set_signal = SignalRW(MockSignalBackend(float)) async def wait_for_set(_): - async for signal, value in observe_signals_values( - fake_other_set_signal, fake_set_signal, timeout=1 - ): - if signal == fake_set_signal and value == target_position: + async for value in observe_value(fake_set_signal, timeout=1): + if value == target_position: break sim_motor.set = AsyncMock(side_effect=wait_for_set) From 5c7d23a4d632091fcf3df8d074024a9c23f118ce Mon Sep 17 00:00:00 2001 From: Dominic Oram Date: Tue, 29 Oct 2024 13:45:17 +0000 Subject: [PATCH 15/30] Add test for context race condition (#600) * Add test for context race condition * Add timeout to wait on connection in test --- tests/epics/signal/test_signals.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/epics/signal/test_signals.py 
b/tests/epics/signal/test_signals.py index e2e5c20f7d..0ce7fd6b45 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -14,12 +14,15 @@ from typing import Any, Literal from unittest.mock import ANY +import bluesky.plan_stubs as bps import numpy as np import numpy.typing as npt import pytest from aioca import CANothing, purge_channel_caches from bluesky.protocols import Reading +from bluesky.run_engine import RunEngine from event_model import DataKey +from ophyd.signal import EpicsSignal from typing_extensions import TypedDict from ophyd_async.core import ( @@ -917,3 +920,21 @@ async def test_signals_created_for_not_prec_0_float_cannot_use_int(ioc: IOC): TypeError, match=f"{ioc.protocol}:float_prec_1 has type float not int" ): await sig.connect() + + +async def test_can_read_using_ophyd_async_then_ophyd(ioc: IOC): + oa_read = f"{ioc.protocol}://{PV_PREFIX}:{ioc.protocol}:float_prec_1" + ophyd_read = f"{PV_PREFIX}:{ioc.protocol}:float_prec_0" + + ophyd_async_sig = epics_signal_rw(float, oa_read) + await ophyd_async_sig.connect() + ophyd_signal = EpicsSignal(ophyd_read) + ophyd_signal.wait_for_connection(timeout=5) + + RE = RunEngine() + + def my_plan(): + yield from bps.rd(ophyd_async_sig) + yield from bps.rd(ophyd_signal) + + RE(my_plan()) From 8a8e9ea769163b8d556b037c3ac532b1e1592dbe Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:57:37 +0000 Subject: [PATCH 16/30] New signal typing (#594) Rewrite the type support, and use a plugin `Connector` architecture to support reconnection for PVI and Tango. Fixes #472, fixes #562, fixes #535, fixes #505, fixes #373, fixes #601 Required for #551 Structure now read from `.value` rather than `.pvi`. Supported in FastCS. 
Requires at least PandABlocks-ioc 0.10.0 ```python from enum import Enum class MyEnum(str, Enum): ONE = "one" TWO = "two" from ophyd_async.core import StrictEnum class MyEnum(StrictEnum): ONE = "one" TWO = "two" ``` ```python from ophyd_async.core import SubsetEnum MySubsetEnum = SubsetEnum["one", "two"] class MySubsetEnum(SubsetEnum): ONE = "one" TWO = "two" ``` ```python import numpy as np x = epics_signal_rw(np.int32, "PV") x = epics_signal_rw(int, "PV") ``` ```python import numpy as np import numpy.typing as npt x = epics_signal_rw(npt.NDArray[np.int32], "PV") from ophyd_async.core import Array1D x = epics_signal_rw(Array1D[np.int32], "PV") ``` ```python import numpy as np import numpy.typing as npt x = epics_signal_rw(npt.NDArray[np.str_], "PV") from collections.abc import Sequence x = epics_signal_rw(Sequence[str], "PV") ``` ```python fake_set_signal = SignalRW(MockSignalBackend(float)) fake_set_signal = soft_signal_rw(float) await fake_set_signal.connect(mock=True) ``` ```python get_mock_put(driver.capture).assert_called_once_with(Writing.ON, wait=ANY, timeout=ANY) get_mock_put(driver.capture).assert_called_once_with(Writing.ON, wait=ANY) ``` ```python class MyDevice(Device): def __init__(self, name: str = ""): self.signal, self.backend_put = soft_signal_r_and_setter(int) class MyDevice(Device): def __init__(self, name: str = ""): self.signal, self.backend_put = soft_signal_r_and_setter(int) super().__init__(name=name) ``` The `Table` type has been suitable for everything we have seen so far, if you need an arbitrary `BaseModel` subclass then please make an issue ```python class SourceDevice(Device): def __init__(self, name: str = ""): self.signal = soft_signal_rw(int) super().__init__(name=name) class ReferenceDevice(Device): def __init__(self, signal: SignalRW[int], name: str = ""): self.signal = signal super().__init__(name=name) def set(self, value) -> AsyncStatus: return self.signal.set(value + 1) from ophyd_async.core import Reference class 
ReferenceDevice(Device): def __init__(self, signal: SignalRW[int], name: str = ""): self._signal_ref = Reference(signal) super().__init__(name=name) def set(self, value) -> AsyncStatus: return self._signal_ref().set(value + 1) ``` --- .../decisions/0008-signal-types.md | 159 +++++ docs/how-to/make-a-simple-device.rst | 8 +- pyproject.toml | 30 +- src/ophyd_async/core/__init__.py | 31 +- src/ophyd_async/core/_detector.py | 15 +- src/ophyd_async/core/_device.py | 192 ++++-- src/ophyd_async/core/_device_filler.py | 191 ++++++ src/ophyd_async/core/_device_save_loader.py | 13 +- src/ophyd_async/core/_mock_signal_backend.py | 66 +- src/ophyd_async/core/_mock_signal_utils.py | 21 +- src/ophyd_async/core/_protocol.py | 34 +- src/ophyd_async/core/_readable.py | 10 +- src/ophyd_async/core/_signal.py | 297 ++++---- src/ophyd_async/core/_signal_backend.py | 195 ++++-- src/ophyd_async/core/_soft_signal_backend.py | 319 ++++----- src/ophyd_async/core/_table.py | 197 +++--- src/ophyd_async/core/_utils.py | 89 ++- .../epics/adaravis/_aravis_controller.py | 4 +- src/ophyd_async/epics/adaravis/_aravis_io.py | 12 +- src/ophyd_async/epics/adcore/_core_io.py | 10 +- src/ophyd_async/epics/adcore/_hdf_writer.py | 4 +- src/ophyd_async/epics/adcore/_utils.py | 25 +- src/ophyd_async/epics/adkinetix/__init__.py | 3 +- .../epics/adkinetix/_kinetix_controller.py | 9 +- .../epics/adkinetix/_kinetix_io.py | 7 +- .../epics/adpilatus/_pilatus_controller.py | 4 +- .../epics/adpilatus/_pilatus_io.py | 5 +- .../epics/adsimdetector/_sim_controller.py | 4 +- src/ophyd_async/epics/advimba/__init__.py | 5 +- .../epics/advimba/_vimba_controller.py | 9 +- src/ophyd_async/epics/advimba/_vimba_io.py | 15 +- src/ophyd_async/epics/demo/_sensor.py | 12 +- src/ophyd_async/epics/eiger/_eiger.py | 3 +- .../epics/eiger/_eiger_controller.py | 2 +- src/ophyd_async/epics/eiger/_eiger_io.py | 6 +- src/ophyd_async/epics/eiger/_odin_io.py | 8 +- src/ophyd_async/epics/pvi/__init__.py | 4 +- 
src/ophyd_async/epics/pvi/_pvi.py | 375 ++-------- src/ophyd_async/epics/signal/__init__.py | 7 +- src/ophyd_async/epics/signal/_aioca.py | 421 +++++------- src/ophyd_async/epics/signal/_common.py | 84 +-- .../epics/signal/_epics_transport.py | 34 - src/ophyd_async/epics/signal/_p4p.py | 643 +++++++----------- src/ophyd_async/epics/signal/_signal.py | 84 ++- src/ophyd_async/fastcs/core.py | 9 + src/ophyd_async/fastcs/panda/__init__.py | 8 +- src/ophyd_async/fastcs/panda/_block.py | 31 +- src/ophyd_async/fastcs/panda/_control.py | 8 +- src/ophyd_async/fastcs/panda/_hdf_panda.py | 24 +- src/ophyd_async/fastcs/panda/_table.py | 80 +-- src/ophyd_async/fastcs/panda/_trigger.py | 16 +- src/ophyd_async/fastcs/panda/_writer.py | 7 +- src/ophyd_async/plan_stubs/_fly.py | 4 +- src/ophyd_async/plan_stubs/_nd_attributes.py | 9 +- src/ophyd_async/py.typed | 0 .../_pattern_detector_controller.py | 3 +- src/ophyd_async/tango/__init__.py | 6 +- .../tango/base_devices/_base_device.py | 218 ++---- src/ophyd_async/tango/demo/_counter.py | 4 +- src/ophyd_async/tango/demo/_mover.py | 4 +- src/ophyd_async/tango/signal/__init__.py | 6 +- src/ophyd_async/tango/signal/_signal.py | 79 +-- .../tango/signal/_tango_transport.py | 78 ++- tests/core/test_device.py | 135 ++-- tests/core/test_device_collector.py | 8 +- tests/core/test_device_save_loader.py | 213 +++--- tests/core/test_flyer.py | 6 +- tests/core/test_log.py | 4 +- tests/core/test_mock_signal_backend.py | 127 ++-- tests/core/test_readable.py | 12 +- tests/core/test_signal.py | 116 +--- tests/core/test_soft_signal_backend.py | 115 ++-- tests/core/test_subset_enum.py | 52 +- tests/core/test_table.py | 53 ++ tests/core/test_utils.py | 11 +- tests/epics/adaravis/test_aravis.py | 4 +- tests/epics/adcore/test_single_trigger.py | 16 +- tests/epics/adcore/test_writers.py | 19 +- tests/epics/adkinetix/test_kinetix.py | 11 +- tests/epics/adsimdetector/test_sim.py | 12 +- tests/epics/advimba/test_vimba.py | 8 +- tests/epics/conftest.py | 9 +- 
tests/epics/demo/test_demo.py | 10 +- tests/epics/eiger/test_eiger_controller.py | 15 +- tests/epics/eiger/test_eiger_detector.py | 4 +- tests/epics/eiger/test_odin_io.py | 26 +- tests/epics/pvi/test_pvi.py | 138 ++-- tests/epics/signal/test_common.py | 19 +- tests/epics/signal/test_records.db | 26 +- tests/epics/signal/test_signals.py | 374 +++++----- tests/epics/test_motor.py | 10 +- tests/fastcs/panda/db/panda.db | 81 +-- tests/fastcs/panda/test_hdf_panda.py | 31 +- tests/fastcs/panda/test_panda_connect.py | 43 +- tests/fastcs/panda/test_panda_control.py | 16 +- tests/fastcs/panda/test_panda_utils.py | 17 +- .../{test_table.py => test_seq_table.py} | 118 ++-- tests/fastcs/panda/test_trigger.py | 30 +- tests/fastcs/panda/test_writer.py | 60 +- tests/plan_stubs/test_ensure_connected.py | 4 +- tests/plan_stubs/test_fly.py | 13 +- tests/tango/test_base_device.py | 18 +- tests/tango/test_tango_signals.py | 45 +- tests/tango/test_tango_transport.py | 56 +- tests/test_data/test_yaml_save.yml | 41 +- 105 files changed, 2995 insertions(+), 3391 deletions(-) create mode 100644 docs/explanations/decisions/0008-signal-types.md create mode 100644 src/ophyd_async/core/_device_filler.py delete mode 100644 src/ophyd_async/epics/signal/_epics_transport.py create mode 100644 src/ophyd_async/fastcs/core.py create mode 100644 src/ophyd_async/py.typed create mode 100644 tests/core/test_table.py rename tests/fastcs/panda/{test_table.py => test_seq_table.py} (73%) diff --git a/docs/explanations/decisions/0008-signal-types.md b/docs/explanations/decisions/0008-signal-types.md new file mode 100644 index 0000000000..021e69c983 --- /dev/null +++ b/docs/explanations/decisions/0008-signal-types.md @@ -0,0 +1,159 @@ +# 8. Settle on Signal Types +Date: 2024-10-18 + +## Status + +Accepted + +## Context + +At present, soft Signals allow any sort of datatype, while CA, PVA, Tango restrict these to what the control system allows. 
This means that some soft signals when `describe()` is called on them will give `dtype=object` which is not understood by downstream tools. It also means that load/save will not necessarily understand how to serialize results. Finally we now require `dtype_numpy` for tiled, so arbitrary object types are not suitable even if they are serializable. We should restrict the datatypes allowed in Signals to objects that are serializable and are sensible to add support for in downstream tools. + +## Decision + +We will allow the following: +- Primitives: + - `bool` + - `int` + - `float` + - `str` +- Enums: + - `StrictEnum` subclass which will be checked to have the same members as the CS + - `SubsetEnum` subclass which will be checked to be a subset of the CS members +- 1D arrays: + - `Array1D[np.bool_]` + - `Array1D[np.int8]` + - `Array1D[np.uint8]` + - `Array1D[np.int16]` + - `Array1D[np.uint16]` + - `Array1D[np.int32]` + - `Array1D[np.uint32]` + - `Array1D[np.int64]` + - `Array1D[np.uint64]` + - `Array1D[np.float32]` + - `Array1D[np.float64]` + - `Sequence[str]` + - `Sequence[MyEnum]` where `MyEnum` is a subclass of `StrictEnum` or `SubsetEnum` +- Specific structures: + - `np.ndarray` to represent arrays where dimensionality and dtype can change and must be read from CS + - `Table` subclass (which is a pydantic `BaseModel`) where all members are 1D arrays + +## Consequences + +Clients will be expected to understand: +- Python primitives (with Enums serializing as strings) +- Numpy arrays +- Pydantic BaseModels + +All of the above have sensible `dtype_numpy` fields, but `Table` will give a structured row-wise `dtype_numpy`, while the data will be serialized in a column-wise fashion. + +The following breaking changes will be made to ophyd-async: + +## pvi structure changes +Structure now read from `.value` rather than `.pvi`. Supported in FastCS. 
Requires at least PandABlocks-ioc 0.10.0 +## `StrictEnum` is now requried for all strictly checked `Enums` +```python +# old +from enum import Enum +class MyEnum(str, Enum): + ONE = "one" + TWO = "two" +# new +from ophyd_async.core import StrictEnum +class MyEnum(StrictEnum): + ONE = "one" + TWO = "two" +``` +## `SubsetEnum` is now an `Enum` subclass: +```python +from ophyd_async.core import SubsetEnum +# old +MySubsetEnum = SubsetEnum["one", "two"] +# new +class MySubsetEnum(SubsetEnum): + ONE = "one" + TWO = "two" +``` +## Use python primitives for scalar types instead of numpy types +```python +# old +import numpy as np +x = epics_signal_rw(np.int32, "PV") +# new +x = epics_signal_rw(int, "PV") +``` +## Use `Array1D` for 1D arrays instead of `npt.NDArray` +```python +import numpy as np +# old +import numpy.typing as npt +x = epics_signal_rw(npt.NDArray[np.int32], "PV") +# new +from ophyd_async.core import Array1D +x = epics_signal_rw(Array1D[np.int32], "PV") +``` +## Use `Sequence[str]` for arrays of strings instead of `npt.NDArray[np.str_]` +```python +import numpy as np +# old +import numpy.typing as npt +x = epics_signal_rw(npt.NDArray[np.str_], "PV") +# new +from collections.abc import Sequence +x = epics_signal_rw(Sequence[str], "PV") +``` +## `MockSignalBackend` requires a real backend +```python +# old +fake_set_signal = SignalRW(MockSignalBackend(float)) +# new +fake_set_signal = soft_signal_rw(float) +await fake_set_signal.connect(mock=True) +``` +## `get_mock_put` is no longer passed timeout as it is handled in `Signal` +```python +# old +get_mock_put(driver.capture).assert_called_once_with(Writing.ON, wait=ANY, timeout=ANY) +# new +get_mock_put(driver.capture).assert_called_once_with(Writing.ON, wait=ANY) +``` +## `super().__init__` required for `Device` subclasses +```python +# old +class MyDevice(Device): + def __init__(self, name: str = ""): + self.signal, self.backend_put = soft_signal_r_and_setter(int) +# new +class MyDevice(Device): + def 
__init__(self, name: str = ""): + self.signal, self.backend_put = soft_signal_r_and_setter(int) + super().__init__(name=name) +``` +## Arbitrary `BaseModel`s not supported, pending use cases for them +The `Table` type has been suitable for everything we have seen so far, if you need an arbitrary `BaseModel` subclass then please make an issue +## Child `Device`s set parent on attach, and can't be public children of more than one parent +```python +class SourceDevice(Device): + def __init__(self, name: str = ""): + self.signal = soft_signal_rw(int) + super().__init__(name=name) + +# old +class ReferenceDevice(Device): + def __init__(self, signal: SignalRW[int], name: str = ""): + self.signal = signal + super().__init__(name=name) + + def set(self, value) -> AsyncStatus: + return self.signal.set(value + 1) +# new +from ophyd_async.core import Reference + +class ReferenceDevice(Device): + def __init__(self, signal: SignalRW[int], name: str = ""): + self._signal_ref = Reference(signal) + super().__init__(name=name) + + def set(self, value) -> AsyncStatus: + return self._signal_ref().set(value + 1) +``` diff --git a/docs/how-to/make-a-simple-device.rst b/docs/how-to/make-a-simple-device.rst index f51fea2120..40edd36426 100644 --- a/docs/how-to/make-a-simple-device.rst +++ b/docs/how-to/make-a-simple-device.rst @@ -1,6 +1,6 @@ .. note:: - Ophyd async is included on a provisional basis until the v1.0 release and + Ophyd async is included on a provisional basis until the v1.0 release and may change API on minor release numbers before then Make a Simple Device @@ -31,7 +31,7 @@ its Python type, which could be: - A primitive (`str`, `int`, `float`) - An array (`numpy.typing.NDArray` ie. ``numpy.typing.NDArray[numpy.uint16]`` or ``Sequence[str]``) - An enum (`enum.Enum`) which **must** also extend `str` - - `str` and ``EnumClass(str, Enum)`` are the only valid ``datatype`` for an enumerated signal. 
+ - `str` and ``EnumClass(StrictEnum)`` are the only valid ``datatype`` for an enumerated signal. The rest of the arguments are PV connection information, in this case the PV suffix. @@ -45,7 +45,7 @@ Finally `super().__init__() ` is called with: without renaming All signals passed into this init method will be monitored between ``stage()`` -and ``unstage()`` and their cached values returned on ``read()`` and +and ``unstage()`` and their cached values returned on ``read()`` and ``read_configuration()`` for perfomance. Movable @@ -64,7 +64,7 @@ informing watchers of the progress. When it gets to the requested value it completes. This co-routine is wrapped in a timeout handler, and passed to an `AsyncStatus` which will start executing it as soon as the Run Engine adds a callback to it. The ``stop()`` method then pokes a PV if the move needs to be -interrupted. +interrupted. Assembly -------- diff --git a/pyproject.toml b/pyproject.toml index 7dbf22271b..8c24913e0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,12 +13,12 @@ classifiers = [ description = "Asynchronous Bluesky hardware abstraction code, compatible with control systems like EPICS and Tango" dependencies = [ "networkx>=2.0", - "numpy<2.0.0", + "numpy", "packaging", "pint", - "bluesky>=1.13.0a3", - "event_model", - "p4p", + "bluesky>=1.13", + "event-model>=1.22.1", + "p4p>=4.2.0a3", "pyyaml", "colorlog", "pydantic>=2.0", @@ -39,10 +39,6 @@ dev = [ "ophyd_async[sim]", "ophyd_async[ca]", "ophyd_async[tango]", - "black", - "flake8", - "flake8-isort", - "Flake8-pyproject", "inflection", "ipython", "ipywidgets", @@ -150,15 +146,17 @@ commands = src = ["src", "tests", "system_tests"] line-length = 88 lint.select = [ - "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b - "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 - "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e - "F", # pyflakes rules - 
https://docs.astral.sh/ruff/rules/#pyflakes-f - "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w - "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i - "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up - "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self + "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b + "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 + "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e + "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f + "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w + "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i + "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up + "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self + "PLC2701", # private import - https://docs.astral.sh/ruff/rules/import-private-name/ ] +lint.preview = true # so that preview mode PLC2701 is enabled [tool.ruff.lint.per-file-ignores] # By default, private member access is allowed in tests diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index ddd3ab1a80..0a6c0d2f94 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -5,7 +5,8 @@ StandardDetector, TriggerInfo, ) -from ._device import Device, DeviceCollector, DeviceVector +from ._device import Device, DeviceCollector, DeviceConnector, DeviceVector +from ._device_filler import DeviceFiller from ._device_save_loader import ( all_at_once, get_signal_values, @@ -54,6 +55,7 @@ assert_emitted, assert_reading, assert_value, + observe_signals_values, observe_value, set_and_wait_for_other_value, set_and_wait_for_value, @@ -62,9 +64,11 @@ wait_for_value, ) from ._signal_backend import ( - RuntimeSubsetEnum, + Array1D, SignalBackend, - SubsetEnum, + SignalDatatype, + SignalDatatypeT, + make_datakey, ) 
from ._soft_signal_backend import SignalMetadata, SoftSignalBackend from ._status import AsyncStatus, WatchableAsyncStatus, completed_status @@ -73,14 +77,17 @@ CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, + Callback, NotConnected, - ReadingValueCallback, + Reference, + StrictEnum, + SubsetEnum, T, WatcherUpdate, get_dtype, + get_enum_cls, get_unique, in_micros, - is_pydantic_model, wait_for_connection, ) @@ -91,8 +98,10 @@ "StandardDetector", "TriggerInfo", "Device", + "DeviceConnector", "DeviceCollector", "DeviceVector", + "DeviceFiller", "all_at_once", "get_signal_values", "load_device", @@ -141,30 +150,36 @@ "assert_reading", "assert_value", "observe_value", + "observe_signals_values", "set_and_wait_for_value", "set_and_wait_for_other_value", "soft_signal_r_and_setter", "soft_signal_rw", "wait_for_value", - "RuntimeSubsetEnum", + "Array1D", "SignalBackend", + "make_datakey", + "StrictEnum", "SubsetEnum", + "SignalDatatype", + "SignalDatatypeT", "SignalMetadata", "SoftSignalBackend", "AsyncStatus", "WatchableAsyncStatus", "DEFAULT_TIMEOUT", "CalculatableTimeout", + "Callback", "CALCULATE_TIMEOUT", "NotConnected", - "ReadingValueCallback", + "Reference", "Table", "T", "WatcherUpdate", "get_dtype", + "get_enum_cls", "get_unique", "in_micros", - "is_pydantic_model", "wait_for_connection", "completed_status", ] diff --git a/src/ophyd_async/core/_detector.py b/src/ophyd_async/core/_detector.py index 4bfed6b451..d2990c9fc8 100644 --- a/src/ophyd_async/core/_detector.py +++ b/src/ophyd_async/core/_detector.py @@ -4,11 +4,7 @@ import time from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, AsyncIterator, Callable, Iterator, Sequence -from enum import Enum from functools import cached_property -from typing import ( - Generic, -) from bluesky.protocols import ( Collectable, @@ -23,14 +19,14 @@ from event_model import DataKey from pydantic import BaseModel, Field, NonNegativeInt, computed_field -from ._device import Device +from 
._device import Device, DeviceConnector from ._protocol import AsyncConfigurable, AsyncReadable from ._signal import SignalR from ._status import AsyncStatus, WatchableAsyncStatus -from ._utils import DEFAULT_TIMEOUT, T, WatcherUpdate, merge_gathered_dicts +from ._utils import DEFAULT_TIMEOUT, StrictEnum, WatcherUpdate, merge_gathered_dicts -class DetectorTrigger(str, Enum): +class DetectorTrigger(StrictEnum): """Type of mechanism for triggering a detector to take frames""" #: Detector generates internal trigger for given rate @@ -172,7 +168,6 @@ class StandardDetector( Flyable, Collectable, WritesStreamAssets, - Generic[T], ): """ Useful detector base class for step and fly scanning detectors. @@ -185,6 +180,7 @@ def __init__( writer: DetectorWriter, config_sigs: Sequence[SignalR] = (), name: str = "", + connector: DeviceConnector | None = None, ) -> None: """ Constructor @@ -213,8 +209,7 @@ def __init__( self._completable_frames: int = 0 self._number_of_triggers_iter: Iterator[int] | None = None self._initial_frame: int = 0 - - super().__init__(name) + super().__init__(name, connector=connector) @property def controller(self) -> DetectorController: diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index bf3e23f5d6..c11aa478f8 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -1,39 +1,77 @@ -"""Base device""" +from __future__ import annotations import asyncio import sys -from collections.abc import Coroutine, Generator, Iterator -from functools import cached_property +from collections.abc import Coroutine, Iterator, Mapping, MutableMapping from logging import LoggerAdapter, getLogger -from typing import ( - Any, - Optional, - TypeVar, -) +from typing import Any, TypeVar from bluesky.protocols import HasName from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop +from ._protocol import Connectable from ._utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection -class 
Device(HasName): - """Common base class for all Ophyd Async Devices. +class DeviceConnector: + """Defines how a `Device` should be connected and type hints processed.""" + + def create_children_from_annotations(self, device: Device): + """Used when children can be created from introspecting the hardware. + + Some control systems allow introspection of a device to determine what + children it has. To allow this to work nicely with typing we add these + hints to the Device like so:: + + my_signal: SignalRW[int] + my_device: MyDevice + + This method will be run during ``Device.__init__``, and is responsible + for turning all of those type hints into real Signal and Device instances. + + Subsequent runs of this function should do nothing, to allow it to be + called early in Devices that need to pass references to their children + during ``__init__``. + """ + + async def connect( + self, + device: Device, + mock: bool, + timeout: float, + force_reconnect: bool, + ): + """Used during ``Device.connect``. + + This is called when a previous connect has not been done, or has been + done in a different mock more. It should connect the Device and all its + children. + """ + coros = { + name: child_device.connect( + mock=mock, timeout=timeout, force_reconnect=force_reconnect + ) + for name, child_device in device.children() + } + await wait_for_connection(**coros) - By default, names and connects all Device children. 
- """ + +class Device(HasName, Connectable): + """Common base class for all Ophyd Async Devices.""" _name: str = "" #: The parent Device if it exists - parent: Optional["Device"] = None + parent: Device | None = None # None if connect hasn't started, a Task if it has _connect_task: asyncio.Task | None = None + # If not None, then this is the mock arg of the previous connect + # to let us know if we can reuse an existing connection + _connect_mock_arg: bool | None = None - # Used to check if the previous connect was mocked, - # if the next mock value differs then we fail - _previous_connect_was_mock = None - - def __init__(self, name: str = "") -> None: + def __init__( + self, name: str = "", connector: DeviceConnector | None = None + ) -> None: + self._connector = connector or DeviceConnector() self.set_name(name) @property @@ -41,13 +79,7 @@ def name(self) -> str: """Return the name of the Device""" return self._name - @cached_property - def log(self): - return LoggerAdapter( - getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name} - ) - - def children(self) -> Iterator[tuple[str, "Device"]]: + def children(self) -> Iterator[tuple[str, Device]]: for attr_name, attr in self.__dict__.items(): if attr_name != "parent" and isinstance(attr, Device): yield attr_name, attr @@ -60,23 +92,32 @@ def set_name(self, name: str): name: New name to set """ - - # Ensure self.log is recreated after a name change - if hasattr(self, "log"): - del self.log - self._name = name - for attr_name, child in self.children(): - child_name = f"{name}-{attr_name.rstrip('_')}" if name else "" + # Ensure self.log is recreated after a name change + self.log = LoggerAdapter( + getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name} + ) + for child_name, child in self.children(): + child_name = f"{self.name}-{child_name.strip('_')}" if self.name else "" child.set_name(child_name) - child.parent = self + + def __setattr__(self, name: str, value: Any) -> None: + if 
name == "parent": + if self.parent not in (value, None): + raise TypeError( + f"Cannot set the parent of {self} to be {value}: " + f"it is already a child of {self.parent}" + ) + elif isinstance(value, Device): + value.parent = self + return super().__setattr__(name, value) async def connect( self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT, force_reconnect: bool = False, - ): + ) -> None: """Connect self and all child Devices. Contains a timeout that gets propagated to child.connect methods. @@ -88,41 +129,27 @@ async def connect( timeout: Time to wait before failing with a TimeoutError. """ - - if ( - self._previous_connect_was_mock is not None - and self._previous_connect_was_mock != mock - ): - raise RuntimeError( - f"`connect(mock={mock})` called on a `Device` where the previous " - f"connect was `mock={self._previous_connect_was_mock}`. Changing mock " - "value between connects is not permitted." - ) - self._previous_connect_was_mock = mock - - # If previous connect with same args has started and not errored, can use it - can_use_previous_connect = self._connect_task and not ( - self._connect_task.done() and self._connect_task.exception() + can_use_previous_connect = ( + mock is self._connect_mock_arg + and self._connect_task + and not (self._connect_task.done() and self._connect_task.exception()) ) if force_reconnect or not can_use_previous_connect: - # Kick off a connection - coros = { - name: child_device.connect( - mock, timeout=timeout, force_reconnect=force_reconnect - ) - for name, child_device in self.children() - } - self._connect_task = asyncio.create_task(wait_for_connection(**coros)) + self._connect_mock_arg = mock + coro = self._connector.connect( + device=self, mock=mock, timeout=timeout, force_reconnect=force_reconnect + ) + self._connect_task = asyncio.create_task(coro) assert self._connect_task, "Connect task not created, this shouldn't happen" # Wait for it to complete await self._connect_task -VT = TypeVar("VT", bound=Device) 
+DeviceT = TypeVar("DeviceT", bound=Device) -class DeviceVector(dict[int, VT], Device): +class DeviceVector(MutableMapping[int, DeviceT], Device): """ Defines device components with indices. @@ -131,10 +158,45 @@ class DeviceVector(dict[int, VT], Device): :class:`~ophyd_async.epics.demo.DynamicSensorGroup` """ - def children(self) -> Generator[tuple[str, Device], None, None]: - for attr_name, attr in self.items(): - if isinstance(attr, Device): - yield str(attr_name), attr + def __init__( + self, + children: Mapping[int, DeviceT], + name: str = "", + ) -> None: + self._children = dict(children) + super().__init__(name=name) + + def __setattr__(self, name: str, child: Any) -> None: + if name != "parent" and isinstance(child, Device): + raise AttributeError( + "DeviceVector can only have integer named children, " + "set via device_vector[i] = child" + ) + super().__setattr__(name, child) + + def __getitem__(self, key: int) -> DeviceT: + return self._children[key] + + def __setitem__(self, key: int, value: DeviceT) -> None: + # Check the types on entry to dict to make sure we can't accidentally + # make a non-integer named child + assert isinstance(key, int), f"Expected int, got {key}" + assert isinstance(value, Device), f"Expected Device, got {value}" + self._children[key] = value + value.parent = self + + def __delitem__(self, key: int) -> None: + del self._children[key] + + def __iter__(self) -> Iterator[int]: + yield from self._children + + def __len__(self) -> int: + return len(self._children) + + def children(self) -> Iterator[tuple[str, Device]]: + for key, child in self._children.items(): + yield str(key), child class DeviceCollector: @@ -195,12 +257,12 @@ def _caller_locals(self): ), "No previous frame to the one with self in it, this shouldn't happen" return caller_frame.f_locals - def __enter__(self) -> "DeviceCollector": + def __enter__(self) -> DeviceCollector: # Stash the names that were defined before we were called self._names_on_enter = 
set(self._caller_locals()) return self - async def __aenter__(self) -> "DeviceCollector": + async def __aenter__(self) -> DeviceCollector: return self.__enter__() async def _on_exit(self) -> None: diff --git a/src/ophyd_async/core/_device_filler.py b/src/ophyd_async/core/_device_filler.py new file mode 100644 index 0000000000..978380161c --- /dev/null +++ b/src/ophyd_async/core/_device_filler.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import re +from collections.abc import Callable +from typing import ( + Generic, + NoReturn, + TypeVar, + get_args, + get_origin, + get_type_hints, +) + +from ._device import Device, DeviceConnector, DeviceVector +from ._signal import Signal, SignalX +from ._signal_backend import SignalBackend, SignalDatatype +from ._utils import get_origin_class + + +def _strip_number_from_string(string: str) -> tuple[str, int | None]: + match = re.match(r"(.*?)(\d*)$", string) + assert match + + name = match.group(1) + number = match.group(2) or None + if number is None: + return name, None + else: + return name, int(number) + + +SignalBackendT = TypeVar("SignalBackendT", bound=SignalBackend) +DeviceConnectorT = TypeVar("DeviceConnectorT", bound=DeviceConnector) + + +class DeviceFiller(Generic[SignalBackendT, DeviceConnectorT]): + def __init__( + self, + device: Device, + signal_backend_factory: Callable[[type[SignalDatatype] | None], SignalBackendT], + device_connector_factory: Callable[[], DeviceConnectorT], + ): + self._device = device + self._signal_backend_factory = signal_backend_factory + self._device_connector_factory = device_connector_factory + self._vectors: dict[str, DeviceVector] = {} + self._vector_device_type: dict[str, type[Device] | None] = {} + self._signal_backends: dict[str, tuple[SignalBackendT, type[Signal]]] = {} + self._device_connectors: dict[str, DeviceConnectorT] = {} + # Get type hints on the class, not the instance + # https://github.com/python/cpython/issues/124840 + self._annotations = 
get_type_hints(type(device)) + for name, annotation in self._annotations.items(): + # names have a trailing underscore if the clash with a bluesky verb, + # so strip this off to get what the CS will provide + stripped_name = name.rstrip("_") + origin = get_origin_class(annotation) + if name == "parent" or name.startswith("_") or not origin: + # Ignore + pass + elif issubclass(origin, Signal): + # SignalX doesn't need datatype, all others need one + datatype = self.get_datatype(name) + if origin != SignalX and datatype is None: + self._raise( + name, + f"Expected SignalX or SignalR/W/RW[type], got {annotation}", + ) + self._signal_backends[stripped_name] = ( + self.make_child_signal(name, origin), + origin, + ) + elif origin == DeviceVector: + # DeviceVector needs a type of device + args = get_args(annotation) or [None] + child_origin = get_origin(args[0]) or args[0] + if child_origin is None or not issubclass(child_origin, Device): + self._raise( + name, + f"Expected DeviceVector[SomeDevice], got {annotation}", + ) + self.make_device_vector(name, child_origin) + elif issubclass(origin, Device): + self._device_connectors[stripped_name] = self.make_child_device( + name, origin + ) + + def unfilled(self) -> set[str]: + return set(self._device_connectors).union(self._signal_backends) + + def _raise(self, name: str, error: str) -> NoReturn: + raise TypeError(f"{type(self._device).__name__}.{name}: {error}") + + def make_device_vector(self, name: str, device_type: type[Device] | None): + self._vectors[name] = DeviceVector({}) + self._vector_device_type[name] = device_type + setattr(self._device, name, self._vectors[name]) + + def make_device_vectors(self, names: list[str]): + basenames: dict[str, set[int]] = {} + for name in names: + basename, number = _strip_number_from_string(name) + if number is not None: + basenames.setdefault(basename, set()).add(number) + for basename, numbers in basenames.items(): + # If contiguous numbers starting at 1 then it's a device vector + 
length = len(numbers) + if length > 1 and numbers == set(range(1, length + 1)): + # DeviceVector needs a type of device + self.make_device_vector(basename, None) + + def get_datatype(self, name: str) -> type[SignalDatatype] | None: + # Get dtype from SignalRW[dtype] or DeviceVector[SignalRW[dtype]] + basename, _ = _strip_number_from_string(name) + if basename in self._vectors: + # We decided to put it in a device vector, so get datatype from that + annotation = self._annotations.get(basename, None) + if annotation: + annotation = get_args(annotation)[0] + else: + # It's not a device vector, so get it from the full name + annotation = self._annotations.get(name, None) + args = get_args(annotation) + if args and get_origin_class(args[0]): + return args[0] + + def make_child_signal(self, name: str, signal_type: type[Signal]) -> SignalBackendT: + if name in self._signal_backends: + # We made it above + backend, expected_signal_type = self._signal_backends.pop(name) + else: + # We need to make a new one + basename, number = _strip_number_from_string(name) + child = getattr(self._device, name, None) + backend = self._signal_backend_factory(self.get_datatype(name)) + signal = signal_type(backend) + if basename in self._vectors and isinstance(number, int): + # We need to add a new entry to an existing DeviceVector + expected_signal_type = self._vector_device_type[basename] or signal_type + self._vectors[basename][number] = signal + elif child is None: + # We need to add a new child to the top level Device + expected_signal_type = signal_type + setattr(self._device, name, signal) + else: + self._raise(name, f"Cannot make child as it would shadow {child}") + if signal_type is not expected_signal_type: + self._raise( + name, + f"is a {signal_type.__name__} not a {expected_signal_type.__name__}", + ) + return backend + + def make_child_device( + self, name: str, device_type: type[Device] = Device + ) -> DeviceConnectorT: + basename, number = _strip_number_from_string(name) + 
child = getattr(self._device, name, None) + if connector := self._device_connectors.pop(name, None): + # We made it above + return connector + elif basename in self._vectors and isinstance(number, int): + # We need to add a new entry to an existing DeviceVector + vector_device_type = self._vector_device_type[basename] or device_type + assert issubclass( + vector_device_type, Device + ), f"{vector_device_type} is not a Device" + connector = self._device_connector_factory() + device = vector_device_type(connector=connector) + self._vectors[basename][number] = device + elif child is None: + # We need to add a new child to the top level Device + connector = self._device_connector_factory() + device = device_type(connector=connector) + setattr(self._device, name, device) + else: + self._raise(name, f"Cannot make child as it would shadow {child}") + connector.create_children_from_annotations(device) + return connector + + def make_soft_device_vector_entries(self, num: int): + for basename, cls in self._vector_device_type.items(): + assert cls, "Shouldn't happen" + for i in range(num): + name = f"{basename}{i + 1}" + if issubclass(cls, Signal): + self.make_child_signal(name, cls) + elif issubclass(cls, Device): + self.make_child_device(name, cls) + else: + self._raise(name, f"Can't make {cls}") diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py index 95e936752b..86e479d136 100644 --- a/src/ophyd_async/core/_device_save_loader.py +++ b/src/ophyd_async/core/_device_save_loader.py @@ -27,11 +27,8 @@ def pydantic_model_abstraction_representer( return dumper.represent_data(model.model_dump(mode="python")) -class OphydDumper(yaml.Dumper): - def represent_data(self, data: Any) -> Any: - if isinstance(data, Enum): - return self.represent_data(data.value) - return super().represent_data(data) +def enum_representer(dumper: yaml.Dumper, enum: Enum) -> yaml.Node: + return dumper.represent_data(enum.value) def get_signal_values( @@ 
-74,7 +71,7 @@ def get_signal_values( for key, value in zip(selected_signals, selected_values, strict=False) } # Ignored values place in with value None so we know which ones were ignored - named_values.update({key: None for key in ignore}) + named_values.update(dict.fromkeys(ignore)) return named_values @@ -111,6 +108,7 @@ def walk_rw_signals( path_prefix = "" signals: dict[str, SignalRW[Any]] = {} + for attr_name, attr in device.children(): dot_path = f"{path_prefix}{attr_name}" if type(attr) is SignalRW: @@ -145,9 +143,10 @@ def save_to_yaml(phases: Sequence[dict[str, Any]], save_path: str | Path) -> Non pydantic_model_abstraction_representer, Dumper=yaml.Dumper, ) + yaml.add_multi_representer(Enum, enum_representer, Dumper=yaml.Dumper) with open(save_path, "w") as file: - yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False) + yaml.dump(phases, file) def load_from_yaml(save_path: str) -> Sequence[dict[str, Any]]: diff --git a/src/ophyd_async/core/_mock_signal_backend.py b/src/ophyd_async/core/_mock_signal_backend.py index 029881cd96..49c835ace5 100644 --- a/src/ophyd_async/core/_mock_signal_backend.py +++ b/src/ophyd_async/core/_mock_signal_backend.py @@ -5,46 +5,37 @@ from bluesky.protocols import Descriptor, Reading -from ._signal_backend import SignalBackend +from ._signal_backend import SignalBackend, SignalDatatypeT from ._soft_signal_backend import SoftSignalBackend -from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T +from ._utils import Callback -class MockSignalBackend(SignalBackend[T]): +class MockSignalBackend(SignalBackend[SignalDatatypeT]): """Signal backend for testing, created by ``Device.connect(mock=True)``.""" - def __init__( - self, - datatype: type[T] | None = None, - initial_backend: SignalBackend[T] | None = None, - ) -> None: + def __init__(self, initial_backend: SignalBackend[SignalDatatypeT]) -> None: if isinstance(initial_backend, MockSignalBackend): - raise ValueError("Cannot make a MockSignalBackend for a 
MockSignalBackends") + raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackend") self.initial_backend = initial_backend - if datatype is None: - assert ( - self.initial_backend - ), "Must supply either initial_backend or datatype" - datatype = self.initial_backend.datatype - - self.datatype = datatype - - if not isinstance(self.initial_backend, SoftSignalBackend): - # If the backend is a hard signal backend, or not provided, - # then we create a soft signal to mimic it - - self.soft_backend = SoftSignalBackend(datatype=datatype) - else: + if isinstance(self.initial_backend, SoftSignalBackend): + # Backend is already a SoftSignalBackend, so use it self.soft_backend = self.initial_backend + else: + # Backend is not a SoftSignalBackend, so create one to mimic it + self.soft_backend = SoftSignalBackend( + datatype=self.initial_backend.datatype + ) + super().__init__(datatype=self.initial_backend.datatype) - def source(self, name: str) -> str: - if self.initial_backend: - return f"mock+{self.initial_backend.source(name)}" - return f"mock+{name}" + def set_value(self, value: SignalDatatypeT): + self.soft_backend.set_value(value) + + def source(self, name: str, read: bool) -> str: + return f"mock+{self.initial_backend.source(name, read)}" - async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None: + async def connect(self, timeout: float) -> None: pass @cached_property @@ -57,28 +48,23 @@ def put_proceeds(self) -> asyncio.Event: put_proceeds.set() return put_proceeds - async def put(self, value: T | None, wait=True, timeout=None): - await self.put_mock(value, wait=wait, timeout=timeout) - await self.soft_backend.put(value, wait=wait, timeout=timeout) - + async def put(self, value: SignalDatatypeT | None, wait: bool): + await self.put_mock(value, wait=wait) + await self.soft_backend.put(value, wait=wait) if wait: - await asyncio.wait_for(self.put_proceeds.wait(), timeout=timeout) - - def set_value(self, value: T): - 
self.soft_backend.set_value(value) + await self.put_proceeds.wait() async def get_reading(self) -> Reading: return await self.soft_backend.get_reading() - async def get_value(self) -> T: + async def get_value(self) -> SignalDatatypeT: return await self.soft_backend.get_value() - async def get_setpoint(self) -> T: - """For a soft signal, the setpoint and readback values are the same.""" + async def get_setpoint(self) -> SignalDatatypeT: return await self.soft_backend.get_setpoint() async def get_datakey(self, source: str) -> Descriptor: return await self.soft_backend.get_datakey(source) - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: + def set_callback(self, callback: Callback[Reading[SignalDatatypeT]] | None) -> None: self.soft_backend.set_callback(callback) diff --git a/src/ophyd_async/core/_mock_signal_utils.py b/src/ophyd_async/core/_mock_signal_utils.py index 33c0f677ba..70f29a2a4f 100644 --- a/src/ophyd_async/core/_mock_signal_utils.py +++ b/src/ophyd_async/core/_mock_signal_utils.py @@ -1,15 +1,14 @@ from collections.abc import Awaitable, Callable, Iterable from contextlib import asynccontextmanager, contextmanager -from typing import Any from unittest.mock import AsyncMock from ._mock_signal_backend import MockSignalBackend -from ._signal import Signal -from ._utils import T +from ._signal import Signal, SignalR +from ._soft_signal_backend import SignalDatatypeT def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend: - backend = signal._backend # noqa:SLF001 + backend = signal._connector.backend # noqa:SLF001 assert isinstance(backend, MockSignalBackend), ( "Expected to receive a `MockSignalBackend`, instead " f" received {type(backend)}. 
" @@ -17,7 +16,7 @@ def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend: return backend -def set_mock_value(signal: Signal[T], value: T): +def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT): """Set the value of a signal that is in mock mode.""" backend = _get_mock_signal_backend(signal) backend.set_value(value) @@ -59,8 +58,8 @@ class _SetValuesIterator: def __init__( self, - signal: Signal, - values: Iterable[Any], + signal: SignalR[SignalDatatypeT], + values: Iterable[SignalDatatypeT], require_all_consumed: bool = False, ): self.signal = signal @@ -99,8 +98,8 @@ def __del__(self): def set_mock_values( - signal: Signal, - values: Iterable[Any], + signal: SignalR[SignalDatatypeT], + values: Iterable[SignalDatatypeT], require_all_consumed: bool = False, ) -> _SetValuesIterator: """Iterator to set a signal to a sequence of values, optionally repeating the @@ -143,7 +142,9 @@ def _unset_side_effect_cm(put_mock: AsyncMock): def callback_on_mock_put( - signal: Signal[T], callback: Callable[[T], None] | Callable[[T], Awaitable[None]] + signal: Signal[SignalDatatypeT], + callback: Callable[[SignalDatatypeT, bool], None] + | Callable[[SignalDatatypeT, bool], Awaitable[None]], ): """For setting a callback when a backend is put to. diff --git a/src/ophyd_async/core/_protocol.py b/src/ophyd_async/core/_protocol.py index 3978f39cc8..74b7bf0c23 100644 --- a/src/ophyd_async/core/_protocol.py +++ b/src/ophyd_async/core/_protocol.py @@ -13,10 +13,36 @@ from bluesky.protocols import HasName, Reading from event_model import DataKey +from ._utils import DEFAULT_TIMEOUT + if TYPE_CHECKING: from ._status import AsyncStatus +@runtime_checkable +class Connectable(Protocol): + @abstractmethod + async def connect( + self, + mock: bool = False, + timeout: float = DEFAULT_TIMEOUT, + force_reconnect: bool = False, + ): + """Connect self and all child Devices. + + Contains a timeout that gets propagated to child.connect methods. 
+ + Parameters + ---------- + mock: + If True then use ``MockSignalBackend`` for all Signals + timeout: + Time to wait before failing with a TimeoutError. + force_reconnect: + Reconnect even if previous connect was successful. + """ + + @runtime_checkable class AsyncReadable(HasName, Protocol): @abstractmethod @@ -33,7 +59,6 @@ async def read(self) -> dict[str, Reading]: ('channel2', {'value': 16, 'timestamp': 1472493713.539238})) """ - ... @abstractmethod async def describe(self) -> dict[str, DataKey]: @@ -53,7 +78,6 @@ async def describe(self) -> dict[str, DataKey]: 'dtype': 'number', 'shape': []})) """ - ... @runtime_checkable @@ -63,14 +87,12 @@ async def read_configuration(self) -> dict[str, Reading]: """Same API as ``read`` but for slow-changing fields related to configuration. e.g., exposure time. These will typically be read only once per run. """ - ... @abstractmethod async def describe_configuration(self) -> dict[str, DataKey]: """Same API as ``describe``, but corresponding to the keys in ``read_configuration``. """ - ... @runtime_checkable @@ -78,12 +100,10 @@ class AsyncPausable(Protocol): @abstractmethod async def pause(self) -> None: """Perform device-specific work when the RunEngine pauses.""" - ... @abstractmethod async def resume(self) -> None: """Perform device-specific work when the RunEngine resumes after a pause.""" - ... @runtime_checkable @@ -95,7 +115,6 @@ def stage(self) -> AsyncStatus: It should return a ``Status`` that is marked done when the device is done staging. """ - ... @abstractmethod def unstage(self) -> AsyncStatus: @@ -104,7 +123,6 @@ def unstage(self) -> AsyncStatus: It should return a ``Status`` that is marked done when the device is finished unstaging. """ - ... 
C = TypeVar("C", contravariant=True) diff --git a/src/ophyd_async/core/_readable.py b/src/ophyd_async/core/_readable.py index 111a26d3b1..f27c061462 100644 --- a/src/ophyd_async/core/_readable.py +++ b/src/ophyd_async/core/_readable.py @@ -150,19 +150,19 @@ def add_children_as_readables( :meth:`HintedSignal.uncached` """ - dict_copy = self.__dict__.copy() + dict_copy = dict(self.children()) yield # Set symmetric difference operator gives all newly added keys - new_keys = dict_copy.keys() ^ self.__dict__.keys() - new_values = [self.__dict__[key] for key in new_keys] + new_dict = dict(self.children()) + new_keys = dict_copy.keys() ^ new_dict.keys() + new_values = [new_dict[key] for key in new_keys] flattened_values = [] for value in new_values: if isinstance(value, DeviceVector): - children = value.children() - flattened_values.extend([x[1] for x in children]) + flattened_values.extend(value.values()) else: flattened_values.append(value) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 0404577ec1..2895cc2ee8 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -2,128 +2,104 @@ import asyncio import functools -from collections.abc import AsyncGenerator, Callable, Mapping -from typing import Any, Generic, TypeVar, cast +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping +from typing import Any, Generic, cast from bluesky.protocols import ( Locatable, Location, Movable, - Reading, Status, Subscribable, ) from event_model import DataKey -from ._device import Device +from ._device import Device, DeviceConnector from ._mock_signal_backend import MockSignalBackend -from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable -from ._signal_backend import SignalBackend -from ._soft_signal_backend import SignalMetadata, SoftSignalBackend +from ._protocol import ( + AsyncConfigurable, + AsyncReadable, + AsyncStageable, + Reading, +) +from ._signal_backend import ( + 
SignalBackend, + SignalDatatypeT, +) +from ._soft_signal_backend import SoftSignalBackend from ._status import AsyncStatus, completed_status from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T -S = TypeVar("S") + +async def _wait_for(coro: Awaitable[T], timeout: float | None, source: str) -> T: + try: + return await asyncio.wait_for(coro, timeout) + except asyncio.TimeoutError as e: + raise asyncio.TimeoutError(source) from e def _add_timeout(func): @functools.wraps(func) async def wrapper(self: Signal, *args, **kwargs): - return await asyncio.wait_for(func(self, *args, **kwargs), self._timeout) + return await _wait_for(func(self, *args, **kwargs), self._timeout, self.source) return wrapper -def _fail(*args, **kwargs): - raise RuntimeError("Signal has not been supplied a backend yet") +class SignalConnector(DeviceConnector): + def __init__(self, backend: SignalBackend): + self.backend = self._init_backend = backend - -class DisconnectedBackend(SignalBackend): - source = connect = put = get_datakey = get_reading = get_value = get_setpoint = ( - set_callback - ) = _fail - - -DISCONNECTED_BACKEND = DisconnectedBackend() + async def connect( + self, + device: Device, + mock: bool, + timeout: float, + force_reconnect: bool, + ): + if mock: + self.backend = MockSignalBackend(self._init_backend) + else: + self.backend = self._init_backend + device.log.debug(f"Connecting to {self.backend.source(device.name, read=True)}") + await self.backend.connect(timeout) -class Signal(Device, Generic[T]): +class Signal(Device, Generic[SignalDatatypeT]): """A Device with the concept of a value, with R, RW, W and X flavours""" + _connector: SignalConnector + def __init__( self, - backend: SignalBackend[T] = DISCONNECTED_BACKEND, + backend: SignalBackend[SignalDatatypeT], timeout: float | None = DEFAULT_TIMEOUT, name: str = "", ) -> None: + super().__init__(name=name, connector=SignalConnector(backend)) self._timeout = timeout - self._backend = backend 
- super().__init__(name) - - async def connect( - self, - mock=False, - timeout=DEFAULT_TIMEOUT, - force_reconnect: bool = False, - backend: SignalBackend[T] | None = None, - ): - if backend: - if ( - self._backend is not DISCONNECTED_BACKEND - and backend is not self._backend - ): - raise ValueError("Backend at connection different from previous one.") - - self._backend = backend - if ( - self._previous_connect_was_mock is not None - and self._previous_connect_was_mock != mock - ): - raise RuntimeError( - f"`connect(mock={mock})` called on a `Signal` where the previous " - f"connect was `mock={self._previous_connect_was_mock}`. Changing mock " - "value between connects is not permitted." - ) - self._previous_connect_was_mock = mock - - if mock and not issubclass(type(self._backend), MockSignalBackend): - # Using a soft backend, look to the initial value - self._backend = MockSignalBackend(initial_backend=self._backend) - - if self._backend is None: - raise RuntimeError("`connect` called on signal without backend") - - can_use_previous_connection: bool = self._connect_task is not None and not ( - self._connect_task.done() and self._connect_task.exception() - ) - - if force_reconnect or not can_use_previous_connection: - self.log.debug(f"Connecting to {self.source}") - self._connect_task = asyncio.create_task( - self._backend.connect(timeout=timeout) - ) - else: - self.log.debug(f"Reusing previous connection to {self.source}") - assert ( - self._connect_task - ), "this assert is for type analysis and will never fail" - await self._connect_task @property def source(self) -> str: """Like ca://PV_PREFIX:SIGNAL, or "" if not set""" - return self._backend.source(self.name) + return self._connector.backend.source(self.name, read=True) + + def __setattr__(self, name: str, value: Any) -> None: + if name != "parent" and isinstance(value, Device): + raise AttributeError( + f"Cannot add Device or Signal {value} as a child of Signal {self}, " + "make a subclass of Device 
instead" + ) + return super().__setattr__(name, value) -class _SignalCache(Generic[T]): - def __init__(self, backend: SignalBackend[T], signal: Signal): +class _SignalCache(Generic[SignalDatatypeT]): + def __init__(self, backend: SignalBackend[SignalDatatypeT], signal: Signal): self._signal = signal self._staged = False self._listeners: dict[Callback, bool] = {} self._valid = asyncio.Event() - self._reading: Reading | None = None - self._value: T | None = None - + self._reading: Reading[SignalDatatypeT] | None = None self.backend = backend signal.log.debug(f"Making subscription on source {signal.source}") backend.set_callback(self._callback) @@ -132,30 +108,33 @@ def close(self): self.backend.set_callback(None) self._signal.log.debug(f"Closing subscription on source {self._signal.source}") - async def get_reading(self) -> Reading: + async def get_reading(self) -> Reading[SignalDatatypeT]: await self._valid.wait() assert self._reading is not None, "Monitor not working" return self._reading - async def get_value(self) -> T: - await self._valid.wait() - assert self._value is not None, "Monitor not working" - return self._value + async def get_value(self) -> SignalDatatypeT: + reading = await self.get_reading() + return reading["value"] - def _callback(self, reading: Reading, value: T): + def _callback(self, reading: Reading[SignalDatatypeT]): self._signal.log.debug( f"Updated subscription: reading of source {self._signal.source} changed" f"from {self._reading} to {reading}" ) self._reading = reading - self._value = value self._valid.set() for function, want_value in self._listeners.items(): self._notify(function, want_value) - def _notify(self, function: Callback, want_value: bool): + def _notify( + self, + function: Callback[dict[str, Reading[SignalDatatypeT]] | SignalDatatypeT], + want_value: bool, + ): + assert self._reading, "Monitor not working" if want_value: - function(self._value) + function(self._reading["value"]) else: function({self._signal.name: 
self._reading}) @@ -173,12 +152,14 @@ def set_staged(self, staged: bool): return self._staged or bool(self._listeners) -class SignalR(Signal[T], AsyncReadable, AsyncStageable, Subscribable): +class SignalR(Signal[SignalDatatypeT], AsyncReadable, AsyncStageable, Subscribable): """Signal that can be read from and monitored""" _cache: _SignalCache | None = None - def _backend_or_cache(self, cached: bool | None) -> _SignalCache | SignalBackend: + def _backend_or_cache( + self, cached: bool | None = None + ) -> _SignalCache | SignalBackend: # If cached is None then calculate it based on whether we already have a cache if cached is None: cached = self._cache is not None @@ -186,11 +167,11 @@ def _backend_or_cache(self, cached: bool | None) -> _SignalCache | SignalBackend assert self._cache, f"{self.source} not being monitored" return self._cache else: - return self._backend + return self._connector.backend def _get_cache(self) -> _SignalCache: if not self._cache: - self._cache = _SignalCache(self._backend, self) + self._cache = _SignalCache(self._connector.backend, self) return self._cache def _del_cache(self, needed: bool): @@ -206,16 +187,16 @@ async def read(self, cached: bool | None = None) -> dict[str, Reading]: @_add_timeout async def describe(self) -> dict[str, DataKey]: """Return a single item dict with the descriptor in it""" - return {self.name: await self._backend.get_datakey(self.source)} + return {self.name: await self._connector.backend.get_datakey(self.source)} @_add_timeout - async def get_value(self, cached: bool | None = None) -> T: + async def get_value(self, cached: bool | None = None) -> SignalDatatypeT: """The current value""" value = await self._backend_or_cache(cached).get_value() self.log.debug(f"get_value() on source {self.source} returned {value}") return value - def subscribe_value(self, function: Callback[T]): + def subscribe_value(self, function: Callback[SignalDatatypeT]): """Subscribe to updates in value of a device""" 
self._get_cache().subscribe(function, want_value=True) @@ -238,84 +219,82 @@ async def unstage(self) -> None: self._del_cache(self._get_cache().set_staged(False)) -class SignalW(Signal[T], Movable): +class SignalW(Signal[SignalDatatypeT], Movable): """Signal that can be set""" - def set( - self, value: T, wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT - ) -> AsyncStatus: + @AsyncStatus.wrap + async def set( + self, + value: SignalDatatypeT, + wait=True, + timeout: CalculatableTimeout = CALCULATE_TIMEOUT, + ) -> None: """Set the value and return a status saying when it's done""" - if timeout is CALCULATE_TIMEOUT: + if timeout == CALCULATE_TIMEOUT: timeout = self._timeout - - async def do_set(): - self.log.debug(f"Putting value {value} to backend at source {self.source}") - await self._backend.put(value, wait=wait, timeout=timeout) - self.log.debug( - f"Successfully put value {value} to backend at source {self.source}" - ) - - return AsyncStatus(do_set()) + source = self._connector.backend.source(self.name, read=False) + self.log.debug(f"Putting value {value} to backend at source {source}") + await _wait_for(self._connector.backend.put(value, wait=wait), timeout, source) + self.log.debug(f"Successfully put value {value} to backend at source {source}") -class SignalRW(SignalR[T], SignalW[T], Locatable): +class SignalRW(SignalR[SignalDatatypeT], SignalW[SignalDatatypeT], Locatable): """Signal that can be both read and set""" + @_add_timeout async def locate(self) -> Location: - location: Location = { - "setpoint": await self._backend.get_setpoint(), - "readback": await self.get_value(), - } - return location + """Return the setpoint and readback.""" + setpoint, readback = await asyncio.gather( + self._connector.backend.get_setpoint(), self._backend_or_cache().get_value() + ) + return Location(setpoint=setpoint, readback=readback) class SignalX(Signal): """Signal that puts the default value""" - def trigger( + @AsyncStatus.wrap + async def trigger( self, 
wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT - ) -> AsyncStatus: + ) -> None: """Trigger the action and return a status saying when it's done""" - if timeout is CALCULATE_TIMEOUT: + if timeout == CALCULATE_TIMEOUT: timeout = self._timeout - coro = self._backend.put(None, wait=wait, timeout=timeout) - return AsyncStatus(coro) + source = self._connector.backend.source(self.name, read=False) + self.log.debug(f"Putting default value to backend at source {source}") + await _wait_for(self._connector.backend.put(None, wait=wait), timeout, source) + self.log.debug(f"Successfully put default value to backend at source {source}") def soft_signal_rw( - datatype: type[T] | None = None, - initial_value: T | None = None, + datatype: type[SignalDatatypeT], + initial_value: SignalDatatypeT | None = None, name: str = "", units: str | None = None, precision: int | None = None, -) -> SignalRW[T]: +) -> SignalRW[SignalDatatypeT]: """Creates a read-writable Signal with a SoftSignalBackend. May pass metadata, which are propagated into describe. """ - metadata = SignalMetadata(units=units, precision=precision) - signal = SignalRW( - SoftSignalBackend(datatype, initial_value, metadata=metadata), - name=name, - ) + backend = SoftSignalBackend(datatype, initial_value, units, precision) + signal = SignalRW(backend=backend, name=name) return signal def soft_signal_r_and_setter( - datatype: type[T] | None = None, - initial_value: T | None = None, + datatype: type[SignalDatatypeT], + initial_value: SignalDatatypeT | None = None, name: str = "", units: str | None = None, precision: int | None = None, -) -> tuple[SignalR[T], Callable[[T], None]]: +) -> tuple[SignalR[SignalDatatypeT], Callable[[SignalDatatypeT], None]]: """Returns a tuple of a read-only Signal and a callable through which the signal can be internally modified within the device. May pass metadata, which are propagated into describe. 
Use soft_signal_rw if you want a device that is externally modifiable """ - metadata = SignalMetadata(units=units, precision=precision) - backend = SoftSignalBackend(datatype, initial_value, metadata=metadata) - signal = SignalR(backend, name=name) - + backend = SoftSignalBackend(datatype, initial_value, units, precision) + signal = SignalR(backend=backend, name=name) return (signal, backend.set_value) @@ -330,7 +309,7 @@ def _generate_assert_error_msg(name: str, expected_result, actual_result) -> str ) -async def assert_value(signal: SignalR[T], value: Any) -> None: +async def assert_value(signal: SignalR[SignalDatatypeT], value: Any) -> None: """Assert a signal's value and compare it an expected signal. Parameters @@ -440,8 +419,10 @@ def assert_emitted(docs: Mapping[str, list[dict]], **numbers: int): async def observe_value( - signal: SignalR[T], timeout: float | None = None, done_status: Status | None = None -) -> AsyncGenerator[T, None]: + signal: SignalR[SignalDatatypeT], + timeout: float | None = None, + done_status: Status | None = None, +) -> AsyncGenerator[SignalDatatypeT, None]: """Subscribe to the value of a signal so it can be iterated from. Parameters @@ -470,10 +451,10 @@ async def observe_value( async def observe_signals_values( - *signals: SignalR[T], + *signals: SignalR[SignalDatatypeT], timeout: float | None = None, done_status: Status | None = None, -) -> AsyncGenerator[tuple[SignalR[T], T], None]: +) -> AsyncGenerator[tuple[SignalR[SignalDatatypeT], SignalDatatypeT], None]: """Subscribe to the value of a signal so it can be iterated from. 
Parameters @@ -495,7 +476,9 @@ async def observe_signals_values( async for value1,value2,value3 in observe_signals_values(sig1,sig2,..): do_something_with(value) """ - q: asyncio.Queue[tuple[SignalR[T], T] | Status] = asyncio.Queue() + q: asyncio.Queue[tuple[SignalR[SignalDatatypeT], SignalDatatypeT] | Status] = ( + asyncio.Queue() + ) if timeout is None: get_value = q.get else: @@ -506,7 +489,7 @@ async def get_value(): cbs: dict[SignalR, Callback] = {} for signal in signals: - def queue_value(value: T, signal=signal): + def queue_value(value: SignalDatatypeT, signal=signal): q.put_nowait((signal, value)) cbs[signal] = queue_value @@ -524,25 +507,27 @@ def queue_value(value: T, signal=signal): else: break else: - yield cast(tuple[SignalR[T], T], item) + yield cast(tuple[SignalR[SignalDatatypeT], SignalDatatypeT], item) finally: for signal, cb in cbs.items(): signal.clear_sub(cb) -class _ValueChecker(Generic[T]): - def __init__(self, matcher: Callable[[T], bool], matcher_name: str): - self._last_value: T | None = None +class _ValueChecker(Generic[SignalDatatypeT]): + def __init__(self, matcher: Callable[[SignalDatatypeT], bool], matcher_name: str): + self._last_value: SignalDatatypeT | None = None self._matcher = matcher self._matcher_name = matcher_name - async def _wait_for_value(self, signal: SignalR[T]): + async def _wait_for_value(self, signal: SignalR[SignalDatatypeT]): async for value in observe_value(signal): self._last_value = value if self._matcher(value): return - async def wait_for_value(self, signal: SignalR[T], timeout: float | None): + async def wait_for_value( + self, signal: SignalR[SignalDatatypeT], timeout: float | None + ): try: await asyncio.wait_for(self._wait_for_value(signal), timeout) except asyncio.TimeoutError as e: @@ -553,8 +538,8 @@ async def wait_for_value(self, signal: SignalR[T], timeout: float | None): async def wait_for_value( - signal: SignalR[T], - match_value: T | Callable[[T], bool], + signal: SignalR[SignalDatatypeT], + 
match_value: SignalDatatypeT | Callable[[SignalDatatypeT], bool], timeout: float | None, ): """Wait for a signal to have a matching value. @@ -589,10 +574,10 @@ async def wait_for_value( async def set_and_wait_for_other_value( - set_signal: SignalW[T], - set_value: T, - match_signal: SignalR[S], - match_value: S | Callable[[S], bool], + set_signal: SignalW[SignalDatatypeT], + set_value: SignalDatatypeT, + match_signal: SignalR[SignalDatatypeT], + match_value: SignalDatatypeT | Callable[[SignalDatatypeT], bool], timeout: float = DEFAULT_TIMEOUT, set_timeout: float | None = None, wait_for_set_completion: bool = True, @@ -654,9 +639,9 @@ async def _wait_for_value(): async def set_and_wait_for_value( - signal: SignalRW[T], - value: T, - match_value: T | Callable[[T], bool] | None = None, + signal: SignalRW[SignalDatatypeT], + value: SignalDatatypeT, + match_value: SignalDatatypeT | Callable[[SignalDatatypeT], bool] | None = None, timeout: float = DEFAULT_TIMEOUT, status_timeout: float | None = None, wait_for_set_completion: bool = True, diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 035936f32c..d6b04b3d84 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -1,97 +1,164 @@ from abc import abstractmethod -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Generic, - Literal, -) +from collections.abc import Sequence +from typing import Generic, TypedDict, TypeVar, get_origin +import numpy as np from bluesky.protocols import Reading -from event_model import DataKey - -from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T +from event_model import DataKey, Dtype, Limits + +from ._table import Table +from ._utils import Callback, StrictEnum, T + +DTypeScalar_co = TypeVar("DTypeScalar_co", covariant=True, bound=np.generic) +Array1D = np.ndarray[tuple[int], np.dtype[DTypeScalar_co]] +Primitive = bool | int | float | str +# NOTE: if you change this union then 
update the docs to match +SignalDatatype = ( + Primitive + | Array1D[np.bool_] + | Array1D[np.int8] + | Array1D[np.uint8] + | Array1D[np.int16] + | Array1D[np.uint16] + | Array1D[np.int32] + | Array1D[np.uint32] + | Array1D[np.int64] + | Array1D[np.uint64] + | Array1D[np.float32] + | Array1D[np.float64] + | np.ndarray + | StrictEnum + | Sequence[str] + | Sequence[StrictEnum] + | Table +) +# TODO: These typevars will not be needed when we drop python 3.11 +# as you can do MyConverter[SignalType: SignalTypeUnion]: +# rather than MyConverter(Generic[SignalType]) +PrimitiveT = TypeVar("PrimitiveT", bound=Primitive) +SignalDatatypeT = TypeVar("SignalDatatypeT", bound=SignalDatatype) +SignalDatatypeV = TypeVar("SignalDatatypeV", bound=SignalDatatype) +EnumT = TypeVar("EnumT", bound=StrictEnum) +TableT = TypeVar("TableT", bound=Table) -class SignalBackend(Generic[T]): +class SignalBackend(Generic[SignalDatatypeT]): """A read/write/monitor backend for a Signals""" - #: Datatype of the signal value - datatype: type[T] | None = None + def __init__(self, datatype: type[SignalDatatypeT] | None): + self.datatype = datatype - @classmethod @abstractmethod - def datatype_allowed(cls, dtype: Any) -> bool: - """Check if a given datatype is acceptable for this signal backend.""" + def source(self, name: str, read: bool) -> str: + """Return source of signal. - #: Like ca://PV_PREFIX:SIGNAL - @abstractmethod - def source(self, name: str) -> str: - """Return source of signal. Signals may pass a name to the backend, which can be - used or discarded.""" + Signals may pass a name to the backend, which can be used or discarded. 
+ """ @abstractmethod - async def connect(self, timeout: float = DEFAULT_TIMEOUT): + async def connect(self, timeout: float): """Connect to underlying hardware""" @abstractmethod - async def put(self, value: T | None, wait=True, timeout=None): - """Put a value to the PV, if wait then wait for completion for up to timeout""" + async def put(self, value: SignalDatatypeT | None, wait: bool): + """Put a value to the PV, if wait then wait for completion""" @abstractmethod async def get_datakey(self, source: str) -> DataKey: """Metadata like source, dtype, shape, precision, units""" @abstractmethod - async def get_reading(self) -> Reading: + async def get_reading(self) -> Reading[SignalDatatypeT]: """The current value, timestamp and severity""" @abstractmethod - async def get_value(self) -> T: + async def get_value(self) -> SignalDatatypeT: """The current value""" @abstractmethod - async def get_setpoint(self) -> T: + async def get_setpoint(self) -> SignalDatatypeT: """The point that a signal was requested to move to.""" @abstractmethod - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: + def set_callback(self, callback: Callback[T] | None) -> None: """Observe changes to the current value, timestamp and severity""" -class _RuntimeSubsetEnumMeta(type): - def __str__(cls): - if hasattr(cls, "choices"): - return f"SubsetEnum{list(cls.choices)}" # type: ignore - return "SubsetEnum" - - def __getitem__(cls, _choices): - if isinstance(_choices, str): - _choices = (_choices,) - else: - if not isinstance(_choices, tuple) or not all( - isinstance(c, str) for c in _choices - ): - raise TypeError( - "Choices must be a str or a tuple of str, " f"not {type(_choices)}." 
- ) - if len(set(_choices)) != len(_choices): - raise TypeError("Duplicate elements in runtime enum choices.") - - class _RuntimeSubsetEnum(cls): - choices = _choices - - return _RuntimeSubsetEnum - - -class RuntimeSubsetEnum(metaclass=_RuntimeSubsetEnumMeta): - choices: ClassVar[tuple[str, ...]] - - def __init__(self): - raise RuntimeError("SubsetEnum cannot be instantiated") - - -if TYPE_CHECKING: - SubsetEnum = Literal -else: - SubsetEnum = RuntimeSubsetEnum +_primitive_dtype: dict[type[Primitive], Dtype] = { + bool: "boolean", + int: "integer", + float: "number", + str: "string", +} + + +class SignalMetadata(TypedDict, total=False): + limits: Limits + choices: list[str] + precision: int + units: str + + +def _datakey_dtype(datatype: type[SignalDatatype]) -> Dtype: + if ( + datatype is np.ndarray + or get_origin(datatype) in (Sequence, np.ndarray) + or issubclass(datatype, Table) + ): + return "array" + elif issubclass(datatype, StrictEnum): + return "string" + elif issubclass(datatype, Primitive): + return _primitive_dtype[datatype] + else: + raise TypeError(f"Can't make dtype for {datatype}") + + +def _datakey_dtype_numpy( + datatype: type[SignalDatatypeT], value: SignalDatatypeT +) -> np.dtype: + if isinstance(value, np.ndarray): + # The value already has a dtype, use that + return value.dtype + elif ( + get_origin(datatype) is Sequence + or datatype is str + or issubclass(datatype, StrictEnum) + ): + # TODO: use np.dtypes.StringDType when we can use in structured arrays + # https://github.com/numpy/numpy/issues/25693 + return np.dtype("S40") + elif isinstance(value, Table): + return value.numpy_dtype() + elif issubclass(datatype, Primitive): + return np.dtype(datatype) + else: + raise TypeError(f"Can't make dtype_numpy for {datatype}") + + +def _datakey_shape(value: SignalDatatype) -> list[int]: + if type(value) in _primitive_dtype or isinstance(value, StrictEnum): + return [] + elif isinstance(value, np.ndarray): + return list(value.shape) + elif 
isinstance(value, Sequence | Table): + return [len(value)] + else: + raise TypeError(f"Can't make shape for {value}") + + +def make_datakey( + datatype: type[SignalDatatypeT], + value: SignalDatatypeT, + source: str, + metadata: SignalMetadata, +) -> DataKey: + dtn = _datakey_dtype_numpy(datatype, value) + return DataKey( + dtype=_datakey_dtype(datatype), + shape=_datakey_shape(value), + # Ignore until https://github.com/bluesky/event-model/issues/308 + dtype_numpy=dtn.descr if len(dtn.descr) > 1 else dtn.str, # type: ignore + source=source, + **metadata, + ) diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index eb4aa47d71..d0e48c7212 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -1,244 +1,175 @@ from __future__ import annotations -import inspect import time -from collections import abc -from enum import Enum -from typing import Generic, cast, get_origin +from abc import abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Generic, get_origin import numpy as np from bluesky.protocols import Reading from event_model import DataKey -from event_model.documents.event_descriptor import Dtype -from pydantic import BaseModel -from typing_extensions import TypedDict from ._signal_backend import ( - RuntimeSubsetEnum, + Array1D, + EnumT, + Primitive, + PrimitiveT, SignalBackend, + SignalDatatype, + SignalDatatypeT, + SignalMetadata, + TableT, + make_datakey, ) -from ._utils import ( - DEFAULT_TIMEOUT, - ReadingValueCallback, - T, - get_dtype, - is_pydantic_model, -) - -primitive_dtypes: dict[type, Dtype] = { - str: "string", - int: "integer", - float: "number", - bool: "boolean", -} +from ._table import Table +from ._utils import Callback, get_dtype, get_enum_cls -class SignalMetadata(TypedDict): - units: str | None - precision: int | None +class SoftConverter(Generic[SignalDatatypeT]): + # 
This is Any -> SignalDatatypeT because we support coercing + # value types to SignalDatatype to allow people to do things like + # SignalRW[Enum].set("enum value") + @abstractmethod + def write_value(self, value: Any) -> SignalDatatypeT: ... -class SoftConverter(Generic[T]): - def value(self, value: T) -> T: - return value +@dataclass +class PrimitiveSoftConverter(SoftConverter[PrimitiveT]): + datatype: type[PrimitiveT] - def write_value(self, value: T) -> T: - return value + def write_value(self, value: Any) -> PrimitiveT: + return self.datatype(value) if value else self.datatype() - def reading(self, value: T, timestamp: float, severity: int) -> Reading: - return Reading( - value=value, - timestamp=timestamp, - alarm_severity=-1 if severity > 2 else severity, - ) - def get_datakey(self, source: str, value, **metadata) -> DataKey: - dk: DataKey = {"source": source, "shape": [], **metadata} # type: ignore - dtype = type(value) - if np.issubdtype(dtype, np.integer): - dtype = int - elif np.issubdtype(dtype, np.floating): - dtype = float - assert ( - dtype in primitive_dtypes - ), f"invalid converter for value of type {type(value)}" - dk["dtype"] = primitive_dtypes[dtype] - # type ignore until https://github.com/bluesky/event-model/issues/308 - try: - dk["dtype_numpy"] = np.dtype(dtype).descr[0][1] # type: ignore - except TypeError: - dk["dtype_numpy"] = "" # type: ignore - return dk - - def make_initial_value(self, datatype: type[T] | None) -> T: - if datatype is None: - return cast(T, None) - - return datatype() - - -class SoftArrayConverter(SoftConverter): - def get_datakey(self, source: str, value, **metadata) -> DataKey: - dtype_numpy = "" - if isinstance(value, list): - if len(value) > 0: - dtype_numpy = np.dtype(type(value[0])).descr[0][1] - else: - dtype_numpy = np.dtype(value.dtype).descr[0][1] +class SequenceStrSoftConverter(SoftConverter[Sequence[str]]): + def write_value(self, value: Any) -> Sequence[str]: + return [str(v) for v in value] if value else [] 
- return { - "source": source, - "dtype": "array", - "dtype_numpy": dtype_numpy, # type: ignore - "shape": [len(value)], - **metadata, - } - def make_initial_value(self, datatype: type[T] | None) -> T: - if datatype is None: - return cast(T, None) +@dataclass +class SequenceEnumSoftConverter(SoftConverter[Sequence[EnumT]]): + datatype: type[EnumT] - if get_origin(datatype) == abc.Sequence: - return cast(T, []) + def write_value(self, value: Any) -> Sequence[EnumT]: + return [self.datatype(v) for v in value] if value else [] - return cast(T, datatype(shape=0)) # type: ignore +@dataclass +class NDArraySoftConverter(SoftConverter[Array1D]): + datatype: np.dtype -class SoftEnumConverter(SoftConverter): - choices: tuple[str, ...] - - def __init__(self, datatype: RuntimeSubsetEnum | type[Enum]): - if issubclass(datatype, Enum): # type: ignore - self.choices = tuple(v.value for v in datatype) - else: - self.choices = datatype.choices - - def write_value(self, value: Enum | str) -> str: - return value # type: ignore + def write_value(self, value: Any) -> Array1D: + return np.array(() if value is None else value, dtype=self.datatype) - def get_datakey(self, source: str, value, **metadata) -> DataKey: - return { - "source": source, - "dtype": "string", - # type ignore until https://github.com/bluesky/event-model/issues/308 - "dtype_numpy": "|S40", # type: ignore - "shape": [], - "choices": self.choices, - **metadata, - } - def make_initial_value(self, datatype: type[T] | None) -> T: - if datatype is None: - return cast(T, None) +@dataclass +class EnumSoftConverter(SoftConverter[EnumT]): + datatype: type[EnumT] - if issubclass(datatype, Enum): - return cast(T, list(datatype.__members__.values())[0]) # type: ignore - return cast(T, self.choices[0]) + def write_value(self, value: Any) -> EnumT: + return ( + self.datatype(value) + if value + else list(self.datatype.__members__.values())[0] + ) -class SoftPydanticModelConverter(SoftConverter): - def __init__(self, datatype: 
type[BaseModel]): - self.datatype = datatype +@dataclass +class TableSoftConverter(SoftConverter[TableT]): + datatype: type[TableT] - def write_value(self, value): + def write_value(self, value: Any) -> TableT: if isinstance(value, dict): return self.datatype(**value) - return value - - -def make_converter(datatype): - is_array = get_dtype(datatype) is not None - is_sequence = get_origin(datatype) == abc.Sequence - is_enum = inspect.isclass(datatype) and ( - issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) - ) - - if is_array or is_sequence: - return SoftArrayConverter() - if is_enum: - return SoftEnumConverter(datatype) # type: ignore - if is_pydantic_model(datatype): - return SoftPydanticModelConverter(datatype) # type: ignore - - return SoftConverter() - - -class SoftSignalBackend(SignalBackend[T]): + elif isinstance(value, self.datatype): + return value + elif value is None: + return self.datatype() + else: + raise TypeError(f"Cannot convert {value} to {self.datatype}") + + +def make_converter(datatype: type[SignalDatatype]) -> SoftConverter: + enum_cls = get_enum_cls(datatype) + if datatype == Sequence[str]: + return SequenceStrSoftConverter() + elif get_origin(datatype) == Sequence and enum_cls: + return SequenceEnumSoftConverter(enum_cls) + elif get_origin(datatype) == np.ndarray: + return NDArraySoftConverter(get_dtype(datatype)) + elif enum_cls: + return EnumSoftConverter(enum_cls) + elif issubclass(datatype, Table): + return TableSoftConverter(datatype) + elif issubclass(datatype, Primitive): + return PrimitiveSoftConverter(datatype) + raise TypeError(f"Can't make converter for {datatype}") + + +class SoftSignalBackend(SignalBackend[SignalDatatypeT]): """An backend to a soft Signal, for test signals see ``MockSignalBackend``.""" - _value: T - _initial_value: T | None - _timestamp: float - _severity: int - - @classmethod - def datatype_allowed(cls, dtype: type) -> bool: - return True # Any value allowed in a soft signal - def __init__( 
self, - datatype: type[T] | None, - initial_value: T | None = None, - metadata: SignalMetadata = None, # type: ignore - ) -> None: - self.datatype = datatype - self._initial_value = initial_value - self._metadata = metadata or {} - self.converter: SoftConverter = make_converter(datatype) - if self._initial_value is None: - self._initial_value = self.converter.make_initial_value(self.datatype) - else: - self._initial_value = self.converter.write_value(self._initial_value) # type: ignore - - self.callback: ReadingValueCallback[T] | None = None - self._severity = 0 - self.set_value(self._initial_value) # type: ignore + datatype: type[SignalDatatypeT] | None, + initial_value: SignalDatatypeT | None = None, + units: str | None = None, + precision: int | None = None, + ): + # Create the right converter for the datatype + self.converter = make_converter(datatype or float) + # Add the extra static metadata to the dictionary + self.metadata: SignalMetadata = {} + if units is not None: + self.metadata["units"] = units + if precision is not None: + self.metadata["precision"] = precision + if enum_cls := get_enum_cls(datatype): + self.metadata["choices"] = [v.value for v in enum_cls] + # Create and set the initial value + self.initial_value = self.converter.write_value(initial_value) + self.reading: Reading[SignalDatatypeT] + self.callback: Callback[Reading[SignalDatatypeT]] | None = None + self.set_value(self.initial_value) + super().__init__(datatype) + + def set_value(self, value: SignalDatatypeT): + self.reading = Reading( + value=self.converter.write_value(value), + timestamp=time.monotonic(), + alarm_severity=0, + ) + if self.callback: + self.callback(self.reading) - def source(self, name: str) -> str: + def source(self, name: str, read: bool) -> str: return f"soft://{name}" - async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None: - """Connection isn't required for soft signals.""" + async def connect(self, timeout: float): pass - async def put(self, value: T 
| None, wait=True, timeout=None): - write_value = ( - self.converter.write_value(value) - if value is not None - else self._initial_value - ) - - self.set_value(write_value) # type: ignore - - def set_value(self, value: T): - """Method to bypass asynchronous logic.""" - self._value = value - self._timestamp = time.monotonic() - reading: Reading = self.converter.reading( - self._value, self._timestamp, self._severity - ) - - if self.callback: - self.callback(reading, self._value) + async def put(self, value: SignalDatatypeT | None, wait: bool) -> None: + write_value = self.initial_value if value is None else value + self.set_value(write_value) async def get_datakey(self, source: str) -> DataKey: - return self.converter.get_datakey(source, self._value, **self._metadata) + return make_datakey( + self.datatype or float, self.reading["value"], source, self.metadata + ) - async def get_reading(self) -> Reading: - return self.converter.reading(self._value, self._timestamp, self._severity) + async def get_reading(self) -> Reading[SignalDatatypeT]: + return self.reading - async def get_value(self) -> T: - return self.converter.value(self._value) + async def get_value(self) -> SignalDatatypeT: + return self.reading["value"] - async def get_setpoint(self) -> T: - """For a soft signal, the setpoint and readback values are the same.""" - return await self.get_value() + async def get_setpoint(self) -> SignalDatatypeT: + # For a soft signal, the setpoint and readback values are the same. 
+ return self.reading["value"] - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: + def set_callback(self, callback: Callback[Reading[SignalDatatypeT]] | None) -> None: if callback: assert not self.callback, "Cannot set a callback when one is already set" - reading: Reading = self.converter.reading( - self._value, self._timestamp, self._severity - ) - callback(reading, self._value) + callback(self.reading) self.callback = callback diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py index f36b60dceb..2b58a1af87 100644 --- a/src/ophyd_async/core/_table.py +++ b/src/ophyd_async/core/_table.py @@ -1,8 +1,13 @@ -from enum import Enum -from typing import TypeVar, get_args, get_origin +from __future__ import annotations + +from collections.abc import Sequence +from typing import Annotated, Any, TypeVar, get_origin import numpy as np -from pydantic import BaseModel, ConfigDict, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation + +from ._utils import get_dtype TableSubclass = TypeVar("TableSubclass", bound="Table") @@ -17,34 +22,38 @@ def _concat(value1, value2): class Table(BaseModel): """An abstraction of a Table of str to numpy array.""" - model_config = ConfigDict(validate_assignment=True, strict=False) - - @staticmethod - def row(cls: type[TableSubclass], **kwargs) -> TableSubclass: # type: ignore - arrayified_kwargs = {} - for field_name, field_value in cls.model_fields.items(): - value = kwargs.pop(field_name) - if field_value.default_factory is None: - raise ValueError( - "`Table` models should have default factories for their " - "mutable empty columns." 
-                )
-            default_array = field_value.default_factory()
-            if isinstance(default_array, np.ndarray):
-                arrayified_kwargs[field_name] = np.array(
-                    [value], dtype=default_array.dtype
-                )
-            elif issubclass(type(value), Enum) and isinstance(value, str):
-                arrayified_kwargs[field_name] = [value]
+    # You can use Table in 2 ways:
+    # 1. Table(**whatever_pva_gives_us) when pvi adds a Signal to a Device that is not
+    #    type hinted
+    # 2. MyTable(**whatever_pva_gives_us) where the Signal is type hinted
+    #
+    # For 1 we want extra="allow" so it is passed through as is. There are no base class
+    # fields, only "extra" fields, so they must be allowed. For 2 we want extra="forbid"
+    # so it is strictly checked against the BaseModel we are supplied.
+    model_config = ConfigDict(extra="allow")
+
+    @classmethod
+    def __init_subclass__(cls):
+        # But forbid extra in subclasses so it gets validated
+        cls.model_config = ConfigDict(validate_assignment=True, extra="forbid")
+        # Change fields to have the correct annotations
+        for k, anno in cls.__annotations__.items():
+            if get_origin(anno) is np.ndarray:
+                dtype = get_dtype(anno)
+                new_anno = Annotated[
+                    anno,
+                    NpArrayPydanticAnnotation.factory(
+                        data_type=dtype.type, dimensions=1, strict_data_typing=False
+                    ),
+                    Field(
+                        default_factory=lambda dtype=dtype: np.array([], dtype=dtype)
+                    ),
+                ]
+            elif get_origin(anno) is Sequence:
+                new_anno = Annotated[anno, Field(default_factory=list)]
+            else:
- ) - return cls(**arrayified_kwargs) + raise TypeError(f"Cannot use annotation {anno} in a Table") + cls.__annotations__[k] = new_anno def __add__(self, right: TableSubclass) -> TableSubclass: """Concatenate the arrays in field values.""" @@ -64,83 +73,71 @@ def __add__(self, right: TableSubclass) -> TableSubclass: } ) + def __eq__(self, value: object) -> bool: + return super().__eq__(value) + def numpy_dtype(self) -> np.dtype: dtype = [] - for field_name, field_value in self.model_fields.items(): - if np.ndarray in ( - get_origin(field_value.annotation), - field_value.annotation, - ): - dtype.append((field_name, getattr(self, field_name).dtype)) + for k, v in self: + if isinstance(v, np.ndarray): + dtype.append((k, v.dtype)) else: - enum_type = get_args(field_value.annotation)[0] - assert issubclass(enum_type, Enum) - enum_values = [element.value for element in enum_type] - max_length_in_enum = max(len(value) for value in enum_values) - dtype.append((field_name, np.dtype(f" list[np.ndarray]: - """Columns in the table can be lists of string enums or numpy arrays. - - This method returns the columns, converting the string enums to numpy arrays. 
- """ - - columns = [] - for field_name, field_value in self.model_fields.items(): - if np.ndarray in ( - get_origin(field_value.annotation), - field_value.annotation, + def numpy_table(self, selection: slice | None = None) -> np.ndarray: + array = None + for k, v in self: + if selection: + v = v[selection] + if array is None: + array = np.empty(v.shape, dtype=self.numpy_dtype()) + array[k] = v + assert array is not None + return array + + @model_validator(mode="before") + @classmethod + def validate_array_dtypes(cls, data: Any) -> Any: + if isinstance(data, dict): + data_dict = data + elif isinstance(data, Table): + data_dict = data.model_dump() + else: + raise AssertionError(f"Cannot construct Table from {data}") + for field_name, field_value in cls.model_fields.items(): + if ( + get_origin(field_value.annotation) is np.ndarray + and field_value.annotation + and field_name in data_dict ): - columns.append(getattr(self, field_name)) - else: - enum_type = get_args(field_value.annotation)[0] - assert issubclass(enum_type, Enum) - enum_values = [element.value for element in enum_type] - max_length_in_enum = max(len(value) for value in enum_values) - dtype = np.dtype(f" "Table": - first_length = len(next(iter(self))[1]) - assert all( - len(field_value) == first_length for _, field_value in self - ), "Rows should all be of equal size." - - if not all( - # Checks if the values are numpy subtypes if the array is a numpy array, - # or if the value is a string enum. - np.issubdtype(getattr(self, field_name).dtype, default_array.dtype) - if isinstance( - default_array := self.model_fields[field_name].default_factory(), # type: ignore - np.ndarray, - ) - else issubclass(get_args(field_value.annotation)[0], Enum) - for field_name, field_value in self.model_fields.items() - ): - raise ValueError( - f"Cannot construct a `{type(self).__name__}`, " - "some rows have incorrect types." 
- ) - + def validate_lengths(self) -> Table: + lengths: dict[int, set[str]] = {} + for field_name, field_value in self: + lengths.setdefault(len(field_value), set()).add(field_name) + assert len(lengths) <= 1, f"Columns should be same length, got {lengths=}" return self + + def __len__(self) -> int: + return len(next(iter(self))[1]) + + def __getitem__(self, item: int | slice) -> np.ndarray: + if isinstance(item, int): + return self.numpy_table(slice(item, item + 1)) + else: + return self.numpy_table(item) diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index 8c90639e21..a8800191d2 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -2,25 +2,35 @@ import asyncio import logging -from collections.abc import Awaitable, Callable, Iterable +from collections.abc import Awaitable, Callable, Iterable, Sequence from dataclasses import dataclass -from typing import Generic, Literal, ParamSpec, TypeVar, get_origin +from enum import Enum, EnumMeta +from typing import Any, Generic, Literal, ParamSpec, TypeVar, get_args, get_origin import numpy as np -from bluesky.protocols import Reading -from pydantic import BaseModel T = TypeVar("T") P = ParamSpec("P") Callback = Callable[[T], None] - -#: A function that will be called with the Reading and value when the -#: monitor updates -ReadingValueCallback = Callable[[Reading, T], None] DEFAULT_TIMEOUT = 10.0 ErrorText = str | dict[str, Exception] +class StrictEnum(str, Enum): + """All members should exist in the Backend, and there will be no extras""" + + +class SubsetEnumMeta(EnumMeta): + def __call__(self, value, *args, **kwargs): # type: ignore + if isinstance(value, str) and not isinstance(value, self): + return value + return super().__call__(value, *args, **kwargs) + + +class SubsetEnum(StrictEnum, metaclass=SubsetEnumMeta): + """All members should exist in the Backend, but there may be extras""" + + CALCULATE_TIMEOUT = "CALCULATE_TIMEOUT" """Sentinel used to implement 
``myfunc(timeout=CalculateTimeout)`` @@ -119,7 +129,22 @@ async def wait_for_connection(**coros: Awaitable[None]): raise NotConnected(exceptions) -def get_dtype(typ: type) -> np.dtype | None: +def get_dtype(datatype: type) -> np.dtype: + """Get the runtime dtype from a numpy ndarray type annotation + + >>> from ophyd_async.core import Array1D + >>> import numpy as np + >>> get_dtype(Array1D[np.int8]) + dtype('int8') + """ + if not get_origin(datatype) == np.ndarray: + raise TypeError(f"Expected Array1D[dtype], got {datatype}") + # datatype = numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]] + # so extract numpy.float64 from it + return np.dtype(get_args(get_args(datatype)[1])[0]) + + +def get_enum_cls(datatype: type | None) -> type[StrictEnum] | None: """Get the runtime dtype from a numpy ndarray type annotation >>> import numpy.typing as npt @@ -127,11 +152,15 @@ def get_dtype(typ: type) -> np.dtype | None: >>> get_dtype(npt.NDArray[np.int8]) dtype('int8') """ - if getattr(typ, "__origin__", None) == np.ndarray: - # datatype = numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]] - # so extract numpy.float64 from it - return np.dtype(typ.__args__[1].__args__[0]) # type: ignore - return None + if get_origin(datatype) is Sequence: + datatype = get_args(datatype)[0] + if datatype and issubclass(datatype, Enum): + if not issubclass(datatype, StrictEnum): + raise TypeError( + f"{datatype} should inherit from .SubsetEnum " + "or ophyd_async.core.StrictEnum" + ) + return datatype def get_unique(values: dict[str, T], types: str) -> T: @@ -187,7 +216,31 @@ def in_micros(t: float) -> int: return int(np.ceil(t * 1e6)) -def is_pydantic_model(datatype) -> bool: - while origin := get_origin(datatype): - datatype = origin - return datatype and issubclass(datatype, BaseModel) +def get_origin_class(annotatation: Any) -> type | None: + origin = get_origin(annotatation) or annotatation + if isinstance(origin, type): + return origin + + +class Reference(Generic[T]): + """Hide an 
object behind a reference. + + Used to opt out of the naming/parent-child relationship of `Device`. + + For example:: + + class DeviceWithRefToSignal(Device): + def __init__(self, signal: SignalRW[int]): + self.signal_ref = Reference(signal) + super().__init__() + + def set(self, value) -> AsyncStatus: + return self.signal_ref().set(value + 1) + + """ + + def __init__(self, obj: T): + self._obj = obj + + def __call__(self) -> T: + return self._obj diff --git a/src/ophyd_async/epics/adaravis/_aravis_controller.py b/src/ophyd_async/epics/adaravis/_aravis_controller.py index 67030cd691..b8765d2970 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_controller.py +++ b/src/ophyd_async/epics/adaravis/_aravis_controller.py @@ -2,11 +2,11 @@ from typing import Literal from ophyd_async.core import ( + AsyncStatus, DetectorController, DetectorTrigger, TriggerInfo, ) -from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore from ._aravis_io import AravisDriverIO, AravisTriggerMode, AravisTriggerSource @@ -70,7 +70,7 @@ def _get_trigger_info( f"use {trigger}" ) if trigger == DetectorTrigger.internal: - return AravisTriggerMode.off, "Freerun" + return AravisTriggerMode.off, AravisTriggerSource.freerun else: return (AravisTriggerMode.on, f"Line{self.gpio_number}") # type: ignore diff --git a/src/ophyd_async/epics/adaravis/_aravis_io.py b/src/ophyd_async/epics/adaravis/_aravis_io.py index 27c2898513..9707beac2d 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_io.py +++ b/src/ophyd_async/epics/adaravis/_aravis_io.py @@ -1,11 +1,9 @@ -from enum import Enum - -from ophyd_async.core import SubsetEnum +from ophyd_async.core import StrictEnum, SubsetEnum from ophyd_async.epics import adcore from ophyd_async.epics.signal import epics_signal_rw_rbv -class AravisTriggerMode(str, Enum): +class AravisTriggerMode(StrictEnum): """GigEVision GenICAM standard: on=externally triggered""" on = "On" @@ -19,7 +17,11 @@ class AravisTriggerMode(str, Enum): To prevent 
requiring one Enum class per possible configuration, we set as this Enum but read from the underlying signal as a str. """ -AravisTriggerSource = SubsetEnum["Freerun", "Line1"] + + +class AravisTriggerSource(SubsetEnum): + freerun = "Freerun" + line1 = "Line1" class AravisDriverIO(adcore.ADBaseIO): diff --git a/src/ophyd_async/epics/adcore/_core_io.py b/src/ophyd_async/epics/adcore/_core_io.py index 7968579117..e044b0e5d1 100644 --- a/src/ophyd_async/epics/adcore/_core_io.py +++ b/src/ophyd_async/epics/adcore/_core_io.py @@ -1,6 +1,4 @@ -from enum import Enum - -from ophyd_async.core import Device +from ophyd_async.core import Device, StrictEnum from ophyd_async.epics.signal import ( epics_signal_r, epics_signal_rw, @@ -10,7 +8,7 @@ from ._utils import ADBaseDataType, FileWriteMode, ImageMode -class Callback(str, Enum): +class Callback(StrictEnum): Enable = "Enable" Disable = "Disable" @@ -68,7 +66,7 @@ def __init__(self, prefix: str, name: str = "") -> None: super().__init__(prefix, name) -class DetectorState(str, Enum): +class DetectorState(StrictEnum): """ Default set of states of an AreaDetector driver. 
See definition in ADApp/ADSrc/ADDriver.h in https://github.com/areaDetector/ADCore @@ -100,7 +98,7 @@ def __init__(self, prefix: str, name: str = "") -> None: super().__init__(prefix, name=name) -class Compression(str, Enum): +class Compression(StrictEnum): none = "None" nbit = "N-bit" szip = "szip" diff --git a/src/ophyd_async/epics/adcore/_hdf_writer.py b/src/ophyd_async/epics/adcore/_hdf_writer.py index bfffa67b89..7d9bfd2b11 100644 --- a/src/ophyd_async/epics/adcore/_hdf_writer.py +++ b/src/ophyd_async/epics/adcore/_hdf_writer.py @@ -134,9 +134,9 @@ async def open(self, multiplier: int = 1) -> dict[str, DataKey]: describe = { ds.data_key: DataKey( source=self.hdf.full_file_name.source, - shape=outer_shape + tuple(ds.shape), + shape=list(outer_shape + tuple(ds.shape)), dtype="array" if ds.shape else "number", - dtype_numpy=ds.dtype_numpy, # type: ignore + dtype_numpy=ds.dtype_numpy, external="STREAM:", ) for ds in self._datasets diff --git a/src/ophyd_async/epics/adcore/_utils.py b/src/ophyd_async/epics/adcore/_utils.py index a1a21b6071..bedbd474c2 100644 --- a/src/ophyd_async/epics/adcore/_utils.py +++ b/src/ophyd_async/epics/adcore/_utils.py @@ -1,11 +1,16 @@ from dataclasses import dataclass -from enum import Enum -from ophyd_async.core import DEFAULT_TIMEOUT, SignalRW, T, wait_for_value -from ophyd_async.core._signal import SignalR +from ophyd_async.core import ( + DEFAULT_TIMEOUT, + SignalDatatypeT, + SignalR, + SignalRW, + StrictEnum, + wait_for_value, +) -class ADBaseDataType(str, Enum): +class ADBaseDataType(StrictEnum): Int8 = "Int8" UInt8 = "UInt8" Int16 = "Int16" @@ -73,25 +78,25 @@ def convert_param_dtype_to_np(datatype: str) -> str: return np_datatype -class FileWriteMode(str, Enum): +class FileWriteMode(StrictEnum): single = "Single" capture = "Capture" stream = "Stream" -class ImageMode(str, Enum): +class ImageMode(StrictEnum): single = "Single" multiple = "Multiple" continuous = "Continuous" -class NDAttributeDataType(str, Enum): +class 
NDAttributeDataType(StrictEnum): INT = "INT" DOUBLE = "DOUBLE" STRING = "STRING" -class NDAttributePvDbrType(str, Enum): +class NDAttributePvDbrType(StrictEnum): DBR_SHORT = "DBR_SHORT" DBR_ENUM = "DBR_ENUM" DBR_INT = "DBR_INT" @@ -122,8 +127,8 @@ class NDAttributeParam: async def stop_busy_record( - signal: SignalRW[T], - value: T, + signal: SignalRW[SignalDatatypeT], + value: SignalDatatypeT, timeout: float = DEFAULT_TIMEOUT, status_timeout: float | None = None, ) -> None: diff --git a/src/ophyd_async/epics/adkinetix/__init__.py b/src/ophyd_async/epics/adkinetix/__init__.py index 7747be3356..d97e40092c 100644 --- a/src/ophyd_async/epics/adkinetix/__init__.py +++ b/src/ophyd_async/epics/adkinetix/__init__.py @@ -1,9 +1,10 @@ from ._kinetix import KinetixDetector from ._kinetix_controller import KinetixController -from ._kinetix_io import KinetixDriverIO +from ._kinetix_io import KinetixDriverIO, KinetixTriggerMode __all__ = [ "KinetixDetector", "KinetixController", "KinetixDriverIO", + "KinetixTriggerMode", ] diff --git a/src/ophyd_async/epics/adkinetix/_kinetix_controller.py b/src/ophyd_async/epics/adkinetix/_kinetix_controller.py index e15cdd8ab0..7bc142d321 100644 --- a/src/ophyd_async/epics/adkinetix/_kinetix_controller.py +++ b/src/ophyd_async/epics/adkinetix/_kinetix_controller.py @@ -1,8 +1,11 @@ import asyncio -from ophyd_async.core import DetectorController, DetectorTrigger -from ophyd_async.core._detector import TriggerInfo -from ophyd_async.core._status import AsyncStatus +from ophyd_async.core import ( + AsyncStatus, + DetectorController, + DetectorTrigger, + TriggerInfo, +) from ophyd_async.epics import adcore from ._kinetix_io import KinetixDriverIO, KinetixTriggerMode diff --git a/src/ophyd_async/epics/adkinetix/_kinetix_io.py b/src/ophyd_async/epics/adkinetix/_kinetix_io.py index 30c4ccd2c3..4b70886648 100644 --- a/src/ophyd_async/epics/adkinetix/_kinetix_io.py +++ b/src/ophyd_async/epics/adkinetix/_kinetix_io.py @@ -1,16 +1,15 @@ -from enum import 
Enum - +from ophyd_async.core import StrictEnum from ophyd_async.epics import adcore from ophyd_async.epics.signal import epics_signal_rw_rbv -class KinetixTriggerMode(str, Enum): +class KinetixTriggerMode(StrictEnum): internal = "Internal" edge = "Rising Edge" gate = "Exp. Gate" -class KinetixReadoutMode(str, Enum): +class KinetixReadoutMode(StrictEnum): sensitivity = 1 speed = 2 dynamic_range = 3 diff --git a/src/ophyd_async/epics/adpilatus/_pilatus_controller.py b/src/ophyd_async/epics/adpilatus/_pilatus_controller.py index 7a1ab3c268..89a47914f4 100644 --- a/src/ophyd_async/epics/adpilatus/_pilatus_controller.py +++ b/src/ophyd_async/epics/adpilatus/_pilatus_controller.py @@ -2,12 +2,12 @@ from ophyd_async.core import ( DEFAULT_TIMEOUT, + AsyncStatus, DetectorController, DetectorTrigger, + TriggerInfo, wait_for_value, ) -from ophyd_async.core._detector import TriggerInfo -from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore from ._pilatus_io import PilatusDriverIO, PilatusTriggerMode diff --git a/src/ophyd_async/epics/adpilatus/_pilatus_io.py b/src/ophyd_async/epics/adpilatus/_pilatus_io.py index de040b5c4f..51ca65ce9c 100644 --- a/src/ophyd_async/epics/adpilatus/_pilatus_io.py +++ b/src/ophyd_async/epics/adpilatus/_pilatus_io.py @@ -1,10 +1,9 @@ -from enum import Enum - +from ophyd_async.core import StrictEnum from ophyd_async.epics import adcore from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw_rbv -class PilatusTriggerMode(str, Enum): +class PilatusTriggerMode(StrictEnum): internal = "Internal" ext_enable = "Ext. Enable" ext_trigger = "Ext. 
Trigger" diff --git a/src/ophyd_async/epics/adsimdetector/_sim_controller.py b/src/ophyd_async/epics/adsimdetector/_sim_controller.py index 44253927cb..cf10674f12 100644 --- a/src/ophyd_async/epics/adsimdetector/_sim_controller.py +++ b/src/ophyd_async/epics/adsimdetector/_sim_controller.py @@ -2,11 +2,11 @@ from ophyd_async.core import ( DEFAULT_TIMEOUT, + AsyncStatus, DetectorController, DetectorTrigger, + TriggerInfo, ) -from ophyd_async.core._detector import TriggerInfo -from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore diff --git a/src/ophyd_async/epics/advimba/__init__.py b/src/ophyd_async/epics/advimba/__init__.py index 0a02423c78..bd0f9e9397 100644 --- a/src/ophyd_async/epics/advimba/__init__.py +++ b/src/ophyd_async/epics/advimba/__init__.py @@ -1,9 +1,12 @@ from ._vimba import VimbaDetector from ._vimba_controller import VimbaController -from ._vimba_io import VimbaDriverIO +from ._vimba_io import VimbaDriverIO, VimbaExposeOutMode, VimbaOnOff, VimbaTriggerSource __all__ = [ "VimbaDetector", "VimbaController", "VimbaDriverIO", + "VimbaExposeOutMode", + "VimbaOnOff", + "VimbaTriggerSource", ] diff --git a/src/ophyd_async/epics/advimba/_vimba_controller.py b/src/ophyd_async/epics/advimba/_vimba_controller.py index 6ffb4cad57..69aba6bf39 100644 --- a/src/ophyd_async/epics/advimba/_vimba_controller.py +++ b/src/ophyd_async/epics/advimba/_vimba_controller.py @@ -1,8 +1,11 @@ import asyncio -from ophyd_async.core import DetectorController, DetectorTrigger -from ophyd_async.core._detector import TriggerInfo -from ophyd_async.core._status import AsyncStatus +from ophyd_async.core import ( + AsyncStatus, + DetectorController, + DetectorTrigger, + TriggerInfo, +) from ophyd_async.epics import adcore from ._vimba_io import VimbaDriverIO, VimbaExposeOutMode, VimbaOnOff, VimbaTriggerSource diff --git a/src/ophyd_async/epics/advimba/_vimba_io.py b/src/ophyd_async/epics/advimba/_vimba_io.py index ac14872ef8..0dc7571b7b 100644 --- 
a/src/ophyd_async/epics/advimba/_vimba_io.py +++ b/src/ophyd_async/epics/advimba/_vimba_io.py @@ -1,10 +1,9 @@ -from enum import Enum - +from ophyd_async.core import StrictEnum from ophyd_async.epics import adcore from ophyd_async.epics.signal import epics_signal_rw_rbv -class VimbaPixelFormat(str, Enum): +class VimbaPixelFormat(StrictEnum): internal = "Mono8" ext_enable = "Mono12" ext_trigger = "Ext. Trigger" @@ -12,7 +11,7 @@ class VimbaPixelFormat(str, Enum): alignment = "Alignment" -class VimbaConvertFormat(str, Enum): +class VimbaConvertFormat(StrictEnum): none = "None" mono8 = "Mono8" mono16 = "Mono16" @@ -20,7 +19,7 @@ class VimbaConvertFormat(str, Enum): rgb16 = "RGB16" -class VimbaTriggerSource(str, Enum): +class VimbaTriggerSource(StrictEnum): freerun = "Freerun" line1 = "Line1" line2 = "Line2" @@ -30,17 +29,17 @@ class VimbaTriggerSource(str, Enum): action1 = "Action1" -class VimbaOverlap(str, Enum): +class VimbaOverlap(StrictEnum): off = "Off" prev_frame = "PreviousFrame" -class VimbaOnOff(str, Enum): +class VimbaOnOff(StrictEnum): on = "On" off = "Off" -class VimbaExposeOutMode(str, Enum): +class VimbaExposeOutMode(StrictEnum): timed = "Timed" # Use ExposureTime PV trigger_width = "TriggerWidth" # Expose for length of high signal diff --git a/src/ophyd_async/epics/demo/_sensor.py b/src/ophyd_async/epics/demo/_sensor.py index 37d590d155..5235fe0aba 100644 --- a/src/ophyd_async/epics/demo/_sensor.py +++ b/src/ophyd_async/epics/demo/_sensor.py @@ -1,10 +1,14 @@ -from enum import Enum - -from ophyd_async.core import ConfigSignal, DeviceVector, HintedSignal, StandardReadable +from ophyd_async.core import ( + ConfigSignal, + DeviceVector, + HintedSignal, + StandardReadable, + StrictEnum, +) from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw -class EnergyMode(str, Enum): +class EnergyMode(StrictEnum): """Energy mode for `Sensor`""" #: Low energy mode diff --git a/src/ophyd_async/epics/eiger/_eiger.py b/src/ophyd_async/epics/eiger/_eiger.py 
index bc898d3660..d485fb47b0 100644 --- a/src/ophyd_async/epics/eiger/_eiger.py +++ b/src/ophyd_async/epics/eiger/_eiger.py @@ -1,7 +1,6 @@ from pydantic import Field -from ophyd_async.core import AsyncStatus, PathProvider, StandardDetector -from ophyd_async.core._detector import TriggerInfo +from ophyd_async.core import AsyncStatus, PathProvider, StandardDetector, TriggerInfo from ._eiger_controller import EigerController from ._eiger_io import EigerDriverIO diff --git a/src/ophyd_async/epics/eiger/_eiger_controller.py b/src/ophyd_async/epics/eiger/_eiger_controller.py index bf668676df..5cae0db68f 100644 --- a/src/ophyd_async/epics/eiger/_eiger_controller.py +++ b/src/ophyd_async/epics/eiger/_eiger_controller.py @@ -4,9 +4,9 @@ DEFAULT_TIMEOUT, DetectorController, DetectorTrigger, + TriggerInfo, set_and_wait_for_other_value, ) -from ophyd_async.core._detector import TriggerInfo from ._eiger_io import EigerDriverIO, EigerTriggerMode diff --git a/src/ophyd_async/epics/eiger/_eiger_io.py b/src/ophyd_async/epics/eiger/_eiger_io.py index 1df672592d..ed61c0b326 100644 --- a/src/ophyd_async/epics/eiger/_eiger_io.py +++ b/src/ophyd_async/epics/eiger/_eiger_io.py @@ -1,10 +1,8 @@ -from enum import Enum - -from ophyd_async.core import Device +from ophyd_async.core import Device, StrictEnum from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw_rbv, epics_signal_w -class EigerTriggerMode(str, Enum): +class EigerTriggerMode(StrictEnum): internal = "ints" edge = "exts" gate = "exte" diff --git a/src/ophyd_async/epics/eiger/_odin_io.py b/src/ophyd_async/epics/eiger/_odin_io.py index c5a38a669b..321b0d7bb6 100644 --- a/src/ophyd_async/epics/eiger/_odin_io.py +++ b/src/ophyd_async/epics/eiger/_odin_io.py @@ -1,6 +1,5 @@ import asyncio from collections.abc import AsyncGenerator, AsyncIterator -from enum import Enum from bluesky.protocols import StreamAsset from event_model import DataKey @@ -12,6 +11,7 @@ DeviceVector, NameProvider, PathProvider, + StrictEnum, 
observe_value, set_and_wait_for_value, ) @@ -22,7 +22,7 @@ ) -class Writing(str, Enum): +class Writing(StrictEnum): ON = "ON" OFF = "OFF" @@ -101,10 +101,10 @@ async def _describe(self) -> dict[str, DataKey]: return { "data": DataKey( source=self._drv.file_name.source, - shape=data_shape, + shape=list(data_shape), dtype="array", # TODO: Use correct type based on eiger https://github.com/bluesky/ophyd-async/issues/529 - dtype_numpy=" tuple[str, int | None]: - match = re.match(r"(.*?)(\d*)$", string) - assert match - - name = match.group(1) - number = match.group(2) or None - if number is None: - return name, None - else: - return name, int(number) - - -def _split_subscript(tp: T) -> tuple[Any, tuple[Any]] | tuple[T, None]: - """Split a subscripted type into the its origin and args. - - If `tp` is not a subscripted type, then just return the type and None as args. - - """ - if get_origin(tp) is not None: - return get_origin(tp), get_args(tp) - - return tp, None - - -def _strip_union(field: T | T) -> tuple[T, bool]: - if get_origin(field) in [Union, types.UnionType]: - args = get_args(field) - is_optional = type(None) in args - for arg in args: - if arg is not type(None): - return arg, is_optional - return field, False - - -def _strip_device_vector(field: type[Device]) -> tuple[bool, type[Device]]: - if get_origin(field) is DeviceVector: - return True, get_args(field)[0] - return False, field - - -@dataclass -class _PVIEntry: - """ - A dataclass to represent a single entry in the PVI table. - This could either be a signal or a sub-table. 
- """ - - sub_entries: dict[str, Union[dict[int, "_PVIEntry"], "_PVIEntry"]] - pvi_pv: str | None = None - device: Device | None = None - common_device_type: type[Device] | None = None - - -def _verify_common_blocks(entry: _PVIEntry, common_device: type[Device]): - if not entry.sub_entries: - return - common_sub_devices = get_type_hints(common_device) - for sub_name, sub_device in common_sub_devices.items(): - if sub_name.startswith("_") or sub_name == "parent": - continue - assert entry.sub_entries - device_t, is_optional = _strip_union(sub_device) - if sub_name not in entry.sub_entries and not is_optional: - raise RuntimeError( - f"sub device `{sub_name}:{type(sub_device)}` " "was not provided by pvi" - ) - if isinstance(entry.sub_entries[sub_name], dict): - for sub_sub_entry in entry.sub_entries[sub_name].values(): # type: ignore - _verify_common_blocks(sub_sub_entry, sub_device) # type: ignore - else: - _verify_common_blocks( - entry.sub_entries[sub_name], # type: ignore - sub_device, # type: ignore - ) - - -_pvi_mapping: dict[frozenset[str], Callable[..., Signal]] = { - frozenset({"r", "w"}): lambda dtype, read_pv, write_pv: epics_signal_rw( - dtype, "pva://" + read_pv, "pva://" + write_pv - ), - frozenset({"rw"}): lambda dtype, read_write_pv: epics_signal_rw( - dtype, "pva://" + read_write_pv, write_pv="pva://" + read_write_pv - ), - frozenset({"r"}): lambda dtype, read_pv: epics_signal_r(dtype, "pva://" + read_pv), - frozenset({"w"}): lambda dtype, write_pv: epics_signal_w( - dtype, "pva://" + write_pv - ), - frozenset({"x"}): lambda _, write_pv: epics_signal_x("pva://" + write_pv), -} - - -def _parse_type( - is_pvi_table: bool, - number_suffix: int | None, - common_device_type: type[Device] | None, -): - if common_device_type: - # pre-defined type - device_cls, _ = _strip_union(common_device_type) - is_device_vector, device_cls = _strip_device_vector(device_cls) - device_cls, device_args = _split_subscript(device_cls) - assert issubclass(device_cls, Device) 
- - is_signal = issubclass(device_cls, Signal) - signal_dtype = device_args[0] if device_args is not None else None - - elif is_pvi_table: - # is a block, we can make it a DeviceVector if it ends in a number - is_device_vector = number_suffix is not None - is_signal = False - signal_dtype = None - device_cls = Device - else: - # is a signal, signals aren't stored in DeviceVectors unless - # they're defined as such in the common_device_type - is_device_vector = False - is_signal = True - signal_dtype = None - device_cls = Signal - - return is_device_vector, is_signal, signal_dtype, device_cls - - -def _mock_common_blocks(device: Device, stripped_type: type | None = None): - device_t = stripped_type or type(device) - sub_devices = ( - (field, field_type) - for field, field_type in get_type_hints(device_t).items() - if not field.startswith("_") and field != "parent" - ) - - for device_name, device_cls in sub_devices: - device_cls, _ = _strip_union(device_cls) - is_device_vector, device_cls = _strip_device_vector(device_cls) - device_cls, device_args = _split_subscript(device_cls) - assert issubclass(device_cls, Device) - signal_dtype = device_args[0] if device_args is not None else None +def _get_signal_details(entry: dict[str, str]) -> tuple[type[Signal], str, str]: + match entry: + case {"r": read_pv}: + return SignalR, read_pv, read_pv + case {"r": read_pv, "w": write_pv}: + return SignalRW, read_pv, write_pv + case {"rw": read_write_pv}: + return SignalRW, read_write_pv, read_write_pv + case {"x": execute_pv}: + return SignalX, execute_pv, execute_pv + case _: + raise TypeError(f"Can't process entry {entry}") - if is_device_vector: - if issubclass(device_cls, Signal): - sub_device_1 = device_cls(SoftSignalBackend(signal_dtype)) - sub_device_2 = device_cls(SoftSignalBackend(signal_dtype)) - sub_device = DeviceVector({1: sub_device_1, 2: sub_device_2}) - else: - if hasattr(device, device_name): - sub_device = getattr(device, device_name) - else: - sub_device = 
DeviceVector( - { - 1: device_cls(), - 2: device_cls(), - } - ) - - for sub_device_in_vector in sub_device.values(): - _mock_common_blocks(sub_device_in_vector, stripped_type=device_cls) - - for value in sub_device.values(): - value.parent = sub_device - else: - if issubclass(device_cls, Signal): - sub_device = device_cls(SoftSignalBackend(signal_dtype)) - else: - sub_device = getattr(device, device_name, device_cls()) - _mock_common_blocks(sub_device, stripped_type=device_cls) - - setattr(device, device_name, sub_device) - sub_device.parent = device - - -async def _get_pvi_entries(entry: _PVIEntry, timeout=DEFAULT_TIMEOUT): - if not entry.pvi_pv or not entry.pvi_pv.endswith(":PVI"): - raise RuntimeError("Top level entry must be a pvi table") - - pvi_table_signal_backend: PvaSignalBackend = PvaSignalBackend( - None, entry.pvi_pv, entry.pvi_pv - ) - await pvi_table_signal_backend.connect( - timeout=timeout - ) # create table signal backend - pva_table = (await pvi_table_signal_backend.get_value())["pvi"] - common_device_type_hints = ( - get_type_hints(entry.common_device_type) if entry.common_device_type else {} - ) - - for sub_name, pva_entries in pva_table.items(): - pvs = list(pva_entries.values()) - is_pvi_table = len(pvs) == 1 and pvs[0].endswith(":PVI") - sub_name_split, sub_number_split = _strip_number_from_string(sub_name) - is_device_vector, is_signal, signal_dtype, device_type = _parse_type( - is_pvi_table, - sub_number_split, - common_device_type_hints.get(sub_name_split), - ) - if is_signal: - device = _pvi_mapping[frozenset(pva_entries.keys())](signal_dtype, *pvs) - else: - device = getattr(entry.device, sub_name, device_type()) - - sub_entry = _PVIEntry( - device=device, common_device_type=device_type, sub_entries={} - ) - - if is_device_vector: - # If device vector then we store sub_name -> {sub_number -> sub_entry} - # and aggregate into `DeviceVector` in `_set_device_attributes` - sub_number_split = 1 if sub_number_split is None else 
sub_number_split - if sub_name_split not in entry.sub_entries: - entry.sub_entries[sub_name_split] = {} - entry.sub_entries[sub_name_split][sub_number_split] = sub_entry # type: ignore - else: - entry.sub_entries[sub_name] = sub_entry +class PviDeviceConnector(DeviceConnector): + def __init__(self, pvi_pv: str = "") -> None: + self.pvi_pv = pvi_pv - if is_pvi_table: - sub_entry.pvi_pv = pvs[0] - await _get_pvi_entries(sub_entry) - - if entry.common_device_type: - _verify_common_blocks(entry, entry.common_device_type) - - -def _set_device_attributes(entry: _PVIEntry): - for sub_name, sub_entry in entry.sub_entries.items(): - if isinstance(sub_entry, dict): - sub_device = DeviceVector() # type: ignore - for key, device_vector_sub_entry in sub_entry.items(): - sub_device[key] = device_vector_sub_entry.device - if device_vector_sub_entry.pvi_pv: - _set_device_attributes(device_vector_sub_entry) - # Set the device vector entry to have the device vector as a parent - device_vector_sub_entry.device.parent = sub_device # type: ignore - else: - sub_device = sub_entry.device - assert sub_device, f"Device of {sub_entry} is None" - if sub_entry.pvi_pv: - _set_device_attributes(sub_entry) - - sub_device.parent = entry.device - setattr(entry.device, sub_name, sub_device) - - -async def fill_pvi_entries( - device: Device, root_pv: str, timeout=DEFAULT_TIMEOUT, mock=False -): - """ - Fills a ``device`` with signals from a the ``root_pvi:PVI`` table. - - If the device names match with parent devices of ``device`` then types are used. 
- """ - if mock: - # set up mock signals for the common annotations - _mock_common_blocks(device) - else: - # check the pvi table for devices and fill the device with them - root_entry = _PVIEntry( - pvi_pv=root_pv, + def create_children_from_annotations(self, device: Device): + self._filler = DeviceFiller( device=device, - common_device_type=type(device), - sub_entries={}, + signal_backend_factory=PvaSignalBackend, + device_connector_factory=PviDeviceConnector, ) - await _get_pvi_entries(root_entry, timeout=timeout) - _set_device_attributes(root_entry) - - # We call set name now the parent field has been set in all of the - # introspect-initialized devices. This will recursively set the names. - device.set_name(device.name) - - -def create_children_from_annotations( - device: Device, - included_optional_fields: tuple[str, ...] = (), - device_vectors: dict[str, int] | None = None, -): - """For intializing blocks at __init__ of ``device``.""" - for name, device_type in get_type_hints(type(device)).items(): - if name in ("_name", "parent"): - continue - device_type, is_optional = _strip_union(device_type) - if is_optional and name not in included_optional_fields: - continue - is_device_vector, device_type = _strip_device_vector(device_type) - if ( - (is_device_vector and (not device_vectors or name not in device_vectors)) - or ((origin := get_origin(device_type)) and issubclass(origin, Signal)) - or (isclass(device_type) and issubclass(device_type, Signal)) - ): - continue - if is_device_vector: - n_device_vector = DeviceVector( - {i: device_type() for i in range(1, device_vectors[name] + 1)} # type: ignore - ) - setattr(device, name, n_device_vector) - for sub_device in n_device_vector.values(): - create_children_from_annotations( - sub_device, device_vectors=device_vectors - ) + async def connect( + self, device: Device, mock: bool, timeout: float, force_reconnect: bool + ) -> None: + if mock: + # Make 2 entries for each DeviceVector + 
self._filler.make_soft_device_vector_entries(2) else: - sub_device = device_type() - setattr(device, name, sub_device) - create_children_from_annotations(sub_device, device_vectors=device_vectors) + pvi_structure = await pvget_with_timeout(self.pvi_pv, timeout) + entries: dict[str, dict[str, str]] = pvi_structure["value"].todict() + # Ensure we have device vectors for everything that should be there + self._filler.make_device_vectors(list(entries)) + for name, entry in entries.items(): + if set(entry) == {"d"}: + connector = self._filler.make_child_device(name) + connector.pvi_pv = entry["d"] + else: + signal_type, read_pv, write_pv = _get_signal_details(entry) + backend = self._filler.make_child_signal(name, signal_type) + backend.read_pv = read_pv + backend.write_pv = write_pv + # Check that all the requested children have been created + if unfilled := self._filler.unfilled(): + raise RuntimeError( + f"{device.name}: cannot provision {unfilled} from " + f"{self.pvi_pv}: {entries}" + ) + # Set the name of the device to name all children + device.set_name(device.name) + return await super().connect(device, mock, timeout, force_reconnect) diff --git a/src/ophyd_async/epics/signal/__init__.py b/src/ophyd_async/epics/signal/__init__.py index 8d7628bf01..703880c9ab 100644 --- a/src/ophyd_async/epics/signal/__init__.py +++ b/src/ophyd_async/epics/signal/__init__.py @@ -1,5 +1,5 @@ -from ._common import LimitPair, Limits, get_supported_values -from ._p4p import PvaSignalBackend +from ._common import get_supported_values +from ._p4p import PvaSignalBackend, pvget_with_timeout from ._signal import ( epics_signal_r, epics_signal_rw, @@ -10,9 +10,8 @@ __all__ = [ "get_supported_values", - "LimitPair", - "Limits", "PvaSignalBackend", + "pvget_with_timeout", "epics_signal_r", "epics_signal_rw", "epics_signal_rw_rbv", diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index bdac6d878f..9dc6e4c5ae 100644 --- 
a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -1,11 +1,8 @@ -import inspect import logging import sys from collections.abc import Sequence -from dataclasses import dataclass -from enum import Enum from math import isnan, nan -from typing import Any, get_origin +from typing import Any, Generic, cast import numpy as np from aioca import ( @@ -21,233 +18,198 @@ from aioca.types import AugmentedValue, Dbr, Format from bluesky.protocols import Reading from epicscorelibs.ca import dbr -from event_model import DataKey -from event_model.documents.event_descriptor import Dtype +from event_model import DataKey, Limits, LimitsRange from ophyd_async.core import ( - DEFAULT_TIMEOUT, + Array1D, + Callback, NotConnected, - ReadingValueCallback, - RuntimeSubsetEnum, SignalBackend, - T, - get_dtype, + SignalDatatype, + SignalDatatypeT, + SignalMetadata, + get_enum_cls, get_unique, + make_datakey, wait_for_connection, ) -from ._common import LimitPair, Limits, common_meta, get_supported_values - -dbr_to_dtype: dict[Dbr, Dtype] = { - dbr.DBR_STRING: "string", - dbr.DBR_SHORT: "integer", - dbr.DBR_FLOAT: "number", - dbr.DBR_CHAR: "string", - dbr.DBR_LONG: "integer", - dbr.DBR_DOUBLE: "number", -} - - -def _data_key_from_augmented_value( - value: AugmentedValue, - *, - choices: list[str] | None = None, - dtype: Dtype | None = None, -) -> DataKey: - """Use the return value of get with FORMAT_CTRL to construct a DataKey - describing the signal. See docstring of AugmentedValue for expected - value fields by DBR type. - - Args: - value (AugmentedValue): Description of the the return type of a DB record - choices: Optional list of enum choices to pass as metadata in the datakey - dtype: Optional override dtype when AugmentedValue is ambiguous, e.g. 
booleans - - Returns: - DataKey: A rich DataKey describing the DB record - """ - source = f"ca://{value.name}" - assert value.ok, f"Error reading {source}: {value}" - - scalar = value.element_count == 1 - dtype = dtype or dbr_to_dtype[value.datatype] # type: ignore - - dtype_numpy = np.dtype(dbr.DbrCodeToType[value.datatype].dtype).descr[0][1] - - d = DataKey( - source=source, - dtype=dtype if scalar else "array", - # Ignore until https://github.com/bluesky/event-model/issues/308 - dtype_numpy=dtype_numpy, # type: ignore - # strictly value.element_count >= len(value) - shape=[] if scalar else [len(value)], - ) - for key in common_meta: - attr = getattr(value, key, nan) - if isinstance(attr, str) or not isnan(attr): - d[key] = attr - - if choices is not None: - d["choices"] = choices # type: ignore - - if limits := _limits_from_augmented_value(value): - d["limits"] = limits # type: ignore - - return d +from ._common import format_datatype, get_supported_values def _limits_from_augmented_value(value: AugmentedValue) -> Limits: - def get_limits(limit: str) -> LimitPair: + def get_limits(limit: str) -> LimitsRange | None: low = getattr(value, f"lower_{limit}_limit", nan) high = getattr(value, f"upper_{limit}_limit", nan) - return LimitPair( - low=None if isnan(low) else low, high=None if isnan(high) else high - ) - - return Limits( - alarm=get_limits("alarm"), - control=get_limits("ctrl"), - display=get_limits("disp"), - warning=get_limits("warning"), - ) - + if not (isnan(low) and isnan(high)): + return LimitsRange( + low=None if isnan(low) else low, + high=None if isnan(high) else high, + ) -@dataclass -class CaConverter: - read_dbr: Dbr | None - write_dbr: Dbr | None + limits = Limits() + if limits_range := get_limits("alarm"): + limits["alarm"] = limits_range + if limits_range := get_limits("ctrl"): + limits["control"] = limits_range + if limits_range := get_limits("disp"): + limits["display"] = limits_range + if limits_range := get_limits("warning"): + 
limits["warning"] = limits_range + return limits + + +def _metadata_from_augmented_value( + value: AugmentedValue, metadata: SignalMetadata +) -> SignalMetadata: + metadata = metadata.copy() + if hasattr(value, "units"): + metadata["units"] = value.units + if hasattr(value, "precision") and not isnan(value.precision): + metadata["precision"] = value.precision + if limits := _limits_from_augmented_value(value): + metadata["limits"] = limits + return metadata + + +class CaConverter(Generic[SignalDatatypeT]): + def __init__( + self, + datatype: type[SignalDatatypeT], + read_dbr: Dbr, + write_dbr: Dbr | None = None, + metadata: SignalMetadata | None = None, + ): + self.datatype = datatype + self.read_dbr: Dbr = read_dbr + self.write_dbr: Dbr | None = write_dbr + self.metadata = metadata or SignalMetadata() - def write_value(self, value) -> Any: + def write_value(self, value: Any) -> Any: + # The ca library will do the conversion for us return value - def value(self, value: AugmentedValue): + def value(self, value: AugmentedValue) -> SignalDatatypeT: # for channel access ca_xxx classes, this # invokes __pos__ operator to return an instance of # the builtin base class return +value # type: ignore - def reading(self, value: AugmentedValue) -> Reading: - return { - "value": self.value(value), - "timestamp": value.timestamp, - "alarm_severity": -1 if value.severity > 2 else value.severity, - } - - def get_datakey(self, value: AugmentedValue) -> DataKey: - return _data_key_from_augmented_value(value) - -class CaLongStrConverter(CaConverter): - def __init__(self): - return super().__init__(dbr.DBR_CHAR_STR, dbr.DBR_CHAR_STR) - - def write_value(self, value: str): - # Add a null in here as this is what the commandline caput does - # TODO: this should be in the server so check if it can be pushed to asyn - return value + "\0" +class DisconnectedCaConverter(CaConverter): + def __getattribute__(self, __name: str) -> Any: + raise NotImplementedError("No PV has been set as 
connect() has not been called") -class CaArrayConverter(CaConverter): - def value(self, value: AugmentedValue): +class CaArrayConverter(CaConverter[np.ndarray]): + def value(self, value: AugmentedValue) -> np.ndarray: + # A less expensive conversion return np.array(value, copy=False) -@dataclass -class CaEnumConverter(CaConverter): - """To prevent issues when a signal is restarted and returns with different enum - values or orders, we put treat an Enum signal as a string, and cache the - choices on this class. - """ +class CaSequenceStrConverter(CaConverter[Sequence[str]]): + def value(self, value: AugmentedValue) -> Sequence[str]: + return [str(v) for v in value] # type: ignore - choices: dict[str, str] - def write_value(self, value: Enum | str): - if isinstance(value, Enum): - return value.value - else: - return value +class CaLongStrConverter(CaConverter[str]): + def __init__(self): + super().__init__(str, dbr.DBR_CHAR_STR, dbr.DBR_CHAR_STR) - def value(self, value: AugmentedValue): - return self.choices[value] # type: ignore + def write_value_and_dbr(self, value: Any) -> Any: + # Add a null in here as this is what the commandline caput does + # TODO: this should be in the server so check if it can be pushed to asyn + return value + "\0" - def get_datakey(self, value: AugmentedValue) -> DataKey: - # Sometimes DBR_TYPE returns as String, must pass choices still - return _data_key_from_augmented_value(value, choices=list(self.choices.keys())) +class CaBoolConverter(CaConverter[bool]): + def __init__(self): + super().__init__(bool, dbr.DBR_SHORT) -@dataclass -class CaBoolConverter(CaConverter): def value(self, value: AugmentedValue) -> bool: return bool(value) - def get_datakey(self, value: AugmentedValue) -> DataKey: - return _data_key_from_augmented_value(value, dtype="boolean") +class CaEnumConverter(CaConverter[str]): + def __init__(self, supported_values: dict[str, str]): + self.supported_values = supported_values + super().__init__( + str, dbr.DBR_STRING, 
metadata=SignalMetadata(choices=list(supported_values)) + ) -class DisconnectedCaConverter(CaConverter): - def __getattribute__(self, __name: str) -> Any: - raise NotImplementedError("No PV has been set as connect() has not been called") + def value(self, value: AugmentedValue) -> str: + return self.supported_values[str(value)] + + +_datatype_converter_from_dbr: dict[ + tuple[Dbr, bool], tuple[type[SignalDatatype], type[CaConverter]] +] = { + (dbr.DBR_STRING, False): (str, CaConverter), + (dbr.DBR_SHORT, False): (int, CaConverter), + (dbr.DBR_FLOAT, False): (float, CaConverter), + (dbr.DBR_ENUM, False): (str, CaConverter), + (dbr.DBR_CHAR, False): (int, CaConverter), + (dbr.DBR_LONG, False): (int, CaConverter), + (dbr.DBR_DOUBLE, False): (float, CaConverter), + (dbr.DBR_STRING, True): (Sequence[str], CaSequenceStrConverter), + (dbr.DBR_SHORT, True): (Array1D[np.int16], CaArrayConverter), + (dbr.DBR_FLOAT, True): (Array1D[np.float32], CaArrayConverter), + (dbr.DBR_ENUM, True): (Sequence[str], CaSequenceStrConverter), + (dbr.DBR_CHAR, True): (Array1D[np.uint8], CaArrayConverter), + (dbr.DBR_LONG, True): (Array1D[np.int32], CaArrayConverter), + (dbr.DBR_DOUBLE, True): (Array1D[np.float64], CaArrayConverter), +} def make_converter( datatype: type | None, values: dict[str, AugmentedValue] ) -> CaConverter: pv = list(values)[0] - pv_dbr = get_unique({k: v.datatype for k, v in values.items()}, "datatypes") + pv_dbr = cast( + Dbr, get_unique({k: v.datatype for k, v in values.items()}, "datatypes") + ) is_array = bool([v for v in values.values() if v.element_count > 1]) - if is_array and datatype is str and pv_dbr == dbr.DBR_CHAR: + # Infer a datatype and converter from the dbr + inferred_datatype, converter_cls = _datatype_converter_from_dbr[(pv_dbr, is_array)] + # Some override cases + if is_array and pv_dbr == dbr.DBR_CHAR and datatype is str: # Override waveform of chars to be treated as string return CaLongStrConverter() - elif is_array and pv_dbr == dbr.DBR_STRING: - 
# Waveform of strings, check we wanted this - if datatype: - datatype_dtype = get_dtype(datatype) - if not datatype_dtype or not np.can_cast(datatype_dtype, np.str_): - raise TypeError(f"{pv} has type [str] not {datatype.__name__}") - return CaArrayConverter(pv_dbr, None) - elif is_array: - pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes") # type: ignore - # This is an array - if datatype: - # Check we wanted an array of this type - dtype = get_dtype(datatype) - if not dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not {datatype.__name__}") - if dtype != pv_dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") - return CaArrayConverter(pv_dbr, None) # type: ignore - elif pv_dbr == dbr.DBR_ENUM and datatype is bool: - # Database can't do bools, so are often representated as enums, - # CA can do int - pv_choices_len = get_unique( - {k: len(v.enums) for k, v in values.items()}, "number of choices" - ) - if pv_choices_len != 2: - raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return CaBoolConverter(dbr.DBR_SHORT, dbr.DBR_SHORT) - elif pv_dbr == dbr.DBR_ENUM: - # This is an Enum + elif not is_array and pv_dbr == dbr.DBR_ENUM: pv_choices = get_unique( {k: tuple(v.enums) for k, v in values.items()}, "choices" ) - supported_values = get_supported_values(pv, datatype, pv_choices) - return CaEnumConverter(dbr.DBR_STRING, None, supported_values) - else: - value = list(values.values())[0] - # Done the dbr check, so enough to check one of the values - if datatype and not isinstance(value, datatype): - # Allow int signals to represent float records when prec is 0 - is_prec_zero_float = ( - isinstance(value, float) - and get_unique({k: v.precision for k, v in values.items()}, "precision") - == 0 + if datatype is bool: + # Database can't do bools, so are often representated as enums of len 2 + if len(pv_choices) != 2: + raise TypeError(f"{pv} has {pv_choices=}, can't map to bool") + return CaBoolConverter() 
+ elif enum_cls := get_enum_cls(datatype): + # If explicitly requested then check + return CaEnumConverter(get_supported_values(pv, enum_cls, pv_choices)) + elif datatype in (None, str): + # Drop to string for safety, but retain choices as metadata + return CaConverter( + str, + dbr.DBR_STRING, + metadata=SignalMetadata(choices=list(pv_choices)), ) - if not (datatype is int and is_prec_zero_float): - raise TypeError( - f"{pv} has type {type(value).__name__.replace('ca_', '')} " - + f"not {datatype.__name__}" - ) - return CaConverter(pv_dbr, None) # type: ignore + elif ( + inferred_datatype is float + and datatype is int + and get_unique({k: v.precision for k, v in values.items()}, "precision") == 0 + ): + # Allow int signals to represent float records when prec is 0 + return CaConverter(int, pv_dbr) + elif datatype in (None, inferred_datatype): + # If datatype matches what we are given then allow it and use inferred converter + return converter_cls(inferred_datatype, pv_dbr) + if pv_dbr == dbr.DBR_ENUM: + inferred_datatype = "str | SubsetEnum | StrictEnum" + raise TypeError( + f"{pv} with inferred datatype {format_datatype(inferred_datatype)}" + f" cannot be coerced to {format_datatype(datatype)}" + ) _tried_pyepics = False @@ -262,42 +224,24 @@ def _use_pyepics_context_if_imported(): _tried_pyepics = True -class CaSignalBackend(SignalBackend[T]): - _ALLOWED_DATATYPES = ( - bool, - int, - float, - str, - Sequence, - Enum, - RuntimeSubsetEnum, - np.ndarray, - ) - - @classmethod - def datatype_allowed(cls, dtype: Any) -> bool: - stripped_origin = get_origin(dtype) or dtype - if dtype is None: - return True - - return inspect.isclass(stripped_origin) and issubclass( - stripped_origin, cls._ALLOWED_DATATYPES - ) - - def __init__(self, datatype: type[T] | None, read_pv: str, write_pv: str): - self.datatype = datatype - if not CaSignalBackend.datatype_allowed(self.datatype): - raise TypeError(f"Given datatype {self.datatype} unsupported in CA.") +class 
CaSignalBackend(SignalBackend[SignalDatatypeT]): + def __init__( + self, + datatype: type[SignalDatatypeT] | None, + read_pv: str = "", + write_pv: str = "", + ): self.read_pv = read_pv self.write_pv = write_pv + self.converter: CaConverter = DisconnectedCaConverter(float, dbr.DBR_DOUBLE) self.initial_values: dict[str, AugmentedValue] = {} - self.converter: CaConverter = DisconnectedCaConverter(None, None) self.subscription: Subscription | None = None + super().__init__(datatype) - def source(self, name: str): - return f"ca://{self.read_pv}" + def source(self, name: str, read: bool): + return f"ca://{self.read_pv if read else self.write_pv}" - async def _store_initial_value(self, pv, timeout: float = DEFAULT_TIMEOUT): + async def _store_initial_value(self, pv: str, timeout: float): try: self.initial_values[pv] = await caget( pv, format=FORMAT_CTRL, timeout=timeout @@ -306,7 +250,7 @@ async def _store_initial_value(self, pv, timeout: float = DEFAULT_TIMEOUT): logging.debug(f"signal ca://{pv} timed out") raise NotConnected(f"ca://{pv}") from exc - async def connect(self, timeout: float = DEFAULT_TIMEOUT): + async def connect(self, timeout: float): _use_pyepics_context_if_imported() if self.read_pv != self.write_pv: # Different, need to connect both @@ -319,7 +263,19 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): await self._store_initial_value(self.read_pv, timeout=timeout) self.converter = make_converter(self.datatype, self.initial_values) - async def put(self, value: T | None, wait=True, timeout=None): + async def _caget(self, pv: str, format: Format) -> AugmentedValue: + return await caget( + pv, datatype=self.converter.read_dbr, format=format, timeout=None + ) + + def _make_reading(self, value: AugmentedValue) -> Reading[SignalDatatypeT]: + return { + "value": self.converter.value(value), + "timestamp": value.timestamp, + "alarm_severity": -1 if value.severity > 2 else value.severity, + } + + async def put(self, value: SignalDatatypeT | None, wait: 
bool): if value is None: write_value = self.initial_values[self.write_pv] else: @@ -329,50 +285,39 @@ async def put(self, value: T | None, wait=True, timeout=None): write_value, datatype=self.converter.write_dbr, wait=wait, - timeout=timeout, - ) - - async def _caget(self, format: Format) -> AugmentedValue: - return await caget( - self.read_pv, - datatype=self.converter.read_dbr, - format=format, timeout=None, ) async def get_datakey(self, source: str) -> DataKey: - value = await self._caget(FORMAT_CTRL) - return self.converter.get_datakey(value) + value = await self._caget(self.read_pv, FORMAT_CTRL) + metadata = _metadata_from_augmented_value(value, self.converter.metadata) + return make_datakey( + self.converter.datatype, self.converter.value(value), source, metadata + ) - async def get_reading(self) -> Reading: - value = await self._caget(FORMAT_TIME) - return self.converter.reading(value) + async def get_reading(self) -> Reading[SignalDatatypeT]: + value = await self._caget(self.read_pv, FORMAT_TIME) + return self._make_reading(value) - async def get_value(self) -> T: - value = await self._caget(FORMAT_RAW) + async def get_value(self) -> SignalDatatypeT: + value = await self._caget(self.read_pv, FORMAT_RAW) return self.converter.value(value) - async def get_setpoint(self) -> T: - value = await caget( - self.write_pv, - datatype=self.converter.read_dbr, - format=FORMAT_RAW, - timeout=None, - ) + async def get_setpoint(self) -> SignalDatatypeT: + value = await self._caget(self.write_pv, FORMAT_RAW) return self.converter.value(value) - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: + def set_callback(self, callback: Callback[Reading[SignalDatatypeT]] | None) -> None: if callback: assert ( not self.subscription ), "Cannot set a callback when one is already set" self.subscription = camonitor( self.read_pv, - lambda v: callback(self.converter.reading(v), self.converter.value(v)), + lambda v: callback(self._make_reading(v)), 
datatype=self.converter.read_dbr, format=FORMAT_TIME, ) - else: - if self.subscription: - self.subscription.close() + elif self.subscription: + self.subscription.close() self.subscription = None diff --git a/src/ophyd_async/epics/signal/_common.py b/src/ophyd_async/epics/signal/_common.py index ae40e93029..d11c85be54 100644 --- a/src/ophyd_async/epics/signal/_common.py +++ b/src/ophyd_async/epics/signal/_common.py @@ -1,57 +1,43 @@ -import inspect -from enum import Enum +from collections.abc import Sequence +from typing import Any, get_args, get_origin -from typing_extensions import TypedDict +import numpy as np -from ophyd_async.core import RuntimeSubsetEnum - -common_meta = { - "units", - "precision", -} - - -class LimitPair(TypedDict): - high: float | None - low: float | None - - -class Limits(TypedDict): - alarm: LimitPair - control: LimitPair - display: LimitPair - warning: LimitPair +from ophyd_async.core import SubsetEnum, get_dtype, get_enum_cls def get_supported_values( pv: str, - datatype: type[str] | None, - pv_choices: tuple[str, ...], + datatype: type, + pv_choices: Sequence[str], ) -> dict[str, str]: - if inspect.isclass(datatype) and issubclass(datatype, RuntimeSubsetEnum): - if not set(datatype.choices).issubset(set(pv_choices)): - raise TypeError( - f"{pv} has choices {pv_choices}, " - f"which is not a superset of {str(datatype)}." - ) - return {x: x or "_" for x in pv_choices} - elif inspect.isclass(datatype) and issubclass(datatype, Enum): - if not issubclass(datatype, str): - raise TypeError( - f"{pv} is type Enum but {datatype} does not inherit from String." 
- ) - - choices = tuple(v.value for v in datatype) + enum_cls = get_enum_cls(datatype) + if not enum_cls: + raise TypeError(f"{datatype} is not an Enum") + choices = [v.value for v in enum_cls] + error_msg = f"{pv} has choices {pv_choices}, but {datatype} requested {choices} " + if issubclass(enum_cls, SubsetEnum): + if not set(choices).issubset(pv_choices): + raise TypeError(error_msg + "to be a subset of them.") + else: if set(choices) != set(pv_choices): - raise TypeError( - f"{pv} has choices {pv_choices}, " - f"which do not match {datatype}, which has {choices}." - ) - return {x: datatype(x) if x else "_" for x in pv_choices} - elif datatype is None or datatype is str: - return {x: x or "_" for x in pv_choices} - - raise TypeError( - f"{pv} has choices {pv_choices}. " - "Use an Enum or SubsetEnum to represent this." - ) + raise TypeError(error_msg + "to be strictly equal to them.") + + # Take order from the pv choices + supported_values = {x: x for x in pv_choices} + # But override those that we specify via the datatype + for v in enum_cls: + supported_values[v.value] = v + return supported_values + + +def format_datatype(datatype: Any) -> str: + if get_origin(datatype) is np.ndarray and get_args(datatype)[0] == tuple[int]: + dtype = get_dtype(datatype) + return f"Array1D[np.{dtype.name}]" + elif get_origin(datatype) is Sequence: + return f"Sequence[{get_args(datatype)[0].__name__}]" + elif isinstance(datatype, type): + return datatype.__name__ + else: + return str(datatype) diff --git a/src/ophyd_async/epics/signal/_epics_transport.py b/src/ophyd_async/epics/signal/_epics_transport.py deleted file mode 100644 index 4737de704f..0000000000 --- a/src/ophyd_async/epics/signal/_epics_transport.py +++ /dev/null @@ -1,34 +0,0 @@ -"""EPICS Signals over CA or PVA""" - -from __future__ import annotations - -from enum import Enum - - -def _make_unavailable_class(error: Exception) -> type: - class TransportNotAvailable: - def __init__(*args, **kwargs): - raise 
NotImplementedError("Transport not available") from error - - return TransportNotAvailable - - -try: - from ._aioca import CaSignalBackend -except ImportError as ca_error: - CaSignalBackend = _make_unavailable_class(ca_error) - - -try: - from ._p4p import PvaSignalBackend -except ImportError as pva_error: - PvaSignalBackend = _make_unavailable_class(pva_error) - - -class _EpicsTransport(Enum): - """The sorts of transport EPICS support""" - - #: Use Channel Access (using aioca library) - ca = CaSignalBackend - #: Use PVAccess (using p4p library) - pva = PvaSignalBackend diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 6fe13d0e2c..3ec4195e53 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -1,198 +1,107 @@ +from __future__ import annotations + import asyncio import atexit -import inspect import logging -import time -from collections.abc import Sequence -from dataclasses import dataclass -from enum import Enum +from collections.abc import Mapping, Sequence from math import isnan, nan -from typing import Any, get_origin +from typing import Any, Generic import numpy as np from bluesky.protocols import Reading -from event_model import DataKey -from event_model.documents.event_descriptor import Dtype +from event_model import DataKey, Limits, LimitsRange from p4p import Value from p4p.client.asyncio import Context, Subscription from pydantic import BaseModel from ophyd_async.core import ( - DEFAULT_TIMEOUT, + Array1D, + Callback, NotConnected, - ReadingValueCallback, - RuntimeSubsetEnum, SignalBackend, - T, - get_dtype, + SignalDatatype, + SignalDatatypeT, + SignalMetadata, + StrictEnum, + Table, + get_enum_cls, get_unique, - is_pydantic_model, + make_datakey, wait_for_connection, ) -from ._common import LimitPair, Limits, common_meta, get_supported_values +from ._common import format_datatype, get_supported_values -# https://mdavidsaver.github.io/p4p/values.html 
-specifier_to_dtype: dict[str, Dtype] = { - "?": "integer", # bool - "b": "integer", # int8 - "B": "integer", # uint8 - "h": "integer", # int16 - "H": "integer", # uint16 - "i": "integer", # int32 - "I": "integer", # uint32 - "l": "integer", # int64 - "L": "integer", # uint64 - "f": "number", # float32 - "d": "number", # float64 - "s": "string", -} -specifier_to_np_dtype: dict[str, str] = { - "?": " DataKey: - """ - Args: - value (Value): Description of the the return type of a DB record - shape: Optional override shape when len(shape) > 1 - choices: Optional list of enum choices to pass as metadata in the datakey - dtype: Optional override dtype when AugmentedValue is ambiguous, e.g. booleans - - Returns: - DataKey: A rich DataKey describing the DB record - """ - shape = shape or [] - type_code = value.type().aspy("value") - - dtype = dtype or specifier_to_dtype[type_code] - - try: - if isinstance(type_code, tuple): - dtype_numpy = "" - if type_code[1] == "enum_t": - if dtype == "boolean": - dtype_numpy = " Limits: +def _limits_from_value(value: Any) -> Limits: def get_limits( substucture_name: str, low_name: str = "limitLow", high_name: str = "limitHigh" - ) -> LimitPair: + ) -> LimitsRange | None: substructure = getattr(value, substucture_name, None) low = getattr(substructure, low_name, nan) high = getattr(substructure, high_name, nan) - return LimitPair( - low=None if isnan(low) else low, high=None if isnan(high) else high - ) - - return Limits( - alarm=get_limits("valueAlarm", "lowAlarmLimit", "highAlarmLimit"), - control=get_limits("control"), - display=get_limits("display"), - warning=get_limits("valueAlarm", "lowWarningLimit", "highWarningLimit"), - ) - + if not (isnan(low) and isnan(high)): + return LimitsRange( + low=None if isnan(low) else low, + high=None if isnan(high) else high, + ) -class PvaConverter: - def write_value(self, value): - return value + limits = Limits() + if limits_range := get_limits("valueAlarm", "lowAlarmLimit", "highAlarmLimit"): 
+ limits["alarm"] = limits_range + if limits_range := get_limits("control"): + limits["control"] = limits_range + if limits_range := get_limits("display"): + limits["display"] = limits_range + if limits_range := get_limits("valueAlarm", "lowWarningLimit", "highWarningLimit"): + limits["warning"] = limits_range + return limits + + +def _metadata_from_value(datatype: type[SignalDatatype], value: Any) -> SignalMetadata: + metadata = SignalMetadata() + value_data: Any = getattr(value, "value", None) + display_data: Any = getattr(value, "display", None) + if hasattr(display_data, "units"): + metadata["units"] = display_data.units + if hasattr(display_data, "precision") and not isnan(display_data.precision): + metadata["precision"] = display_data.precision + if limits := _limits_from_value(value): + metadata["limits"] = limits + # Get choices from display or value + if datatype is str or issubclass(datatype, StrictEnum): + if hasattr(display_data, "choices"): + metadata["choices"] = display_data.choices + elif hasattr(value_data, "choices"): + metadata["choices"] = value_data.choices + return metadata - def value(self, value): - return value["value"] - def reading(self, value) -> Reading: - ts = value["timeStamp"] - sv = value["alarm"]["severity"] - return { - "value": self.value(value), - "timestamp": ts["secondsPastEpoch"] + ts["nanoseconds"] * 1e-9, - "alarm_severity": -1 if sv > 2 else sv, - } +class PvaConverter(Generic[SignalDatatypeT]): + value_fields = ("value",) + reading_fields = ("alarm", "timeStamp") - def get_datakey(self, source: str, value) -> DataKey: - return _data_key_from_value(source, value) + def __init__(self, datatype: type[SignalDatatypeT]): + self.datatype = datatype - def metadata_fields(self) -> list[str]: - """ - Fields to request from PVA for metadata. 
- """ - return ["alarm", "timeStamp"] + def value(self, value: Any) -> SignalDatatypeT: + # for channel access ca_xxx classes, this + # invokes __pos__ operator to return an instance of + # the builtin base class + return value["value"] - def value_fields(self) -> list[str]: - """ - Fields to request from PVA for the value. - """ - return ["value"] + def write_value(self, value: Any) -> Any: + # The pva library will do the conversion for us + return value -class PvaArrayConverter(PvaConverter): - def get_datakey(self, source: str, value) -> DataKey: - return _data_key_from_value( - source, value, dtype="array", shape=[len(value["value"])] - ) +class DisconnectedPvaConverter(PvaConverter): + def __getattribute__(self, __name: str) -> Any: + raise NotImplementedError("No PV has been set as connect() has not been called") -class PvaNDArrayConverter(PvaConverter): - def metadata_fields(self) -> list[str]: - return super().metadata_fields() + ["dimension"] +class PvaNDArrayConverter(PvaConverter[SignalDatatypeT]): + value_fields = ("value", "dimension") def _get_dimensions(self, value) -> list[int]: dimensions: list[Value] = value["dimension"] @@ -205,243 +114,202 @@ def _get_dimensions(self, value) -> list[int]: # last index changing fastest. return dims[::-1] - def value(self, value): + def value(self, value: Any) -> SignalDatatypeT: dims = self._get_dimensions(value) return value["value"].reshape(dims) - def get_datakey(self, source: str, value) -> DataKey: - dims = self._get_dimensions(value) - return _data_key_from_value(source, value, dtype="array", shape=dims) - - def write_value(self, value): + def write_value(self, value: Any) -> Any: # No clear use-case for writing directly to an NDArray, and some # complexities around flattening to 1-D - e.g. dimension-order. # Don't support this for now. 
raise TypeError("Writing to NDArray not supported") -@dataclass -class PvaEnumConverter(PvaConverter): - """To prevent issues when a signal is restarted and returns with different enum - values or orders, we put treat an Enum signal as a string, and cache the - choices on this class. - """ - - def __init__(self, choices: dict[str, str]): - self.choices = tuple(choices.values()) +class PvaEnumConverter(PvaConverter[str]): + def __init__( + self, datatype: type[str] = str, supported_values: Mapping[str, str] = {} + ): + self.supported_values = supported_values + super().__init__(datatype) - def write_value(self, value: Enum | str): - if isinstance(value, Enum): - return value.value + def value(self, value: Any) -> str: + str_value = value["value"]["choices"][value["value"]["index"]] + if self.supported_values: + return self.supported_values[str_value] else: - return value - - def value(self, value): - return self.choices[value["value"]["index"]] + return str_value - def get_datakey(self, source: str, value) -> DataKey: - return _data_key_from_value( - source, value, choices=list(self.choices), dtype="string" - ) +class PvaEnumBoolConverter(PvaConverter[bool]): + def __init__(self): + super().__init__(bool) -class PvaEmumBoolConverter(PvaConverter): - def value(self, value): + def value(self, value: Any) -> bool: return bool(value["value"]["index"]) - def get_datakey(self, source: str, value) -> DataKey: - return _data_key_from_value(source, value, dtype="boolean") - - -class PvaTableConverter(PvaConverter): - def value(self, value): - return value["value"].todict() - - def get_datakey(self, source: str, value) -> DataKey: - # This is wrong, but defer until we know how to actually describe a table - return _data_key_from_value(source, value, dtype="object") # type: ignore +class PvaTableConverter(PvaConverter[Table]): + def value(self, value) -> Table: + return self.datatype(**value["value"].todict()) -class PvaPydanticModelConverter(PvaConverter): - def 
__init__(self, datatype: BaseModel): - self.datatype = datatype - - def value(self, value: Value): - return self.datatype(**value.todict()) # type: ignore - - def write_value(self, value: BaseModel | dict[str, Any]): - if isinstance(value, self.datatype): # type: ignore - return value.model_dump(mode="python") # type: ignore + def write_value(self, value: BaseModel | dict[str, Any]) -> Any: + if isinstance(value, self.datatype): + return value.model_dump(mode="python") return value -class PvaDictConverter(PvaConverter): - def reading(self, value) -> Reading: - ts = time.time() - value = value.todict() - # Alarm severity is vacuously 0 for a table - return {"value": value, "timestamp": ts, "alarm_severity": 0} - - def value(self, value: Value): - return value.todict() - - def get_datakey(self, source: str, value) -> DataKey: - raise NotImplementedError("Describing Dict signals not currently supported") - - def metadata_fields(self) -> list[str]: - """ - Fields to request from PVA for metadata. - """ - return [] - - def value_fields(self) -> list[str]: - """ - Fields to request from PVA for the value. 
- """ - return [] +# https://mdavidsaver.github.io/p4p/values.html +_datatype_converter_from_typeid: dict[ + tuple[str, str], tuple[type[SignalDatatype], type[PvaConverter]] +] = { + ("epics:nt/NTScalar:1.0", "?"): (bool, PvaConverter), + ("epics:nt/NTScalar:1.0", "b"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "B"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "h"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "H"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "i"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "I"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "l"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "L"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "f"): (float, PvaConverter), + ("epics:nt/NTScalar:1.0", "d"): (float, PvaConverter), + ("epics:nt/NTScalar:1.0", "s"): (str, PvaConverter), + ("epics:nt/NTEnum:1.0", "S"): (str, PvaEnumConverter), + ("epics:nt/NTScalarArray:1.0", "a?"): (Array1D[np.bool_], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "ab"): (Array1D[np.int8], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "aB"): (Array1D[np.uint8], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "ah"): (Array1D[np.int16], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "aH"): (Array1D[np.uint16], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "ai"): (Array1D[np.int32], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "aI"): (Array1D[np.uint32], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "al"): (Array1D[np.int64], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "aL"): (Array1D[np.uint64], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "af"): (Array1D[np.float32], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "ad"): (Array1D[np.float64], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "as"): (Sequence[str], PvaConverter), + ("epics:nt/NTTable:1.0", "S"): (Table, PvaTableConverter), + ("epics:nt/NTNDArray:1.0", "v"): (np.ndarray, PvaNDArrayConverter), +} -class DisconnectedPvaConverter(PvaConverter): - def 
__getattribute__(self, __name: str) -> Any: - raise NotImplementedError("No PV has been set as connect() has not been called") +def _get_specifier(value: Value): + typ = value.type("value").aspy() + if isinstance(typ, tuple): + return typ[0] + else: + return str(typ) def make_converter(datatype: type | None, values: dict[str, Any]) -> PvaConverter: pv = list(values)[0] typeid = get_unique({k: v.getID() for k, v in values.items()}, "typeids") - typ = get_unique( - {k: type(v.get("value")) for k, v in values.items()}, "value types" + specifier = get_unique( + {k: _get_specifier(v) for k, v in values.items()}, + "value type specifiers", ) - if "NTScalarArray" in typeid and typ is list: - # Waveform of strings, check we wanted this - if datatype and datatype != Sequence[str]: - raise TypeError(f"{pv} has type [str] not {datatype.__name__}") - return PvaArrayConverter() - elif "NTScalarArray" in typeid or "NTNDArray" in typeid: - pv_dtype = get_unique( - {k: v["value"].dtype for k, v in values.items()}, "dtypes" - ) - # This is an array - if datatype: - # Check we wanted an array of this type - dtype = get_dtype(datatype) - if not dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not {datatype.__name__}") - if dtype != pv_dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") - if "NTNDArray" in typeid: - return PvaNDArrayConverter() - else: - return PvaArrayConverter() - elif "NTEnum" in typeid and datatype is bool: - # Wanted a bool, but database represents as an enum - pv_choices_len = get_unique( - {k: len(v["value"]["choices"]) for k, v in values.items()}, - "number of choices", - ) - if pv_choices_len != 2: - raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return PvaEmumBoolConverter() - elif "NTEnum" in typeid: - # This is an Enum + # Infer a datatype and converter from the typeid and specifier + inferred_datatype, converter_cls = _datatype_converter_from_typeid[ + (typeid, specifier) + ] + # Some override cases + if 
typeid == "epics:nt/NTEnum:1.0": pv_choices = get_unique( {k: tuple(v["value"]["choices"]) for k, v in values.items()}, "choices" ) - return PvaEnumConverter(get_supported_values(pv, datatype, pv_choices)) - elif "NTScalar" in typeid: - if ( - typ is str - and inspect.isclass(datatype) - and issubclass(datatype, RuntimeSubsetEnum) - ): + if datatype is bool: + # Database can't do bools, so are often representated as enums of len 2 + if len(pv_choices) != 2: + raise TypeError(f"{pv} has {pv_choices=}, can't map to bool") + return PvaEnumBoolConverter() + elif enum_cls := get_enum_cls(datatype): + # We were given an enum class, so make class from that return PvaEnumConverter( - get_supported_values(pv, datatype, datatype.choices) # type: ignore - ) - elif datatype and not issubclass(typ, datatype): - # Allow int signals to represent float records when prec is 0 - is_prec_zero_float = typ is float and ( - get_unique( - {k: v["display"]["precision"] for k, v in values.items()}, - "precision", - ) - == 0 + supported_values=get_supported_values(pv, enum_cls, pv_choices) ) - if not (datatype is int and is_prec_zero_float): - raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") - return PvaConverter() - elif "NTTable" in typeid: - if is_pydantic_model(datatype): - return PvaPydanticModelConverter(datatype) # type: ignore - return PvaTableConverter() - elif "structure" in typeid: - return PvaDictConverter() - else: - raise TypeError(f"{pv}: Unsupported typeid {typeid}") - - -class PvaSignalBackend(SignalBackend[T]): - _ctxt: Context | None = None - - _ALLOWED_DATATYPES = ( - bool, - int, - float, - str, - Sequence, - np.ndarray, - Enum, - RuntimeSubsetEnum, - BaseModel, - dict, + elif datatype in (None, str): + # Still use the Enum converter, but make choices from what it has + return PvaEnumConverter() + elif ( + inferred_datatype is float + and datatype is int + and get_unique( + {k: v["display"]["precision"] for k, v in values.items()}, "precision" + 
) + == 0 + ): + # Allow int signals to represent float records when prec is 0 + return PvaConverter(int) + elif inferred_datatype is str and (enum_cls := get_enum_cls(datatype)): + # Allow strings to be used as enums until QSRV supports this + return PvaConverter(str) + elif inferred_datatype is Table and datatype and issubclass(datatype, Table): + # Use a custom table class + return PvaTableConverter(datatype) + elif datatype in (None, inferred_datatype): + # If datatype matches what we are given then allow it and use inferred converter + return converter_cls(inferred_datatype) + raise TypeError( + f"{pv} with inferred datatype {format_datatype(inferred_datatype)}" + f" from {typeid=} {specifier=}" + f" cannot be coerced to {format_datatype(datatype)}" ) - @classmethod - def datatype_allowed(cls, dtype: Any) -> bool: - stripped_origin = get_origin(dtype) or dtype - if dtype is None: - return True - return inspect.isclass(stripped_origin) and issubclass( - stripped_origin, cls._ALLOWED_DATATYPES - ) - def __init__(self, datatype: type[T] | None, read_pv: str, write_pv: str): - self.datatype = datatype - if not PvaSignalBackend.datatype_allowed(self.datatype): - raise TypeError(f"Given datatype {self.datatype} unsupported in PVA.") +_context: Context | None = None + + +def context() -> Context: + global _context + if _context is None: + _context = Context("pva", nt=False) + + @atexit.register + def _del_ctxt(): + # If we don't do this we get messages like this on close: + # Error in sys.excepthook: + # Original exception was: + global _context + del _context + return _context + + +async def pvget_with_timeout(pv: str, timeout: float) -> Any: + try: + return await asyncio.wait_for(context().get(pv), timeout=timeout) + except asyncio.TimeoutError as exc: + logging.debug(f"signal pva://{pv} timed out", exc_info=True) + raise NotConnected(f"pva://{pv}") from exc + + +def _pva_request_string(fields: Sequence[str]) -> str: + """Converts a list of requested fields into a 
PVA request string which can be + passed to p4p. + """ + return f"field({','.join(fields)})" + + +class PvaSignalBackend(SignalBackend[SignalDatatypeT]): + def __init__( + self, + datatype: type[SignalDatatypeT] | None, + read_pv: str = "", + write_pv: str = "", + ): self.read_pv = read_pv self.write_pv = write_pv + self.converter: PvaConverter = DisconnectedPvaConverter(float) self.initial_values: dict[str, Any] = {} - self.converter: PvaConverter = DisconnectedPvaConverter() self.subscription: Subscription | None = None + super().__init__(datatype) - def source(self, name: str): - return f"pva://{self.read_pv}" - - @property - def ctxt(self) -> Context: - if PvaSignalBackend._ctxt is None: - PvaSignalBackend._ctxt = Context("pva", nt=False) + def source(self, name: str, read: bool): + return f"pva://{self.read_pv if read else self.write_pv}" - @atexit.register - def _del_ctxt(): - # If we don't do this we get messages like this on close: - # Error in sys.excepthook: - # Original exception was: - PvaSignalBackend._ctxt = None - - return PvaSignalBackend._ctxt - - async def _store_initial_value(self, pv, timeout: float = DEFAULT_TIMEOUT): - try: - self.initial_values[pv] = await asyncio.wait_for( - self.ctxt.get(pv), timeout=timeout - ) - except asyncio.TimeoutError as exc: - logging.debug(f"signal pva://{pv} timed out", exc_info=True) - raise NotConnected(f"pva://{pv}") from exc + async def _store_initial_value(self, pv: str, timeout: float): + self.initial_values[pv] = await pvget_with_timeout(pv, timeout) - async def connect(self, timeout: float = DEFAULT_TIMEOUT): + async def connect(self, timeout: float): if self.read_pv != self.write_pv: # Different, need to connect both await wait_for_connection( @@ -453,66 +321,61 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): await self._store_initial_value(self.read_pv, timeout=timeout) self.converter = make_converter(self.datatype, self.initial_values) - async def put(self, value: T | None, wait=True, 
timeout=None): + def _make_reading(self, value: Any) -> Reading[SignalDatatypeT]: + ts = value["timeStamp"] + sv = value["alarm"]["severity"] + return { + "value": self.converter.value(value), + "timestamp": ts["secondsPastEpoch"] + ts["nanoseconds"] * 1e-9, + "alarm_severity": -1 if sv > 2 else sv, + } + + async def put(self, value: SignalDatatypeT | None, wait: bool): if value is None: write_value = self.initial_values[self.write_pv] else: write_value = self.converter.write_value(value) - coro = self.ctxt.put(self.write_pv, {"value": write_value}, wait=wait) - try: - await asyncio.wait_for(coro, timeout) - except asyncio.TimeoutError as exc: - logging.debug( - f"signal pva://{self.write_pv} timed out \ - put value: {write_value}", - exc_info=True, - ) - raise NotConnected(f"pva://{self.write_pv}") from exc + await context().put(self.write_pv, {"value": write_value}, wait=wait) async def get_datakey(self, source: str) -> DataKey: - value = await self.ctxt.get(self.read_pv) - return self.converter.get_datakey(source, value) - - def _pva_request_string(self, fields: list[str]) -> str: - """ - Converts a list of requested fields into a PVA request string which can be - passed to p4p. 
- """ - return f"field({','.join(fields)})" + value = await context().get(self.read_pv) + metadata = _metadata_from_value(self.converter.datatype, value) + return make_datakey( + self.converter.datatype, self.converter.value(value), source, metadata + ) async def get_reading(self) -> Reading: - request: str = self._pva_request_string( - self.converter.value_fields() + self.converter.metadata_fields() + request = _pva_request_string( + self.converter.value_fields + self.converter.reading_fields ) - value = await self.ctxt.get(self.read_pv, request=request) - return self.converter.reading(value) + value = await context().get(self.read_pv, request=request) + return self._make_reading(value) - async def get_value(self) -> T: - request: str = self._pva_request_string(self.converter.value_fields()) - value = await self.ctxt.get(self.read_pv, request=request) + async def get_value(self) -> SignalDatatypeT: + request = _pva_request_string(self.converter.value_fields) + value = await context().get(self.read_pv, request=request) return self.converter.value(value) - async def get_setpoint(self) -> T: - value = await self.ctxt.get(self.write_pv, "field(value)") + async def get_setpoint(self) -> SignalDatatypeT: + request = _pva_request_string(self.converter.value_fields) + value = await context().get(self.write_pv, request=request) return self.converter.value(value) - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: + def set_callback(self, callback: Callback[Reading[SignalDatatypeT]] | None) -> None: if callback: assert ( not self.subscription ), "Cannot set a callback when one is already set" async def async_callback(v): - callback(self.converter.reading(v), self.converter.value(v)) + callback(self._make_reading(v)) - request: str = self._pva_request_string( - self.converter.value_fields() + self.converter.metadata_fields() + request = _pva_request_string( + self.converter.value_fields + self.converter.reading_fields ) - - self.subscription = 
self.ctxt.monitor( + self.subscription = context().monitor( self.read_pv, async_callback, request=request ) - else: - if self.subscription: - self.subscription.close() + elif self.subscription: + self.subscription.close() self.subscription = None diff --git a/src/ophyd_async/epics/signal/_signal.py b/src/ophyd_async/epics/signal/_signal.py index 6711ac734e..180ed2a6e3 100644 --- a/src/ophyd_async/epics/signal/_signal.py +++ b/src/ophyd_async/epics/signal/_signal.py @@ -2,46 +2,81 @@ from __future__ import annotations +from enum import Enum + from ophyd_async.core import ( SignalBackend, + SignalDatatypeT, SignalR, SignalRW, SignalW, SignalX, - T, get_unique, ) -from ._epics_transport import _EpicsTransport -_default_epics_transport = _EpicsTransport.ca +def _make_unavailable_class(error: Exception) -> type: + class TransportNotAvailable: + def __init__(*args, **kwargs): + raise NotImplementedError("Transport not available") from error + + return TransportNotAvailable + + +class EpicsProtocol(Enum): + CA = "ca" + PVA = "pva" + + +_default_epics_protocol = EpicsProtocol.CA + +try: + from ._p4p import PvaSignalBackend +except ImportError as pva_error: + PvaSignalBackend = _make_unavailable_class(pva_error) +else: + _default_epics_protocol = EpicsProtocol.PVA + +try: + from ._aioca import CaSignalBackend +except ImportError as ca_error: + CaSignalBackend = _make_unavailable_class(ca_error) +else: + _default_epics_protocol = EpicsProtocol.CA -def _transport_pv(pv: str) -> tuple[_EpicsTransport, str]: +def _protocol_pv(pv: str) -> tuple[EpicsProtocol, str]: split = pv.split("://", 1) if len(split) > 1: # We got something like pva://mydevice, so use specified comms mode - transport_str, pv = split - transport = _EpicsTransport[transport_str] + scheme, pv = split + protocol = EpicsProtocol(scheme) else: # No comms mode specified, use the default - transport = _default_epics_transport - return transport, pv + protocol = _default_epics_protocol + return protocol, pv def 
_epics_signal_backend( - datatype: type[T] | None, read_pv: str, write_pv: str -) -> SignalBackend[T]: + datatype: type[SignalDatatypeT] | None, read_pv: str, write_pv: str +) -> SignalBackend[SignalDatatypeT]: """Create an epics signal backend.""" - r_transport, r_pv = _transport_pv(read_pv) - w_transport, w_pv = _transport_pv(write_pv) - transport = get_unique({read_pv: r_transport, write_pv: w_transport}, "transports") - return transport.value(datatype, r_pv, w_pv) + r_protocol, r_pv = _protocol_pv(read_pv) + w_protocol, w_pv = _protocol_pv(write_pv) + protocol = get_unique({read_pv: r_protocol, write_pv: w_protocol}, "protocols") + match protocol: + case EpicsProtocol.CA: + return CaSignalBackend(datatype, r_pv, w_pv) + case EpicsProtocol.PVA: + return PvaSignalBackend(datatype, r_pv, w_pv) def epics_signal_rw( - datatype: type[T], read_pv: str, write_pv: str | None = None, name: str = "" -) -> SignalRW[T]: + datatype: type[SignalDatatypeT], + read_pv: str, + write_pv: str | None = None, + name: str = "", +) -> SignalRW[SignalDatatypeT]: """Create a `SignalRW` backed by 1 or 2 EPICS PVs Parameters @@ -58,8 +93,11 @@ def epics_signal_rw( def epics_signal_rw_rbv( - datatype: type[T], write_pv: str, read_suffix: str = "_RBV", name: str = "" -) -> SignalRW[T]: + datatype: type[SignalDatatypeT], + write_pv: str, + read_suffix: str = "_RBV", + name: str = "", +) -> SignalRW[SignalDatatypeT]: """Create a `SignalRW` backed by 1 or 2 EPICS PVs, with a suffix on the readback pv Parameters @@ -74,7 +112,9 @@ def epics_signal_rw_rbv( return epics_signal_rw(datatype, f"{write_pv}{read_suffix}", write_pv, name) -def epics_signal_r(datatype: type[T], read_pv: str, name: str = "") -> SignalR[T]: +def epics_signal_r( + datatype: type[SignalDatatypeT], read_pv: str, name: str = "" +) -> SignalR[SignalDatatypeT]: """Create a `SignalR` backed by 1 EPICS PV Parameters @@ -88,7 +128,9 @@ def epics_signal_r(datatype: type[T], read_pv: str, name: str = "") -> SignalR[T return 
SignalR(backend, name=name) -def epics_signal_w(datatype: type[T], write_pv: str, name: str = "") -> SignalW[T]: +def epics_signal_w( + datatype: type[SignalDatatypeT], write_pv: str, name: str = "" +) -> SignalW[SignalDatatypeT]: """Create a `SignalW` backed by 1 EPICS PVs Parameters @@ -110,5 +152,5 @@ def epics_signal_x(write_pv: str, name: str = "") -> SignalX: write_pv: The PV to write its initial value to on trigger """ - backend: SignalBackend = _epics_signal_backend(None, write_pv, write_pv) + backend = _epics_signal_backend(None, write_pv, write_pv) return SignalX(backend, name=name) diff --git a/src/ophyd_async/fastcs/core.py b/src/ophyd_async/fastcs/core.py new file mode 100644 index 0000000000..bd2e32a033 --- /dev/null +++ b/src/ophyd_async/fastcs/core.py @@ -0,0 +1,9 @@ +from ophyd_async.core import Device, DeviceConnector +from ophyd_async.epics.pvi import PviDeviceConnector + + +def fastcs_connector(device: Device, uri: str) -> DeviceConnector: + # TODO: add Tango support based on uri scheme + connector = PviDeviceConnector(uri + "PVI") + connector.create_children_from_annotations(device) + return connector diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py index 0dbe7222b0..29b27d557b 100644 --- a/src/ophyd_async/fastcs/panda/__init__.py +++ b/src/ophyd_async/fastcs/panda/__init__.py @@ -1,10 +1,10 @@ from ._block import ( + BitMux, CommonPandaBlocks, DataBlock, - EnableDisableOptions, PcapBlock, PcompBlock, - PcompDirectionOptions, + PcompDirection, PulseBlock, SeqBlock, TimeUnits, @@ -29,10 +29,10 @@ __all__ = [ "CommonPandaBlocks", "DataBlock", - "EnableDisableOptions", + "BitMux", "PcapBlock", "PcompBlock", - "PcompDirectionOptions", + "PcompDirection", "PulseBlock", "SeqBlock", "TimeUnits", diff --git a/src/ophyd_async/fastcs/panda/_block.py b/src/ophyd_async/fastcs/panda/_block.py index 6e28b43f4e..67767ba372 100644 --- a/src/ophyd_async/fastcs/panda/_block.py +++ 
b/src/ophyd_async/fastcs/panda/_block.py @@ -1,13 +1,16 @@ -from __future__ import annotations - -from enum import Enum - -from ophyd_async.core import Device, DeviceVector, SignalR, SignalRW, SubsetEnum +from ophyd_async.core import ( + Device, + DeviceVector, + SignalR, + SignalRW, + StrictEnum, + SubsetEnum, +) from ._table import DatasetTable, SeqTable -class CaptureMode(str, Enum): +class CaptureMode(StrictEnum): FIRST_N = "FIRST_N" LAST_N = "LAST_N" FOREVER = "FOREVER" @@ -32,26 +35,28 @@ class PulseBlock(Device): width: SignalRW[float] -class PcompDirectionOptions(str, Enum): +class PcompDirection(StrictEnum): positive = "Positive" negative = "Negative" either = "Either" -EnableDisableOptions = SubsetEnum["ZERO", "ONE"] +class BitMux(SubsetEnum): + zero = "ZERO" + one = "ONE" class PcompBlock(Device): active: SignalR[bool] - dir: SignalRW[PcompDirectionOptions] - enable: SignalRW[EnableDisableOptions] + dir: SignalRW[PcompDirection] + enable: SignalRW[BitMux] pulses: SignalRW[int] start: SignalRW[int] step: SignalRW[int] width: SignalRW[int] -class TimeUnits(str, Enum): +class TimeUnits(StrictEnum): min = "min" s = "s" ms = "ms" @@ -60,11 +65,11 @@ class TimeUnits(str, Enum): class SeqBlock(Device): table: SignalRW[SeqTable] - active: SignalRW[bool] + active: SignalR[bool] repeats: SignalRW[int] prescale: SignalRW[float] prescale_units: SignalRW[TimeUnits] - enable: SignalRW[EnableDisableOptions] + enable: SignalRW[BitMux] class PcapBlock(Device): diff --git a/src/ophyd_async/fastcs/panda/_control.py b/src/ophyd_async/fastcs/panda/_control.py index 04827a282b..1fe14c7909 100644 --- a/src/ophyd_async/fastcs/panda/_control.py +++ b/src/ophyd_async/fastcs/panda/_control.py @@ -1,12 +1,10 @@ -import asyncio - from ophyd_async.core import ( + AsyncStatus, DetectorController, DetectorTrigger, + TriggerInfo, wait_for_value, ) -from ophyd_async.core._detector import TriggerInfo -from ophyd_async.core._status import AsyncStatus from ._block import PcapBlock @@ -33,5 
+31,5 @@ async def wait_for_idle(self): pass async def disarm(self): - await asyncio.gather(self.pcap.arm.set(False)) + await self.pcap.arm.set(False) await wait_for_value(self.pcap.active, False, timeout=1) diff --git a/src/ophyd_async/fastcs/panda/_hdf_panda.py b/src/ophyd_async/fastcs/panda/_hdf_panda.py index 5045d7b27f..f75403bbb3 100644 --- a/src/ophyd_async/fastcs/panda/_hdf_panda.py +++ b/src/ophyd_async/fastcs/panda/_hdf_panda.py @@ -2,8 +2,8 @@ from collections.abc import Sequence -from ophyd_async.core import DEFAULT_TIMEOUT, PathProvider, SignalR, StandardDetector -from ophyd_async.epics.pvi import create_children_from_annotations, fill_pvi_entries +from ophyd_async.core import PathProvider, SignalR, StandardDetector +from ophyd_async.fastcs.core import fastcs_connector from ._block import CommonPandaBlocks from ._control import PandaPcapController @@ -18,12 +18,10 @@ def __init__( config_sigs: Sequence[SignalR] = (), name: str = "", ): - self._prefix = prefix - - create_children_from_annotations(self) + # This has to be first so we make self.pcap + connector = fastcs_connector(self, prefix) controller = PandaPcapController(pcap=self.pcap) writer = PandaHDFWriter( - prefix=prefix, path_provider=path_provider, name_provider=lambda: name, panda_data_block=self.data, @@ -33,17 +31,5 @@ def __init__( writer=writer, config_sigs=config_sigs, name=name, - ) - - async def connect( - self, - mock: bool = False, - timeout: float = DEFAULT_TIMEOUT, - force_reconnect: bool = False, - ): - # TODO: this doesn't support caching - # https://github.com/bluesky/ophyd-async/issues/472 - await fill_pvi_entries(self, self._prefix + "PVI", timeout=timeout, mock=mock) - await super().connect( - mock=mock, timeout=timeout, force_reconnect=force_reconnect + connector=connector, ) diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index a021d23fa8..cb3c4ce1d9 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ 
b/src/ophyd_async/fastcs/panda/_table.py @@ -1,27 +1,22 @@ from collections.abc import Sequence -from enum import Enum -from typing import Annotated import numpy as np -import numpy.typing as npt -from pydantic import Field, model_validator -from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation -from typing_extensions import TypedDict +from pydantic import model_validator -from ophyd_async.core import Table +from ophyd_async.core import Array1D, StrictEnum, Table -class PandaHdf5DatasetType(str, Enum): +class PandaHdf5DatasetType(StrictEnum): FLOAT_64 = "float64" UINT_32 = "uint32" -class DatasetTable(TypedDict): - name: npt.NDArray[np.str_] +class DatasetTable(Table): + name: Sequence[str] hdf5_type: Sequence[PandaHdf5DatasetType] -class SeqTrigger(str, Enum): +class SeqTrigger(StrictEnum): IMMEDIATE = "Immediate" BITA_0 = "BITA=0" BITA_1 = "BITA=1" @@ -37,45 +32,27 @@ class SeqTrigger(str, Enum): POSC_LT = "POSC<=POSITION" -PydanticNp1DArrayInt32 = Annotated[ - np.ndarray[tuple[int], np.dtype[np.int32]], - NpArrayPydanticAnnotation.factory( - data_type=np.int32, dimensions=1, strict_data_typing=False - ), - Field(default_factory=lambda: np.array([], np.int32)), -] -PydanticNp1DArrayBool = Annotated[ - np.ndarray[tuple[int], np.dtype[np.bool_]], - NpArrayPydanticAnnotation.factory( - data_type=np.bool_, dimensions=1, strict_data_typing=False - ), - Field(default_factory=lambda: np.array([], dtype=np.bool_)), -] -TriggerStr = Annotated[Sequence[SeqTrigger], Field(default_factory=list)] - - class SeqTable(Table): - repeats: PydanticNp1DArrayInt32 - trigger: TriggerStr - position: PydanticNp1DArrayInt32 - time1: PydanticNp1DArrayInt32 - outa1: PydanticNp1DArrayBool - outb1: PydanticNp1DArrayBool - outc1: PydanticNp1DArrayBool - outd1: PydanticNp1DArrayBool - oute1: PydanticNp1DArrayBool - outf1: PydanticNp1DArrayBool - time2: PydanticNp1DArrayInt32 - outa2: PydanticNp1DArrayBool - outb2: PydanticNp1DArrayBool - outc2: PydanticNp1DArrayBool - outd2: 
PydanticNp1DArrayBool - oute2: PydanticNp1DArrayBool - outf2: PydanticNp1DArrayBool + repeats: Array1D[np.uint16] + trigger: Sequence[SeqTrigger] + position: Array1D[np.int32] + time1: Array1D[np.uint32] + outa1: Array1D[np.bool_] + outb1: Array1D[np.bool_] + outc1: Array1D[np.bool_] + outd1: Array1D[np.bool_] + oute1: Array1D[np.bool_] + outf1: Array1D[np.bool_] + time2: Array1D[np.uint32] + outa2: Array1D[np.bool_] + outb2: Array1D[np.bool_] + outc2: Array1D[np.bool_] + outd2: Array1D[np.bool_] + oute2: Array1D[np.bool_] + outf2: Array1D[np.bool_] - @classmethod - def row( # type: ignore - cls, + @staticmethod + def row( *, repeats: int = 1, trigger: str = SeqTrigger.IMMEDIATE, @@ -95,7 +72,8 @@ def row( # type: ignore oute2: bool = False, outf2: bool = False, ) -> "SeqTable": - return Table.row(**locals()) + # Let pydantic do the conversions for us + return SeqTable(**{k: [v] for k, v in locals().items()}) # type: ignore @model_validator(mode="after") def validate_max_length(self) -> "SeqTable": @@ -104,6 +82,6 @@ def validate_max_length(self) -> "SeqTable": the pydantic field doesn't work """ - first_length = len(next(iter(self))[1]) - assert 0 <= first_length < 4096, f"Length {first_length} not in range." 
+ first_length = len(self) + assert first_length <= 4096, f"Length {first_length} is too long" return self diff --git a/src/ophyd_async/fastcs/panda/_trigger.py b/src/ophyd_async/fastcs/panda/_trigger.py index 3eee7fe33d..0aa3633760 100644 --- a/src/ophyd_async/fastcs/panda/_trigger.py +++ b/src/ophyd_async/fastcs/panda/_trigger.py @@ -4,7 +4,7 @@ from ophyd_async.core import FlyerController, wait_for_value -from ._block import PcompBlock, PcompDirectionOptions, SeqBlock, TimeUnits +from ._block import BitMux, PcompBlock, PcompDirection, SeqBlock, TimeUnits from ._table import SeqTable @@ -21,7 +21,7 @@ def __init__(self, seq: SeqBlock) -> None: async def prepare(self, value: SeqTableInfo): await asyncio.gather( self.seq.prescale_units.set(TimeUnits.us), - self.seq.enable.set("ZERO"), + self.seq.enable.set(BitMux.zero), ) await asyncio.gather( self.seq.prescale.set(value.prescale_as_us), @@ -30,14 +30,14 @@ async def prepare(self, value: SeqTableInfo): ) async def kickoff(self) -> None: - await self.seq.enable.set("ONE") + await self.seq.enable.set(BitMux.one) await wait_for_value(self.seq.active, True, timeout=1) async def complete(self) -> None: await wait_for_value(self.seq.active, False, timeout=None) async def stop(self): - await self.seq.enable.set("ZERO") + await self.seq.enable.set(BitMux.zero) await wait_for_value(self.seq.active, False, timeout=1) @@ -54,7 +54,7 @@ class PcompInfo(BaseModel): ), ge=0, ) - direction: PcompDirectionOptions = Field( + direction: PcompDirection = Field( description=( "Specifies which direction the motor counts should be " "moving. 
Pulses won't be sent unless the values are moving in " @@ -68,7 +68,7 @@ def __init__(self, pcomp: PcompBlock) -> None: self.pcomp = pcomp async def prepare(self, value: PcompInfo): - await self.pcomp.enable.set("ZERO") + await self.pcomp.enable.set(BitMux.zero) await asyncio.gather( self.pcomp.start.set(value.start_postion), self.pcomp.width.set(value.pulse_width), @@ -78,12 +78,12 @@ async def prepare(self, value: PcompInfo): ) async def kickoff(self) -> None: - await self.pcomp.enable.set("ONE") + await self.pcomp.enable.set(BitMux.one) await wait_for_value(self.pcomp.active, True, timeout=1) async def complete(self, timeout: float | None = None) -> None: await wait_for_value(self.pcomp.active, False, timeout=timeout) async def stop(self): - await self.pcomp.enable.set("ZERO") + await self.pcomp.enable.set(BitMux.zero) await wait_for_value(self.pcomp.active, False, timeout=1) diff --git a/src/ophyd_async/fastcs/panda/_writer.py b/src/ophyd_async/fastcs/panda/_writer.py index f435b63ce8..e10613ec55 100644 --- a/src/ophyd_async/fastcs/panda/_writer.py +++ b/src/ophyd_async/fastcs/panda/_writer.py @@ -25,13 +25,11 @@ class PandaHDFWriter(DetectorWriter): def __init__( self, - prefix: str, path_provider: PathProvider, name_provider: NameProvider, panda_data_block: DataBlock, ) -> None: self.panda_data_block = panda_data_block - self._prefix = prefix self._path_provider = path_provider self._name_provider = name_provider self._datasets: list[HDFDataset] = [] @@ -89,8 +87,7 @@ async def _describe(self) -> dict[str, DataKey]: shape=list(ds.shape), dtype="array" if ds.shape != [1] else "number", # PandA data should always be written as Float64 - # Ignore type check until https://github.com/bluesky/event-model/issues/308 - dtype_numpy=" None: HDFDataset( dataset_name, "/" + dataset_name, [1], multiplier=1, chunk_shape=(1024,) ) - for dataset_name in capture_table["name"] + for dataset_name in capture_table.name ] # Warn user if dataset table is empty in PandA diff --git 
a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py index 043da476f4..d2e757681e 100644 --- a/src/ophyd_async/plan_stubs/_fly.py +++ b/src/ophyd_async/plan_stubs/_fly.py @@ -9,7 +9,7 @@ in_micros, ) from ophyd_async.fastcs.panda import ( - PcompDirectionOptions, + PcompDirection, PcompInfo, SeqTable, SeqTableInfo, @@ -147,7 +147,7 @@ def fly_and_collect_with_static_pcomp( number_of_pulses: int, pulse_width: int, rising_edge_step: int, - direction: PcompDirectionOptions, + direction: PcompDirection, trigger_info: TriggerInfo, ): # Set up scan and prepare trigger diff --git a/src/ophyd_async/plan_stubs/_nd_attributes.py b/src/ophyd_async/plan_stubs/_nd_attributes.py index 95a473033d..5c3ae4d5a8 100644 --- a/src/ophyd_async/plan_stubs/_nd_attributes.py +++ b/src/ophyd_async/plan_stubs/_nd_attributes.py @@ -16,12 +16,12 @@ def setup_ndattributes( device: NDArrayBaseIO, ndattributes: Sequence[NDAttributePv | NDAttributeParam] ): - xml_text = ET.Element("Attributes") + root = ET.Element("Attributes") for ndattribute in ndattributes: if isinstance(ndattribute, NDAttributeParam): ET.SubElement( - xml_text, + root, "Attribute", name=ndattribute.name, type="PARAM", @@ -32,7 +32,7 @@ def setup_ndattributes( ) elif isinstance(ndattribute, NDAttributePv): ET.SubElement( - xml_text, + root, "Attribute", name=ndattribute.name, type="EPICS_PV", @@ -45,7 +45,8 @@ def setup_ndattributes( f"Invalid type for ndattributes: {type(ndattribute)}. " "Expected NDAttributePv or NDAttributeParam." 
) - yield from bps.mv(device.nd_attributes_file, xml_text) + xml_text = ET.tostring(root, encoding="unicode") + yield from bps.abs_set(device.nd_attributes_file, xml_text, wait=True) def setup_ndstats_sum(detector: Device): diff --git a/src/ophyd_async/py.typed b/src/ophyd_async/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py index d11cd7e1e0..05b56dfe96 100644 --- a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py +++ b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py @@ -1,7 +1,6 @@ import asyncio -from ophyd_async.core import DetectorController, PathProvider -from ophyd_async.core._detector import TriggerInfo +from ophyd_async.core import DetectorController, PathProvider, TriggerInfo from ._pattern_generator import PatternGenerator diff --git a/src/ophyd_async/tango/__init__.py b/src/ophyd_async/tango/__init__.py index 5b45a067c4..4cf7197a91 100644 --- a/src/ophyd_async/tango/__init__.py +++ b/src/ophyd_async/tango/__init__.py @@ -7,14 +7,13 @@ AttributeProxy, CommandProxy, TangoSignalBackend, - __tango_signal_auto, ensure_proper_executor, get_dtype_extended, get_python_type, get_tango_trl, get_trl_descriptor, infer_python_type, - infer_signal_character, + infer_signal_type, make_backend, tango_signal_r, tango_signal_rw, @@ -32,12 +31,11 @@ "get_trl_descriptor", "get_tango_trl", "infer_python_type", - "infer_signal_character", + "infer_signal_type", "make_backend", "AttributeProxy", "CommandProxy", "ensure_proper_executor", - "__tango_signal_auto", "tango_signal_r", "tango_signal_rw", "tango_signal_w", diff --git a/src/ophyd_async/tango/base_devices/_base_device.py b/src/ophyd_async/tango/base_devices/_base_device.py index 9d01539263..d73328724c 100644 --- a/src/ophyd_async/tango/base_devices/_base_device.py +++ 
b/src/ophyd_async/tango/base_devices/_base_device.py @@ -1,21 +1,12 @@ from __future__ import annotations -from typing import ( - TypeVar, - get_args, - get_origin, - get_type_hints, -) +from typing import TypeVar -from ophyd_async.core import ( - DEFAULT_TIMEOUT, - Device, - Signal, -) +from ophyd_async.core import Device, DeviceConnector, DeviceFiller from ophyd_async.tango.signal import ( TangoSignalBackend, - __tango_signal_auto, - make_backend, + infer_python_type, + infer_signal_type, ) from tango import DeviceProxy as DeviceProxy from tango.asyncio import DeviceProxy as AsyncDeviceProxy @@ -50,64 +41,14 @@ def __init__( device_proxy: DeviceProxy | None = None, name: str = "", ) -> None: - self.trl = trl if trl else "" - self.proxy = device_proxy - tango_create_children_from_annotations(self) - super().__init__(name=name) - - def set_trl(self, trl: str): - """Set the Tango resource locator.""" - if not isinstance(trl, str): - raise ValueError("TRL must be a string.") - self.trl = trl - - async def connect( - self, - mock: bool = False, - timeout: float = DEFAULT_TIMEOUT, - force_reconnect: bool = False, - ): - if self.trl and self.proxy is None: - self.proxy = await AsyncDeviceProxy(self.trl) - elif self.proxy and not self.trl: - self.trl = self.proxy.name() - - # Set the trl of the signal backends - for child in self.children(): - if isinstance(child[1], Signal): - if isinstance(child[1]._backend, TangoSignalBackend): # noqa: SLF001 - resource_name = child[0].lstrip("_") - read_trl = f"{self.trl}/{resource_name}" - child[1]._backend.set_trl(read_trl, read_trl) # noqa: SLF001 - - if self.proxy is not None: - self.register_signals() - await _fill_proxy_entries(self) - - # set_name should be called again to propagate the new signal names - self.set_name(self.name) - - # Set the polling configuration - if self._polling[0]: - for child in self.children(): - child_type = type(child[1]) - if issubclass(child_type, Signal): - if isinstance(child[1]._backend, 
TangoSignalBackend): # noqa: SLF001 # type: ignore - child[1]._backend.set_polling(*self._polling) # noqa: SLF001 # type: ignore - child[1]._backend.allow_events(False) # noqa: SLF001 # type: ignore - if self._signal_polling: - for signal_name, polling in self._signal_polling.items(): - if hasattr(self, signal_name): - attr = getattr(self, signal_name) - if isinstance(attr._backend, TangoSignalBackend): # noqa: SLF001 - attr._backend.set_polling(*polling) # noqa: SLF001 - attr._backend.allow_events(False) # noqa: SLF001 - - await super().connect(mock=mock, timeout=timeout) - - # Users can override this method to register new signals - def register_signals(self): - pass + connector = TangoDeviceConnector( + trl=trl, + device_proxy=device_proxy, + polling=self._polling, + signal_polling=self._signal_polling, + ) + connector.create_children_from_annotations(self) + super().__init__(name=name, connector=connector) def tango_polling( @@ -150,76 +91,67 @@ def decorator(cls): return decorator -def tango_create_children_from_annotations( - device: TangoDevice, included_optional_fields: tuple[str, ...] 
= () -): - """Initialize blocks at __init__ of `device`.""" - for name, device_type in get_type_hints(type(device)).items(): - if name in ("_name", "parent"): - continue - - # device_type, is_optional = _strip_union(device_type) - # if is_optional and name not in included_optional_fields: - # continue - # - # is_device_vector, device_type = _strip_device_vector(device_type) - # if is_device_vector: - # n_device_vector = DeviceVector() - # setattr(device, name, n_device_vector) - - # else: - origin = get_origin(device_type) - origin = origin if origin else device_type - - if issubclass(origin, Signal): - type_args = get_args(device_type) - datatype = type_args[0] if type_args else None - backend = make_backend(datatype=datatype, device_proxy=device.proxy) - setattr(device, name, origin(name=name, backend=backend)) - - elif issubclass(origin, Device) or isinstance(origin, Device): - assert callable(origin), f"{origin} is not callable." - setattr(device, name, origin()) - - -async def _fill_proxy_entries(device: TangoDevice): - if device.proxy is None: - raise RuntimeError(f"Device proxy is not connected for {device.name}") - proxy_trl = device.trl - children = [name.lstrip("_") for name, _ in device.children()] - proxy_attributes = list(device.proxy.get_attribute_list()) - proxy_commands = list(device.proxy.get_command_list()) - combined = proxy_attributes + proxy_commands - - for name in combined: - if name not in children: - full_trl = f"{proxy_trl}/{name}" - try: - auto_signal = await __tango_signal_auto( - trl=full_trl, device_proxy=device.proxy +class TangoDeviceConnector(DeviceConnector): + def __init__( + self, + trl: str | None, + device_proxy: DeviceProxy | None, + polling: tuple[bool, float, float | None, float | None], + signal_polling: dict[str, tuple[bool, float, float, float]], + ) -> None: + self.trl = trl + self.proxy = device_proxy + self._polling = polling + self._signal_polling = signal_polling + + def create_children_from_annotations(self, device: 
Device): + self._filler = DeviceFiller( + device=device, + signal_backend_factory=TangoSignalBackend, + device_connector_factory=lambda: TangoDeviceConnector( + None, None, (False, 0.1, None, None), {} + ), + ) + + async def connect( + self, device: Device, mock: bool, timeout: float, force_reconnect: bool + ) -> None: + if mock: + # Make 2 entries for each DeviceVector + self._filler.make_soft_device_vector_entries(2) + else: + if self.trl and self.proxy is None: + self.proxy = await AsyncDeviceProxy(self.trl) + elif self.proxy and not self.trl: + self.trl = self.proxy.name() + else: + raise TypeError("Neither proxy nor trl supplied") + + children = sorted( + set() + .union(self.proxy.get_attribute_list()) + .union(self.proxy.get_command_list()) + ) + for name in children: + # TODO: strip attribute name + full_trl = f"{self.trl}/{name}" + signal_type = await infer_signal_type(full_trl, self.proxy) + if signal_type: + backend = self._filler.make_child_signal(name, signal_type) + backend.datatype = await infer_python_type(full_trl, self.proxy) + backend.set_trl(full_trl) + if polling := self._signal_polling.get(name, ()): + backend.set_polling(*polling) + backend.allow_events(False) + elif self._polling[0]: + backend.set_polling(*self._polling) + backend.allow_events(False) + # Check that all the requested children have been created + if unfilled := self._filler.unfilled(): + raise RuntimeError( + f"{device.name}: cannot provision {unfilled} from " + f"{self.trl}: {children}" ) - setattr(device, name, auto_signal) - except RuntimeError as e: - if "Commands with different in and out dtypes" in str(e): - print( - f"Skipping {name}. Commands with different in and out dtypes" - f" are not supported." 
- ) - continue - raise e - - -# def _strip_union(field: T | T) -> tuple[T, bool]: -# if get_origin(field) is Union: -# args = get_args(field) -# is_optional = type(None) in args -# for arg in args: -# if arg is not type(None): -# return arg, is_optional -# return field, False -# -# -# def _strip_device_vector(field: type[Device]) -> tuple[bool, type[Device]]: -# if get_origin(field) is DeviceVector: -# return True, get_args(field)[0] -# return False, field + # Set the name of the device to name all children + device.set_name(device.name) + return await super().connect(device, mock, timeout, force_reconnect) diff --git a/src/ophyd_async/tango/demo/_counter.py b/src/ophyd_async/tango/demo/_counter.py index c8903dfd6d..b23392d234 100644 --- a/src/ophyd_async/tango/demo/_counter.py +++ b/src/ophyd_async/tango/demo/_counter.py @@ -19,7 +19,7 @@ class TangoCounter(TangoReadable): counts: SignalR[int] sample_time: SignalRW[float] start: SignalX - _reset: SignalX + reset_: SignalX def __init__(self, trl: str | None = "", name=""): super().__init__(trl, name=name) @@ -34,4 +34,4 @@ async def trigger(self) -> None: @AsyncStatus.wrap async def reset(self) -> None: - await self._reset.trigger(wait=True, timeout=DEFAULT_TIMEOUT) + await self.reset_.trigger(wait=True, timeout=DEFAULT_TIMEOUT) diff --git a/src/ophyd_async/tango/demo/_mover.py b/src/ophyd_async/tango/demo/_mover.py index ce50356a55..bb15ac1b50 100644 --- a/src/ophyd_async/tango/demo/_mover.py +++ b/src/ophyd_async/tango/demo/_mover.py @@ -29,7 +29,7 @@ class TangoMover(TangoReadable, Movable, Stoppable): position: SignalRW[float] velocity: SignalRW[float] state: SignalR[DevState] - _stop: SignalX + stop_: SignalX def __init__(self, trl: str | None = "", name=""): super().__init__(trl, name=name) @@ -74,4 +74,4 @@ async def set(self, value: float, timeout: CalculatableTimeout = CALCULATE_TIMEO def stop(self, success: bool = True) -> AsyncStatus: self._set_success = success - return self._stop.trigger() + return 
self.stop_.trigger() diff --git a/src/ophyd_async/tango/signal/__init__.py b/src/ophyd_async/tango/signal/__init__.py index 8923718b6a..4462f6d8b8 100644 --- a/src/ophyd_async/tango/signal/__init__.py +++ b/src/ophyd_async/tango/signal/__init__.py @@ -1,7 +1,6 @@ from ._signal import ( - __tango_signal_auto, infer_python_type, - infer_signal_character, + infer_signal_type, make_backend, tango_signal_r, tango_signal_rw, @@ -29,11 +28,10 @@ "get_trl_descriptor", "get_tango_trl", "infer_python_type", - "infer_signal_character", + "infer_signal_type", "make_backend", "tango_signal_r", "tango_signal_rw", "tango_signal_w", "tango_signal_x", - "__tango_signal_auto", ) diff --git a/src/ophyd_async/tango/signal/_signal.py b/src/ophyd_async/tango/signal/_signal.py index f9274842d5..26c954ab3e 100644 --- a/src/ophyd_async/tango/signal/_signal.py +++ b/src/ophyd_async/tango/signal/_signal.py @@ -2,21 +2,28 @@ from __future__ import annotations +import logging from enum import Enum, IntEnum import numpy.typing as npt -from ophyd_async.core import DEFAULT_TIMEOUT, SignalR, SignalRW, SignalW, SignalX, T -from ophyd_async.tango.signal._tango_transport import ( - TangoSignalBackend, - get_python_type, +from ophyd_async.core import ( + DEFAULT_TIMEOUT, + Signal, + SignalDatatypeT, + SignalR, + SignalRW, + SignalW, + SignalX, ) from tango import AttrDataFormat, AttrWriteType, CmdArgType, DeviceProxy, DevState from tango.asyncio import DeviceProxy as AsyncDeviceProxy +from ._tango_transport import TangoSignalBackend, get_python_type + def make_backend( - datatype: type[T] | None, + datatype: type[SignalDatatypeT] | None, read_trl: str = "", write_trl: str = "", device_proxy: DeviceProxy | None = None, @@ -25,13 +32,13 @@ def make_backend( def tango_signal_rw( - datatype: type[T], + datatype: type[SignalDatatypeT], read_trl: str, write_trl: str = "", device_proxy: DeviceProxy | None = None, timeout: float = DEFAULT_TIMEOUT, name: str = "", -) -> SignalRW[T]: +) -> 
SignalRW[SignalDatatypeT]: """Create a `SignalRW` backed by 1 or 2 Tango Attribute/Command Parameters @@ -54,12 +61,12 @@ def tango_signal_rw( def tango_signal_r( - datatype: type[T], + datatype: type[SignalDatatypeT], read_trl: str, device_proxy: DeviceProxy | None = None, timeout: float = DEFAULT_TIMEOUT, name: str = "", -) -> SignalR[T]: +) -> SignalR[SignalDatatypeT]: """Create a `SignalR` backed by 1 Tango Attribute/Command Parameters @@ -80,12 +87,12 @@ def tango_signal_r( def tango_signal_w( - datatype: type[T], + datatype: type[SignalDatatypeT], write_trl: str, device_proxy: DeviceProxy | None = None, timeout: float = DEFAULT_TIMEOUT, name: str = "", -) -> SignalW[T]: +) -> SignalW[SignalDatatypeT]: """Create a `SignalW` backed by 1 Tango Attribute/Command Parameters @@ -128,39 +135,10 @@ def tango_signal_x( return SignalX(backend, timeout=timeout, name=name) -async def __tango_signal_auto( - datatype: type[T] | None = None, - *, - trl: str, - device_proxy: DeviceProxy | None, - timeout: float = DEFAULT_TIMEOUT, - name: str = "", -) -> SignalW | SignalX | SignalR | SignalRW | None: - try: - signal_character = await infer_signal_character(trl, device_proxy) - except RuntimeError as e: - if "Commands with different in and out dtypes" in str(e): - return None - else: - raise e - - if datatype is None: - datatype = await infer_python_type(trl, device_proxy) - - backend = make_backend(datatype, trl, trl, device_proxy) - if signal_character == "RW": - return SignalRW(backend=backend, timeout=timeout, name=name) - if signal_character == "R": - return SignalR(backend=backend, timeout=timeout, name=name) - if signal_character == "W": - return SignalW(backend=backend, timeout=timeout, name=name) - if signal_character == "X": - return SignalX(backend=backend, timeout=timeout, name=name) - - async def infer_python_type( trl: str = "", proxy: DeviceProxy | None = None ) -> object | npt.NDArray | type[DevState] | IntEnum: + # TODO: work out if this is still needed 
device_trl, tr_name = trl.rsplit("/", 1) if proxy is None: dev_proxy = await AsyncDeviceProxy(device_trl) @@ -187,7 +165,9 @@ async def infer_python_type( return npt.NDArray[py_type] if isarray else py_type -async def infer_signal_character(trl, proxy: DeviceProxy | None = None) -> str: +async def infer_signal_type( + trl, proxy: DeviceProxy | None = None +) -> type[Signal] | None: device_trl, tr_name = trl.rsplit("/", 1) if proxy is None: dev_proxy = await AsyncDeviceProxy(device_trl) @@ -204,20 +184,19 @@ async def infer_signal_character(trl, proxy: DeviceProxy | None = None) -> str: if tr_name in dev_proxy.get_attribute_list(): config = await dev_proxy.get_attribute_config(tr_name) if config.writable in [AttrWriteType.READ_WRITE, AttrWriteType.READ_WITH_WRITE]: - return "RW" + return SignalRW elif config.writable == AttrWriteType.READ: - return "R" + return SignalR else: - return "W" + return SignalW if tr_name in dev_proxy.get_command_list(): config = await dev_proxy.get_command_config(tr_name) if config.in_type == CmdArgType.DevVoid: - return "X" + return SignalX elif config.in_type != config.out_type: - raise RuntimeError( - "Commands with different in and out dtypes are not" " supported" - ) + logging.debug("Commands with different in and out dtypes are not supported") + return None else: - return "RW" + return SignalRW raise RuntimeError(f"Unable to infer signal character for {trl}") diff --git a/src/ophyd_async/tango/signal/_tango_transport.py b/src/ophyd_async/tango/signal/_tango_transport.py index 54cea4b610..6de2c85e96 100644 --- a/src/ophyd_async/tango/signal/_tango_transport.py +++ b/src/ophyd_async/tango/signal/_tango_transport.py @@ -2,7 +2,6 @@ import functools import time from abc import abstractmethod -from asyncio import CancelledError from collections.abc import Callable, Coroutine from enum import Enum from typing import Any, TypeVar, cast @@ -11,12 +10,11 @@ from bluesky.protocols import Descriptor, Reading from ophyd_async.core import ( - 
DEFAULT_TIMEOUT, AsyncStatus, + Callback, NotConnected, - ReadingValueCallback, SignalBackend, - T, + SignalDatatypeT, get_dtype, get_unique, wait_for_connection, @@ -121,7 +119,7 @@ def has_subscription(self) -> bool: """indicates, that this trl already subscribed""" @abstractmethod - def subscribe_callback(self, callback: ReadingValueCallback | None): + def subscribe_callback(self, callback: Callback | None): """subscribe tango CHANGE event to callback""" @abstractmethod @@ -140,7 +138,7 @@ def set_polling( class AttributeProxy(TangoProxy): - _callback: ReadingValueCallback | None = None + _callback: Callback | None = None _eid: int | None = None _poll_task: asyncio.Task | None = None _abs_change: float | None = None @@ -178,6 +176,7 @@ async def get_w_value(self) -> object: async def put( self, value: object | None, wait: bool = True, timeout: float | None = None ) -> AsyncStatus | None: + # TODO: remove the timeout from this as it is handled at the signal level if wait: try: @@ -236,7 +235,7 @@ async def get_reading(self) -> Reading: def has_subscription(self) -> bool: return bool(self._callback) - def subscribe_callback(self, callback: ReadingValueCallback | None): + def subscribe_callback(self, callback: Callback | None): # If the attribute supports events, then we can subscribe to them # If the callback is not a callable, then we raise an error if callback is not None and not callable(callback): @@ -283,21 +282,20 @@ def unsubscribe_callback(self): if self._callback is not None: # Call the callback with the last reading try: - self._callback(self._last_reading, self._last_reading["value"]) + self._callback(self._last_reading) except TypeError: pass self._callback = None def _event_processor(self, event): if not event.err: - value = event.attr_value.value reading = Reading( - value=value, + value=event.attr_value.value, timestamp=event.get_date().totime(), alarm_severity=event.attr_value.quality, ) if self._callback is not None: - self._callback(reading, 
value) + self._callback(reading) async def poll(self): """ @@ -310,7 +308,7 @@ async def poll(self): flag = 0 # Initial reading if self._callback is not None: - self._callback(last_reading, last_reading["value"]) + self._callback(last_reading) except Exception as e: raise RuntimeError(f"Could not poll the attribute: {e}") from e @@ -325,7 +323,7 @@ async def poll(self): diff = abs(reading["value"] - last_reading["value"]) if self._abs_change is not None and diff >= abs(self._abs_change): if self._callback is not None: - self._callback(reading, reading["value"]) + self._callback(reading) flag = 0 elif ( @@ -333,13 +331,13 @@ async def poll(self): and diff >= self._rel_change * abs(last_reading["value"]) ): if self._callback is not None: - self._callback(reading, reading["value"]) + self._callback(reading) flag = 0 else: flag = (flag + 1) % 4 if flag == 0 and self._callback is not None: - self._callback(reading, reading["value"]) + self._callback(reading) last_reading = reading.copy() if self._callback is None: @@ -358,13 +356,13 @@ async def poll(self): reading["value"], last_reading["value"] ): if self._callback is not None: - self._callback(reading, reading["value"]) + self._callback(reading) else: break else: if reading["value"] != last_reading["value"]: if self._callback is not None: - self._callback(reading, reading["value"]) + self._callback(reading) else: break last_reading = reading.copy() @@ -390,7 +388,7 @@ def set_polling( class CommandProxy(TangoProxy): _last_reading: Reading = Reading(value=None, timestamp=0, alarm_severity=0) - def subscribe_callback(self, callback: ReadingValueCallback | None) -> None: + def subscribe_callback(self, callback: Callback | None) -> None: raise NotImplementedError("Cannot subscribe to commands") def unsubscribe_callback(self) -> None: @@ -584,14 +582,14 @@ def get_trl_descriptor( async def get_tango_trl( - full_trl: str, device_proxy: DeviceProxy | TangoProxy | None + full_trl: str, device_proxy: DeviceProxy | TangoProxy 
| None, timeout: float ) -> TangoProxy: if isinstance(device_proxy, TangoProxy): return device_proxy device_trl, trl_name = full_trl.rsplit("/", 1) trl_name = trl_name.lower() if device_proxy is None: - device_proxy = await AsyncDeviceProxy(device_trl) + device_proxy = await AsyncDeviceProxy(device_trl, timeout=timeout) # all attributes can be always accessible with low register if isinstance(device_proxy, DeviceProxy): @@ -620,16 +618,15 @@ async def get_tango_trl( raise RuntimeError(f"{trl_name} cannot be found in {device_proxy.name()}") -class TangoSignalBackend(SignalBackend[T]): +class TangoSignalBackend(SignalBackend[SignalDatatypeT]): def __init__( self, - datatype: type[T] | None, + datatype: type[SignalDatatypeT] | None, read_trl: str = "", write_trl: str = "", device_proxy: DeviceProxy | None = None, ): self.device_proxy = device_proxy - self.datatype = datatype self.read_trl = read_trl self.write_trl = write_trl self.proxies: dict[str, TangoProxy | DeviceProxy | None] = { @@ -646,6 +643,7 @@ def __init__( ) self.support_events: bool = True self.status: AsyncStatus | None = None + super().__init__(datatype) @classmethod def datatype_allowed(cls, dtype: Any) -> bool: @@ -659,14 +657,14 @@ def set_trl(self, read_trl: str = "", write_trl: str = ""): write_trl: self.device_proxy, } - def source(self, name: str) -> str: - return self.read_trl + def source(self, name: str, read: bool) -> str: + return self.read_trl if read else self.write_trl - async def _connect_and_store_config(self, trl: str) -> None: + async def _connect_and_store_config(self, trl: str, timeout: float) -> None: if not trl: raise RuntimeError(f"trl not set for {self}") try: - self.proxies[trl] = await get_tango_trl(trl, self.proxies[trl]) + self.proxies[trl] = await get_tango_trl(trl, self.proxies[trl], timeout) if self.proxies[trl] is None: raise NotConnected(f"Not connected to {trl}") # Pyright does not believe that self.proxies[trl] is not None despite @@ -674,27 +672,27 @@ async def 
_connect_and_store_config(self, trl: str) -> None: await self.proxies[trl].connect() # type: ignore self.trl_configs[trl] = await self.proxies[trl].get_config() # type: ignore self.proxies[trl].support_events = self.support_events # type: ignore - except CancelledError as ce: - raise NotConnected(f"Could not connect to {trl}") from ce + except TimeoutError as ce: + raise NotConnected(f"tango://{trl}") from ce - async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None: + async def connect(self, timeout: float) -> None: if not self.read_trl: raise RuntimeError(f"trl not set for {self}") if self.read_trl != self.write_trl: # Different, need to connect both await wait_for_connection( - read_trl=self._connect_and_store_config(self.read_trl), - write_trl=self._connect_and_store_config(self.write_trl), + read_trl=self._connect_and_store_config(self.read_trl, timeout), + write_trl=self._connect_and_store_config(self.write_trl, timeout), ) else: # The same, so only need to connect one - await self._connect_and_store_config(self.read_trl) + await self._connect_and_store_config(self.read_trl, timeout) self.proxies[self.read_trl].set_polling(*self._polling) # type: ignore self.descriptor = get_trl_descriptor( self.datatype, self.read_trl, self.trl_configs ) - async def put(self, value: T | None, wait=True, timeout=None) -> None: + async def put(self, value: SignalDatatypeT | None, wait=True, timeout=None) -> None: if self.proxies[self.write_trl] is None: raise NotConnected(f"Not connected to {self.write_trl}") self.status = None @@ -704,28 +702,28 @@ async def put(self, value: T | None, wait=True, timeout=None) -> None: async def get_datakey(self, source: str) -> Descriptor: return self.descriptor - async def get_reading(self) -> Reading: + async def get_reading(self) -> Reading[SignalDatatypeT]: if self.proxies[self.read_trl] is None: raise NotConnected(f"Not connected to {self.read_trl}") return await self.proxies[self.read_trl].get_reading() # type: ignore - async 
def get_value(self) -> T: + async def get_value(self) -> SignalDatatypeT: if self.proxies[self.read_trl] is None: raise NotConnected(f"Not connected to {self.read_trl}") proxy = self.proxies[self.read_trl] if proxy is None: raise NotConnected(f"Not connected to {self.read_trl}") - return cast(T, await proxy.get()) + return cast(SignalDatatypeT, await proxy.get()) - async def get_setpoint(self) -> T: + async def get_setpoint(self) -> SignalDatatypeT: if self.proxies[self.write_trl] is None: raise NotConnected(f"Not connected to {self.write_trl}") proxy = self.proxies[self.write_trl] if proxy is None: raise NotConnected(f"Not connected to {self.write_trl}") - return cast(T, await proxy.get_w_value()) + return cast(SignalDatatypeT, await proxy.get_w_value()) - def set_callback(self, callback: ReadingValueCallback | None) -> None: + def set_callback(self, callback: Callback | None) -> None: if self.proxies[self.read_trl] is None: raise NotConnected(f"Not connected to {self.read_trl}") if self.support_events is False and self._polling[0] is False: diff --git a/tests/core/test_device.py b/tests/core/test_device.py index 85d2a0fc14..0fff56357b 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -9,19 +9,20 @@ Device, DeviceCollector, DeviceVector, - MockSignalBackend, NotConnected, - SoftSignalBackend, + Reference, + SignalRW, + soft_signal_rw, wait_for_connection, ) from ophyd_async.epics import motor from ophyd_async.plan_stubs import ensure_connected -from ophyd_async.sim.demo import SimMotor class DummyBaseDevice(Device): def __init__(self) -> None: self.connected = False + super().__init__() async def connect( self, mock=False, timeout=DEFAULT_TIMEOUT, force_reconnect: bool = False @@ -32,11 +33,11 @@ async def connect( class DummyDeviceGroup(Device): def __init__(self, name: str) -> None: self.child1 = DummyBaseDevice() - self.child2 = DummyBaseDevice() + self._child2 = DummyBaseDevice() self.dict_with_children: DeviceVector[DummyBaseDevice] = 
DeviceVector( {123: DummyBaseDevice()} ) - self.set_name(name) + super().__init__(name) @pytest.fixture @@ -44,23 +45,55 @@ def parent() -> DummyDeviceGroup: return DummyDeviceGroup("parent") +class DeviceWithNamedChild(Device): + def __init__(self, name: str = "") -> None: + super().__init__(name) + self.child = soft_signal_rw(int, name="foo") + + +def test_device_signal_naming(): + device = DeviceWithNamedChild("bar") + assert device.name == "bar" + assert device.child.name == "foo" + + +class DeviceWithRefToSignal(Device): + def __init__(self, signal: SignalRW[int]): + self.signal_ref = Reference(signal) + super().__init__(name="bat") + + def get_source(self) -> str: + return self.signal_ref().source + + +def test_device_with_signal_ref_does_not_rename(): + device = DeviceWithNamedChild() + device.set_name("bar") + assert dict(device.children()) == {"child": device.child} + private_device = DeviceWithRefToSignal(device.child) + assert device.child.source == private_device.get_source() + assert dict(private_device.children()) == {} + assert device.name == "bar" + assert device.child.name == "bar-child" + assert private_device.name == "bat" + + def test_device_children(parent: DummyDeviceGroup): - names = ["child1", "child2", "dict_with_children"] + names = ["child1", "_child2", "dict_with_children"] for idx, (name, child) in enumerate(parent.children()): assert name == names[idx] - assert ( - type(child) is DummyBaseDevice - if name.startswith("child") - else type(child) is DeviceVector + expected_type = ( + DeviceVector if name == "dict_with_children" else DummyBaseDevice ) + assert type(child) is expected_type assert child.parent == parent def test_device_vector_children(): parent = DummyDeviceGroup("root") - device_vector_children = list(parent.dict_with_children.children()) - assert device_vector_children == [("123", parent.dict_with_children[123])] + device_vector_children = list(parent.dict_with_children.items()) + assert device_vector_children == [(123, 
parent.dict_with_children[123])] async def test_children_of_device_have_set_names_and_get_connected( @@ -68,7 +101,7 @@ async def test_children_of_device_have_set_names_and_get_connected( ): assert parent.name == "parent" assert parent.child1.name == "parent-child1" - assert parent.child2.name == "parent-child2" + assert parent._child2.name == "parent-child2" assert parent.dict_with_children.name == "parent-dict_with_children" assert parent.dict_with_children[123].name == "parent-dict_with_children-123" @@ -84,7 +117,7 @@ async def test_device_with_device_collector(): assert parent.name == "parent" assert parent.child1.name == "parent-child1" - assert parent.child2.name == "parent-child2" + assert parent._child2.name == "parent-child2" assert parent.dict_with_children.name == "parent-dict_with_children" assert parent.dict_with_children[123].name == "parent-dict_with_children-123" assert parent.child1.connected @@ -127,62 +160,6 @@ async def test_device_log_has_correct_name(): assert device.log.extra["ophyd_async_device_name"] == "device" -async def test_device_lazily_connects(RE): - class MockSignalBackendFailingFirst(MockSignalBackend): - succeed_on_connect = False - - async def connect(self, timeout=DEFAULT_TIMEOUT): - if self.succeed_on_connect: - self.succeed_on_connect = False - await super().connect(timeout=timeout) - else: - self.succeed_on_connect = True - raise RuntimeError("connect fail") - - test_motor = motor.Motor("BLxxI-MO-TABLE-01:X") - test_motor.user_setpoint._backend = MockSignalBackendFailingFirst(int) - - with pytest.raises(NotConnected, match="RuntimeError: connect fail"): - await test_motor.connect(mock=True) - - assert ( - test_motor._connect_task - and test_motor._connect_task.done() - and test_motor._connect_task.exception() - ) - - RE(ensure_connected(test_motor, mock=True)) - - assert ( - test_motor._connect_task - and test_motor._connect_task.done() - and not test_motor._connect_task.exception() - ) - - with pytest.raises(NotConnected, 
match="RuntimeError: connect fail"): - RE(ensure_connected(test_motor, mock=True, force_reconnect=True)) - - assert ( - test_motor._connect_task - and test_motor._connect_task.done() - and test_motor._connect_task.exception() - ) - - -async def test_device_refuses_two_connects_differing_on_mock_attribute(RE): - motor = SimMotor("motor") - assert not motor._connect_task - await motor.connect(mock=False) - assert isinstance(motor.units._backend, SoftSignalBackend) - assert motor._connect_task - with pytest.raises(RuntimeError) as exc: - await motor.connect(mock=True) - assert str(exc.value) == ( - "`connect(mock=True)` called on a `Device` where the previous connect was " - "`mock=False`. Changing mock value between connects is not permitted." - ) - - class MotorBundle(Device): def __init__(self, name: str) -> None: self.X = motor.Motor("BLxxI-MO-TABLE-01:X") @@ -194,6 +171,7 @@ def __init__(self, name: str) -> None: 2: motor.Motor("BLxxI-MO-TABLE-21:Z"), } ) + super().__init__(name) async def test_device_with_children_lazily_connects(RE): @@ -215,23 +193,6 @@ async def test_device_with_children_lazily_connects(RE): ) -async def test_device_with_device_collector_refuses_to_connect_if_mock_switch(): - mock_motor = motor.Motor("NONE_EXISTENT") - with pytest.raises(NotConnected): - await mock_motor.connect(mock=False, timeout=0.01) - assert ( - mock_motor._connect_task is not None - and mock_motor._connect_task.done() - and mock_motor._connect_task.exception() - ) - with pytest.raises(RuntimeError) as exc: - await mock_motor.connect(mock=True, timeout=0.01) - assert str(exc.value) == ( - "`connect(mock=True)` called on a `Device` where the previous connect was " - "`mock=False`. Changing mock value between connects is not permitted." 
- ) - - async def test_no_reconnect_signals_if_not_forced(): parent = DummyDeviceGroup("parent") diff --git a/tests/core/test_device_collector.py b/tests/core/test_device_collector.py index cdb6e122c5..856016511d 100644 --- a/tests/core/test_device_collector.py +++ b/tests/core/test_device_collector.py @@ -15,14 +15,18 @@ class FailingDevice(Device): - async def connect(self, mock: bool = False, timeout=DEFAULT_TIMEOUT): + async def connect( + self, mock: bool = False, timeout=DEFAULT_TIMEOUT, force_reconnect=False + ): raise AttributeError() class WorkingDevice(Device): connected = False - async def connect(self, mock: bool = True, timeout=DEFAULT_TIMEOUT): + async def connect( + self, mock: bool = True, timeout=DEFAULT_TIMEOUT, force_reconnect=False + ): self.connected = True return await super().connect(mock=True) diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index 03dd73d2c0..e81918d2f6 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from enum import Enum from os import path from typing import Any from unittest.mock import patch @@ -9,13 +8,13 @@ import pytest import yaml from bluesky.run_engine import RunEngine -from pydantic import BaseModel, Field -from pydantic_numpy.typing import NpNDArrayFp16, NpNDArrayInt32 from ophyd_async.core import ( + Array1D, Device, - SignalR, SignalRW, + StrictEnum, + Table, all_at_once, get_signal_values, load_device, @@ -28,43 +27,41 @@ from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw -class DummyChildDevice(Device): - def __init__(self) -> None: - self.sig1: SignalRW = epics_signal_rw(str, "Value1") - self.sig2: SignalR = epics_signal_r(str, "Value2") - - -class EnumTest(str, Enum): +class EnumTest(StrictEnum): VAL1 = "val1" VAL2 = "val2" +class DummyChildDevice(Device): + def __init__(self) -> None: + self.str_sig = epics_signal_rw(str, "StrSignal") + 
super().__init__() + + class DummyDeviceGroup(Device): def __init__(self, name: str): - self.child1: DummyChildDevice = DummyChildDevice() - self.child2: DummyChildDevice = DummyChildDevice() - self.parent_sig1: SignalRW = epics_signal_rw(str, "ParentValue1") - self.parent_sig2: SignalR = epics_signal_r( + self.child1 = DummyChildDevice() + self.child2 = DummyChildDevice() + self.str_sig = epics_signal_rw(str, "ParentValue1") + self.parent_sig2 = epics_signal_r( int, "ParentValue2" ) # Ensure only RW are found - self.parent_sig3: SignalRW = epics_signal_rw(str, "ParentValue3") - self.position: npt.NDArray[np.int32] + self.table_sig = epics_signal_rw(Table, "TableSignal") + self.array_sig = epics_signal_rw(Array1D[np.uint32], "ArraySignal") + self.enum_sig = epics_signal_rw(EnumTest, "EnumSignal") + super().__init__(name) -class MyEnum(str, Enum): +class MyEnum(StrictEnum): one = "one" two = "two" three = "three" -class SomePvaPydanticModel(BaseModel): - some_int_field: int = Field(default=1) - some_pydantic_numpy_field_float: NpNDArrayFp16 = Field( - default_factory=lambda: np.array([1, 2, 3]) - ) - some_pydantic_numpy_field_int: NpNDArrayInt32 = Field( - default_factory=lambda: np.array([1, 2, 3]) - ) +class SomeTable(Table): + some_float: Array1D[np.float64] + some_int: Array1D[np.int32] + some_enum: Sequence[MyEnum] class DummyDeviceGroupAllTypes(Device): @@ -86,9 +83,8 @@ def __init__(self, name: str): self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14") self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15") self.pv_array_str = epics_signal_rw(Sequence[str], "PV16") - self.pv_protocol_device_abstraction = epics_signal_rw( - SomePvaPydanticModel, "pva://PV17" - ) + self.pv_protocol_device_abstraction = epics_signal_rw(Table, "pva://PV17") + super().__init__(name) @pytest.fixture @@ -107,9 +103,14 @@ async def device_all_types() -> DummyDeviceGroupAllTypes: # Dummy function to check different phases save properly def 
sort_signal_by_phase(values: dict[str, Any]) -> list[dict[str, Any]]: - phase_1 = {"child1.sig1": values["child1.sig1"]} - phase_2 = {"child2.sig1": values["child2.sig1"]} - return [phase_1, phase_2] + phase_1 = {"child1.str_sig": values["child1.str_sig"]} + phase_2 = {"child2.str_sig": values["child2.str_sig"]} + phase_3 = { + key: value + for key, value in values.items() + if key not in phase_1 and key not in phase_2 + } + return [phase_1, phase_2, phase_3] async def test_enum_yaml_formatting(tmp_path): @@ -123,7 +124,9 @@ async def test_enum_yaml_formatting(tmp_path): assert saved_enums == enums -async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path): +async def test_save_device_all_types( + RE: RunEngine, device_all_types: DummyDeviceGroupAllTypes, tmp_path +): # Populate fake device with PV's... await device_all_types.pv_int.set(1) await device_all_types.pv_float.set(1.234) @@ -171,7 +174,13 @@ async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path): await device_all_types.pv_array_str.set( ["one", "two", "three"], ) - await device_all_types.pv_protocol_device_abstraction.set(SomePvaPydanticModel()) + await device_all_types.pv_protocol_device_abstraction.set( + SomeTable( + some_float=np.arange(3, dtype=np.float64), + some_int=np.arange(3), + some_enum=[MyEnum.one, MyEnum.two, MyEnum.three], + ) + ) # Create save plan from utility functions def save_my_device(): @@ -185,39 +194,40 @@ def save_my_device(): actual_file_path = path.join(tmp_path, "test_file.yaml") with open(actual_file_path) as actual_file: with open("tests/test_data/test_yaml_save.yml") as expected_file: - assert actual_file.read() == expected_file.read() + assert yaml.safe_load(actual_file) == yaml.safe_load(expected_file) -async def test_save_device(RE: RunEngine, device, tmp_path): +async def test_save_device(RE: RunEngine, device: DummyDeviceGroup, tmp_path): # Populate fake device with PV's... 
- await device.child1.sig1.set("test_string") + await device.child1.str_sig.set("test_string") # Test tables PVs table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "VAL2": np.array([1, 1, 1, 1, 1])} - await device.child2.sig1.set(table_pv) - - # Test enum PVs - await device.parent_sig3.set(EnumTest.VAL1) + array_pv = np.array([2, 2, 2, 2, 2]) + await device.array_sig.set(array_pv) + await device.table_sig.set(table_pv) + await device.enum_sig.set(EnumTest.VAL2) # Create save plan from utility functions def save_my_device(): signalRWs = walk_rw_signals(device) assert list(signalRWs.keys()) == [ - "child1.sig1", - "child2.sig1", - "parent_sig1", - "parent_sig3", + "child1.str_sig", + "child2.str_sig", + "str_sig", + "table_sig", + "array_sig", + "enum_sig", ] assert all(isinstance(signal, SignalRW) for signal in list(signalRWs.values())) - values = yield from get_signal_values(signalRWs, ignore=["parent_sig1"]) - - assert values == { - "child1.sig1": "test_string", - "child2.sig1": table_pv, - "parent_sig3": "val1", - "parent_sig1": None, - } + values = yield from get_signal_values(signalRWs, ignore=["str_sig"]) + assert np.array_equal(values["array_sig"], array_pv) + assert values["enum_sig"] == "val2" + assert values["table_sig"] == Table(**table_pv) + assert values["str_sig"] is None + assert values["child1.str_sig"] == "test_string" + assert values["child2.str_sig"] == "" save_to_yaml([values], path.join(tmp_path, "test_file.yaml")) @@ -225,60 +235,74 @@ def save_my_device(): with open(path.join(tmp_path, "test_file.yaml")) as file: yaml_content = yaml.load(file, yaml.Loader)[0] - assert len(yaml_content) == 4 - assert yaml_content["child1.sig1"] == "test_string" - assert np.array_equal( - yaml_content["child2.sig1"]["VAL1"], np.array([1, 1, 1, 1, 1]) - ) - assert np.array_equal( - yaml_content["child2.sig1"]["VAL2"], np.array([1, 1, 1, 1, 1]) - ) - assert yaml_content["parent_sig3"] == "val1" - assert yaml_content["parent_sig1"] is None + assert 
yaml_content["child1.str_sig"] == "test_string" + assert yaml_content["child2.str_sig"] == "" + assert np.array_equal(yaml_content["table_sig"]["VAL1"], table_pv["VAL1"]) + assert np.array_equal(yaml_content["table_sig"]["VAL2"], table_pv["VAL2"]) + assert np.array_equal(yaml_content["array_sig"], array_pv) + assert yaml_content["enum_sig"] == "val2" + assert yaml_content["str_sig"] is None -async def test_yaml_formatting(RE: RunEngine, device, tmp_path): +async def test_yaml_formatting(RE: RunEngine, device: DummyDeviceGroup, tmp_path): file_path = path.join(tmp_path, "test_file.yaml") - await device.child1.sig1.set("test_string") - table_pv = {"VAL1": np.array([1, 2, 3, 4, 5]), "VAL2": np.array([6, 7, 8, 9, 10])} - await device.child2.sig1.set(table_pv) + await device.child1.str_sig.set("test_string") + table = {"VAL1": np.array([1, 2, 3, 4, 5]), "VAL2": np.array([6, 7, 8, 9, 10])} + await device.array_sig.set(np.array([11, 12, 13, 14, 15])) + await device.table_sig.set(table) + await device.enum_sig.set(EnumTest.VAL2) RE(save_device(device, file_path, sorter=sort_signal_by_phase)) with open(file_path) as file: expected = """\ -- child1.sig1: test_string -- child2.sig1: +- child1.str_sig: test_string +- child2.str_sig: '' +- array_sig: [11, 12, 13, 14, 15] + enum_sig: val2 + str_sig: '' + table_sig: VAL1: [1, 2, 3, 4, 5] VAL2: [6, 7, 8, 9, 10] """ + # assert False, file.read() assert file.read() == expected -async def test_load_from_yaml(RE: RunEngine, device, tmp_path): +async def test_load_from_yaml(RE: RunEngine, device: DummyDeviceGroup, tmp_path): file_path = path.join(tmp_path, "test_file.yaml") array = np.array([1, 1, 1, 1, 1]) - await device.child1.sig1.set("initial_string") - await device.child2.sig1.set(array) - await device.parent_sig1.set(None) + table = {"VAL1": np.array([1, 2, 3, 4, 5]), "VAL2": np.array([6, 7, 8, 9, 10])} + await device.child1.str_sig.set("initial_string") + await device.array_sig.set(array) + await device.str_sig.set(None) + await 
device.enum_sig.set(EnumTest.VAL2) + await device.table_sig.set(table) RE(save_device(device, file_path, sorter=sort_signal_by_phase)) values = load_from_yaml(file_path) - assert values[0]["child1.sig1"] == "initial_string" - assert np.array_equal(values[1]["child2.sig1"], array) + assert values[0]["child1.str_sig"] == "initial_string" + assert values[1]["child2.str_sig"] == "" + assert values[2]["str_sig"] == "" + assert values[2]["enum_sig"] == "val2" + assert np.array_equal(values[2]["array_sig"], array) + assert np.array_equal(values[2]["table_sig"]["VAL1"], table["VAL1"]) + assert np.array_equal(values[2]["table_sig"]["VAL2"], table["VAL2"]) -async def test_set_signal_values_restores_value(RE: RunEngine, device, tmp_path): +async def test_set_signal_values_restores_value( + RE: RunEngine, device: DummyDeviceGroup, tmp_path +): file_path = path.join(tmp_path, "test_file.yaml") - await device.child1.sig1.set("initial_string") - await device.child2.sig1.set(np.array([1, 1, 1, 1, 1])) + await device.str_sig.set("initial_string") + await device.array_sig.set(np.array([1, 1, 1, 1, 1])) RE(save_device(device, file_path, sorter=sort_signal_by_phase)) - await device.child1.sig1.set("changed_string") - await device.child2.sig1.set(np.array([2, 2, 2, 2, 2])) - string_value = await device.child1.sig1.get_value() - array_value = await device.child2.sig1.get_value() + await device.str_sig.set("changed_string") + await device.array_sig.set(np.array([2, 2, 2, 2, 2])) + string_value = await device.str_sig.get_value() + array_value = await device.array_sig.get_value() assert string_value == "changed_string" assert np.array_equal(array_value, np.array([2, 2, 2, 2, 2])) @@ -287,8 +311,8 @@ async def test_set_signal_values_restores_value(RE: RunEngine, device, tmp_path) RE(set_signal_values(signals_to_set, values)) - string_value = await device.child1.sig1.get_value() - array_value = await device.child2.sig1.get_value() + string_value = await device.str_sig.get_value() + 
array_value = await device.array_sig.get_value() assert string_value == "initial_string" assert np.array_equal(array_value, np.array([1, 1, 1, 1, 1])) @@ -297,7 +321,10 @@ async def test_set_signal_values_restores_value(RE: RunEngine, device, tmp_path) @patch("ophyd_async.core._device_save_loader.walk_rw_signals") @patch("ophyd_async.core._device_save_loader.set_signal_values") async def test_load_device( - mock_set_signal_values, mock_walk_rw_signals, mock_load_from_yaml, device + mock_set_signal_values, + mock_walk_rw_signals, + mock_load_from_yaml, + device: DummyDeviceGroup, ): RE = RunEngine() RE(load_device(device, "path")) @@ -306,22 +333,22 @@ async def test_load_device( mock_set_signal_values.assert_called_once() -async def test_set_signal_values_skips_ignored_values(device): +async def test_set_signal_values_skips_ignored_values(device: DummyDeviceGroup): RE = RunEngine() array = np.array([1, 1, 1, 1, 1]) - await device.child1.sig1.set("initial_string") - await device.child2.sig1.set(array) - await device.parent_sig1.set(None) + await device.child1.str_sig.set("initial_string") + await device.array_sig.set(array) + await device.str_sig.set(None) signals_of_device = walk_rw_signals(device) - values_to_set = [{"child1.sig1": None, "child2.sig1": np.array([2, 3, 4])}] + values_to_set = [{"child1.str_sig": None, "array_sig": np.array([2, 3, 4])}] RE(set_signal_values(signals_of_device, values_to_set)) - assert np.all(await device.child2.sig1.get_value() == np.array([2, 3, 4])) - assert await device.child1.sig1.get_value() == "initial_string" + assert np.all(await device.array_sig.get_value() == np.array([2, 3, 4])) + assert await device.child1.str_sig.get_value() == "initial_string" def test_all_at_once_sorter(): - assert all_at_once({"child1.sig1": 0}) == [{"child1.sig1": 0}] + assert all_at_once({"child1.str_sig": 0}) == [{"child1.str_sig": 0}] diff --git a/tests/core/test_flyer.py b/tests/core/test_flyer.py index 45ddaa8c86..79bcf8fe7e 100644 --- 
a/tests/core/test_flyer.py +++ b/tests/core/test_flyer.py @@ -1,6 +1,5 @@ import time from collections.abc import AsyncGenerator, AsyncIterator, Sequence -from enum import Enum from typing import Any from unittest.mock import Mock @@ -19,14 +18,15 @@ FlyerController, StandardDetector, StandardFlyer, + StrictEnum, TriggerInfo, + assert_emitted, observe_value, ) -from ophyd_async.core._signal import assert_emitted from ophyd_async.epics.signal import epics_signal_rw -class TriggerState(str, Enum): +class TriggerState(StrictEnum): null = "null" preparing = "preparing" starting = "starting" diff --git a/tests/core/test_log.py b/tests/core/test_log.py index 41c63be4d8..6bdc0070d1 100644 --- a/tests/core/test_log.py +++ b/tests/core/test_log.py @@ -5,7 +5,7 @@ import pytest -from ophyd_async.core import Device, _log, config_ophyd_async_logging +from ophyd_async.core import Device, _log, config_ophyd_async_logging # noqa: PLC2701 # Allow this importing of _log for now to test the internal interface # But this needs resolving. 
@@ -60,7 +60,7 @@ def getEffectiveLevel(self): # Full format looks like: -#'[test_device][W 240501 13:28:08.937 test_log:35] here is a warning\n' +# '[test_device][W 240501 13:28:08.937 test_log:35] here is a warning\n' def test_logger_adapter_ophyd_async_device(): log_buffer = io.StringIO() log_stream = logging.StreamHandler(stream=log_buffer) diff --git a/tests/core/test_mock_signal_backend.py b/tests/core/test_mock_signal_backend.py index b0ac8012b3..128dd7eef3 100644 --- a/tests/core/test_mock_signal_backend.py +++ b/tests/core/test_mock_signal_backend.py @@ -1,6 +1,5 @@ import asyncio import re -from itertools import repeat from unittest.mock import ANY, AsyncMock, MagicMock, call import pytest @@ -25,19 +24,18 @@ from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw -@pytest.mark.parametrize("connect_mock_mode", [True, False]) -async def test_mock_signal_backend(connect_mock_mode): - mock_signal = SignalRW(MockSignalBackend(datatype=str)) +async def test_mock_signal_backend(): + mock_signal = SignalRW(SoftSignalBackend(datatype=str)) # If mock is false it will be handled like a normal signal, otherwise it will # initalize a new backend from the one in the line above - await mock_signal.connect(mock=connect_mock_mode) - assert isinstance(mock_signal._backend, MockSignalBackend) - - assert await mock_signal._backend.get_value() == "" - await mock_signal._backend.put("test") - assert await mock_signal._backend.get_value() == "test" - assert mock_signal._backend.put_mock.call_args_list == [ - call("test", wait=True, timeout=None), + await mock_signal.connect(mock=True) + assert isinstance(mock_signal._connector.backend, MockSignalBackend) + + assert await mock_signal._connector.backend.get_value() == "" + await mock_signal._connector.backend.put("test", True) + assert await mock_signal._connector.backend.get_value() == "test" + assert mock_signal._connector.backend.put_mock.call_args_list == [ + call("test", wait=True), ] @@ -65,25 +63,25 @@ async 
def test_set_mock_value(): mock_signal = SignalRW(SoftSignalBackend(int)) await mock_signal.connect(mock=True) assert await mock_signal.get_value() == 0 - assert mock_signal._backend - assert await mock_signal._backend.get_value() == 0 + assert mock_signal._connector.backend + assert await mock_signal._connector.backend.get_value() == 0 set_mock_value(mock_signal, 10) assert await mock_signal.get_value() == 10 - assert await mock_signal._backend.get_value() == 10 + assert await mock_signal._connector.backend.get_value() == 10 async def test_set_mock_put_proceeds(): mock_signal = SignalW(SoftSignalBackend(str)) await mock_signal.connect(mock=True) - assert isinstance(mock_signal._backend, MockSignalBackend) + assert isinstance(mock_signal._connector.backend, MockSignalBackend) - assert mock_signal._backend.put_proceeds.is_set() is True + assert mock_signal._connector.backend.put_proceeds.is_set() is True set_mock_put_proceeds(mock_signal, False) - assert mock_signal._backend.put_proceeds.is_set() is False + assert mock_signal._connector.backend.put_proceeds.is_set() is False set_mock_put_proceeds(mock_signal, True) - assert mock_signal._backend.put_proceeds.is_set() is True + assert mock_signal._connector.backend.put_proceeds.is_set() is True async def test_set_mock_put_proceeds_timeout(): @@ -92,21 +90,21 @@ async def test_set_mock_put_proceeds_timeout(): set_mock_put_proceeds(mock_signal, False) - with pytest.raises(asyncio.exceptions.TimeoutError): - await mock_signal.set("test", wait=True, timeout=1) + with pytest.raises(asyncio.TimeoutError): + await mock_signal.set("test", wait=True, timeout=0.1) async def test_put_proceeds_timeout(): mock_signal = SignalW(SoftSignalBackend(str)) await mock_signal.connect(mock=True) - assert isinstance(mock_signal._backend, MockSignalBackend) + assert isinstance(mock_signal._connector.backend, MockSignalBackend) - assert mock_signal._backend.put_proceeds.is_set() is True + assert 
mock_signal._connector.backend.put_proceeds.is_set() is True set_mock_put_proceeds(mock_signal, False) - assert mock_signal._backend.put_proceeds.is_set() is False + assert mock_signal._connector.backend.put_proceeds.is_set() is False set_mock_put_proceeds(mock_signal, True) - assert mock_signal._backend.put_proceeds.is_set() is True + assert mock_signal._connector.backend.put_proceeds.is_set() is True async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend(): @@ -145,27 +143,25 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend(): async def test_get_mock_put(): mock_signal = epics_signal_rw(str, "READ_PV", "WRITE_PV", name="mock_name") await mock_signal.connect(mock=True) - await mock_signal.set("test_value", wait=True, timeout=100) + await mock_signal.set("test_value", wait=True) mock = get_mock_put(mock_signal) - mock.assert_called_once_with("test_value", wait=True, timeout=100) + mock.assert_called_once_with("test_value", wait=True) - def err_text(text, wait, timeout): + def err_text(text, wait): return ( - f"Expected: put('{re.escape(str(text))}', wait={re.escape(str(wait))}," - f" timeout={re.escape(str(timeout))})", - "Actual: put('test_value', wait=True, timeout=100)", + f"Expected: put('{re.escape(str(text))}', wait={re.escape(str(wait))})", + "Actual: put('test_value', wait=True)", ) - for text, wait, timeout in [ - ("wrong_name", True, 100), # name wrong - ("test_value", False, 100), # wait wrong - ("test_value", True, 0), # timeout wrong - ("test_value", False, 0), # wait and timeout wrong + for text, wait in [ + ("wrong_name", True), # name wrong + ("test_value", False), # wait wrong + ("wrong_name", False), # name and wait wrong ]: with pytest.raises(AssertionError) as exc: - mock.assert_called_once_with(text, wait=wait, timeout=timeout) - for err_substr in err_text(text, wait, timeout): + mock.assert_called_once_with(text, wait=wait) + for err_substr in err_text(text, wait): assert err_substr in 
str(exc.value) @@ -186,8 +182,8 @@ async def test_blocks_during_put(mock_signals): signal1, signal2 = mock_signals async with mock_puts_blocked(signal1, signal2): - status1 = signal1.set("second_value", wait=True, timeout=1) - status2 = signal2.set("second_value", wait=True, timeout=1) + status1 = signal1.set("second_value", wait=True, timeout=None) + status2 = signal2.set("second_value", wait=True, timeout=None) assert await signal1.get_value() == "second_value" assert await signal2.get_value() == "second_value" assert not status1.done @@ -197,8 +193,8 @@ async def test_blocks_during_put(mock_signals): assert status1.done assert status2.done - assert await signal1._backend.get_value() == "second_value" - assert await signal2._backend.get_value() == "second_value" + assert await signal1._connector.backend.get_value() == "second_value" + assert await signal2._connector.backend.get_value() == "second_value" async def test_callback_on_mock_put_as_context_manager(mock_signals): @@ -206,12 +202,12 @@ async def test_callback_on_mock_put_as_context_manager(mock_signals): signal2_callbacks = MagicMock() signal1, signal2 = mock_signals with callback_on_mock_put(signal1, signal1_callbacks): - await signal1.set("second_value", wait=True, timeout=1) + await signal1.set("second_value", wait=True) with callback_on_mock_put(signal2, signal2_callbacks): - await signal2.set("second_value", wait=True, timeout=1) + await signal2.set("second_value", wait=True) - signal1_callbacks.assert_called_once_with("second_value", wait=True, timeout=1) - signal2_callbacks.assert_called_once_with("second_value", wait=True, timeout=1) + signal1_callbacks.assert_called_once_with("second_value", wait=True) + signal2_callbacks.assert_called_once_with("second_value", wait=True) async def test_callback_on_mock_put_not_as_context_manager(): @@ -225,7 +221,6 @@ async def test_callback_on_mock_put_not_as_context_manager(): assert calls == [ { "_args": (10.0,), - "timeout": 10.0, "wait": True, } ] @@ 
-236,12 +231,12 @@ async def test_async_callback_on_mock_put(mock_signals): signal2_callbacks = AsyncMock() signal1, signal2 = mock_signals with callback_on_mock_put(signal1, signal1_callbacks): - await signal1.set("second_value", wait=True, timeout=1) + await signal1.set("second_value", wait=True) with callback_on_mock_put(signal2, signal2_callbacks): - await signal2.set("second_value", wait=True, timeout=1) + await signal2.set("second_value", wait=True) - signal1_callbacks.assert_awaited_once_with("second_value", wait=True, timeout=1) - signal2_callbacks.assert_awaited_once_with("second_value", wait=True, timeout=1) + signal1_callbacks.assert_awaited_once_with("second_value", wait=True) + signal2_callbacks.assert_awaited_once_with("second_value", wait=True) async def test_callback_on_mock_put_fails_if_args_are_not_correct(): @@ -283,7 +278,7 @@ async def test_set_mock_values_exhausted_passes(mock_signals): iterator = set_mock_values( signal2, - repeat(iter(["second_value", "third_value"]), 6), + ["second_value", "third_value"] * 3, require_all_consumed=False, ) calls = 0 @@ -294,7 +289,7 @@ async def test_set_mock_values_exhausted_passes(mock_signals): async def test_set_mock_values_exhausted_fails(mock_signals): - signal1, signal2 = mock_signals + signal1, _ = mock_signals for value_set in ( iterator := set_mock_values( @@ -312,17 +307,17 @@ async def test_set_mock_values_exhausted_fails(mock_signals): async def test_reset_mock_put_calls(mock_signals): - signal1, signal2 = mock_signals + signal1, _ = mock_signals await signal1.set("test_value", wait=True, timeout=1) - get_mock_put(signal1).assert_called_with("test_value", wait=ANY, timeout=ANY) + get_mock_put(signal1).assert_called_with("test_value", wait=ANY) reset_mock_put_calls(signal1) with pytest.raises(AssertionError) as exc: - get_mock_put(signal1).assert_called_with("test_value", wait=ANY, timeout=ANY) + get_mock_put(signal1).assert_called_with("test_value", wait=ANY) # Replacing spaces because they 
change between runners # (e.g the github actions runner has more) assert str(exc.value).replace(" ", "").replace("\n", "") == ( "expectedcallnotfound." - "Expected:put('test_value',wait=,timeout=)" + "Expected:put('test_value',wait=)" "Actual:notcalled." ) @@ -333,8 +328,8 @@ def __init__(self, name): self.my_signal = soft_signal_rw( datatype=int, initial_value=10, - name=name, ) + super().__init__(name) mocked_device = SomeDevice("mocked_device") await mocked_device.connect(mock=True) @@ -349,8 +344,9 @@ async def test_mock_signal_of_soft_signal_backend_receives_metadata(): class SomeDevice(Device): def __init__(self, name): self.my_signal = soft_signal_rw( - datatype=float, initial_value=1.0, name=name, units="mm", precision=2 + datatype=float, initial_value=1.0, units="mm", precision=2 ) + super().__init__(name) mocked_device = SomeDevice("mocked_device") await mocked_device.connect(mock=True) @@ -358,21 +354,21 @@ def __init__(self, name): await soft_device.connect(mock=False) assert await mocked_device.my_signal.describe() == { - "mocked_device": { + "mocked_device-my_signal": { "dtype": "number", "dtype_numpy": " int: + return len(list(re.finditer(re.escape(substring), string))) async def test_signal_connects_to_previous_backend(caplog): caplog.set_level(logging.DEBUG) - int_mock_backend = MockSignalBackend(int) + int_mock_backend = MockSignalBackend(SoftSignalBackend(int)) original_connect = int_mock_backend.connect times_backend_connect_called = 0 @@ -85,42 +57,17 @@ async def new_connect(timeout=1): int_mock_backend.connect = new_connect signal = Signal(int_mock_backend) await asyncio.gather(signal.connect(), signal.connect()) - response = f"Reusing previous connection to {signal.source}" - assert response in caplog.text + assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1 assert times_backend_connect_called == 1 async def test_signal_connects_with_force_reconnect(caplog): caplog.set_level(logging.DEBUG) - signal = 
Signal(MockSignalBackend(int)) + signal = Signal(MockSignalBackend(SoftSignalBackend(int))) await signal.connect() - assert signal._backend.datatype is int + assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1 await signal.connect(force_reconnect=True) - response = f"Connecting to {signal.source}" - assert response in caplog.text - assert "Reusing previous connection to" not in caplog.text - - -@pytest.mark.parametrize( - "first, second", - [(True, False), (True, False)], -) -async def test_rejects_reconnect_when_connects_have_diff_mock_status( - caplog, first, second -): - caplog.set_level(logging.DEBUG) - signal = Signal(MockSignalBackend(int)) - await signal.connect(mock=first) - assert signal._backend.datatype is int - with pytest.raises(RuntimeError) as exc: - await signal.connect(mock=second) - - assert f"`connect(mock={second})` called on a `Signal` where the previous " in str( - exc.value - ) - - response = f"Connecting to {signal.source}" - assert response in caplog.text + assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 2 async def test_signal_lazily_connects(RE): @@ -135,7 +82,7 @@ async def connect(self, timeout=DEFAULT_TIMEOUT): self.succeed_on_connect = True raise RuntimeError("connect fail") - signal = SignalRW(MockSignalBackendFailingFirst(int)) + signal = SignalRW(MockSignalBackendFailingFirst(SoftSignalBackend(int))) with pytest.raises(RuntimeError, match="connect fail"): await signal.connect(mock=False) @@ -348,25 +295,18 @@ async def test_create_soft_signal(signal_method, signal_class): SIGNAL_NAME = "TEST-PREFIX:SIGNAL" INITIAL_VALUE = "INITIAL" if signal_method == soft_signal_r_and_setter: - signal, unused_backend_set = signal_method(str, INITIAL_VALUE, SIGNAL_NAME) + signal, _ = signal_method(str, INITIAL_VALUE, SIGNAL_NAME) elif signal_method == soft_signal_rw: signal = signal_method(str, INITIAL_VALUE, SIGNAL_NAME) + else: + raise ValueError(signal_method) assert signal.source == 
f"soft://{SIGNAL_NAME}" assert isinstance(signal, signal_class) - assert isinstance(signal._backend, SoftSignalBackend) await signal.connect() + assert isinstance(signal._connector.backend, SoftSignalBackend) assert (await signal.get_value()) == INITIAL_VALUE -async def test_soft_signal_numpy(): - float_signal = soft_signal_rw(numpy.float64, numpy.float64(1), "float_signal") - int_signal = soft_signal_rw(numpy.int32, numpy.int32(1), "int_signal") - await float_signal.connect() - await int_signal.connect() - assert (await float_signal.describe())["float_signal"]["dtype"] == "number" - assert (await int_signal.describe())["int_signal"]["dtype"] == "integer" - - @pytest.fixture async def mock_signal(): mock_signal = epics_signal_rw(int, "pva://mock_signal", name="mock_signal") @@ -439,13 +379,6 @@ async def test_assert_configuration(mock_readable: DummyReadable): await assert_configuration(mock_readable, dummy_config_reading) -async def test_signal_connect_logs(caplog): - caplog.set_level(logging.DEBUG) - mock_signal_rw = epics_signal_rw(int, "pva://mock_signal", name="mock_signal") - await mock_signal_rw.connect(mock=True) - assert caplog.text.endswith("Connecting to mock+pva://mock_signal\n") - - async def test_signal_get_and_set_logging(caplog): caplog.set_level(logging.DEBUG) mock_signal_rw = epics_signal_rw(int, "pva://mock_signal", name="mock_signal") @@ -477,17 +410,12 @@ def some_function(self): pass err_str = ( - "Given datatype .SomeClass'>" - " unsupported in %s." 
) - with pytest.raises(TypeError, match=err_str % ("PVA",)): - epics_signal_rw(SomeClass, "pva://mock_signal", name="mock_signal") - with pytest.raises(TypeError, match=err_str % ("CA",)): - epics_signal_rw(SomeClass, "ca://mock_signal", name="mock_signal") - - # Any dtype allowed in soft signal - signal = soft_signal_rw(SomeClass, SomeClass(), "soft_signal") - assert isinstance((await signal.get_value()), SomeClass) - await signal.set(1) - assert (await signal.get_value()) == 1 + with pytest.raises(TypeError, match=err_str): + await epics_signal_rw(SomeClass, "pva://mock_signal").connect(mock=True) + with pytest.raises(TypeError, match=err_str): + await epics_signal_rw(SomeClass, "ca://mock_signal").connect(mock=True) + with pytest.raises(TypeError, match=err_str): + soft_signal_rw(SomeClass) diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py index 65316d18a8..f23046ec2b 100644 --- a/tests/core/test_soft_signal_backend.py +++ b/tests/core/test_soft_signal_backend.py @@ -1,7 +1,6 @@ import asyncio import time from collections.abc import Callable, Sequence -from enum import Enum from typing import Any import numpy as np @@ -9,10 +8,16 @@ import pytest from bluesky.protocols import Reading -from ophyd_async.core import Signal, SignalBackend, SignalMetadata, SoftSignalBackend, T +from ophyd_async.core import ( + SignalBackend, + SoftSignalBackend, + StrictEnum, + T, + soft_signal_rw, +) -class MyEnum(str, Enum): +class MyEnum(StrictEnum): a = "Aaa" b = "Bbb" c = "Ccc" @@ -31,7 +36,7 @@ def string_d(value): def enum_d(value): - return {"dtype": "string", "shape": [], "choices": ("Aaa", "Bbb", "Ccc")} + return {"dtype": "string", "shape": [], "choices": ["Aaa", "Bbb", "Ccc"]} def waveform_d(value): @@ -41,11 +46,8 @@ def waveform_d(value): class MonitorQueue: def __init__(self, backend: SignalBackend): self.backend = backend - self.updates: asyncio.Queue[tuple[Reading, Any]] = asyncio.Queue() - 
backend.set_callback(self.add_reading_value) - - def add_reading_value(self, reading: Reading, value): - self.updates.put_nowait((reading, value)) + self.updates: asyncio.Queue[Reading] = asyncio.Queue() + backend.set_callback(self.updates.put_nowait) async def assert_updates(self, expected_value): expected_reading = { @@ -53,12 +55,12 @@ async def assert_updates(self, expected_value): "timestamp": pytest.approx(time.monotonic(), rel=0.1), "alarm_severity": 0, } - reading, value = await self.updates.get() + reading = await self.updates.get() backend_value = await self.backend.get_value() backend_reading = await self.backend.get_reading() - assert value == expected_value == backend_value + assert reading["value"] == expected_value == backend_value assert reading == expected_reading == backend_reading def close(self): @@ -70,19 +72,19 @@ def close(self): [ (int, 0, 43, integer_d, " None: pass - soft_signal = Signal(SoftSignalBackend(myClass)) - await soft_signal.connect() - - with pytest.raises(AssertionError): - await soft_signal._backend.get_datakey("") + with pytest.raises(TypeError): + SoftSignalBackend(myClass) async def test_soft_signal_descriptor_with_metadata(): - soft_signal = Signal( - SoftSignalBackend(int, 0, metadata=SignalMetadata(units="mm", precision=0)) - ) + soft_signal = soft_signal_rw(int, 0, units="mm", precision=0) await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert datakey["units"] == "mm" - assert datakey["precision"] == 0 + datakey = await soft_signal.describe() + assert datakey[""]["units"] == "mm" + assert datakey[""]["precision"] == 0 - soft_signal = Signal(SoftSignalBackend(int, metadata=SignalMetadata(units=""))) + soft_signal = soft_signal_rw(int, units="") await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert datakey["units"] == "" - assert not hasattr(datakey, "precision") + datakey = await soft_signal.describe() + assert datakey[""]["units"] == "" + assert 
not hasattr(datakey[""], "precision") async def test_soft_signal_descriptor_with_no_metadata_not_passed(): - soft_signal = Signal(SoftSignalBackend(int)) + soft_signal = soft_signal_rw(int) await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert not hasattr(datakey, "units") - assert not hasattr(datakey, "precision") + datakey = await soft_signal.describe() + assert not hasattr(datakey[""], "units") + assert not hasattr(datakey[""], "precision") - soft_signal = Signal(SoftSignalBackend(int, metadata=None)) - await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert not hasattr(datakey, "units") - assert not hasattr(datakey, "precision") - soft_signal = Signal(SoftSignalBackend(int, metadata={})) +async def test_soft_signal_coerces_numpy_types(): + soft_signal = soft_signal_rw(float) await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert not hasattr(datakey, "units") - assert not hasattr(datakey, "precision") + assert await soft_signal.get_value() == 0.0 + assert type(await soft_signal.get_value()) is float + await soft_signal.set(np.float64(1.1)) + assert await soft_signal.get_value() == 1.1 + assert type(await soft_signal.get_value()) is float + soft_signal._connector.backend.set_value(np.float64(2.2)) + assert await soft_signal.get_value() == 2.2 + assert type(await soft_signal.get_value()) is float diff --git a/tests/core/test_subset_enum.py b/tests/core/test_subset_enum.py index 3512075e7f..ae9e56d2e8 100644 --- a/tests/core/test_subset_enum.py +++ b/tests/core/test_subset_enum.py @@ -1,4 +1,3 @@ -import pytest from epicscorelibs.ca import dbr from p4p import Value as P4PValue from p4p.nt import NTEnum @@ -7,32 +6,27 @@ from ophyd_async.epics.signal import epics_signal_rw # Allow these imports from private modules for tests -from ophyd_async.epics.signal._aioca import make_converter as ca_make_converter -from ophyd_async.epics.signal._p4p import 
make_converter as pva_make_converter +from ophyd_async.epics.signal._aioca import ( + make_converter as ca_make_converter, # noqa: PLC2701 +) +from ophyd_async.epics.signal._p4p import ( + make_converter as pva_make_converter, # noqa: PLC2701 +) -async def test_runtime_enum_behaviour(): - rt_enum = SubsetEnum["A", "B"] +class AB(SubsetEnum): + a = "A" + b = "B" - with pytest.raises(RuntimeError) as exc: - rt_enum() - assert str(exc.value) == "SubsetEnum cannot be instantiated" - assert issubclass(rt_enum, SubsetEnum) +class AB1(SubsetEnum): + a = "A1" + b = "B1" - # Our metaclass doesn't cache already created runtime enums, - # so we can't do this - assert not issubclass(rt_enum, SubsetEnum["A", "B"]) - assert not issubclass(rt_enum, SubsetEnum["B", "A"]) - assert rt_enum is not SubsetEnum["A", "B"] - assert rt_enum is not SubsetEnum["B", "A"] - assert str(rt_enum) == "SubsetEnum['A', 'B']" - assert str(SubsetEnum) == "SubsetEnum" - - with pytest.raises(TypeError) as exc: - SubsetEnum["A", "B", "A"] - assert str(exc.value) == "Duplicate elements in runtime enum choices." 
+class AB2(SubsetEnum): + a = "A2" + b = "B2" async def test_ca_runtime_enum_converter(): @@ -51,12 +45,11 @@ def __init__(self): self.enums = ["A", "B", "C"] # More than the runtime enum epics_value = EpicsValue() - rt_enum = SubsetEnum["A", "B"] converter = ca_make_converter( - rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value} + AB, values={"READ_PV": epics_value, "WRITE_PV": epics_value} ) - assert converter.choices == {"A": "A", "B": "B", "C": "C"} - assert set(rt_enum.choices).issubset(set(converter.choices.keys())) + assert converter.supported_values == {"A": "A", "B": "B", "C": "C"} + assert set(AB).issubset(set(converter.supported_values.keys())) async def test_pva_runtime_enum_converter(): @@ -67,16 +60,15 @@ async def test_pva_runtime_enum_converter(): "value.choices": ["A", "B", "C"], }, ) - rt_enum = SubsetEnum["A", "B"] converter = pva_make_converter( - rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value} + AB, values={"READ_PV": epics_value, "WRITE_PV": epics_value} ) - assert {"A", "B"}.issubset(set(converter.choices)) + assert {"A", "B"}.issubset(set(converter.supported_values)) async def test_runtime_enum_signal(): - signal_rw_pva = epics_signal_rw(SubsetEnum["A1", "B1"], "ca://RW_PV", name="signal") - signal_rw_ca = epics_signal_rw(SubsetEnum["A2", "B2"], "ca://RW_PV", name="signal") + signal_rw_pva = epics_signal_rw(AB1, "ca://RW_PV", name="signal") + signal_rw_ca = epics_signal_rw(AB2, "ca://RW_PV", name="signal") await signal_rw_pva.connect(mock=True) await signal_rw_ca.connect(mock=True) assert await signal_rw_pva.get_value() == "A1" diff --git a/tests/core/test_table.py b/tests/core/test_table.py new file mode 100644 index 0000000000..3f7a7a6403 --- /dev/null +++ b/tests/core/test_table.py @@ -0,0 +1,53 @@ +from collections.abc import Sequence + +import numpy as np +import pytest +from pydantic import ValidationError + +from ophyd_async.core import Array1D, Table + + +class MyTable(Table): + bool: 
Array1D[np.bool_] + uint: Array1D[np.uint32] + str: Sequence[str] + + +@pytest.mark.parametrize( + ["kwargs", "error_msg"], + [ + ( + {"bool": [3, 4], "uint": [3, 4], "str": ["", ""]}, + "bool: Cannot cast [3, 4] to bool without losing precision", + ), + ( + {"bool": np.array([1], dtype=np.uint8), "uint": [-3], "str": [""]}, + "uint: Cannot cast [-3] to uint32 without losing precision", + ), + ( + {"bool": [0], "uint": np.array([1.8], dtype=np.float64), "str": [""]}, + "uint: Cannot cast [1.8] to uint32 without losing precision", + ), + ( + {"bool": [0, 1], "uint": [3, 4], "str": [44, ""]}, + "Input should be a valid string [type=string_type, input_value=44,", + ), + ], +) +def test_table_wrong_types(kwargs, error_msg): + with pytest.raises(ValidationError) as cm: + MyTable(**kwargs) + assert error_msg in str(cm.value) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"bool": np.array([1], dtype=np.uint8), "uint": [3], "str": ["a"]}, + {"bool": [False], "uint": np.array([1], dtype=np.float64), "str": ["b"]}, + ], +) +def test_table_coerces(kwargs): + t = MyTable(**kwargs) + for k, v in t: + assert v == pytest.approx(kwargs[k]) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index d29f595901..a40a8ed7a1 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -6,17 +6,18 @@ DEFAULT_TIMEOUT, Device, DeviceCollector, - MockSignalBackend, + SoftSignalBackend, NotConnected, SignalRW, ) +from ophyd_async.core import soft_signal_rw from ophyd_async.epics.signal import epics_signal_rw -class ValueErrorBackend(MockSignalBackend): +class ValueErrorBackend(SoftSignalBackend): def __init__(self, exc_text=""): self.exc_text = exc_text - super().__init__(datatype=int, initial_backend=None) + super().__init__(datatype=int) async def connect(self, timeout: float = DEFAULT_TIMEOUT): raise ValueError(self.exc_text) @@ -24,7 +25,7 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT): class WorkingDummyChildDevice(Device): def __init__(self, 
name: str = "working_dummy_child_device") -> None: - self.working_signal = SignalRW(backend=MockSignalBackend(datatype=int)) + self.working_signal = soft_signal_rw(int) super().__init__(name=name) @@ -44,7 +45,7 @@ class ValueErrorDummyChildDevice(Device): def __init__( self, name: str = "value_error_dummy_child_device", exc_text="" ) -> None: - self.value_error_signal = SignalRW(backend=ValueErrorBackend(exc_text=exc_text)) + self.value_error_signal = SignalRW(ValueErrorBackend(exc_text=exc_text)) super().__init__(name=name) diff --git a/tests/epics/adaravis/test_aravis.py b/tests/epics/adaravis/test_aravis.py index 270b661e55..e9c1d52414 100644 --- a/tests/epics/adaravis/test_aravis.py +++ b/tests/epics/adaravis/test_aravis.py @@ -89,7 +89,7 @@ async def test_decribe_describes_writer_dataset( assert await test_adaravis.describe() == { "test_adaravis1": { "source": "mock+ca://ARAVIS1:HDF1:FullFileName_RBV", - "shape": (10, 10), + "shape": [10, 10], "dtype": "array", "dtype_numpy": "|i1", "external": "STREAM:", @@ -135,7 +135,7 @@ async def test_can_decribe_collect( assert (await test_adaravis.describe_collect()) == { "test_adaravis1": { "source": "mock+ca://ARAVIS1:HDF1:FullFileName_RBV", - "shape": (10, 10), + "shape": [10, 10], "dtype": "array", "dtype_numpy": "|i1", "external": "STREAM:", diff --git a/tests/epics/adcore/test_single_trigger.py b/tests/epics/adcore/test_single_trigger.py index 9bb5135cc0..b5b31756c4 100644 --- a/tests/epics/adcore/test_single_trigger.py +++ b/tests/epics/adcore/test_single_trigger.py @@ -8,25 +8,21 @@ @pytest.fixture -async def single_trigger_det_with_stats(): - stats = adcore.NDPluginStatsIO("PREFIX:STATS", name="stats") +async def single_trigger_det(): + stats = adcore.NDPluginStatsIO("PREFIX:STATS:") det = adcore.SingleTriggerDetector( - drv=adcore.ADBaseIO("PREFIX:DRV"), + drv=adcore.ADBaseIO("PREFIX:DRV:"), stats=stats, read_uncached=[stats.unique_id], name="det", ) - - # Set non-default values to check they are set back - # 
These are using set_mock_value to simulate the backend IOC being setup - # in a particular way, rather than values being set by the Ophyd signals - yield det, stats + yield det async def test_single_trigger_det( - single_trigger_det_with_stats: adcore.SingleTriggerDetector, RE: RunEngine + single_trigger_det: adcore.SingleTriggerDetector, + RE: RunEngine, ): - single_trigger_det, stats = single_trigger_det_with_stats names = [] docs = [] RE.subscribe(lambda name, _: names.append(name)) diff --git a/tests/epics/adcore/test_writers.py b/tests/epics/adcore/test_writers.py index af32f86667..9c402afdda 100644 --- a/tests/epics/adcore/test_writers.py +++ b/tests/epics/adcore/test_writers.py @@ -1,3 +1,4 @@ +import xml.etree.ElementTree as ET from unittest.mock import patch import pytest @@ -8,11 +9,11 @@ PathProvider, StandardDetector, StaticPathProvider, + set_mock_value, ) -from ophyd_async.core._mock_signal_utils import set_mock_value from ophyd_async.epics import adaravis, adcore, adkinetix, adpilatus, advimba -from ophyd_async.epics.signal._signal import epics_signal_r -from ophyd_async.plan_stubs._nd_attributes import setup_ndattributes, setup_ndstats_sum +from ophyd_async.epics.signal import epics_signal_r +from ophyd_async.plan_stubs import setup_ndattributes, setup_ndstats_sum class DummyDatasetDescriber(DatasetDescriber): @@ -107,14 +108,14 @@ async def test_stats_describe_when_plugin_configured( assert descriptor == { "test": { "source": "mock+ca://HDF:FullFileName_RBV", - "shape": (10, 10), + "shape": [10, 10], "dtype": "array", "dtype_numpy": " None: - await fill_pvi_entries( - self, self._prefix + "PVI", timeout=timeout, mock=mock - ) +DeviceT = TypeVar("DeviceT", bound=Device) - await super().connect(mock=mock) - yield TestDevice +def with_pvi_connector( + device_type: type[DeviceT], prefix: str, name: str = "" +) -> DeviceT: + connector = PviDeviceConnector(prefix + ":PVI") + device = device_type(connector=connector, name=name) + 
connector.create_children_from_annotations(device) + return device -async def test_fill_pvi_entries_mock_mode(pvi_test_device_t): +async def test_fill_pvi_entries_mock_mode(): async with DeviceCollector(mock=True): - test_device = pvi_test_device_t("PREFIX:") + test_device = with_pvi_connector(Block3, "PREFIX:") # device vectors are typed assert isinstance(test_device.device_vector[1], Block2) assert isinstance(test_device.device_vector[2], Block2) # elements of device vectors are typed recursively - assert test_device.device_vector[1].signal_rw._backend.datatype is int + assert test_device.device_vector[1].signal_rw._connector.backend.datatype is int assert isinstance(test_device.device_vector[1].device, Block1) - assert test_device.device_vector[1].device.signal_rw._backend.datatype is int # type: ignore assert ( - test_device.device_vector[1].device.device_vector_signal_rw[1]._backend.datatype # type: ignore + test_device.device_vector[1].device.signal_rw._connector.backend.datatype is int + ) # type: ignore + assert ( + test_device.device_vector[1] + .device.device_vector_signal_rw[1] + ._connector.backend.datatype # type: ignore is float ) @@ -76,9 +77,9 @@ async def test_fill_pvi_entries_mock_mode(pvi_test_device_t): assert isinstance(test_device.device, Block2) # elements of top level blocks are typed recursively - assert test_device.device.signal_rw._backend.datatype is int # type: ignore + assert test_device.device.signal_rw._connector.backend.datatype is int # type: ignore assert isinstance(test_device.device.device, Block1) - assert test_device.device.device.signal_rw._backend.datatype is int # type: ignore + assert test_device.device.device.signal_rw._connector.backend.datatype is int # type: ignore assert test_device.signal_rw.parent == test_device assert test_device.device_vector.parent == test_device @@ -93,48 +94,23 @@ async def test_fill_pvi_entries_mock_mode(pvi_test_device_t): ) # top level signals are typed - assert 
test_device.signal_rw._backend.datatype is int - + assert test_device.signal_rw._connector.backend.datatype is int -@pytest.fixture -def pvi_test_device_create_children_from_annotations_t(): - """A fixture since pytest discourages init in test case classes""" - class TestDevice(Block3, Device): - def __init__(self, prefix: str, name: str = ""): - self._prefix = prefix - super().__init__(name) - create_children_from_annotations(self) - - async def connect( # type: ignore - self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT - ) -> None: - await fill_pvi_entries( - self, self._prefix + "PVI", timeout=timeout, mock=mock - ) - - await super().connect(mock=mock) - - yield TestDevice - - -async def test_device_create_children_from_annotations( - pvi_test_device_create_children_from_annotations_t, -): - device = pvi_test_device_create_children_from_annotations_t("PREFIX:") +async def test_device_create_children_from_annotations(): + device = with_pvi_connector(Block3, "PREFIX:") block_2_device = device.device block_1_device = device.device.device top_block_1_device = device.signal_device - # The create_children_from_annotations has only made blocks, - # not signals or device vectors + # The create_children_from_annotations has made blocks all the way down assert isinstance(block_2_device, Block2) assert isinstance(block_1_device, Block1) assert isinstance(top_block_1_device, Block1) - assert not hasattr(device, "signal_x") - assert not hasattr(device, "signal_rw") - assert not hasattr(top_block_1_device, "signal_rw") + assert hasattr(device, "signal_x") + assert hasattr(device, "signal_rw") + assert hasattr(top_block_1_device, "signal_rw") await device.connect(mock=True) @@ -144,59 +120,25 @@ async def test_device_create_children_from_annotations( assert device.signal_device is top_block_1_device -@pytest.fixture -def pvi_test_device_with_device_vectors_t(): - """A fixture since pytest discourages init in test case classes""" - - class TestBlock(Device): - 
device_vector: DeviceVector[Block1] - device: Block1 | None - signal_x: SignalX - signal_rw: SignalRW[int] | None - - class TestDevice(TestBlock): - def __init__(self, prefix: str, name: str = ""): - self._prefix = prefix - create_children_from_annotations( - self, - included_optional_fields=("device", "signal_rw"), - device_vectors={"device_vector": 2}, - ) - super().__init__(name) - - async def connect( # type: ignore - self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT - ) -> None: - await fill_pvi_entries( - self, self._prefix + "PVI", timeout=timeout, mock=mock - ) - - await super().connect(mock=mock) - - yield TestDevice - +async def test_device_create_children_from_annotations_with_device_vectors(): + device = with_pvi_connector(Block4, "PREFIX:", name="test_device") + await device.connect(mock=True) -async def test_device_create_children_from_annotations_with_device_vectors( - pvi_test_device_with_device_vectors_t, -): - device = pvi_test_device_with_device_vectors_t("PREFIX:", name="test_device") + block_1_device = device.device + block_2_device_vector = device.device_vector assert device.device_vector[1].name == "test_device-device_vector-1" assert device.device_vector[2].name == "test_device-device_vector-2" - block_1_device = device.device - block_2_device_vector = device.device_vector # create_children_from_annotiations should have made DeviceVectors # and an optional Block, but no signals assert hasattr(device, "device_vector") - assert not hasattr(device, "signal_rw") + assert hasattr(device, "signal_rw") assert isinstance(block_2_device_vector, DeviceVector) assert isinstance(block_2_device_vector[1], Block1) assert len(device.device_vector) == 2 assert isinstance(block_1_device, Block1) - await device.connect(mock=True) - # The memory addresses have not changed assert device.device is block_1_device assert device.device_vector is block_2_device_vector diff --git a/tests/epics/signal/test_common.py b/tests/epics/signal/test_common.py index 
7a16a59a51..124273b2f0 100644 --- a/tests/epics/signal/test_common.py +++ b/tests/epics/signal/test_common.py @@ -2,6 +2,7 @@ import pytest +from ophyd_async.core import StrictEnum from ophyd_async.epics.signal import get_supported_values @@ -19,7 +20,7 @@ class MyEnum(Enum): def test_given_pv_has_choices_not_in_supplied_enum_then_raises(): - class MyEnum(str, Enum): + class MyEnum(StrictEnum): TEST = "test" with pytest.raises(TypeError): @@ -27,7 +28,7 @@ class MyEnum(str, Enum): def test_given_supplied_enum_has_choices_not_in_pv_then_raises(): - class MyEnum(str, Enum): + class MyEnum(StrictEnum): TEST = "test" OTHER = "unexpected_choice" @@ -35,20 +36,8 @@ class MyEnum(str, Enum): get_supported_values("", MyEnum, ("test",)) -@pytest.mark.parametrize( - "datatype", - [None, str], -) -def test_given_no_enum_or_string_then_returns_generated_choices_enum_with_pv_choices( - datatype, -): - supported_vals = get_supported_values("", datatype, ("test",)) - assert len(supported_vals) == 1 - assert "test" in supported_vals - - def test_given_a_supplied_enum_that_matches_the_pv_choices_then_enum_type_is_returned(): - class MyEnum(str, Enum): + class MyEnum(StrictEnum): TEST_1 = "test_1" TEST_2 = "test_2" diff --git a/tests/epics/signal/test_records.db b/tests/epics/signal/test_records.db index e5aa5a776c..18ad9a6678 100644 --- a/tests/epics/signal/test_records.db +++ b/tests/epics/signal/test_records.db @@ -184,7 +184,7 @@ record(waveform, "$(P)longstr") { record(lsi, "$(P)longstr2") { field(SIZV, "80") field(INP, {const:"a string that is just longer than forty characters"}) - field(PINI, "YES") + field(PINI, "YES") } record(waveform, "$(P)table:labels") { @@ -328,27 +328,3 @@ record(waveform, "$(P)ntndarray:data") } }) } - -record(ao, "$(P)width") -{ - info(Q:group, { - "$(P)pvi": { - "pvi.width.rw": { - "+channel": "NAME", - "+type": "plain" - } - } - }) -} - -record(ao, "$(P)height") -{ - info(Q:group, { - "$(P)pvi": { - "pvi.height.rw": { - "+channel": "NAME", - 
"+type": "plain" - } - } - }) -} diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index 0ce7fd6b45..4c6aadfa69 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -1,6 +1,5 @@ import asyncio import random -import re import string import subprocess import sys @@ -16,33 +15,33 @@ import bluesky.plan_stubs as bps import numpy as np -import numpy.typing as npt import pytest -from aioca import CANothing, purge_channel_caches +from aioca import purge_channel_caches from bluesky.protocols import Reading from bluesky.run_engine import RunEngine -from event_model import DataKey +from event_model import DataKey, Limits, LimitsRange from ophyd.signal import EpicsSignal -from typing_extensions import TypedDict from ophyd_async.core import ( + Array1D, NotConnected, SignalBackend, + SignalRW, + StrictEnum, SubsetEnum, T, + Table, load_from_yaml, save_to_yaml, ) from ophyd_async.epics.signal import ( - LimitPair, - Limits, epics_signal_r, epics_signal_rw, epics_signal_rw_rbv, epics_signal_w, epics_signal_x, ) -from ophyd_async.epics.signal._epics_transport import _EpicsTransport # noqa +from ophyd_async.epics.signal._signal import _epics_signal_backend # noqa: PLC2701 RECORDS = str(Path(__file__).parent / "test_records.db") PV_PREFIX = "".join(random.choice(string.ascii_lowercase) for _ in range(12)) @@ -54,16 +53,14 @@ class IOC: protocol: Literal["ca", "pva"] async def make_backend( - self, typ: type | None, suff: str, connect=True + self, typ: type | None, suff: str, timeout=10.0 ) -> SignalBackend: # Calculate the pv - pv = f"{PV_PREFIX}:{self.protocol}:{suff}" + pv = f"{self.protocol}://{PV_PREFIX}:{self.protocol}:{suff}" # Make and connect the backend - cls = _EpicsTransport[self.protocol].value - backend = cls(typ, pv, pv) # type: ignore - if connect: - await asyncio.wait_for(backend.connect(), 10) # type: ignore - return backend # type: ignore + backend = _epics_signal_backend(typ, pv, pv) + 
await backend.connect(timeout=timeout) + return backend # Use a module level fixture per protocol so it's fast to run tests. This means @@ -89,11 +86,13 @@ def ioc(request: pytest.FixtureRequest): ) start_time = time.monotonic() - while "iocRun: All initialization complete" not in ( - process.stdout.readline().strip() # type: ignore - ): + line = "" + while "iocRun: All initialization complete" not in line: + if line: + print(line) if time.monotonic() - start_time > 10: raise TimeoutError("IOC did not start in time") + line = process.stdout.readline().strip() # type: ignore yield IOC(process, protocol) @@ -131,11 +130,8 @@ def assert_types_are_equal(t_actual, t_expected, actual_value): class MonitorQueue: def __init__(self, backend: SignalBackend): self.backend = backend - self.subscription = backend.set_callback(self.add_reading_value) - self.updates: asyncio.Queue[tuple[Reading, Any]] = asyncio.Queue() - - def add_reading_value(self, reading: Reading, value): - self.updates.put_nowait((reading, value)) + self.updates: asyncio.Queue[Reading] = asyncio.Queue() + self.subscription = backend.set_callback(self.updates.put_nowait) async def assert_updates(self, expected_value, expected_type=None): expected_reading = { @@ -144,14 +140,15 @@ async def assert_updates(self, expected_value, expected_type=None): "alarm_severity": 0, } backend_reading = await asyncio.wait_for(self.backend.get_reading(), timeout=5) - reading, value = await asyncio.wait_for(self.updates.get(), timeout=5) backend_value = await asyncio.wait_for(self.backend.get_value(), timeout=5) + update_reading = await asyncio.wait_for(self.updates.get(), timeout=5) + update_value = update_reading["value"] - assert value == expected_value == backend_value + assert update_value == expected_value == backend_value if expected_type: - assert_types_are_equal(type(value), expected_type, value) + assert_types_are_equal(type(update_value), expected_type, update_value) assert_types_are_equal(type(backend_value), 
expected_type, backend_value) - assert reading == expected_reading == backend_reading + assert update_reading == expected_reading == backend_reading def close(self): self.backend.set_callback(None) @@ -186,7 +183,7 @@ async def assert_monitor_then_put( datatype if check_type else None, ) # Put to new value and check that - await backend.put(put_value) + await backend.put(put_value, wait=True) await q.assert_updates( pytest.approx(put_value), datatype if check_type else None ) @@ -194,41 +191,31 @@ async def assert_monitor_then_put( q.close() -async def put_error( - ioc: IOC, - suffix: str, - put_value: T, - datatype: type[T] | None = None, -): - backend = await ioc.make_backend(datatype, suffix) - # The below will work without error - await backend.put(put_value) - # Change the name of write_pv to mock disconnection - backend.__setattr__("write_pv", "Disconnect") - await backend.put(put_value, timeout=3) +class MyEnum(StrictEnum): + a = "Aaa" + b = "Bbb" + c = "Ccc" -class MyEnum(str, Enum): +class MySubsetEnum(SubsetEnum): a = "Aaa" b = "Bbb" c = "Ccc" -MySubsetEnum = SubsetEnum["Aaa", "Bbb", "Ccc"] - _metadata: dict[str, dict[str, dict[str, Any]]] = { "ca": { "boolean": {"units": ANY, "limits": ANY}, "integer": {"units": ANY, "limits": ANY}, "number": {"units": ANY, "limits": ANY, "precision": ANY}, - "enum": {"limits": ANY}, - "string": {"limits": ANY}, + "enum": {}, + "string": {}, }, "pva": { - "boolean": {"limits": ANY}, + "boolean": {}, "integer": {"units": ANY, "precision": ANY, "limits": ANY}, "number": {"units": ANY, "precision": ANY, "limits": ANY}, - "enum": {"limits": ANY}, + "enum": {}, "string": {"units": ANY, "precision": ANY, "limits": ANY}, }, } @@ -260,7 +247,7 @@ def get_dtype_numpy(suffix: str) -> str: # type: ignore if "float" in suffix or "double" in suffix: return " str: # type: ignore elif "64" in suffix: int_str += "8" else: - int_str += "4" + int_str += "8" return int_str if "str" in suffix or "enum" in suffix: return "|S40" @@ -311,70 
+298,70 @@ def get_dtype_numpy(suffix: str) -> str: # type: ignore (MyEnum, "enum", MyEnum.b, MyEnum.c, {"ca", "pva"}), # numpy arrays of numpy types ( - npt.NDArray[np.int8], + Array1D[np.int8], "int8a", [-128, 127], [-8, 3, 44], {"pva"}, ), ( - npt.NDArray[np.uint8], + Array1D[np.uint8], "uint8a", [0, 255], [218], {"ca", "pva"}, ), ( - npt.NDArray[np.int16], + Array1D[np.int16], "int16a", [-32768, 32767], [-855], {"ca", "pva"}, ), ( - npt.NDArray[np.uint16], + Array1D[np.uint16], "uint16a", [0, 65535], [5666], {"pva"}, ), ( - npt.NDArray[np.int32], + Array1D[np.int32], "int32a", [-2147483648, 2147483647], [-2], {"ca", "pva"}, ), ( - npt.NDArray[np.uint32], + Array1D[np.uint32], "uint32a", [0, 4294967295], [1022233], {"pva"}, ), ( - npt.NDArray[np.int64], + Array1D[np.int64], "int64a", [-2147483649, 2147483648], [-3], {"pva"}, ), ( - npt.NDArray[np.uint64], + Array1D[np.uint64], "uint64a", [0, 4294967297], [995444], {"pva"}, ), ( - npt.NDArray[np.float32], + Array1D[np.float32], "float32a", [0.000002, -123.123], [1.0], {"ca", "pva"}, ), ( - npt.NDArray[np.float64], + Array1D[np.float64], "float64a", [0.1, -12345678.123], [0.2], @@ -385,14 +372,7 @@ def get_dtype_numpy(suffix: str) -> str: # type: ignore "stra", ["five", "six", "seven"], ["nine", "ten"], - {"pva"}, - ), - ( - npt.NDArray[np.str_], - "stra", - ["five", "six", "seven"], - ["nine", "ten"], - {"ca"}, + {"pva", "ca"}, ), # Can't do long strings until https://github.com/epics-base/pva2pva/issues/17 # (str, "longstr", ls1, ls2), @@ -476,21 +456,22 @@ async def test_bool_conversion_of_enum(ioc: IOC, suffix: str, tmp_path: Path) -> async def test_error_raised_on_disconnected_PV(ioc: IOC) -> None: if ioc.protocol == "pva": - err = NotConnected expected = "pva://Disconnect" elif ioc.protocol == "ca": - err = CANothing - expected = "Disconnect: User specified timeout on IO operation expired" - with pytest.raises(err, match=expected): - await put_error( - ioc, - suffix="bool", - put_value=False, - 
datatype=bool, - ) + expected = "ca://Disconnect" + else: + raise TypeError() + backend = await ioc.make_backend(bool, "bool") + signal = SignalRW(backend) + # The below will work without error + await signal.set(False) + # Change the name of write_pv to mock disconnection + backend.__setattr__("write_pv", "Disconnect") + with pytest.raises(asyncio.TimeoutError, match=expected): + await signal.set(True, timeout=0.1) -class BadEnum(str, Enum): +class BadEnum(StrictEnum): a = "Aaa" b = "B" c = "Ccc" @@ -502,12 +483,12 @@ def test_enum_equality(): possibly more. """ - class GeneratedChoices(str, Enum): + class GeneratedChoices(StrictEnum): a = "Aaa" b = "B" c = "Ccc" - class ExtendedGeneratedChoices(str, Enum): + class ExtendedGeneratedChoices(StrictEnum): a = "Aaa" b = "B" c = "Ccc" @@ -532,75 +513,108 @@ class EnumNoString(Enum): a = "Aaa" +class SubsetEnumWrongChoices(SubsetEnum): + a = "Aaa" + b = "B" + c = "Ccc" + + @pytest.mark.parametrize( - "typ, suff, error", + "typ, suff, errors", [ ( BadEnum, "enum", ( - "has choices ('Aaa', 'Bbb', 'Ccc'), which do not match " - ", which has ('Aaa', 'B', 'Ccc')" + "has choices ('Aaa', 'Bbb', 'Ccc')", + "but ", + "requested ['Aaa', 'B', 'Ccc'] to be strictly equal", ), ), ( - rt_enum := SubsetEnum["Aaa", "B", "Ccc"], + SubsetEnumWrongChoices, "enum", ( - "has choices ('Aaa', 'Bbb', 'Ccc'), " - # SubsetEnum string output isn't deterministic - f"which is not a superset of {str(rt_enum)}." 
+ "has choices ('Aaa', 'Bbb', 'Ccc')", + "but ", + "requested ['Aaa', 'B', 'Ccc'] to be a subset", ), ), - (int, "str", "has type str not int"), - (str, "float", "has type float not str"), - (str, "stra", "has type [str] not str"), - (int, "uint8a", "has type [uint8] not int"), + ( + int, + "str", + ("with inferred datatype str", "cannot be coerced to int"), + ), + ( + str, + "float", + ("with inferred datatype float", "cannot be coerced to str"), + ), + ( + str, + "stra", + ("with inferred datatype Sequence[str]", "cannot be coerced to str"), + ), + ( + int, + "uint8a", + ("with inferred datatype Array1D[np.uint8]", "cannot be coerced to int"), + ), ( float, "enum", + ("with inferred datatype str", "cannot be coerced to float"), + ), + ( + Array1D[np.int32], + "float64a", ( - "has choices ('Aaa', 'Bbb', 'Ccc'). " - "Use an Enum or SubsetEnum to represent this." + "with inferred datatype Array1D[np.float64]", + "cannot be coerced to Array1D[np.int32]", ), ), - (npt.NDArray[np.int32], "float64a", "has type [float64] not [int32]"), ], ) -async def test_backend_wrong_type_errors(ioc: IOC, typ, suff, error): - with pytest.raises( - TypeError, match=re.escape(f"{PV_PREFIX}:{ioc.protocol}:{suff} {error}") - ): +async def test_backend_wrong_type_errors(ioc: IOC, typ, suff, errors): + with pytest.raises(TypeError) as cm: await ioc.make_backend(typ, suff) + for error in errors: + assert error in str(cm.value) async def test_backend_put_enum_string(ioc: IOC) -> None: backend = await ioc.make_backend(MyEnum, "enum2") # Don't do this in production code, but allow on CLI - await backend.put("Ccc") # type: ignore + await backend.put("Ccc", wait=True) # type: ignore assert MyEnum.c == await backend.get_value() async def test_backend_enum_which_doesnt_inherit_string(ioc: IOC) -> None: with pytest.raises(TypeError): backend = await ioc.make_backend(EnumNoString, "enum2") - await backend.put("Aaa") + await backend.put("Aaa", wait=True) async def test_backend_get_setpoint(ioc: IOC) 
-> None: backend = await ioc.make_backend(MyEnum, "enum2") - await backend.put("Ccc") + await backend.put("Ccc", wait=True) assert await backend.get_setpoint() == MyEnum.c -def approx_table(table): - return {k: pytest.approx(v) for k, v in table.items()} +def approx_table(datatype: type[Table], table: Table): + new_table = datatype(**table.model_dump()) + for k, v in new_table: + if datatype is Table: + setattr(new_table, k, pytest.approx(v)) + else: + object.__setattr__(new_table, k, pytest.approx(v)) + return new_table -class MyTable(TypedDict): - bool: npt.NDArray[np.bool_] - int: npt.NDArray[np.int32] - float: npt.NDArray[np.float64] +class MyTable(Table): + bool: Array1D[np.bool_] + int: Array1D[np.int32] + float: Array1D[np.float64] str: Sequence[str] enum: Sequence[MyEnum] @@ -623,19 +637,6 @@ async def test_pva_table(ioc: IOC) -> None: str=["Hello", "Bat"], enum=[MyEnum.c, MyEnum.b], ) - # TODO: what should this be for a variable length table? - datakey = { - "dtype": "object", - "shape": [], - "source": "test-source", - "dtype_numpy": "", - "limits": { - "alarm": {"high": None, "low": None}, - "control": {"high": None, "low": None}, - "display": {"high": None, "low": None}, - "warning": {"high": None, "low": None}, - }, - } # Make and connect the backend for t, i, p in [(MyTable, initial, put), (None, put, initial)]: backend = await ioc.make_backend(t, "table") @@ -643,50 +644,31 @@ async def test_pva_table(ioc: IOC) -> None: q = MonitorQueue(backend) try: # Check datakey - assert datakey == await backend.get_datakey("test-source") + dk = await backend.get_datakey("test-source") + expected_dk = { + "dtype": "array", + "shape": [len(i)], + "source": "test-source", + "dtype_numpy": [ + # natively bool fields are uint8, so if we don't provide a Table + # subclass to specify bool, that is what we get + ("bool", "|b1" if t else "|u1"), + ("int", " None: - if ioc.protocol == "ca": - # CA can't do structure - return - # Make and connect the backend - backend = 
await ioc.make_backend(dict[str, Any], "pvi") - - # Make a monitor queue that will monitor for updates - q = MonitorQueue(backend) - - expected = { - "pvi": { - "width": { - "rw": f"{PV_PREFIX}:{ioc.protocol}:width", - }, - "height": { - "rw": f"{PV_PREFIX}:{ioc.protocol}:height", - }, - }, - "record": ANY, - } - - try: - # Check datakey - with pytest.raises(NotImplementedError): - await backend.get_datakey("") - # Check initial value - await q.assert_updates(expected) - await backend.get_value() - - finally: - q.close() - - async def test_pva_ntdarray(ioc: IOC): if ioc.protocol == "ca": # CA can't do ndarray @@ -695,11 +677,11 @@ async def test_pva_ntdarray(ioc: IOC): put = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64).reshape((2, 3)) initial = np.zeros_like(put) - backend = await ioc.make_backend(npt.NDArray[np.int64], "ntndarray") + backend = await ioc.make_backend(np.ndarray, "ntndarray") # Backdoor into the "raw" data underlying the NDArray in QSrv # not supporting direct writes to NDArray at the moment. 
- raw_data_backend = await ioc.make_backend(npt.NDArray[np.int64], "ntndarray:data") + raw_data_backend = await ioc.make_backend(Array1D[np.int64], "ntndarray:data") # Make a monitor queue that will monitor for updates for i, p in [(initial, put), (put, initial)]: @@ -707,13 +689,12 @@ async def test_pva_ntdarray(ioc: IOC): assert { "source": "test-source", "dtype": "array", - "dtype_numpy": "", + "dtype_numpy": " None: + def __init__(self, pvi: dict[str, Any]) -> None: self.pvi = pvi def get(self, item: str): @@ -40,7 +39,7 @@ def get(self, item: str): class MockCtxt: - def __init__(self, pvi: dict[str, _PVIEntry]) -> None: + def __init__(self, pvi: dict[str, Any]) -> None: self.pvi = copy.copy(pvi) def get(self, pv: str, timeout: float = 0.0): @@ -55,16 +54,8 @@ class CommonPandaBlocksNoData(Device): seq: DeviceVector[SeqBlock] class Panda(CommonPandaBlocksNoData): - def __init__(self, prefix: str, name: str = ""): - self._prefix = prefix - create_children_from_annotations(self) - super().__init__(name) - - async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): - await fill_pvi_entries( - self, self._prefix + "PVI", timeout=timeout, mock=mock - ) - await super().connect(mock=mock, timeout=timeout) + def __init__(self, uri: str, name: str = ""): + super().__init__(name=name, connector=fastcs_connector(self, uri)) yield Panda @@ -72,7 +63,7 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): @pytest.fixture async def mock_panda(panda_t): async with DeviceCollector(mock=True): - mock_panda = panda_t("PANDAQSRV:", "mock_panda") + mock_panda = panda_t("PANDAQSRV:") assert mock_panda.name == "mock_panda" yield mock_panda @@ -126,13 +117,13 @@ async def test_panda_children_connected(mock_panda): async def test_panda_with_missing_blocks(panda_pva, panda_t): - panda = panda_t("PANDAQSRVI:") - with pytest.raises(RuntimeError) as exc: + panda = panda_t("PANDAQSRVI:", name="mypanda") + with pytest.raises( + RuntimeError, 
+ match="mypanda: cannot provision {'pcap'} from PANDAQSRVI:PVI: {'pulse1': " + + "{'d': 'PANDAQSRVI:PULSE1:PVI'}, 'seq1': {'d': 'PANDAQSRVI:SEQ1:PVI'}}", + ): await panda.connect() - assert ( - str(exc.value) - == "sub device `pcap:` was not provided by pvi" - ) async def test_panda_with_extra_blocks_and_signals(panda_pva, panda_t): @@ -158,16 +149,16 @@ async def test_panda_gets_types_from_common_class(panda_pva, panda_t): assert isinstance(panda.pulse[1], PulseBlock) # others are just Devices - assert isinstance(panda.extra, Device) + assert isinstance(panda.extra, DeviceVector) # predefined signals get set up with the correct datatype - assert panda.pcap.active._backend.datatype is bool + assert panda.pcap.active._connector.backend.datatype is bool # works with custom datatypes - assert panda.seq[1].table._backend.datatype is SeqTable + assert panda.seq[1].table._connector.backend.datatype is SeqTable # others are given the None datatype - assert panda.pcap.newsignal._backend.datatype is None + assert panda.pcap.newsignal._connector.backend.datatype is None async def test_panda_block_missing_signals(panda_pva, panda_t): diff --git a/tests/fastcs/panda/test_panda_control.py b/tests/fastcs/panda/test_panda_control.py index 0920cd394f..e568dc3080 100644 --- a/tests/fastcs/panda/test_panda_control.py +++ b/tests/fastcs/panda/test_panda_control.py @@ -4,25 +4,17 @@ import pytest -from ophyd_async.core import DEFAULT_TIMEOUT, DetectorTrigger, Device, DeviceCollector -from ophyd_async.core._detector import TriggerInfo -from ophyd_async.epics.pvi import fill_pvi_entries +from ophyd_async.core import DetectorTrigger, Device, DeviceCollector, TriggerInfo from ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.fastcs.core import fastcs_connector from ophyd_async.fastcs.panda import CommonPandaBlocks, PandaPcapController @pytest.fixture async def mock_panda(): class Panda(CommonPandaBlocks): - def __init__(self, prefix: str, name: str = ""): - self._prefix = 
prefix - super().__init__(name) - - async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): - await fill_pvi_entries( - self, self._prefix + "PVI", timeout=timeout, mock=mock - ) - await super().connect(mock=mock, timeout=timeout) + def __init__(self, uri: str, name: str = ""): + super().__init__(name=name, connector=fastcs_connector(self, uri)) async with DeviceCollector(mock=True): mock_panda = Panda("PANDACONTROLLER:", name="mock_panda") diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index c099614981..c60d9210af 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -2,9 +2,9 @@ import yaml from bluesky import RunEngine -from ophyd_async.core import DEFAULT_TIMEOUT, DeviceCollector, load_device, save_device -from ophyd_async.epics.pvi import fill_pvi_entries +from ophyd_async.core import DeviceCollector, load_device, save_device from ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.fastcs.core import fastcs_connector from ophyd_async.fastcs.panda import ( CommonPandaBlocks, DataBlock, @@ -18,15 +18,8 @@ async def get_mock_panda(): class Panda(CommonPandaBlocks): data: DataBlock - def __init__(self, prefix: str, name: str = ""): - self._prefix = prefix - super().__init__(name) - - async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): - await fill_pvi_entries( - self, self._prefix + "PVI", timeout=timeout, mock=mock - ) - await super().connect(mock=mock, timeout=timeout) + def __init__(self, uri: str, name: str = ""): + super().__init__(name=name, connector=fastcs_connector(self, uri)) async with DeviceCollector(mock=True): mock_panda = Panda("PANDA") @@ -94,7 +87,6 @@ def check_equal_with_seq_tables(actual, expected): "pulse.1.width": 0.0, "pulse.2.delay": 0.0, "pulse.2.width": 0.0, - "seq.1.active": False, "seq.1.table": { "outa1": [False], "outa2": [False], @@ -136,7 +128,6 @@ def 
check_equal_with_seq_tables(actual, expected): "time2": [], "trigger": [], }, - "seq.2.active": False, "seq.2.repeats": 0, "seq.2.prescale": 0.0, "seq.2.enable": "ZERO", diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_seq_table.py similarity index 73% rename from tests/fastcs/panda/test_table.py rename to tests/fastcs/panda/test_seq_table.py index ed963c91f5..1130cac330 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_seq_table.py @@ -4,8 +4,7 @@ import pytest from pydantic import ValidationError -from ophyd_async.fastcs.panda import SeqTable -from ophyd_async.fastcs.panda._table import SeqTrigger +from ophyd_async.fastcs.panda import SeqTable, SeqTrigger def test_seq_table_converts_lists(): @@ -20,7 +19,7 @@ def test_seq_table_converts_lists(): def test_seq_table_validation_errors(): - with pytest.raises(ValidationError, match="81 validation errors for SeqTable"): + with pytest.raises(ValidationError, match="17 validation errors for SeqTable"): SeqTable( repeats=0, trigger=SeqTrigger.IMMEDIATE, @@ -42,53 +41,56 @@ def test_seq_table_validation_errors(): ) large_seq_table = SeqTable( - repeats=np.zeros(4095, dtype=np.int32), - trigger=["Immediate"] * 4095, - position=np.zeros(4095, dtype=np.int32), - time1=np.zeros(4095, dtype=np.int32), - outa1=np.zeros(4095, dtype=np.bool_), - outb1=np.zeros(4095, dtype=np.bool_), - outc1=np.zeros(4095, dtype=np.bool_), - outd1=np.zeros(4095, dtype=np.bool_), - oute1=np.zeros(4095, dtype=np.bool_), - outf1=np.zeros(4095, dtype=np.bool_), - time2=np.zeros(4095, dtype=np.int32), - outa2=np.zeros(4095, dtype=np.bool_), - outb2=np.zeros(4095, dtype=np.bool_), - outc2=np.zeros(4095, dtype=np.bool_), - outd2=np.zeros(4095, dtype=np.bool_), - oute2=np.zeros(4095, dtype=np.bool_), - outf2=np.zeros(4095, dtype=np.bool_), + repeats=np.zeros(4096, dtype=np.uint16), + trigger=[SeqTrigger.IMMEDIATE] * 4096, + position=np.zeros(4096, dtype=np.int32), + time1=np.zeros(4096, dtype=np.uint32), + 
outa1=np.zeros(4096, dtype=np.bool_), + outb1=np.zeros(4096, dtype=np.bool_), + outc1=np.zeros(4096, dtype=np.bool_), + outd1=np.zeros(4096, dtype=np.bool_), + oute1=np.zeros(4096, dtype=np.bool_), + outf1=np.zeros(4096, dtype=np.bool_), + time2=np.zeros(4096, dtype=np.uint32), + outa2=np.zeros(4096, dtype=np.bool_), + outb2=np.zeros(4096, dtype=np.bool_), + outc2=np.zeros(4096, dtype=np.bool_), + outd2=np.zeros(4096, dtype=np.bool_), + oute2=np.zeros(4096, dtype=np.bool_), + outf2=np.zeros(4096, dtype=np.bool_), ) with pytest.raises( ValidationError, match=( "1 validation error for SeqTable\n " - "Assertion failed, Length 4096 not in range." + "Assertion failed, Length 4097 is too long." ), ): large_seq_table + SeqTable.row() + # TODO: validation of numpy types disabled until bool -> uint8 CA issue resolved + # with pytest.raises( + # ValidationError, + # match="1 validation error for SeqTable\n Assertion failed, repeats: " + # + "expected dtype uint16, got int32", + # ): + # row_one = SeqTable.row() + # wrong_types = { + # field_name: field_value.astype(np.int32) + # if isinstance(field_value, np.ndarray) + # else field_value + # for field_name, field_value in row_one + # } + # SeqTable(**wrong_types) with pytest.raises( ValidationError, - match="12 validation errors for SeqTable", - ): - row_one = SeqTable.row() - wrong_types = { - field_name: field_value.astype(np.unicode_) - for field_name, field_value in row_one - if isinstance(field_value, np.ndarray) - } - SeqTable(**wrong_types) - with pytest.raises( - TypeError, - match="Row column should be numpy arrays or sequence of string `Enum`", + match="trigger.0\n Input should be 'Immediate', 'BITA=0'", ): SeqTable.row(trigger="A") def test_seq_table_pva_conversion(): pva_dict = { - "repeats": np.array([1, 2, 3, 4], dtype=np.int32), + "repeats": np.array([1, 2, 3, 4], dtype=np.uint16), "trigger": [ SeqTrigger.IMMEDIATE, SeqTrigger.IMMEDIATE, @@ -96,14 +98,14 @@ def test_seq_table_pva_conversion(): 
SeqTrigger.IMMEDIATE, ], "position": np.array([1, 2, 3, 4], dtype=np.int32), - "time1": np.array([1, 0, 1, 0], dtype=np.int32), + "time1": np.array([1, 0, 1, 0], dtype=np.uint32), "outa1": np.array([1, 0, 1, 0], dtype=np.bool_), "outb1": np.array([1, 0, 1, 0], dtype=np.bool_), "outc1": np.array([1, 0, 1, 0], dtype=np.bool_), "outd1": np.array([1, 0, 1, 0], dtype=np.bool_), "oute1": np.array([1, 0, 1, 0], dtype=np.bool_), "outf1": np.array([1, 0, 1, 0], dtype=np.bool_), - "time2": np.array([1, 2, 3, 4], dtype=np.int32), + "time2": np.array([1, 2, 3, 4], dtype=np.uint32), "outa2": np.array([1, 0, 1, 0], dtype=np.bool_), "outb2": np.array([1, 0, 1, 0], dtype=np.bool_), "outc2": np.array([1, 0, 1, 0], dtype=np.bool_), @@ -221,50 +223,20 @@ def _assert_col_equal(column1, column2): ): _assert_col_equal(column1, column2) - assert np.array_equal( - seq_table_from_pva_dict.numpy_columns(), - [ - np.array([1, 2, 3, 4], dtype=np.int32), - np.array( - [ - "Immediate", - "Immediate", - "BITC=0", - "Immediate", - ], - dtype=" PandaHDFWriter: dp = StaticPathProvider(fp, tmp_path / mock_panda.name, create_dir_depth=-1) async with DeviceCollector(mock=True): writer = PandaHDFWriter( - prefix="TEST-PANDA", path_provider=dp, name_provider=lambda: mock_panda.name, panda_data_block=mock_panda.data, @@ -138,11 +117,11 @@ async def test_open_returns_correct_descriptors( description = await mock_writer.open() # to make capturing status not time out # Check if empty datasets table leads to warning log message - if len(table["name"]) == 0: + if len(table.name) == 0: assert "DATASETS table is empty!" 
in caplog.text for key, entry, expected_key in zip( - description.keys(), description.values(), table["name"], strict=False + description.keys(), description.values(), table.name, strict=False ): assert key == expected_key assert entry == { @@ -222,9 +201,9 @@ def assert_resource_document(name, resource_doc): assert type(mock_writer._file) is HDFFile assert mock_writer._file._last_emitted == 1 - for i in range(len(table["name"])): + for i in range(len(table.name)): resource_doc = mock_writer._file._bundles[i].stream_resource_doc - name = table["name"][i] + name = table.name[i] assert_resource_document(name=name, resource_doc=resource_doc) @@ -238,7 +217,6 @@ async def test_oserror_when_hdf_dir_does_not_exist(tmp_path, mock_panda): ) async with DeviceCollector(mock=True): writer = PandaHDFWriter( - prefix="TEST-PANDA", path_provider=dp, name_provider=lambda: "test-panda", panda_data_block=mock_panda.data, diff --git a/tests/plan_stubs/test_ensure_connected.py b/tests/plan_stubs/test_ensure_connected.py index c876932f58..5737deef72 100644 --- a/tests/plan_stubs/test_ensure_connected.py +++ b/tests/plan_stubs/test_ensure_connected.py @@ -1,6 +1,6 @@ import pytest -from ophyd_async.core import Device, MockSignalBackend, NotConnected, SignalRW +from ophyd_async.core import Device, NotConnected, soft_signal_rw from ophyd_async.epics.signal import epics_signal_rw from ophyd_async.plan_stubs import ensure_connected @@ -24,7 +24,7 @@ def connect(): assert isinstance(device1.signal._connect_task.exception(), NotConnected) - device1.signal = SignalRW(MockSignalBackend(str)) + device1.signal = soft_signal_rw(str) RE(connect()) assert device1.signal._connect_task.exception() is None diff --git a/tests/plan_stubs/test_fly.py b/tests/plan_stubs/test_fly.py index 4181da1c3a..1e6c6afc42 100644 --- a/tests/plan_stubs/test_fly.py +++ b/tests/plan_stubs/test_fly.py @@ -24,8 +24,8 @@ observe_value, set_mock_value, ) -from ophyd_async.epics.pvi import fill_pvi_entries from 
ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.fastcs.core import fastcs_connector from ophyd_async.fastcs.panda import ( CommonPandaBlocks, StaticPcompTriggerLogic, @@ -170,15 +170,8 @@ def dummy_arm_2(self=None): @pytest.fixture async def mock_panda(): class Panda(CommonPandaBlocks): - def __init__(self, prefix: str, name: str = ""): - self._prefix = prefix - super().__init__(name) - - async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): - await fill_pvi_entries( - self, self._prefix + "PVI", timeout=timeout, mock=mock - ) - await super().connect(mock, timeout) + def __init__(self, uri: str, name: str = ""): + super().__init__(name=name, connector=fastcs_connector(self, uri)) async with DeviceCollector(mock=True): mock_panda = Panda("PANDAQSRV:", "mock_panda") diff --git a/tests/tango/test_base_device.py b/tests/tango/test_base_device.py index 3c23461f99..e1ed92eb4f 100644 --- a/tests/tango/test_base_device.py +++ b/tests/tango/test_base_device.py @@ -5,12 +5,11 @@ import bluesky.plan_stubs as bps import bluesky.plans as bp import numpy as np -import numpy.typing as npt import pytest from bluesky import RunEngine import tango -from ophyd_async.core import DeviceCollector, HintedSignal, SignalRW, T +from ophyd_async.core import Array1D, DeviceCollector, HintedSignal, SignalRW, T from ophyd_async.tango import TangoReadable, get_python_type from ophyd_async.tango.demo import ( DemoCounter, @@ -176,7 +175,7 @@ def raise_exception_cmd(self): class TestTangoReadable(TangoReadable): __test__ = False justvalue: SignalRW[int] - array: SignalRW[npt.NDArray[float]] + array: SignalRW[Array1D[np.float64]] limitedvalue: SignalRW[float] def __init__( @@ -325,16 +324,9 @@ async def test_connect(tango_test_device): @pytest.mark.asyncio async def test_set_trl(tango_test_device): values, description = await describe_class(tango_test_device) - - # async with DeviceCollector(): - # test_device = TestTangoReadable(trl=tango_test_device) 
test_device = TestTangoReadable(name="test_device") - with pytest.raises(ValueError) as excinfo: - test_device.set_trl(0) - assert "TRL must be a string." in str(excinfo.value) - - test_device.set_trl(tango_test_device) + test_device._connector.trl = tango_test_device await test_device.connect() assert test_device.name == "test_device" @@ -350,12 +342,12 @@ async def test_connect_proxy(tango_test_device, proxy: bool | None): test_device = TestTangoReadable(trl=tango_test_device) test_device.proxy = None await test_device.connect() - assert isinstance(test_device.proxy, tango._tango.DeviceProxy) + assert isinstance(test_device._connector.proxy, tango._tango.DeviceProxy) elif proxy: proxy = await AsyncDeviceProxy(tango_test_device) test_device = TestTangoReadable(device_proxy=proxy) await test_device.connect() - assert isinstance(test_device.proxy, tango._tango.DeviceProxy) + assert isinstance(test_device._connector.proxy, tango._tango.DeviceProxy) else: proxy = None test_device = TestTangoReadable(device_proxy=proxy) diff --git a/tests/tango/test_tango_signals.py b/tests/tango/test_tango_signals.py index 40a5b591aa..b128709be0 100644 --- a/tests/tango/test_tango_signals.py +++ b/tests/tango/test_tango_signals.py @@ -3,7 +3,6 @@ import time from enum import Enum, IntEnum from random import choice -from typing import Any import numpy as np import numpy.typing as npt @@ -14,7 +13,6 @@ from ophyd_async.core import SignalBackend, SignalR, SignalRW, SignalW, SignalX, T from ophyd_async.tango import ( TangoSignalBackend, - __tango_signal_auto, tango_signal_r, tango_signal_rw, tango_signal_w, @@ -27,6 +25,11 @@ from tango.test_context import MultiDeviceTestContext from tango.test_utils import assert_close + +def __tango_signal_auto(*args, **kwargs): + raise RuntimeError("Fix this later") + + # -------------------------------------------------------------------- """ Since TangoTest does not support EchoMode, we create our own Device. 
@@ -248,7 +251,7 @@ async def make_backend( backend = TangoSignalBackend(typ, pv, pv) backend.allow_events(allow_events) if connect: - await asyncio.wait_for(backend.connect(), 10) + await backend.connect(1) return backend @@ -261,35 +264,25 @@ async def prepare_device(echo_device: str, pv: str, put_value: T) -> None: # -------------------------------------------------------------------- class MonitorQueue: def __init__(self, backend: SignalBackend): - self.updates: asyncio.Queue[tuple[Reading, Any]] = asyncio.Queue() + self.updates: asyncio.Queue[Reading] = asyncio.Queue() self.backend = backend - self.subscription = backend.set_callback(self.add_reading_value) - - def add_reading_value(self, reading: Reading, value): - self.updates.put_nowait((reading, value)) + self.subscription = backend.set_callback(self.updates.put_nowait) async def assert_updates(self, expected_value): expected_reading = { "timestamp": pytest.approx(time.time(), rel=0.1), "alarm_severity": 0, } - update_reading, update_value = await self.updates.get() - get_reading = await self.backend.get_reading() - # If update_value is a numpy.ndarray, convert it to a list - if isinstance(update_value, np.ndarray): - update_value = update_value.tolist() - assert_close(update_value, expected_value) - assert_close(await self.backend.get_value(), expected_value) - - update_reading = dict(update_reading) + update_reading = dict(await asyncio.wait_for(self.updates.get(), timeout=5)) update_value = update_reading.pop("value") - - get_reading = dict(get_reading) - get_value = get_reading.pop("value") - - assert update_reading == expected_reading == get_reading assert_close(update_value, expected_value) - assert_close(get_value, expected_value) + backend_reading = dict( + await asyncio.wait_for(self.backend.get_reading(), timeout=5) + ) + backend_reading.pop("value") + backend_value = await asyncio.wait_for(self.backend.get_value(), timeout=5) + assert_close(backend_value, expected_value) + assert update_reading 
== expected_reading == backend_reading def close(self): self.backend.set_callback(None) @@ -348,7 +341,7 @@ async def test_backend_get_put_monitor_attr( get_test_descriptor(py_type, initial_value, False), py_type, ), - timeout=10, # Timeout in seconds + timeout=100, # Timeout in seconds ) except asyncio.TimeoutError: pytest.fail("Test timed out") @@ -606,6 +599,7 @@ async def test_tango_signal_x(tango_test_device: str, use_proxy: bool): # -------------------------------------------------------------------- @pytest.mark.asyncio +@pytest.mark.skip("Not sure if we need tango_signal_auto") @pytest.mark.parametrize( "pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy", [ @@ -674,6 +668,7 @@ async def _test_signal(dtype, proxy): # -------------------------------------------------------------------- @pytest.mark.asyncio +@pytest.mark.skip("Not sure if we need tango_signal_auto") @pytest.mark.parametrize( "pv, tango_type, d_format, py_type, initial_value, put_value, use_dtype, use_proxy", [ @@ -748,6 +743,7 @@ async def _test_signal(dtype, proxy): # -------------------------------------------------------------------- @pytest.mark.asyncio +@pytest.mark.skip("Not sure if we need tango_signal_auto") @pytest.mark.parametrize("use_proxy", [True, False]) async def test_tango_signal_auto_cmds_void(tango_test_device: str, use_proxy: bool): proxy = await DeviceProxy(tango_test_device) if use_proxy else None @@ -764,6 +760,7 @@ async def test_tango_signal_auto_cmds_void(tango_test_device: str, use_proxy: bo # -------------------------------------------------------------------- @pytest.mark.asyncio +@pytest.mark.skip("Not sure if we need tango_signal_auto") async def test_tango_signal_auto_badtrl(tango_test_device: str): proxy = await DeviceProxy(tango_test_device) with pytest.raises(RuntimeError) as exc_info: diff --git a/tests/tango/test_tango_transport.py b/tests/tango/test_tango_transport.py index 4e951ba80e..036946fbe0 100644 --- 
a/tests/tango/test_tango_transport.py +++ b/tests/tango/test_tango_transport.py @@ -207,9 +207,9 @@ async def test_get_tango_trl( proxy = await DeviceProxy(tango_test_device) if proxy_needed else None if should_raise: with pytest.raises(RuntimeError): - await get_tango_trl(trl, proxy) + await get_tango_trl(trl, proxy, 1) else: - result = await get_tango_trl(trl, proxy) + result = await get_tango_trl(trl, proxy, 1) assert isinstance(result, expected_type) @@ -328,9 +328,9 @@ async def test_attribute_subscribe_callback(echo_device): attr_proxy = backend.proxies[source] val = None - def callback(reading, value): + def callback(reading): nonlocal val - val = value + val = reading["value"] attr_proxy.subscribe_callback(callback) assert attr_proxy.has_subscription() @@ -356,7 +356,7 @@ async def test_attribute_unsubscribe_callback(echo_device): backend = await make_backend(float, source) attr_proxy = backend.proxies[source] - def callback(reading, value): + def callback(reading): pass attr_proxy.subscribe_callback(callback) @@ -385,9 +385,9 @@ async def test_attribute_poll(tango_test_device): attr_proxy = AttributeProxy(device_proxy, "floatvalue") attr_proxy.support_events = False - def callback(reading, value): + def callback(reading): nonlocal val - val = value + val = reading["value"] def bad_callback(): pass @@ -445,9 +445,9 @@ async def test_attribute_poll_stringsandarrays(tango_test_device, attr): attr_proxy = AttributeProxy(device_proxy, attr) attr_proxy.support_events = False - def callback(reading, value): + def callback(reading): nonlocal val - val = value + val = reading["value"] val = None attr_proxy.set_polling(True, 0.1) @@ -592,7 +592,7 @@ async def test_tango_transport_source(echo_device): await prepare_device(echo_device, "float_scalar_attr", 1.0) source = echo_device + "/" + "float_scalar_attr" transport = await make_backend(float, source) - transport_source = transport.source("") + transport_source = transport.source("", True) assert transport_source 
== source @@ -620,10 +620,10 @@ async def test_tango_transport_connect(echo_device): source = echo_device + "/" + "float_scalar_attr" backend = await make_backend(float, source, connect=False) assert backend is not None - await backend.connect() + await backend.connect(1) backend.read_trl = "" with pytest.raises(RuntimeError) as exc_info: - await backend.connect() + await backend.connect(1) assert "trl not set" in str(exc_info.value) @@ -633,11 +633,11 @@ async def test_tango_transport_connect_and_store_config(echo_device): await prepare_device(echo_device, "float_scalar_attr", 1.0) source = echo_device + "/" + "float_scalar_attr" transport = await make_backend(float, source, connect=False) - await transport._connect_and_store_config(source) + await transport._connect_and_store_config(source, 1) assert transport.trl_configs[source] is not None with pytest.raises(RuntimeError) as exc_info: - await transport._connect_and_store_config("") + await transport._connect_and_store_config("", 1) assert "trl not set" in str(exc_info.value) @@ -652,8 +652,8 @@ async def test_tango_transport_put(echo_device): await transport.put(1.0) assert "Not connected" in str(exc_info.value) - await transport.connect() - source = transport.source("") + await transport.connect(1) + source = transport.source("", True) await transport.put(2.0) val = await transport.proxies[source].get_w_value() assert val == 2.0 @@ -665,7 +665,7 @@ async def test_tango_transport_get_datakey(echo_device): await prepare_device(echo_device, "float_scalar_attr", 1.0) source = echo_device + "/" + "float_scalar_attr" transport = await make_backend(float, source, connect=False) - await transport.connect() + await transport.connect(1) datakey = await transport.get_datakey(source) assert datakey["source"] == source assert datakey["dtype"] == "number" @@ -683,7 +683,7 @@ async def test_tango_transport_get_reading(echo_device): await transport.put(1.0) assert "Not connected" in str(exc_info.value) - await 
transport.connect() + await transport.connect(1) reading = await transport.get_reading() assert reading["value"] == 1.0 @@ -699,7 +699,7 @@ async def test_tango_transport_get_value(echo_device): await transport.put(1.0) assert "Not connected" in str(exc_info.value) - await transport.connect() + await transport.connect(1) value = await transport.get_value() assert value == 1.0 @@ -715,7 +715,7 @@ async def test_tango_transport_get_setpoint(echo_device): await transport.put(1.0) assert "Not connected" in str(exc_info.value) - await transport.connect() + await transport.connect(1) new_setpoint = 2.0 await transport.put(new_setpoint) setpoint = await transport.get_setpoint() @@ -733,12 +733,12 @@ async def test_set_callback(echo_device): await transport.put(1.0) assert "Not connected" in str(exc_info.value) - await transport.connect() + await transport.connect(1) val = None - def callback(reading, value): + def callback(reading): nonlocal val - val = value + val = reading["value"] # Correct usage transport.set_callback(callback) @@ -800,7 +800,7 @@ async def test_tango_transport_read_and_write_trl(tango_test_device): # Test with existing proxy transport = TangoSignalBackend(float, read_trl, write_trl, device_proxy) - await transport.connect() + await transport.connect(1) reading = await transport.get_reading() initial_value = reading["value"] new_value = initial_value + 1.0 @@ -810,7 +810,7 @@ async def test_tango_transport_read_and_write_trl(tango_test_device): # Without pre-existing proxy transport = TangoSignalBackend(float, read_trl, write_trl, None) - await transport.connect() + await transport.connect(1) reading = await transport.get_reading() initial_value = reading["value"] new_value = initial_value + 1.0 @@ -828,7 +828,7 @@ async def test_tango_transport_read_only_trl(tango_test_device): # Test with existing proxy transport = TangoSignalBackend(int, read_trl, read_trl, device_proxy) - await transport.connect() + await transport.connect(1) with 
pytest.raises(RuntimeError) as exc_info: await transport.put(1) assert "is not writable" in str(exc_info.value) @@ -844,11 +844,11 @@ async def test_tango_transport_nonexistent_trl(tango_test_device): # Test with existing proxy transport = TangoSignalBackend(int, nonexistent_trl, nonexistent_trl, device_proxy) with pytest.raises(RuntimeError) as exc_info: - await transport.connect() + await transport.connect(1) assert "cannot be found" in str(exc_info.value) # Without pre-existing proxy transport = TangoSignalBackend(int, nonexistent_trl, nonexistent_trl, None) with pytest.raises(RuntimeError) as exc_info: - await transport.connect() + await transport.connect(1) assert "cannot be found" in str(exc_info.value) diff --git a/tests/test_data/test_yaml_save.yml b/tests/test_data/test_yaml_save.yml index 17dc9e61a9..d00d38dff9 100644 --- a/tests/test_data/test_yaml_save.yml +++ b/tests/test_data/test_yaml_save.yml @@ -1,16 +1,34 @@ -- pv_array_float32: [-3.4028234663852886e+38, 3.4028234663852886e+38, 1.1754943508222875e-38, - 1.401298464324817e-45, 0.0, 1.2339999675750732, 234000.0, 3.4499998946557753e-06] - pv_array_float64: [-1.7976931348623157e+308, 1.7976931348623157e+308, 2.2250738585072014e-308, - 5.0e-324, 0.0, 1.234, 234000.0, 3.45e-06] +- pv_array_float32: + [ + -3.4028234663852886e+38, + 3.4028234663852886e+38, + 1.1754943508222875e-38, + 1.401298464324817e-45, + 0.0, + 1.2339999675750732, + 234000.0, + 3.4499998946557753e-06, + ] + pv_array_float64: + [ + -1.7976931348623157e+308, + 1.7976931348623157e+308, + 2.2250738585072014e-308, + 5.0e-324, + 0.0, + 1.234, + 234000.0, + 3.45e-06, + ] pv_array_int16: [-32768, 32767, 0, 1, 2, 3, 4] pv_array_int32: [-2147483648, 2147483647, 0, 1, 2, 3, 4] pv_array_int64: [-9223372036854775808, 9223372036854775807, 0, 1, 2, 3, 4] pv_array_int8: [-128, 127, 0, 1, 2, 3, 4] pv_array_npstr: [one, two, three] pv_array_str: - - one - - two - - three + - one + - two + - three pv_array_uint16: [0, 65535, 0, 1, 2, 3, 4] 
pv_array_uint32: [0, 4294967295, 0, 1, 2, 3, 4] pv_array_uint64: [0, 18446744073709551615, 0, 1, 2, 3, 4] @@ -20,7 +38,10 @@ pv_float: 1.234 pv_int: 1 pv_protocol_device_abstraction: - some_int_field: 1 - some_pydantic_numpy_field_float: [1, 2, 3] - some_pydantic_numpy_field_int: [1, 2, 3] + some_enum: + - one + - two + - three + some_float: [0.0, 1.0, 2.0] + some_int: [0, 1, 2] pv_str: test_string From 23a0e8311b2e427a3ddf17604641549c1b31d6f3 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Wed, 30 Oct 2024 13:36:04 +0000 Subject: [PATCH 17/30] Update copier template to 2.4.0 (#628) * Update copier template to 2.4.0 * Choose a machine accessible URL for license --- .copier-answers.yml | 2 +- .github/CONTRIBUTING.md | 2 +- Dockerfile | 2 +- README.md | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.copier-answers.yml b/.copier-answers.yml index 3461332e26..f382bb7996 100644 --- a/.copier-answers.yml +++ b/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2.3.0 +_commit: 2.4.0 _src_path: gh:DiamondLightSource/python-copier-template author_email: tom.cobb@diamond.ac.uk author_name: Tom Cobb diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 27f6450d97..06311934d0 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -24,4 +24,4 @@ It is recommended that developers use a [vscode devcontainer](https://code.visua This project was created using the [Diamond Light Source Copier Template](https://github.com/DiamondLightSource/python-copier-template) for Python projects. -For more information on common tasks like setting up a developer environment, running the tests, and setting a pre-commit hook, see the template's [How-to guides](https://diamondlightsource.github.io/python-copier-template/2.3.0/how-to.html). 
+For more information on common tasks like setting up a developer environment, running the tests, and setting a pre-commit hook, see the template's [How-to guides](https://diamondlightsource.github.io/python-copier-template/2.4.0/how-to.html). diff --git a/Dockerfile b/Dockerfile index 4e70aa5621..d65b65cbbd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ # The devcontainer should use the developer target and run as root with podman # or docker with user namespaces. ARG PYTHON_VERSION=3.11 -FROM python:${PYTHON_VERSION} as developer +FROM python:${PYTHON_VERSION} AS developer # Allow Qt 6 (pyside6) UI to work in the container - also see apt-get below ENV MPLBACKEND=QtAgg diff --git a/README.md b/README.md index 9f4bbe81c9..7881d80e02 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ [![CI](https://github.com/bluesky/ophyd-async/actions/workflows/ci.yml/badge.svg)](https://github.com/bluesky/ophyd-async/actions/workflows/ci.yml) [![Coverage](https://codecov.io/gh/bluesky/ophyd-async/branch/main/graph/badge.svg)](https://codecov.io/gh/bluesky/ophyd-async) [![PyPI](https://img.shields.io/pypi/v/ophyd-async.svg)](https://pypi.org/project/ophyd-async) -[![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) +[![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://choosealicense.com/licenses/bsd-3-clause) # ophyd-async @@ -13,7 +13,7 @@ Asynchronous Bluesky hardware abstraction code, compatible with control systems | Documentation | | | Releases | | -Ophyd-async is a Python library for asynchronously interfacing with hardware, intended to +Ophyd-async is a Python library for asynchronously interfacing with hardware, intended to be used as an abstraction layer that enables experiment orchestration and data acquisition code to operate above the specifics of particular devices and control systems. 
From 394dc4e375fc696838547bc31eae346eaa6b1eb6 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Wed, 30 Oct 2024 14:04:58 +0000 Subject: [PATCH 18/30] Logo (#629) --- README.md | 2 +- docs/conf.py | 7 +- .../bluesky_ophyd_epics_devices_logo.svg | 389 ------------------ docs/images/bluesky_ophyd_logo.svg | 323 --------------- docs/images/ophyd-async-logo.svg | 358 ++++++++++++++++ .../{ophyd_favicon.svg => ophyd-favicon.svg} | 0 6 files changed, 361 insertions(+), 718 deletions(-) delete mode 100644 docs/images/bluesky_ophyd_epics_devices_logo.svg delete mode 100644 docs/images/bluesky_ophyd_logo.svg create mode 100644 docs/images/ophyd-async-logo.svg rename docs/images/{ophyd_favicon.svg => ophyd-favicon.svg} (100%) diff --git a/README.md b/README.md index 7881d80e02..8d1c2ac289 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![PyPI](https://img.shields.io/pypi/v/ophyd-async.svg)](https://pypi.org/project/ophyd-async) [![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://choosealicense.com/licenses/bsd-3-clause) -# ophyd-async +# ![ophyd-async](https://raw.githubusercontent.com/bluesky/ophyd-async/main/docs/images/ophyd-async-logo.svg) Asynchronous Bluesky hardware abstraction code, compatible with control systems like EPICS and Tango. diff --git a/docs/conf.py b/docs/conf.py index 77565b386d..7a7d1db7bc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -188,9 +188,6 @@ # will fix the switcher at the end of the docs workflow, but never gets a chance # to complete as the docs build warns and fails. 
html_theme_options = { - "logo": { - "text": project, - }, "use_edit_page_button": True, "github_url": f"https://github.com/{github_user}/{github_repo}", "icon_links": [ @@ -230,8 +227,8 @@ html_show_copyright = False # Logo -html_logo = "images/bluesky_ophyd_logo.svg" -html_favicon = "images/ophyd_favicon.svg" +html_logo = "images/ophyd-async-logo.svg" +html_favicon = "images/ophyd-favicon.svg" # If False and a module has the __all__ attribute set, autosummary documents # every member listed in __all__ and no others. Default is True diff --git a/docs/images/bluesky_ophyd_epics_devices_logo.svg b/docs/images/bluesky_ophyd_epics_devices_logo.svg deleted file mode 100644 index 5eefa60947..0000000000 --- a/docs/images/bluesky_ophyd_epics_devices_logo.svg +++ /dev/null @@ -1,389 +0,0 @@ - - - - - - image/svg+xml - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/images/bluesky_ophyd_logo.svg b/docs/images/bluesky_ophyd_logo.svg deleted file mode 100644 index 0ea32810e2..0000000000 --- a/docs/images/bluesky_ophyd_logo.svg +++ /dev/null @@ -1,323 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/images/ophyd-async-logo.svg b/docs/images/ophyd-async-logo.svg new file mode 100644 index 0000000000..5b20bada91 --- /dev/null +++ b/docs/images/ophyd-async-logo.svg @@ -0,0 +1,358 @@ + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/images/ophyd_favicon.svg b/docs/images/ophyd-favicon.svg similarity index 100% rename from docs/images/ophyd_favicon.svg rename to docs/images/ophyd-favicon.svg From 4c6d0d9161391e01f0516c70f202511df284e466 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:01:29 +0000 
Subject: [PATCH 19/30] Allow CA/PVA mismatching enums to be bools (#632) --- src/ophyd_async/epics/signal/_aioca.py | 15 +++++++++------ src/ophyd_async/epics/signal/_p4p.py | 18 +++++++++++------- tests/epics/demo/test_demo.py | 16 +++++++++++++++- tests/epics/signal/test_signals.py | 6 ++++++ 4 files changed, 41 insertions(+), 14 deletions(-) diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index 9dc6e4c5ae..b7add417b4 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -175,16 +175,19 @@ def make_converter( if is_array and pv_dbr == dbr.DBR_CHAR and datatype is str: # Override waveform of chars to be treated as string return CaLongStrConverter() + elif not is_array and datatype is bool and pv_dbr == dbr.DBR_ENUM: + # Database can't do bools, so are often representated as enums of len 2 + pv_num_choices = get_unique( + {k: len(v.enums) for k, v in values.items()}, "number of choices" + ) + if pv_num_choices != 2: + raise TypeError(f"{pv} has {pv_num_choices} choices, can't map to bool") + return CaBoolConverter() elif not is_array and pv_dbr == dbr.DBR_ENUM: pv_choices = get_unique( {k: tuple(v.enums) for k, v in values.items()}, "choices" ) - if datatype is bool: - # Database can't do bools, so are often representated as enums of len 2 - if len(pv_choices) != 2: - raise TypeError(f"{pv} has {pv_choices=}, can't map to bool") - return CaBoolConverter() - elif enum_cls := get_enum_cls(datatype): + if enum_cls := get_enum_cls(datatype): # If explicitly requested then check return CaEnumConverter(get_supported_values(pv, enum_cls, pv_choices)) elif datatype in (None, str): diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 3ec4195e53..737b60fc9c 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -212,16 +212,20 @@ def make_converter(datatype: type | None, values: dict[str, Any]) -> 
PvaConverte (typeid, specifier) ] # Some override cases - if typeid == "epics:nt/NTEnum:1.0": + if datatype is bool and typeid == "epics:nt/NTEnum:1.0": + # Database can't do bools, so are often representated as enums of len 2 + pv_num_choices = get_unique( + {k: len(v["value"]["choices"]) for k, v in values.items()}, + "number of choices", + ) + if pv_num_choices != 2: + raise TypeError(f"{pv} has {pv_num_choices} choices, can't map to bool") + return PvaEnumBoolConverter() + elif typeid == "epics:nt/NTEnum:1.0": pv_choices = get_unique( {k: tuple(v["value"]["choices"]) for k, v in values.items()}, "choices" ) - if datatype is bool: - # Database can't do bools, so are often representated as enums of len 2 - if len(pv_choices) != 2: - raise TypeError(f"{pv} has {pv_choices=}, can't map to bool") - return PvaEnumBoolConverter() - elif enum_cls := get_enum_cls(datatype): + if enum_cls := get_enum_cls(datatype): # We were given an enum class, so make class from that return PvaEnumConverter( supported_values=get_supported_values(pv, enum_cls, pv_choices) diff --git a/tests/epics/demo/test_demo.py b/tests/epics/demo/test_demo.py index 3cdd816749..f3259a7ae6 100644 --- a/tests/epics/demo/test_demo.py +++ b/tests/epics/demo/test_demo.py @@ -293,9 +293,23 @@ async def test_assembly_renaming() -> None: async def test_dynamic_sensor_group_disconnected(): - with pytest.raises(NotConnected): + with pytest.raises(NotConnected) as e: async with DeviceCollector(timeout=0.1): mock_sensor_group_dynamic = demo.SensorGroup("MOCK:SENSOR:") + expected = """ +mock_sensor_group_dynamic: NotConnected: + sensors: NotConnected: + 1: NotConnected: + value: NotConnected: ca://MOCK:SENSOR:1:Value + mode: NotConnected: ca://MOCK:SENSOR:1:Mode + 2: NotConnected: + value: NotConnected: ca://MOCK:SENSOR:2:Value + mode: NotConnected: ca://MOCK:SENSOR:2:Mode + 3: NotConnected: + value: NotConnected: ca://MOCK:SENSOR:3:Value + mode: NotConnected: ca://MOCK:SENSOR:3:Mode +""" + assert str(e.value) == 
expected assert mock_sensor_group_dynamic.name == "mock_sensor_group_dynamic" diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index 4c6aadfa69..2a7a867da3 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -902,6 +902,12 @@ async def test_signals_created_for_not_prec_0_float_cannot_use_int(ioc: IOC): await sig.connect() +async def test_bool_works_for_mismatching_enums(ioc: IOC): + pv_name = f"{ioc.protocol}://{PV_PREFIX}:{ioc.protocol}:bool" + sig = epics_signal_rw(bool, pv_name, pv_name + "_unnamed") + await sig.connect() + + async def test_can_read_using_ophyd_async_then_ophyd(ioc: IOC): oa_read = f"{ioc.protocol}://{PV_PREFIX}:{ioc.protocol}:float_prec_1" ophyd_read = f"{PV_PREFIX}:{ioc.protocol}:float_prec_0" From e9762d29d59df928281fb7d1ca6855be816a556c Mon Sep 17 00:00:00 2001 From: James Souter <107045742+jsouter@users.noreply.github.com> Date: Thu, 31 Oct 2024 14:30:57 +0000 Subject: [PATCH 20/30] Allow shared parent mock to be passed to Device.connect (#599) * Allow unittest.Mock to be passed to Device.connect * use bool | Mock signature for connect everywhere * store Mocks for Devices and MockSignalBackends for Signals in dictionaries * assert mock calls explicitly in epics/demo/test_demo.py --- src/ophyd_async/core/__init__.py | 2 ++ src/ophyd_async/core/_device.py | 29 +++++++++++----- src/ophyd_async/core/_mock_signal_backend.py | 18 ++++++---- src/ophyd_async/core/_mock_signal_utils.py | 21 +++++++----- src/ophyd_async/core/_protocol.py | 4 ++- src/ophyd_async/core/_signal.py | 8 +++-- src/ophyd_async/epics/pvi/_pvi.py | 4 ++- .../plan_stubs/_ensure_connected.py | 4 ++- .../tango/base_devices/_base_device.py | 3 +- tests/core/test_device.py | 11 ++++--- tests/core/test_mock_signal_backend.py | 7 ++-- tests/core/test_signal.py | 8 ++--- tests/epics/demo/test_demo.py | 33 +++++++++++++++---- 13 files changed, 106 insertions(+), 46 deletions(-) diff --git 
a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 0a6c0d2f94..97a2c0c6d4 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -23,6 +23,7 @@ from ._mock_signal_backend import MockSignalBackend from ._mock_signal_utils import ( callback_on_mock_put, + get_mock, get_mock_put, mock_puts_blocked, reset_mock_put_calls, @@ -117,6 +118,7 @@ "config_ophyd_async_logging", "MockSignalBackend", "callback_on_mock_put", + "get_mock", "get_mock_put", "mock_puts_blocked", "reset_mock_put_calls", diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index c11aa478f8..5274b89018 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -5,6 +5,7 @@ from collections.abc import Coroutine, Iterator, Mapping, MutableMapping from logging import LoggerAdapter, getLogger from typing import Any, TypeVar +from unittest.mock import Mock from bluesky.protocols import HasName from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop @@ -12,6 +13,8 @@ from ._protocol import Connectable from ._utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection +_device_mocks: dict[Device, Mock] = {} + class DeviceConnector: """Defines how a `Device` should be connected and type hints processed.""" @@ -37,7 +40,7 @@ def create_children_from_annotations(self, device: Device): async def connect( self, device: Device, - mock: bool, + mock: bool | Mock, timeout: float, force_reconnect: bool, ): @@ -47,12 +50,12 @@ async def connect( done in a different mock more. It should connect the Device and all its children. 
""" - coros = { - name: child_device.connect( - mock=mock, timeout=timeout, force_reconnect=force_reconnect + coros = {} + for name, child_device in device.children(): + child_mock = getattr(mock, name) if mock else mock # Mock() or False + coros[name] = child_device.connect( + mock=child_mock, timeout=timeout, force_reconnect=force_reconnect ) - for name, child_device in device.children() - } await wait_for_connection(**coros) @@ -114,7 +117,7 @@ def __setattr__(self, name: str, value: Any) -> None: async def connect( self, - mock: bool = False, + mock: bool | Mock = False, timeout: float = DEFAULT_TIMEOUT, force_reconnect: bool = False, ) -> None: @@ -129,13 +132,18 @@ async def connect( timeout: Time to wait before failing with a TimeoutError. """ + uses_mock = bool(mock) can_use_previous_connect = ( - mock is self._connect_mock_arg + uses_mock is self._connect_mock_arg and self._connect_task and not (self._connect_task.done() and self._connect_task.exception()) ) + if mock is True: + mock = Mock() # create a new Mock if one not provided if force_reconnect or not can_use_previous_connect: - self._connect_mock_arg = mock + self._connect_mock_arg = uses_mock + if self._connect_mock_arg: + _device_mocks[self] = mock coro = self._connector.connect( device=self, mock=mock, timeout=timeout, force_reconnect=force_reconnect ) @@ -198,6 +206,9 @@ def children(self) -> Iterator[tuple[str, Device]]: for key, child in self._children.items(): yield str(key), child + def __hash__(self): # to allow DeviceVector to be used as dict keys and in sets + return hash(id(self)) + class DeviceCollector: """Collector of top level Device instances to be used as a context manager diff --git a/src/ophyd_async/core/_mock_signal_backend.py b/src/ophyd_async/core/_mock_signal_backend.py index 49c835ace5..878313e051 100644 --- a/src/ophyd_async/core/_mock_signal_backend.py +++ b/src/ophyd_async/core/_mock_signal_backend.py @@ -1,7 +1,7 @@ import asyncio from collections.abc import Callable 
from functools import cached_property -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock from bluesky.protocols import Descriptor, Reading @@ -13,7 +13,11 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]): """Signal backend for testing, created by ``Device.connect(mock=True)``.""" - def __init__(self, initial_backend: SignalBackend[SignalDatatypeT]) -> None: + def __init__( + self, + initial_backend: SignalBackend[SignalDatatypeT], + mock: Mock, + ) -> None: if isinstance(initial_backend, MockSignalBackend): raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackend") @@ -27,6 +31,12 @@ def __init__(self, initial_backend: SignalBackend[SignalDatatypeT]) -> None: self.soft_backend = SoftSignalBackend( datatype=self.initial_backend.datatype ) + + # use existing Mock if provided + self.mock = mock + self.put_mock = AsyncMock(name="put", spec=Callable) + self.mock.attach_mock(self.put_mock, "put") + super().__init__(datatype=self.initial_backend.datatype) def set_value(self, value: SignalDatatypeT): @@ -38,10 +48,6 @@ def source(self, name: str, read: bool) -> str: async def connect(self, timeout: float) -> None: pass - @cached_property - def put_mock(self) -> AsyncMock: - return AsyncMock(name="put", spec=Callable) - @cached_property def put_proceeds(self) -> asyncio.Event: put_proceeds = asyncio.Event() diff --git a/src/ophyd_async/core/_mock_signal_utils.py b/src/ophyd_async/core/_mock_signal_utils.py index 70f29a2a4f..30d48dbfe0 100644 --- a/src/ophyd_async/core/_mock_signal_utils.py +++ b/src/ophyd_async/core/_mock_signal_utils.py @@ -1,19 +1,18 @@ from collections.abc import Awaitable, Callable, Iterable from contextlib import asynccontextmanager, contextmanager -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock +from ._device import Device, _device_mocks from ._mock_signal_backend import MockSignalBackend -from ._signal import Signal, SignalR +from ._signal import Signal, SignalR, 
_mock_signal_backends from ._soft_signal_backend import SignalDatatypeT def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend: - backend = signal._connector.backend # noqa:SLF001 - assert isinstance(backend, MockSignalBackend), ( - "Expected to receive a `MockSignalBackend`, instead " - f" received {type(backend)}. " - ) - return backend + assert ( + signal in _mock_signal_backends + ), f"Signal {signal} not connected in mock mode" + return _mock_signal_backends[signal] def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT): @@ -46,6 +45,12 @@ def get_mock_put(signal: Signal) -> AsyncMock: return _get_mock_signal_backend(signal).put_mock +def get_mock(device: Device | Signal) -> Mock: + if isinstance(device, Signal): + return _get_mock_signal_backend(device).mock + return _device_mocks[device] + + def reset_mock_put_calls(signal: Signal): backend = _get_mock_signal_backend(signal) backend.put_mock.reset_mock() diff --git a/src/ophyd_async/core/_protocol.py b/src/ophyd_async/core/_protocol.py index 74b7bf0c23..d703d00112 100644 --- a/src/ophyd_async/core/_protocol.py +++ b/src/ophyd_async/core/_protocol.py @@ -16,6 +16,8 @@ from ._utils import DEFAULT_TIMEOUT if TYPE_CHECKING: + from unittest.mock import Mock + from ._status import AsyncStatus @@ -24,7 +26,7 @@ class Connectable(Protocol): @abstractmethod async def connect( self, - mock: bool = False, + mock: bool | Mock = False, timeout: float = DEFAULT_TIMEOUT, force_reconnect: bool = False, ): diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 2895cc2ee8..371cb5a0de 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -4,6 +4,7 @@ import functools from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from typing import Any, Generic, cast +from unittest.mock import Mock from bluesky.protocols import ( Locatable, @@ -30,6 +31,8 @@ from ._status import AsyncStatus, completed_status from ._utils 
import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T +_mock_signal_backends: dict[Device, MockSignalBackend] = {} + async def _wait_for(coro: Awaitable[T], timeout: float | None, source: str) -> T: try: @@ -53,12 +56,13 @@ def __init__(self, backend: SignalBackend): async def connect( self, device: Device, - mock: bool, + mock: bool | Mock, timeout: float, force_reconnect: bool, ): if mock: - self.backend = MockSignalBackend(self._init_backend) + self.backend = MockSignalBackend(self._init_backend, mock) + _mock_signal_backends[device] = self.backend else: self.backend = self._init_backend device.log.debug(f"Connecting to {self.backend.source(device.name, read=True)}") diff --git a/src/ophyd_async/epics/pvi/_pvi.py b/src/ophyd_async/epics/pvi/_pvi.py index 22f160ea91..5bd7a38ef8 100644 --- a/src/ophyd_async/epics/pvi/_pvi.py +++ b/src/ophyd_async/epics/pvi/_pvi.py @@ -1,5 +1,7 @@ from __future__ import annotations +from unittest.mock import Mock + from ophyd_async.core import ( Device, DeviceConnector, @@ -41,7 +43,7 @@ def create_children_from_annotations(self, device: Device): ) async def connect( - self, device: Device, mock: bool, timeout: float, force_reconnect: bool + self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool ) -> None: if mock: # Make 2 entries for each DeviceVector diff --git a/src/ophyd_async/plan_stubs/_ensure_connected.py b/src/ophyd_async/plan_stubs/_ensure_connected.py index 3a64619a5c..d4835b710c 100644 --- a/src/ophyd_async/plan_stubs/_ensure_connected.py +++ b/src/ophyd_async/plan_stubs/_ensure_connected.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock + import bluesky.plan_stubs as bps from ophyd_async.core import DEFAULT_TIMEOUT, Device, wait_for_connection @@ -5,7 +7,7 @@ def ensure_connected( *devices: Device, - mock: bool = False, + mock: bool | Mock = False, timeout: float = DEFAULT_TIMEOUT, force_reconnect=False, ): diff --git a/src/ophyd_async/tango/base_devices/_base_device.py 
b/src/ophyd_async/tango/base_devices/_base_device.py index d73328724c..f93b5a3eda 100644 --- a/src/ophyd_async/tango/base_devices/_base_device.py +++ b/src/ophyd_async/tango/base_devices/_base_device.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TypeVar +from unittest.mock import Mock from ophyd_async.core import Device, DeviceConnector, DeviceFiller from ophyd_async.tango.signal import ( @@ -114,7 +115,7 @@ def create_children_from_annotations(self, device: Device): ) async def connect( - self, device: Device, mock: bool, timeout: float, force_reconnect: bool + self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool ) -> None: if mock: # Make 2 entries for each DeviceVector diff --git a/tests/core/test_device.py b/tests/core/test_device.py index 0fff56357b..2fc127f17b 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -193,21 +193,24 @@ async def test_device_with_children_lazily_connects(RE): ) -async def test_no_reconnect_signals_if_not_forced(): +@pytest.mark.parametrize("use_Mock", [False, True]) +async def test_no_reconnect_signals_if_not_forced(use_Mock): parent = DummyDeviceGroup("parent") + connect_mock_arg = Mock() if use_Mock else True + async def inner_connect(mock, timeout, force_reconnect): parent.child1.connected = True parent.child1.connect = Mock(side_effect=inner_connect) - await parent.connect(mock=True, timeout=0.01) + await parent.connect(mock=connect_mock_arg, timeout=0.01) assert parent.child1.connected assert parent.child1.connect.call_count == 1 - await parent.connect(mock=True, timeout=0.01) + await parent.connect(mock=connect_mock_arg, timeout=0.01) assert parent.child1.connected assert parent.child1.connect.call_count == 1 for count in range(2, 10): - await parent.connect(mock=True, timeout=0.01, force_reconnect=True) + await parent.connect(mock=connect_mock_arg, timeout=0.01, force_reconnect=True) assert parent.child1.connected assert 
parent.child1.connect.call_count == count diff --git a/tests/core/test_mock_signal_backend.py b/tests/core/test_mock_signal_backend.py index 128dd7eef3..b8fead82f4 100644 --- a/tests/core/test_mock_signal_backend.py +++ b/tests/core/test_mock_signal_backend.py @@ -134,9 +134,10 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend(): exc_msgs.append(str(exc.value)) for msg in exc_msgs: - assert msg == ( - "Expected to receive a `MockSignalBackend`, instead " - f" received {SoftSignalBackend}. " + assert re.match( + r"Signal " + r"not connected in mock mode", + msg, ) diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index ec8b777743..a9964f5132 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -3,7 +3,7 @@ import re import time from asyncio import Event -from unittest.mock import ANY +from unittest.mock import ANY, Mock import pytest from bluesky.protocols import Reading @@ -44,7 +44,7 @@ def num_occurrences(substring: str, string: str) -> int: async def test_signal_connects_to_previous_backend(caplog): caplog.set_level(logging.DEBUG) - int_mock_backend = MockSignalBackend(SoftSignalBackend(int)) + int_mock_backend = MockSignalBackend(SoftSignalBackend(int), Mock()) original_connect = int_mock_backend.connect times_backend_connect_called = 0 @@ -63,7 +63,7 @@ async def new_connect(timeout=1): async def test_signal_connects_with_force_reconnect(caplog): caplog.set_level(logging.DEBUG) - signal = Signal(MockSignalBackend(SoftSignalBackend(int))) + signal = Signal(MockSignalBackend(SoftSignalBackend(int), Mock())) await signal.connect() assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1 await signal.connect(force_reconnect=True) @@ -82,7 +82,7 @@ async def connect(self, timeout=DEFAULT_TIMEOUT): self.succeed_on_connect = True raise RuntimeError("connect fail") - signal = SignalRW(MockSignalBackendFailingFirst(SoftSignalBackend(int))) + signal = 
SignalRW(MockSignalBackendFailingFirst(SoftSignalBackend(int), Mock())) with pytest.raises(RuntimeError, match="connect fail"): await signal.connect(mock=False) diff --git a/tests/epics/demo/test_demo.py b/tests/epics/demo/test_demo.py index f3259a7ae6..7b14154c41 100644 --- a/tests/epics/demo/test_demo.py +++ b/tests/epics/demo/test_demo.py @@ -15,6 +15,7 @@ assert_reading, assert_value, callback_on_mock_put, + get_mock, get_mock_put, set_mock_value, ) @@ -190,13 +191,33 @@ async def test_retrieve_mock_and_assert(mock_mover: demo.Mover): await mock_mover.velocity.set(100) await mock_mover.setpoint.set(67) + assert parent_mock.mock_calls == [ + call.velocity(100, wait=True), + call.setpoint(67, wait=True), + ] - parent_mock.assert_has_calls( - [ - call.velocity(100, wait=True), - call.setpoint(67, wait=True), - ] - ) + +async def test_mocks_in_device_share_parent(): + mock = Mock() + async with DeviceCollector(mock=mock): + mock_mover = demo.Mover("BLxxI-MO-TABLE-01:Y:") + + assert get_mock(mock_mover) is mock + assert get_mock(mock_mover.setpoint) is mock.setpoint + assert get_mock_put(mock_mover.setpoint) is mock.setpoint.put + await mock_mover.setpoint.set(10) + get_mock_put(mock_mover.setpoint).assert_called_once_with(10, wait=ANY) + + await mock_mover.velocity.set(100) + await mock_mover.setpoint.set(67) + + mock.reset_mock() + await mock_mover.velocity.set(100) + await mock_mover.setpoint.set(67) + assert mock.mock_calls == [ + call.velocity.put(100, wait=True), + call.setpoint.put(67, wait=True), + ] async def test_read_mover(mock_mover: demo.Mover): From 94b6d676b34bc017c659c886c0006e2d24651cfb Mon Sep 17 00:00:00 2001 From: James Souter <107045742+jsouter@users.noreply.github.com> Date: Fri, 1 Nov 2024 07:41:07 +0000 Subject: [PATCH 21/30] Add device name to AsyncStatusBase repr (#607) --- src/ophyd_async/core/_status.py | 28 ++++++++++++++++++----- tests/core/test_status.py | 21 ++++++++++++----- tests/core/test_watchable_async_status.py | 19 
+++++++++++++++ 3 files changed, 56 insertions(+), 12 deletions(-) diff --git a/src/ophyd_async/core/_status.py b/src/ophyd_async/core/_status.py index 93b9888404..e4ff8803ff 100644 --- a/src/ophyd_async/core/_status.py +++ b/src/ophyd_async/core/_status.py @@ -13,6 +13,7 @@ from bluesky.protocols import Status +from ._device import Device from ._protocol import Watcher from ._utils import Callback, P, T, WatcherUpdate @@ -23,13 +24,14 @@ class AsyncStatusBase(Status): """Convert asyncio awaitable to bluesky Status interface""" - def __init__(self, awaitable: Coroutine | asyncio.Task): + def __init__(self, awaitable: Coroutine | asyncio.Task, name: str | None = None): if isinstance(awaitable, asyncio.Task): self.task = awaitable else: self.task = asyncio.create_task(awaitable) self.task.add_done_callback(self._run_callbacks) self._callbacks: list[Callback[Status]] = [] + self._name = name def __await__(self): return self.task.__await__() @@ -76,7 +78,11 @@ def __repr__(self) -> str: status = "done" else: status = "pending" - return f"<{type(self).__name__}, task: {self.task.get_coro()}, {status}>" + device_str = f"device: {self._name}, " if self._name else "" + return ( + f"<{type(self).__name__}, {device_str}" + f"task: {self.task.get_coro()}, {status}>" + ) __str__ = __repr__ @@ -90,7 +96,11 @@ def wrap(cls: type[AS], f: Callable[P, Coroutine]) -> Callable[P, AS]: @functools.wraps(f) def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: - return cls(f(*args, **kwargs)) + if args and isinstance(args[0], Device): + name = args[0].name + else: + name = None + return cls(f(*args, **kwargs), name=name) # type is actually functools._Wrapped[P, Awaitable, P, AS] # but functools._Wrapped is not necessarily available @@ -100,11 +110,13 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: class WatchableAsyncStatus(AsyncStatusBase, Generic[T]): """Convert AsyncIterator of WatcherUpdates to bluesky Status interface.""" - def __init__(self, iterator: 
AsyncIterator[WatcherUpdate[T]]): + def __init__( + self, iterator: AsyncIterator[WatcherUpdate[T]], name: str | None = None + ): self._watchers: list[Watcher] = [] self._start = time.monotonic() self._last_update: WatcherUpdate[T] | None = None - super().__init__(self._notify_watchers_from(iterator)) + super().__init__(self._notify_watchers_from(iterator), name) async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]): async for update in iterator: @@ -136,7 +148,11 @@ def wrap( @functools.wraps(f) def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS: - return cls(f(*args, **kwargs)) + if args and isinstance(args[0], Device): + name = args[0].name + else: + name = None + return cls(f(*args, **kwargs), name=name) return cast(Callable[P, WAS], wrap_f) diff --git a/tests/core/test_status.py b/tests/core/test_status.py index 263a39dca6..109dfd4b97 100644 --- a/tests/core/test_status.py +++ b/tests/core/test_status.py @@ -118,16 +118,14 @@ class FailingMovable(Movable, Device): def _fail(self): raise ValueError("This doesn't work") - async def _set(self, value): + @AsyncStatus.wrap + async def set(self, value): if value: - self._fail() - - def set(self, value) -> AsyncStatus: - return AsyncStatus(self._set(value)) + return self._fail() async def test_status_propogates_traceback_under_RE(RE) -> None: - expected_call_stack = ["_set", "_fail"] + expected_call_stack = ["set", "_fail"] d = FailingMovable() with pytest.raises(FailedStatus) as ctx: RE(bps.mv(d, 3)) @@ -203,3 +201,14 @@ async def test_completed_status(): with pytest.raises(ValueError): await completed_status(ValueError()) await completed_status() + + +async def test_device_name_in_failure_message_AsyncStatus_wrap(RE): + device_name = "MyFailingMovable" + d = FailingMovable(name=device_name) + with pytest.raises(FailedStatus) as ctx: + RE(bps.mv(d, 3)) + # FailingMovable.set is decorated with @AsyncStatus.wrap + # undecorated methods will not print the device name + status: AsyncStatus = 
ctx.value.args[0] + assert f"device: {device_name}" in repr(status) diff --git a/tests/core/test_watchable_async_status.py b/tests/core/test_watchable_async_status.py index d3f954d239..618a848501 100644 --- a/tests/core/test_watchable_async_status.py +++ b/tests/core/test_watchable_async_status.py @@ -5,9 +5,11 @@ import bluesky.plan_stubs as bps import pytest from bluesky.protocols import Movable +from bluesky.utils import FailedStatus from ophyd_async.core import ( AsyncStatus, + Device, StandardReadable, WatchableAsyncStatus, WatcherUpdate, @@ -197,3 +199,20 @@ async def test_watchableasyncstatus_times_out(RE): await asyncio.sleep(0.01) assert not st.success assert isinstance(st.exception(), asyncio.TimeoutError) + + +async def test_device_name_in_failure_message_WatchableAsyncStatus_wrap(RE): + class FailingWatchableMovable(Movable, Device): + @WatchableAsyncStatus.wrap + async def set(self, value) -> AsyncIterator: + yield WatcherUpdate(0, 0, value) + raise ValueError("This doesn't work") + + device_name = "MyFailingMovable" + d = FailingWatchableMovable(name=device_name) + with pytest.raises(FailedStatus) as ctx: + RE(bps.mv(d, 3)) + # FailingMovable.set is decorated with @AsyncStatus.wrap + # undecorated methods will not print the device name + status: AsyncStatus = ctx.value.args[0] + assert f"device: {device_name}" in repr(status) From 5ae409b9374757c309cea0735f4e953f3ff0bda6 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:29:39 +0000 Subject: [PATCH 22/30] Temporary fix for PyPI publishing (#634) https://github.com/DiamondLightSource/python-copier-template/issues/210 --- .github/workflows/_pypi.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/_pypi.yml b/.github/workflows/_pypi.yml index 0c5258dbee..574e299b05 100644 --- a/.github/workflows/_pypi.yml +++ b/.github/workflows/_pypi.yml @@ -15,3 +15,5 @@ jobs: - name: Publish to PyPI using trusted publishing uses: 
pypa/gh-action-pypi-publish@release/v1 + with: + attestations: false From 74c7d48223e781953b614b8bac7caefc94c45907 Mon Sep 17 00:00:00 2001 From: Tom Willemsen Date: Fri, 1 Nov 2024 15:04:02 +0000 Subject: [PATCH 23/30] Windows: fix unit tests & enable CI (#633) * Fix race condition in mock_signal_backend * Fix numpy datatypes in test_mock_signal_backend * Fix adaravis URI paths * Fix path in test_kinetix * Fix paths in simdetector tests * Loosen timings against race condition in test_sim_detector * Correct paths in test_advimba * Fix paths in test_eiger * Fix broken tests in test_signals * Fix tange test_base_device tests on Windows Missing process=True argument was causing all subsequent tests to fail on windows * Fix race conditions in tango tests * Use tango FQTRL to placate windows unit tests * Correct paths in fastcs tests * Correct path logic for windows compatibility * Slacken timings for race condition in test_sim_motor * Slacken race condition timings in test_motor * Enable windows CI builds * Try allowing asyncio time to clean up tasks * Object IDs on windows are uppercase * Try wait_for_wakeups --- .github/workflows/ci.yml | 2 +- tests/core/test_flyer.py | 4 ++++ tests/core/test_mock_signal_backend.py | 4 ++-- tests/core/test_soft_signal_backend.py | 9 ++++++++- tests/epics/adaravis/test_aravis.py | 4 +++- tests/epics/adkinetix/test_kinetix.py | 4 +++- tests/epics/adsimdetector/test_sim.py | 10 ++++++--- tests/epics/advimba/test_vimba.py | 4 +++- tests/epics/eiger/test_odin_io.py | 7 +++---- tests/epics/signal/test_signals.py | 10 +++++++-- tests/epics/test_motor.py | 28 +++++++++++++++++--------- tests/fastcs/panda/test_hdf_panda.py | 2 +- tests/fastcs/panda/test_writer.py | 5 +++-- tests/sim/demo/test_sim_motor.py | 4 ++-- tests/sim/test_sim_detector.py | 3 +++ tests/tango/test_base_device.py | 4 ++-- tests/tango/test_tango_transport.py | 8 +++++--- 17 files changed, 76 insertions(+), 36 deletions(-) diff --git a/.github/workflows/ci.yml 
b/.github/workflows/ci.yml index 265fea6424..d11c9c6dca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: if: needs.check.outputs.branch-pr == '' strategy: matrix: - runs-on: ["ubuntu-latest"] # can add windows-latest, macos-latest + runs-on: ["ubuntu-latest", "windows-latest"] # can add macos-latest python-version: ["3.10","3.11"] # 3.12 should be added when p4p is updated include: # Include one that runs in the dev environment diff --git a/tests/core/test_flyer.py b/tests/core/test_flyer.py index 79bcf8fe7e..ef0d0eb39d 100644 --- a/tests/core/test_flyer.py +++ b/tests/core/test_flyer.py @@ -1,3 +1,4 @@ +import asyncio import time from collections.abc import AsyncGenerator, AsyncIterator, Sequence from typing import Any @@ -356,6 +357,9 @@ def flying_plan(): with pytest.raises(Exception, match=match_msg): RE(flying_plan()) + # Try explicitly letting event loop clean up tasks...? + await asyncio.sleep(1.0) + @pytest.mark.parametrize( ["kwargs", "error_msg"], diff --git a/tests/core/test_mock_signal_backend.py b/tests/core/test_mock_signal_backend.py index b8fead82f4..a40faea4e5 100644 --- a/tests/core/test_mock_signal_backend.py +++ b/tests/core/test_mock_signal_backend.py @@ -135,7 +135,7 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend(): for msg in exc_msgs: assert re.match( - r"Signal " + r"Signal " r"not connected in mock mode", msg, ) @@ -190,7 +190,7 @@ async def test_blocks_during_put(mock_signals): assert not status1.done assert not status2.done - await asyncio.sleep(1e-4) + await asyncio.sleep(0.1) assert status1.done assert status2.done diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py index f23046ec2b..fc60a2bbfa 100644 --- a/tests/core/test_soft_signal_backend.py +++ b/tests/core/test_soft_signal_backend.py @@ -1,4 +1,5 @@ import asyncio +import os import time from collections.abc import Callable, Sequence from typing import Any @@ -67,10 
+68,16 @@ def close(self): self.backend.set_callback(None) +# Can be removed once numpy >=2 is pinned. +default_int_type = ( + " OdinDriverAndWriter: async def test_when_open_called_then_file_correctly_set( - odin_driver_and_writer: OdinDriverAndWriter, + odin_driver_and_writer: OdinDriverAndWriter, tmp_path: Path ): driver, writer = odin_driver_and_writer path_info = writer._path_provider.return_value - expected_path = "/tmp" expected_filename = "filename.h5" - path_info.directory_path = Path(expected_path) + path_info.directory_path = tmp_path path_info.filename = expected_filename await writer.open() - get_mock_put(driver.file_path).assert_called_once_with(expected_path, wait=ANY) + get_mock_put(driver.file_path).assert_called_once_with(str(tmp_path), wait=ANY) get_mock_put(driver.file_name).assert_called_once_with(expected_filename, wait=ANY) diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index 2a7a867da3..250bec3567 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -1,4 +1,5 @@ import asyncio +import os import random import string import subprocess @@ -99,7 +100,7 @@ def ioc(request: pytest.FixtureRequest): # close backend caches before the event loop purge_channel_caches() try: - print(process.communicate("exit")[0]) + print(process.communicate("exit()")[0]) except ValueError: # Someone else already called communicate pass @@ -260,7 +261,11 @@ def get_dtype_numpy(suffix: str) -> str: # type: ignore elif "64" in suffix: int_str += "8" else: - int_str += "8" + int_str += ( + "4" + if os.name == "nt" and np.version.version.startswith("1.") + else "8" + ) return int_str if "str" in suffix or "enum" in suffix: return "|S40" @@ -908,6 +913,7 @@ async def test_bool_works_for_mismatching_enums(ioc: IOC): await sig.connect() +@pytest.mark.skipif(os.name == "nt", reason="Hangs on windows for unknown reasons") async def test_can_read_using_ophyd_async_then_ophyd(ioc: IOC): oa_read = 
f"{ioc.protocol}://{PV_PREFIX}:{ioc.protocol}:float_prec_1" ophyd_read = f"{PV_PREFIX}:{ioc.protocol}:float_prec_0" diff --git a/tests/epics/test_motor.py b/tests/epics/test_motor.py index cd8f11ba8d..1760f245ba 100644 --- a/tests/epics/test_motor.py +++ b/tests/epics/test_motor.py @@ -18,9 +18,16 @@ ) from ophyd_async.epics import motor -# Long enough for multiple asyncio event loop cycles to run so -# all the tasks have a chance to run -A_BIT = 0.001 + +async def wait_for_wakeups(max_yields=10): + loop = asyncio.get_event_loop() + # If anything has called loop.call_soon or is scheduled a wakeup + # then let it run + for _ in range(max_yields): + await asyncio.sleep(0) + if not loop._ready: + return + raise RuntimeError(f"Tasks still scheduling wakeups after {max_yields} yields") @pytest.fixture @@ -37,7 +44,7 @@ async def sim_motor(): async def wait_for_eq(item, attribute, comparison, timeout): timeout_time = time.monotonic() + timeout while getattr(item, attribute) != comparison: - await asyncio.sleep(A_BIT) + await wait_for_wakeups() if time.monotonic() > timeout_time: raise TimeoutError @@ -49,6 +56,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: s.watch(watcher) done = Mock() s.add_callback(done) + await wait_for_wakeups() await wait_for_eq(watcher, "call_count", 1, 1) assert watcher.call_args == call( name="sim_motor", @@ -78,7 +86,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: set_mock_value(sim_motor.motor_done_move, True) set_mock_value(sim_motor.user_readback, 0.55) set_mock_put_proceeds(sim_motor.user_setpoint, True) - await asyncio.sleep(A_BIT) + await wait_for_wakeups() await wait_for_eq(s, "done", True, 1) done.assert_called_once_with(s) @@ -90,7 +98,7 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None: s.watch(watcher) done = Mock() s.add_callback(done) - await asyncio.sleep(A_BIT) + await wait_for_wakeups() assert watcher.call_count == 1 assert watcher.call_args == call( 
name="sim_motor", @@ -99,7 +107,7 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None: target=0.55, unit="mm", precision=3, - time_elapsed=pytest.approx(0.0, abs=0.05), + time_elapsed=pytest.approx(0.0, abs=0.2), ) watcher.reset_mock() assert 0.55 == await sim_motor.user_setpoint.get_value() @@ -115,10 +123,10 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None: target=0.55, unit="mm", precision=3, - time_elapsed=pytest.approx(0.1, abs=0.05), + time_elapsed=pytest.approx(0.1, abs=0.2), ) set_mock_put_proceeds(sim_motor.user_setpoint, True) - await asyncio.sleep(A_BIT) + await wait_for_wakeups() assert s.done done.assert_called_once_with(s) @@ -157,7 +165,7 @@ async def test_motor_moving_stopped(sim_motor: motor.Motor): assert not s.done await sim_motor.stop() set_mock_put_proceeds(sim_motor.user_setpoint, True) - await asyncio.sleep(A_BIT) + await wait_for_wakeups() assert s.done assert s.success is False diff --git a/tests/fastcs/panda/test_hdf_panda.py b/tests/fastcs/panda/test_hdf_panda.py index d6fcc4ee9f..4e609c63e5 100644 --- a/tests/fastcs/panda/test_hdf_panda.py +++ b/tests/fastcs/panda/test_hdf_panda.py @@ -149,7 +149,7 @@ def flying_plan(): "uid": ANY, "data_key": data_key_name, "mimetype": "application/x-hdf5", - "uri": "file://localhost" + str(tmp_path / "test-panda.h5"), + "uri": "file://localhost/" + str(tmp_path / "test-panda.h5").lstrip("/"), "parameters": { "dataset": f"/{dataset_name}", "swmr": False, diff --git a/tests/fastcs/panda/test_writer.py b/tests/fastcs/panda/test_writer.py index 386dc16fdc..8dc7586dfd 100644 --- a/tests/fastcs/panda/test_writer.py +++ b/tests/fastcs/panda/test_writer.py @@ -187,7 +187,8 @@ def assert_resource_document(name, resource_doc): "uid": ANY, "data_key": name, "mimetype": "application/x-hdf5", - "uri": "file://localhost" + str(tmp_path / "mock_panda" / "data.h5"), + "uri": "file://localhost/" + + str(tmp_path / "mock_panda" / "data.h5").lstrip("/"), "parameters": { "dataset": 
f"/{name}", "swmr": False, @@ -195,7 +196,7 @@ def assert_resource_document(name, resource_doc): "chunk_shape": (1024,), }, } - assert "mock_panda/data.h5" in resource_doc["uri"] + assert os.path.join("mock_panda", "data.h5") in resource_doc["uri"] [item async for item in mock_writer.collect_stream_docs(1)] assert type(mock_writer._file) is HDFFile diff --git a/tests/sim/demo/test_sim_motor.py b/tests/sim/demo/test_sim_motor.py index 7d4ed46f19..9a8ac75a6a 100644 --- a/tests/sim/demo/test_sim_motor.py +++ b/tests/sim/demo/test_sim_motor.py @@ -42,9 +42,9 @@ async def test_stop(): async with DeviceCollector(): m1 = SimMotor("M1", instant=False) - # this move should take 10 seconds but we will stop it after 0.2 + # this move should take 10 seconds but we will stop it after 0.5 move_status = m1.set(10) - await asyncio.sleep(0.2) + await asyncio.sleep(0.5) await m1.stop(success=False) new_pos = await m1.user_readback.get_value() assert new_pos < 10 diff --git a/tests/sim/test_sim_detector.py b/tests/sim/test_sim_detector.py index 7631a5b558..062e31004c 100644 --- a/tests/sim/test_sim_detector.py +++ b/tests/sim/test_sim_detector.py @@ -1,3 +1,4 @@ +import os from collections import defaultdict import bluesky.plans as bp @@ -44,6 +45,8 @@ def plan(): docs, start=1, descriptor=1, stream_resource=2, stream_datum=2, event=1, stop=1 ) path = docs["stream_resource"][0]["uri"].split("://localhost")[-1] + if os.name == "nt": + path = path.lstrip("/") h5file = h5py.File(path) assert list(h5file["/entry"]) == ["data", "sum"] assert list(h5file["/entry/sum"]) == [44540.0] diff --git a/tests/tango/test_base_device.py b/tests/tango/test_base_device.py index e1ed92eb4f..d825b37138 100644 --- a/tests/tango/test_base_device.py +++ b/tests/tango/test_base_device.py @@ -290,7 +290,7 @@ def demo_test_context(): "devices": [{"name": "demo/counter/1"}, {"name": "demo/counter/2"}], }, ) - yield MultiDeviceTestContext(content) + yield MultiDeviceTestContext(content, process=True) # 
-------------------------------------------------------------------- @@ -384,7 +384,7 @@ async def test_tango_demo(demo_test_context): RE(bp.count(list(detector.counters.values()))) set_status = detector.set(1.0) - await asyncio.sleep(0.1) + await asyncio.sleep(1.0) stop_status = detector.stop() await set_status await stop_status diff --git a/tests/tango/test_tango_transport.py b/tests/tango/test_tango_transport.py index 036946fbe0..f1164455c6 100644 --- a/tests/tango/test_tango_transport.py +++ b/tests/tango/test_tango_transport.py @@ -244,6 +244,7 @@ async def test_attribute_proxy_put(tango_test_device, attr, wait): else: if not wait: raise AssertionError("If wait is False, put should return a status object") + await asyncio.sleep(1.0) updated_value = await attr_proxy.get() if isinstance(new_value, np.ndarray): assert np.all(updated_value == new_value) @@ -284,6 +285,7 @@ async def test_attribute_proxy_get_w_value(tango_test_device, attr, new_value): device_proxy = await DeviceProxy(tango_test_device) attr_proxy = AttributeProxy(device_proxy, attr) await attr_proxy.put(new_value) + await asyncio.sleep(1.0) attr_proxy_value = await attr_proxy.get() if isinstance(new_value, np.ndarray): assert np.all(attr_proxy_value == new_value) @@ -794,9 +796,9 @@ async def test_tango_transport_allow_events(echo_device, allow): @pytest.mark.asyncio async def test_tango_transport_read_and_write_trl(tango_test_device): device_proxy = await DeviceProxy(tango_test_device) - trl = device_proxy.dev_name() - read_trl = trl + "/" + "readback" - write_trl = trl + "/" + "setpoint" + # Must use a FQTRL, at least on windows. 
+ read_trl = tango_test_device + "/" + "readback" + write_trl = tango_test_device + "/" + "setpoint" # Test with existing proxy transport = TangoSignalBackend(float, read_trl, write_trl, device_proxy) From 0a126358918338d4378c3ed1b4df26b6514652d1 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:02:36 +0000 Subject: [PATCH 24/30] Declarative EPICS and StandardReadable Devices (#598) Add optional Declarative Device support Includes: - An ADR for optional Declarative Devices - Support for `StandardReadable` Declarative Devices via `StandardReadableFormat` annotations - Support for EPICS Declarative Devices via an `EpicsDevice` baseclass and `PvSuffic` annotations - Updates to the EPICS and Tango demo devices to use them Structure read from `.value` now includes `DeviceVector` support. Requires at least PandABlocks-ioc 0.11.2 `ophyd_async.epics.signal` moves to `ophyd_async.epics.core` with a backwards compat module that emits deprecation warning. 
```python from ophyd_async.epics.signal import epics_signal_rw from ophyd_async.epics.core import epics_signal_rw ``` `StandardReadable` wrappers change to enum members of `StandardReadableFormat` (normally imported as `Format`) ```python from ophyd_async.core import ConfigSignal, HintedSignal class MyDevice(StandardReadable): def __init__(self): self.add_readables([sig1], ConfigSignal) self.add_readables([sig2], HintedSignal) self.add_readables([sig3], HintedSignal.uncached) from ophyd_async.core import StandardReadableFormat as Format class MyDevice(StandardReadable): def __init__(self): self.add_readables([sig1], Format.CONFIG_SIGNAL) self.add_readables([sig2], Format.HINTED_SIGNAL) self.add_readables([sig3], Format.HINTED_UNCACHED_SIGNAL ``` ```python from ophyd_async.core import ConfigSignal, HintedSignal from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw class Sensor(StandardReadable): def __init__(self, prefix: str, name="") -> None: with self.add_children_as_readables(HintedSignal): self.value = epics_signal_r(float, prefix + "Value") with self.add_children_as_readables(ConfigSignal): self.mode = epics_signal_rw(EnergyMode, prefix + "Mode") super().__init__(name=name) from typing import Annotated as A from ophyd_async.core import StandardReadableFormat as Format from ophyd_async.epics.core import EpicsDevice, PvSuffix, epics_signal_r, epics_signal_rw class Sensor(StandardReadable, EpicsDevice): value: A[SignalR[float], PvSuffix("Value"), Format.HINTED_SIGNAL] mode: A[SignalRW[EnergyMode], PvSuffix("Mode"), Format.CONFIG_SIGNAL] ``` --- docs/examples/foo_detector.py | 2 +- .../0009-procedural-vs-declarative-devices.md | 140 ++++++++ pyproject.toml | 3 +- src/ophyd_async/core/__init__.py | 8 +- src/ophyd_async/core/_device.py | 1 + src/ophyd_async/core/_device_filler.py | 338 +++++++++++------- src/ophyd_async/core/_readable.py | 257 +++++++------ src/ophyd_async/core/_utils.py | 10 +- src/ophyd_async/epics/adaravis/_aravis_io.py | 2 +- 
src/ophyd_async/epics/adcore/_core_io.py | 2 +- .../epics/adcore/_single_trigger.py | 13 +- .../epics/adkinetix/_kinetix_io.py | 2 +- .../epics/adpilatus/_pilatus_io.py | 2 +- src/ophyd_async/epics/advimba/_vimba_io.py | 2 +- src/ophyd_async/epics/core/__init__.py | 26 ++ .../epics/{signal => core}/_aioca.py | 9 +- .../epics/core/_epics_connector.py | 53 +++ src/ophyd_async/epics/core/_epics_device.py | 13 + .../epics/{signal => core}/_p4p.py | 9 +- src/ophyd_async/epics/core/_pvi_connector.py | 92 +++++ .../epics/{signal => core}/_signal.py | 47 ++- .../{signal/_common.py => core/_util.py} | 20 +- src/ophyd_async/epics/demo/_mover.py | 9 +- src/ophyd_async/epics/demo/_sensor.py | 21 +- src/ophyd_async/epics/eiger/_eiger_io.py | 2 +- src/ophyd_async/epics/eiger/_odin_io.py | 2 +- src/ophyd_async/epics/motor.py | 9 +- src/ophyd_async/epics/pvi/__init__.py | 3 - src/ophyd_async/epics/pvi/_pvi.py | 73 ---- src/ophyd_async/epics/signal.py | 11 + src/ophyd_async/epics/signal/__init__.py | 20 -- src/ophyd_async/fastcs/core.py | 4 +- src/ophyd_async/sim/demo/_sim_motor.py | 7 +- .../tango/base_devices/_base_device.py | 31 +- src/ophyd_async/tango/demo/_counter.py | 22 +- src/ophyd_async/tango/demo/_mover.py | 7 +- system_tests/epics/eiger/test_eiger_system.py | 2 +- tests/core/test_device_save_loader.py | 2 +- tests/core/test_flyer.py | 2 +- tests/core/test_mock_signal_backend.py | 2 +- tests/core/test_readable.py | 139 ++++--- tests/core/test_signal.py | 8 +- tests/core/test_subset_enum.py | 6 +- tests/core/test_utils.py | 2 +- tests/epics/adcore/test_writers.py | 2 +- tests/epics/demo/test_demo.py | 7 +- tests/epics/pvi/test_pvi.py | 37 +- tests/epics/signal/test_common.py | 2 +- tests/epics/signal/test_signals.py | 9 +- tests/fastcs/panda/db/panda.db | 16 +- tests/fastcs/panda/test_panda_connect.py | 8 +- tests/fastcs/panda/test_panda_control.py | 2 +- tests/fastcs/panda/test_panda_utils.py | 2 +- tests/plan_stubs/test_ensure_connected.py | 2 +- 
tests/plan_stubs/test_fly.py | 2 +- tests/tango/test_base_device.py | 22 +- 56 files changed, 968 insertions(+), 578 deletions(-) create mode 100644 docs/explanations/decisions/0009-procedural-vs-declarative-devices.md create mode 100644 src/ophyd_async/epics/core/__init__.py rename src/ophyd_async/epics/{signal => core}/_aioca.py (98%) create mode 100644 src/ophyd_async/epics/core/_epics_connector.py create mode 100644 src/ophyd_async/epics/core/_epics_device.py rename src/ophyd_async/epics/{signal => core}/_p4p.py (98%) create mode 100644 src/ophyd_async/epics/core/_pvi_connector.py rename src/ophyd_async/epics/{signal => core}/_signal.py (78%) rename src/ophyd_async/epics/{signal/_common.py => core/_util.py} (76%) delete mode 100644 src/ophyd_async/epics/pvi/__init__.py delete mode 100644 src/ophyd_async/epics/pvi/_pvi.py create mode 100644 src/ophyd_async/epics/signal.py delete mode 100644 src/ophyd_async/epics/signal/__init__.py diff --git a/docs/examples/foo_detector.py b/docs/examples/foo_detector.py index c1849e11cf..c8f906bd9d 100644 --- a/docs/examples/foo_detector.py +++ b/docs/examples/foo_detector.py @@ -10,7 +10,7 @@ StandardDetector, ) from ophyd_async.epics import adcore -from ophyd_async.epics.signal import epics_signal_rw_rbv +from ophyd_async.epics.core import epics_signal_rw_rbv class FooDriver(adcore.ADBaseIO): diff --git a/docs/explanations/decisions/0009-procedural-vs-declarative-devices.md b/docs/explanations/decisions/0009-procedural-vs-declarative-devices.md new file mode 100644 index 0000000000..9572b7b583 --- /dev/null +++ b/docs/explanations/decisions/0009-procedural-vs-declarative-devices.md @@ -0,0 +1,140 @@ +# 9. Procedural vs Declarative Devices + +Date: 01/10/24 + +## Status + +Accepted + +## Context + +In [](./0006-procedural-device-definitions.rst) we decided we preferred the procedural approach to devices, because of the issue of applying structure like `DeviceVector`. 
Since then we have `FastCS` and `Tango` support which use a declarative approach. We need to decide whether we are happy with this situation, or whether we should go all in one way or the other. A suitable test Device would be: + +```python +class EpicsProceduralDevice(StandardReadable): + def __init__(self, prefix: str, num_values: int, name="") -> None: + with self.add_children_as_readables(): + self.value = DeviceVector( + { + i: epics_signal_r(float, f"{prefix}Value{i}") + for i in range(1, num_values + 1) + } + ) + with self.add_children_as_readables(ConfigSignal): + self.mode = epics_signal_rw(EnergyMode, prefix + "Mode") + super().__init__(name=name) +``` + +and a Tango/FastCS procedural equivalent would be (if we add support to StandardReadable for Format.HINTED_SIGNAL and Format.CONFIG_SIGNAL annotations): +```python +class TangoDeclarativeDevice(StandardReadable, TangoDevice): + value: Annotated[DeviceVector[SignalR[float]], Format.HINTED_SIGNAL] + mode: Annotated[SignalRW[EnergyMode], Format.CONFIG_SIGNAL] +``` + +But we could specify the Tango one procedurally (with some slight ugliness around the DeviceVector): +```python +class TangoProceduralDevice(StandardReadable): + def __init__(self, prefix: str, name="") -> None: + with self.add_children_as_readables(): + self.value = DeviceVector({0: tango_signal_r(float)}) + with self.add_children_as_readables(ConfigSignal): + self.mode = tango_signal_rw(EnergyMode) + super().__init__(name=name, connector=TangoConnector(prefix)) +``` + +or the EPICS one could be declarative: +```python +class EpicsDeclarativeDevice(StandardReadable, EpicsDevice): + value: Annotated[ + DeviceVector[SignalR[float]], Format.HINTED_SIGNAL, EpicsSuffix("Value%d", "num_values") + ] + mode: Annotated[SignalRW[EnergyMode], Format.CONFIG_SIGNAL, EpicsSuffix("Mode")] +``` + +Which do we prefer? + +## Decision + +We decided that the declarative approach is to be preferred until we need to write formatted strings. 
At that point we should drop to an `__init__` method and a for loop. This is not a step towards only supporting the declarative approach and there are no plans to drop the procedural approach. + +The two approaches now look like: + +```python +class Sensor(StandardReadable, EpicsDevice): + """A demo sensor that produces a scalar value based on X and Y Movers""" + + value: A[SignalR[float], PvSuffix("Value"), Format.HINTED_SIGNAL] + mode: A[SignalRW[EnergyMode], PvSuffix("Mode"), Format.CONFIG_SIGNAL] + + +class SensorGroup(StandardReadable): + def __init__(self, prefix: str, name: str = "", sensor_count: int = 3) -> None: + with self.add_children_as_readables(): + self.sensors = DeviceVector( + {i: Sensor(f"{prefix}{i}:") for i in range(1, sensor_count + 1)} + ) + super().__init__(name) +``` + +## Consequences + +We need to: +- Add support for reading annotations and `PvSuffix` in an `ophyd_async.epics.core.EpicsDevice` baseclass +- Do the `Format.HINTED_SIGNAL` and `Format.CONFIG_SIGNAL` flags in annotations for `StandardReadable` +- Ensure we can always drop to `__init__` + + +## pvi structure changes +Structure read from `.value` now includes `DeviceVector` support. Requires at least PandABlocks-ioc 0.11.2 + +## Epics `signal` module moves +`ophyd_async.epics.signal` moves to `ophyd_async.epics.core` with a backwards compat module that emits deprecation warning. 
+```python +# old +from ophyd_async.epics.signal import epics_signal_rw +# new +from ophyd_async.epics.core import epics_signal_rw +``` + +## `StandardReadable` wrappers change to `StandardReadableFormat` +`StandardReadable` wrappers change to enum members of `StandardReadableFormat` (normally imported as `Format`) +```python +# old +from ophyd_async.core import ConfigSignal, HintedSignal +class MyDevice(StandardReadable): + def __init__(self): + self.add_readables([sig1], ConfigSignal) + self.add_readables([sig2], HintedSignal) + self.add_readables([sig3], HintedSignal.uncached) +# new +from ophyd_async.core import StandardReadableFormat as Format +class MyDevice(StandardReadable): + def __init__(self): + self.add_readables([sig1], Format.CONFIG_SIGNAL) + self.add_readables([sig2], Format.HINTED_SIGNAL) + self.add_readables([sig3], Format.HINTED_UNCACHED_SIGNAL +``` + +## Declarative Devices are now available +```python +# old +from ophyd_async.core import ConfigSignal, HintedSignal +from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw + +class Sensor(StandardReadable): + def __init__(self, prefix: str, name="") -> None: + with self.add_children_as_readables(HintedSignal): + self.value = epics_signal_r(float, prefix + "Value") + with self.add_children_as_readables(ConfigSignal): + self.mode = epics_signal_rw(EnergyMode, prefix + "Mode") + super().__init__(name=name) +# new +from typing import Annotated as A +from ophyd_async.core import StandardReadableFormat as Format +from ophyd_async.epics.core import EpicsDevice, PvSuffix, epics_signal_r, epics_signal_rw + +class Sensor(StandardReadable, EpicsDevice): + value: A[SignalR[float], PvSuffix("Value"), Format.HINTED_SIGNAL] + mode: A[SignalRW[EnergyMode], PvSuffix("Mode"), Format.CONFIG_SIGNAL] +``` diff --git a/pyproject.toml b/pyproject.toml index 8c24913e0b..41975b0ae7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,8 @@ reportMissingImports = false # Ignore missing stubs in imported 
modules # Run pytest with all our checkers, and don't spam us with massive tracebacks on error addopts = """ --tb=native -vv --strict-markers --doctest-modules - --doctest-glob="*.rst" --doctest-glob="*.md" --ignore=docs/examples + --doctest-glob="*.rst" --doctest-glob="*.md" + --ignore=docs/examples --ignore=src/ophyd_async/epics/signal.py """ # https://iscinumpy.gitlab.io/post/bound-version-constraints/#watch-for-warnings filterwarnings = "error" diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 97a2c0c6d4..f19f4cb59e 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -45,7 +45,12 @@ UUIDFilenameProvider, YMDPathProvider, ) -from ._readable import ConfigSignal, HintedSignal, StandardReadable +from ._readable import ( + ConfigSignal, + HintedSignal, + StandardReadable, + StandardReadableFormat, +) from ._signal import ( Signal, SignalR, @@ -142,6 +147,7 @@ "ConfigSignal", "HintedSignal", "StandardReadable", + "StandardReadableFormat", "Signal", "SignalR", "SignalRW", diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index 5274b89018..1fe7855f3d 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -75,6 +75,7 @@ def __init__( self, name: str = "", connector: DeviceConnector | None = None ) -> None: self._connector = connector or DeviceConnector() + self._connector.create_children_from_annotations(self) self.set_name(name) @property diff --git a/src/ophyd_async/core/_device_filler.py b/src/ophyd_async/core/_device_filler.py index 978380161c..dfca58981f 100644 --- a/src/ophyd_async/core/_device_filler.py +++ b/src/ophyd_async/core/_device_filler.py @@ -1,14 +1,18 @@ from __future__ import annotations -import re -from collections.abc import Callable +from abc import abstractmethod +from collections.abc import Callable, Iterator, Sequence from typing import ( + Any, Generic, + NewType, NoReturn, + Protocol, TypeVar, + cast, get_args, - 
get_origin, get_type_hints, + runtime_checkable, ) from ._device import Device, DeviceConnector, DeviceVector @@ -16,21 +20,29 @@ from ._signal_backend import SignalBackend, SignalDatatype from ._utils import get_origin_class +SignalBackendT = TypeVar("SignalBackendT", bound=SignalBackend) +DeviceConnectorT = TypeVar("DeviceConnectorT", bound=DeviceConnector) +# Unique name possibly with trailing understore, the attribute name on the Device +UniqueName = NewType("UniqueName", str) +# Logical name without trailing underscore, the name in the control system +LogicalName = NewType("LogicalName", str) -def _strip_number_from_string(string: str) -> tuple[str, int | None]: - match = re.match(r"(.*?)(\d*)$", string) - assert match - name = match.group(1) - number = match.group(2) or None - if number is None: - return name, None - else: - return name, int(number) +def _get_datatype(annotation: Any) -> type | None: + """Return int from SignalRW[int].""" + args = get_args(annotation) + if len(args) == 1 and get_origin_class(args[0]): + return args[0] -SignalBackendT = TypeVar("SignalBackendT", bound=SignalBackend) -DeviceConnectorT = TypeVar("DeviceConnectorT", bound=DeviceConnector) +def _logical(name: UniqueName) -> LogicalName: + return LogicalName(name.rstrip("_")) + + +@runtime_checkable +class DeviceAnnotation(Protocol): + @abstractmethod + def __call__(self, parent: Device, child: Device): ... 
class DeviceFiller(Generic[SignalBackendT, DeviceConnectorT]): @@ -43,107 +55,180 @@ def __init__( self._device = device self._signal_backend_factory = signal_backend_factory self._device_connector_factory = device_connector_factory - self._vectors: dict[str, DeviceVector] = {} - self._vector_device_type: dict[str, type[Device] | None] = {} - self._signal_backends: dict[str, tuple[SignalBackendT, type[Signal]]] = {} - self._device_connectors: dict[str, DeviceConnectorT] = {} + # Annotations stored ready for the creation phase + self._uncreated_signals: dict[UniqueName, type[Signal]] = {} + self._uncreated_devices: dict[UniqueName, type[Device]] = {} + self._extras: dict[UniqueName, Sequence[Any]] = {} + self._signal_datatype: dict[LogicalName, type | None] = {} + self._vector_device_type: dict[LogicalName, type[Device] | None] = {} + # Backends and Connectors stored ready for the connection phase + self._unfilled_backends: dict[ + LogicalName, tuple[SignalBackendT, type[Signal]] + ] = {} + self._unfilled_connectors: dict[LogicalName, DeviceConnectorT] = {} + # Once they are filled they go here in case we reconnect + self._filled_backends: dict[ + LogicalName, tuple[SignalBackendT, type[Signal]] + ] = {} + self._filled_connectors: dict[LogicalName, DeviceConnectorT] = {} + self._scan_for_annotations() + + def _raise(self, name: str, error: str) -> NoReturn: + raise TypeError(f"{type(self._device).__name__}.{name}: {error}") + + def _store_signal_datatype(self, name: UniqueName, annotation: Any): + origin = get_origin_class(annotation) + datatype = _get_datatype(annotation) + if origin == SignalX: + # SignalX doesn't need datatype + self._signal_datatype[_logical(name)] = None + elif origin and issubclass(origin, Signal) and datatype: + # All other Signals need one + self._signal_datatype[_logical(name)] = datatype + else: + # Not recognized + self._raise( + name, + f"Expected SignalX or SignalR/W/RW[type], got {annotation}", + ) + + def _scan_for_annotations(self): 
# Get type hints on the class, not the instance # https://github.com/python/cpython/issues/124840 - self._annotations = get_type_hints(type(device)) - for name, annotation in self._annotations.items(): - # names have a trailing underscore if the clash with a bluesky verb, - # so strip this off to get what the CS will provide - stripped_name = name.rstrip("_") + cls = type(self._device) + # Get hints without Annotated for determining types + hints = get_type_hints(cls) + # Get hints with Annotated for wrapping signals and backends + extra_hints = get_type_hints(cls, include_extras=True) + for attr_name, annotation in hints.items(): + name = UniqueName(attr_name) origin = get_origin_class(annotation) - if name == "parent" or name.startswith("_") or not origin: - # Ignore - pass - elif issubclass(origin, Signal): - # SignalX doesn't need datatype, all others need one - datatype = self.get_datatype(name) - if origin != SignalX and datatype is None: - self._raise( - name, - f"Expected SignalX or SignalR/W/RW[type], got {annotation}", - ) - self._signal_backends[stripped_name] = ( - self.make_child_signal(name, origin), - origin, - ) + if ( + name == "parent" + or name.startswith("_") + or not origin + or not issubclass(origin, Device) + ): + # Ignore any child that is not a public Device + continue + self._extras[name] = getattr(extra_hints[attr_name], "__metadata__", ()) + if issubclass(origin, Signal): + self._store_signal_datatype(name, annotation) + self._uncreated_signals[name] = origin elif origin == DeviceVector: - # DeviceVector needs a type of device - args = get_args(annotation) or [None] - child_origin = get_origin(args[0]) or args[0] + child_type = _get_datatype(annotation) + child_origin = get_origin_class(child_type) if child_origin is None or not issubclass(child_origin, Device): self._raise( name, f"Expected DeviceVector[SomeDevice], got {annotation}", ) - self.make_device_vector(name, child_origin) - elif issubclass(origin, Device): - 
self._device_connectors[stripped_name] = self.make_child_device( - name, origin - ) + if issubclass(child_origin, Signal): + self._store_signal_datatype(name, child_type) + self._vector_device_type[_logical(name)] = child_origin + setattr(self._device, name, DeviceVector({})) + else: + self._uncreated_devices[name] = origin - def unfilled(self) -> set[str]: - return set(self._device_connectors).union(self._signal_backends) + def check_created(self): + uncreated = sorted(set(self._uncreated_signals).union(self._uncreated_devices)) + if uncreated: + raise RuntimeError( + f"{self._device.name}: {uncreated} have not been created yet" + ) - def _raise(self, name: str, error: str) -> NoReturn: - raise TypeError(f"{type(self._device).__name__}.{name}: {error}") + def create_signals_from_annotations( + self, + filled=True, + ) -> Iterator[tuple[SignalBackendT, list[Any]]]: + for name in list(self._uncreated_signals): + child_type = self._uncreated_signals.pop(name) + backend = self._signal_backend_factory( + self._signal_datatype[_logical(name)] + ) + extras = list(self._extras[name]) + yield backend, extras + signal = child_type(backend) + for anno in extras: + assert isinstance(anno, DeviceAnnotation), anno + anno(self._device, signal) + setattr(self._device, name, signal) + dest = self._filled_backends if filled else self._unfilled_backends + dest[_logical(name)] = (backend, child_type) - def make_device_vector(self, name: str, device_type: type[Device] | None): - self._vectors[name] = DeviceVector({}) - self._vector_device_type[name] = device_type - setattr(self._device, name, self._vectors[name]) - - def make_device_vectors(self, names: list[str]): - basenames: dict[str, set[int]] = {} - for name in names: - basename, number = _strip_number_from_string(name) - if number is not None: - basenames.setdefault(basename, set()).add(number) - for basename, numbers in basenames.items(): - # If contiguous numbers starting at 1 then it's a device vector - length = len(numbers) 
- if length > 1 and numbers == set(range(1, length + 1)): - # DeviceVector needs a type of device - self.make_device_vector(basename, None) - - def get_datatype(self, name: str) -> type[SignalDatatype] | None: - # Get dtype from SignalRW[dtype] or DeviceVector[SignalRW[dtype]] - basename, _ = _strip_number_from_string(name) - if basename in self._vectors: - # We decided to put it in a device vector, so get datatype from that - annotation = self._annotations.get(basename, None) - if annotation: - annotation = get_args(annotation)[0] - else: - # It's not a device vector, so get it from the full name - annotation = self._annotations.get(name, None) - args = get_args(annotation) - if args and get_origin_class(args[0]): - return args[0] - - def make_child_signal(self, name: str, signal_type: type[Signal]) -> SignalBackendT: - if name in self._signal_backends: + def create_devices_from_annotations( + self, + filled=True, + ) -> Iterator[tuple[DeviceConnectorT, list[Any]]]: + for name in list(self._uncreated_devices): + child_type = self._uncreated_devices.pop(name) + connector = self._device_connector_factory() + extras = list(self._extras[name]) + yield connector, extras + device = child_type(connector=connector) + for anno in extras: + assert isinstance(anno, DeviceAnnotation), anno + anno(self._device, device) + setattr(self._device, name, device) + dest = self._filled_connectors if filled else self._unfilled_connectors + dest[_logical(name)] = connector + + def create_device_vector_entries_to_mock(self, num: int): + for name, cls in self._vector_device_type.items(): + assert cls, "Shouldn't happen" + for i in range(1, num + 1): + if issubclass(cls, Signal): + self.fill_child_signal(name, cls, i) + elif issubclass(cls, Device): + self.fill_child_device(name, cls, i) + else: + self._raise(name, f"Can't make {cls}") + + def check_filled(self, source: str): + unfilled = sorted(set(self._unfilled_connectors).union(self._unfilled_backends)) + if unfilled: + raise 
RuntimeError( + f"{self._device.name}: cannot provision {unfilled} from {source}" + ) + + def _ensure_device_vector(self, name: LogicalName) -> DeviceVector: + if not hasattr(self._device, name): + # We have no type hints, so use whatever we are told + self._vector_device_type[name] = None + setattr(self._device, name, DeviceVector({})) + vector = getattr(self._device, name) + if not isinstance(vector, DeviceVector): + self._raise(name, f"Expected DeviceVector, got {vector}") + return vector + + def fill_child_signal( + self, + name: str, + signal_type: type[Signal], + vector_index: int | None = None, + ) -> SignalBackendT: + name = cast(LogicalName, name) + if name in self._unfilled_backends: # We made it above - backend, expected_signal_type = self._signal_backends.pop(name) + backend, expected_signal_type = self._unfilled_backends.pop(name) + self._filled_backends[name] = backend, expected_signal_type + elif name in self._filled_backends: + # We made it and filled it so return for validation + backend, expected_signal_type = self._filled_backends[name] + elif vector_index: + # We need to add a new entry to a DeviceVector + vector = self._ensure_device_vector(name) + backend = self._signal_backend_factory(self._signal_datatype.get(name)) + expected_signal_type = self._vector_device_type[name] or signal_type + vector[vector_index] = signal_type(backend) + elif child := getattr(self._device, name, None): + # There is an existing child, so raise + self._raise(name, f"Cannot make child as it would shadow {child}") else: - # We need to make a new one - basename, number = _strip_number_from_string(name) - child = getattr(self._device, name, None) - backend = self._signal_backend_factory(self.get_datatype(name)) - signal = signal_type(backend) - if basename in self._vectors and isinstance(number, int): - # We need to add a new entry to an existing DeviceVector - expected_signal_type = self._vector_device_type[basename] or signal_type - self._vectors[basename][number] = 
signal - elif child is None: - # We need to add a new child to the top level Device - expected_signal_type = signal_type - setattr(self._device, name, signal) - else: - self._raise(name, f"Cannot make child as it would shadow {child}") + # We need to add a new child to the top level Device + backend = self._signal_backend_factory(None) + expected_signal_type = signal_type + setattr(self._device, name, signal_type(backend)) if signal_type is not expected_signal_type: self._raise( name, @@ -151,41 +236,34 @@ def make_child_signal(self, name: str, signal_type: type[Signal]) -> SignalBacke ) return backend - def make_child_device( - self, name: str, device_type: type[Device] = Device + def fill_child_device( + self, + name: str, + device_type: type[Device] = Device, + vector_index: int | None = None, ) -> DeviceConnectorT: - basename, number = _strip_number_from_string(name) - child = getattr(self._device, name, None) - if connector := self._device_connectors.pop(name, None): + name = cast(LogicalName, name) + if name in self._unfilled_connectors: # We made it above - return connector - elif basename in self._vectors and isinstance(number, int): - # We need to add a new entry to an existing DeviceVector - vector_device_type = self._vector_device_type[basename] or device_type + connector = self._unfilled_connectors.pop(name) + self._filled_connectors[name] = connector + elif name in self._filled_backends: + # We made it and filled it so return for validation + connector = self._filled_connectors[name] + elif vector_index: + # We need to add a new entry to a DeviceVector + vector = self._ensure_device_vector(name) + vector_device_type = self._vector_device_type[name] or device_type assert issubclass( vector_device_type, Device ), f"{vector_device_type} is not a Device" connector = self._device_connector_factory() - device = vector_device_type(connector=connector) - self._vectors[basename][number] = device - elif child is None: + vector[vector_index] = 
vector_device_type(connector=connector) + elif child := getattr(self._device, name, None): + # There is an existing child, so raise + self._raise(name, f"Cannot make child as it would shadow {child}") + else: # We need to add a new child to the top level Device connector = self._device_connector_factory() - device = device_type(connector=connector) - setattr(self._device, name, device) - else: - self._raise(name, f"Cannot make child as it would shadow {child}") - connector.create_children_from_annotations(device) + setattr(self._device, name, device_type(connector=connector)) return connector - - def make_soft_device_vector_entries(self, num: int): - for basename, cls in self._vector_device_type.items(): - assert cls, "Shouldn't happen" - for i in range(num): - name = f"{basename}{i + 1}" - if issubclass(cls, Signal): - self.make_child_signal(name, cls) - elif issubclass(cls, Device): - self.make_child_device(name, cls) - else: - self._raise(name, f"Can't make {cls}") diff --git a/src/ophyd_async/core/_readable.py b/src/ophyd_async/core/_readable.py index f27c061462..e309e73c12 100644 --- a/src/ophyd_async/core/_readable.py +++ b/src/ophyd_async/core/_readable.py @@ -1,6 +1,8 @@ import warnings -from collections.abc import Callable, Generator, Sequence +from collections.abc import Awaitable, Callable, Generator, Sequence from contextlib import contextmanager +from enum import Enum +from typing import Any, cast from bluesky.protocols import HasHints, Hints, Reading from event_model import DataKey @@ -11,11 +13,61 @@ from ._status import AsyncStatus from ._utils import merge_gathered_dicts -ReadableChild = AsyncReadable | AsyncConfigurable | AsyncStageable | HasHints -ReadableChildWrapper = ( - Callable[[ReadableChild], ReadableChild] - | type["ConfigSignal"] - | type["HintedSignal"] + +class StandardReadableFormat(Enum): + """Declare how a `Device` should contribute to the `StandardReadable` verbs.""" + + #: Detect which verbs the child supports and contribute to: + 
#: + #: - ``read()``, ``describe()`` if it is `bluesky.protocols.Readable` + #: - ``read_configuration()``, ``describe_configuration()`` if it is + #: `bluesky.protocols.Configurable` + #: - ``stage()``, ``unstage()`` if it is `bluesky.protocols.Stageable` + #: - ``hints`` if it `bluesky.protocols.HasHints` + CHILD = "CHILD" + #: Contribute the `Signal` value to ``read_configuration()`` and + #: ``describe_configuration()`` + CONFIG_SIGNAL = "CONFIG_SIGNAL" + #: Contribute the monitored `Signal` value to ``read()`` and ``describe()``` and + #: put the signal name in ``hints`` + HINTED_SIGNAL = "HINTED_SIGNAL" + #: Contribute the uncached `Signal` value to ``read()`` and ``describe()``` + UNCACHED_SIGNAL = "UNCACHED_SIGNAL" + #: Contribute the uncached `Signal` value to ``read()`` and ``describe()``` and + #: put the signal name in ``hints`` + HINTED_UNCACHED_SIGNAL = "HINTED_UNCACHED_SIGNAL" + + def __call__(self, parent: Device, child: Device): + if not isinstance(parent, StandardReadable): + raise TypeError(f"Expected parent to be StandardReadable, got {parent}") + parent.add_readables([child], self) + + +# Back compat +class _WarningMatcher: + def __init__(self, name: str, target: StandardReadableFormat): + self._name = name + self._target = target + + def __eq__(self, value: object) -> bool: + warnings.warn( + DeprecationWarning( + f"Use `StandardReadableFormat.{self._target.name}` " + f"instead of `{self._name}`" + ), + stacklevel=2, + ) + return value == self._target + + +def _compat_format(name: str, target: StandardReadableFormat) -> StandardReadableFormat: + return cast(StandardReadableFormat, _WarningMatcher(name, target)) + + +ConfigSignal = _compat_format("ConfigSignal", StandardReadableFormat.CONFIG_SIGNAL) +HintedSignal: Any = _compat_format("HintedSignal", StandardReadableFormat.HINTED_SIGNAL) +HintedSignal.uncached = _compat_format( + "HintedSignal.uncached", StandardReadableFormat.HINTED_UNCACHED_SIGNAL ) @@ -31,38 +83,13 @@ class StandardReadable( 
# These must be immutable types to avoid accidental sharing between # different instances of the class - _readables: tuple[AsyncReadable, ...] = () - _configurables: tuple[AsyncConfigurable, ...] = () + _describe_config_funcs: tuple[Callable[[], Awaitable[dict[str, DataKey]]], ...] = () + _read_config_funcs: tuple[Callable[[], Awaitable[dict[str, Reading]]], ...] = () + _describe_funcs: tuple[Callable[[], Awaitable[dict[str, DataKey]]], ...] = () + _read_funcs: tuple[Callable[[], Awaitable[dict[str, Reading]]], ...] = () _stageables: tuple[AsyncStageable, ...] = () _has_hints: tuple[HasHints, ...] = () - def set_readable_signals( - self, - read: Sequence[SignalR] = (), - config: Sequence[SignalR] = (), - read_uncached: Sequence[SignalR] = (), - ): - """ - Parameters - ---------- - read: - Signals to make up :meth:`~StandardReadable.read` - conf: - Signals to make up :meth:`~StandardReadable.read_configuration` - read_uncached: - Signals to make up :meth:`~StandardReadable.read` that won't be cached - """ - warnings.warn( - DeprecationWarning( - "Migrate to `add_children_as_readables` context manager or " - "`add_readables` method" - ), - stacklevel=2, - ) - self.add_readables(read, wrapper=HintedSignal) - self.add_readables(config, wrapper=ConfigSignal) - self.add_readables(read_uncached, wrapper=HintedSignal.uncached) - @AsyncStatus.wrap async def stage(self) -> None: for sig in self._stageables: @@ -75,19 +102,17 @@ async def unstage(self) -> None: async def describe_configuration(self) -> dict[str, DataKey]: return await merge_gathered_dicts( - [sig.describe_configuration() for sig in self._configurables] + [func() for func in self._describe_config_funcs] ) async def read_configuration(self) -> dict[str, Reading]: - return await merge_gathered_dicts( - [sig.read_configuration() for sig in self._configurables] - ) + return await merge_gathered_dicts([func() for func in self._read_config_funcs]) async def describe(self) -> dict[str, DataKey]: - return await 
merge_gathered_dicts([sig.describe() for sig in self._readables]) + return await merge_gathered_dicts([func() for func in self._describe_funcs]) async def read(self) -> dict[str, Reading]: - return await merge_gathered_dicts([sig.read() for sig in self._readables]) + return await merge_gathered_dicts([func() for func in self._read_funcs]) @property def hints(self) -> Hints: @@ -127,27 +152,13 @@ def hints(self) -> Hints: @contextmanager def add_children_as_readables( self, - wrapper: ReadableChildWrapper | None = None, + format: StandardReadableFormat = StandardReadableFormat.CHILD, ) -> Generator[None, None, None]: - """Context manager to wrap adding Devices + """Context manager that calls `add_readables` on child Devices added within. - Add Devices to this class instance inside the Context Manager to automatically - add them to the correct fields, based on the Device's interfaces. - - The provided wrapper class will be applied to all Devices and can be used to - specify their behaviour. - - Parameters - ---------- - wrapper: - Wrapper class to apply to all Devices created inside the context manager. - - See Also - -------- - :func:`~StandardReadable.add_readables` - :class:`ConfigSignal` - :class:`HintedSignal` - :meth:`HintedSignal.uncached` + Scans ``self.children()`` on entry and exit to context manager, and calls + `add_readables` on any that are added with the provided + `StandardReadableFormat`. 
""" dict_copy = dict(self.children()) @@ -167,95 +178,83 @@ def add_children_as_readables( flattened_values.append(value) new_devices = list(filter(lambda x: isinstance(x, Device), flattened_values)) - self.add_readables(new_devices, wrapper) + self.add_readables(new_devices, format) def add_readables( self, - devices: Sequence[ReadableChild], - wrapper: ReadableChildWrapper | None = None, + devices: Sequence[Device], + format: StandardReadableFormat = StandardReadableFormat.CHILD, ) -> None: - """Add the given devices to the lists of known Devices + """Add devices to contribute to various bluesky verbs. - Add the provided Devices to the relevant fields, based on the Signal's - interfaces. + Use output from the given devices to contribute to the verbs of the following + interfaces: - The provided wrapper class will be applied to all Devices and can be used to - specify their behaviour. + - `bluesky.protocols.Readable` + - `bluesky.protocols.Configurable` + - `bluesky.protocols.Stageable` + - `bluesky.protocols.HasHints` Parameters ---------- devices: The devices to be added - wrapper: - Wrapper class to apply to all Devices created inside the context manager. 
- - See Also - -------- - :func:`~StandardReadable.add_children_as_readables` - :class:`ConfigSignal` - :class:`HintedSignal` - :meth:`HintedSignal.uncached` + format: + Determines which of the devices functions are added to which verb as per the + `StandardReadableFormat` documentation """ - for readable in devices: - obj = readable - if wrapper: - obj = wrapper(readable) - - if isinstance(obj, AsyncReadable): - self._readables += (obj,) - - if isinstance(obj, AsyncConfigurable): - self._configurables += (obj,) - - if isinstance(obj, AsyncStageable): - self._stageables += (obj,) - - if isinstance(obj, HasHints): - self._has_hints += (obj,) - - -class ConfigSignal(AsyncConfigurable): - def __init__(self, signal: ReadableChild) -> None: - assert isinstance(signal, SignalR), f"Expected signal, got {signal}" + for device in devices: + match format: + case StandardReadableFormat.CHILD: + if isinstance(device, AsyncConfigurable): + self._describe_config_funcs += (device.describe_configuration,) + self._read_config_funcs += (device.read_configuration,) + if isinstance(device, AsyncReadable): + self._describe_funcs += (device.describe,) + self._read_funcs += (device.read,) + if isinstance(device, AsyncStageable): + self._stageables += (device,) + if isinstance(device, HasHints): + self._has_hints += (device,) + case StandardReadableFormat.CONFIG_SIGNAL: + assert isinstance(device, SignalR), f"{device} is not a SignalR" + self._describe_config_funcs += (device.describe,) + self._read_config_funcs += (device.read,) + case StandardReadableFormat.HINTED_SIGNAL: + assert isinstance(device, SignalR), f"{device} is not a SignalR" + self._describe_funcs += (device.describe,) + self._read_funcs += (device.read,) + self._stageables += (device,) + self._has_hints += (_HintsFromName(device),) + case StandardReadableFormat.UNCACHED_SIGNAL: + assert isinstance(device, SignalR), f"{device} is not a SignalR" + self._describe_funcs += (device.describe,) + self._read_funcs += 
(_UncachedRead(device),) + case StandardReadableFormat.HINTED_UNCACHED_SIGNAL: + assert isinstance(device, SignalR), f"{device} is not a SignalR" + self._describe_funcs += (device.describe,) + self._read_funcs += (_UncachedRead(device),) + self._has_hints += (_HintsFromName(device),) + + +class _UncachedRead: + def __init__(self, signal: SignalR) -> None: self.signal = signal - async def read_configuration(self) -> dict[str, Reading]: - return await self.signal.read() - - async def describe_configuration(self) -> dict[str, DataKey]: - return await self.signal.describe() - - @property - def name(self) -> str: - return self.signal.name + async def __call__(self) -> dict[str, Reading]: + return await self.signal.read(cached=False) -class HintedSignal(HasHints, AsyncReadable): - def __init__(self, signal: ReadableChild, allow_cache: bool = True) -> None: - assert isinstance(signal, SignalR), f"Expected signal, got {signal}" - self.signal = signal - self.cached = None if allow_cache else allow_cache - if allow_cache: - self.stage = signal.stage - self.unstage = signal.unstage - - async def read(self) -> dict[str, Reading]: - return await self.signal.read(cached=self.cached) - - async def describe(self) -> dict[str, DataKey]: - return await self.signal.describe() +class _HintsFromName(HasHints): + def __init__(self, device: Device) -> None: + self.device = device @property def name(self) -> str: - return self.signal.name + return self.device.name @property def hints(self) -> Hints: - if self.signal.name == "": - return {"fields": []} - return {"fields": [self.signal.name]} - - @classmethod - def uncached(cls, signal: ReadableChild) -> "HintedSignal": - return cls(signal, allow_cache=False) + fields = [self.name] if self.name else [] + return {"fields": fields} diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index a8800191d2..db4afae04a 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -5,7 +5,15 @@ from 
collections.abc import Awaitable, Callable, Iterable, Sequence from dataclasses import dataclass from enum import Enum, EnumMeta -from typing import Any, Generic, Literal, ParamSpec, TypeVar, get_args, get_origin +from typing import ( + Any, + Generic, + Literal, + ParamSpec, + TypeVar, + get_args, + get_origin, +) import numpy as np diff --git a/src/ophyd_async/epics/adaravis/_aravis_io.py b/src/ophyd_async/epics/adaravis/_aravis_io.py index 9707beac2d..e16beb41f4 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_io.py +++ b/src/ophyd_async/epics/adaravis/_aravis_io.py @@ -1,6 +1,6 @@ from ophyd_async.core import StrictEnum, SubsetEnum from ophyd_async.epics import adcore -from ophyd_async.epics.signal import epics_signal_rw_rbv +from ophyd_async.epics.core import epics_signal_rw_rbv class AravisTriggerMode(StrictEnum): diff --git a/src/ophyd_async/epics/adcore/_core_io.py b/src/ophyd_async/epics/adcore/_core_io.py index e044b0e5d1..97332d9a15 100644 --- a/src/ophyd_async/epics/adcore/_core_io.py +++ b/src/ophyd_async/epics/adcore/_core_io.py @@ -1,5 +1,5 @@ from ophyd_async.core import Device, StrictEnum -from ophyd_async.epics.signal import ( +from ophyd_async.epics.core import ( epics_signal_r, epics_signal_rw, epics_signal_rw_rbv, diff --git a/src/ophyd_async/epics/adcore/_single_trigger.py b/src/ophyd_async/epics/adcore/_single_trigger.py index c39e4e46ab..9fd81b413d 100644 --- a/src/ophyd_async/epics/adcore/_single_trigger.py +++ b/src/ophyd_async/epics/adcore/_single_trigger.py @@ -3,13 +3,8 @@ from bluesky.protocols import Triggerable -from ophyd_async.core import ( - AsyncStatus, - ConfigSignal, - HintedSignal, - SignalR, - StandardReadable, -) +from ophyd_async.core import AsyncStatus, SignalR, StandardReadable +from ophyd_async.core import StandardReadableFormat as Format from ._core_io import ADBaseIO, NDPluginBaseIO from ._utils import ImageMode @@ -28,10 +23,10 @@ def __init__( self.add_readables( [self.drv.array_counter, *read_uncached], - 
wrapper=HintedSignal.uncached, + Format.HINTED_UNCACHED_SIGNAL, ) - self.add_readables([self.drv.acquire_time], wrapper=ConfigSignal) + self.add_readables([self.drv.acquire_time], Format.CONFIG_SIGNAL) super().__init__(name=name) diff --git a/src/ophyd_async/epics/adkinetix/_kinetix_io.py b/src/ophyd_async/epics/adkinetix/_kinetix_io.py index 4b70886648..bbe53eb410 100644 --- a/src/ophyd_async/epics/adkinetix/_kinetix_io.py +++ b/src/ophyd_async/epics/adkinetix/_kinetix_io.py @@ -1,6 +1,6 @@ from ophyd_async.core import StrictEnum from ophyd_async.epics import adcore -from ophyd_async.epics.signal import epics_signal_rw_rbv +from ophyd_async.epics.core import epics_signal_rw_rbv class KinetixTriggerMode(StrictEnum): diff --git a/src/ophyd_async/epics/adpilatus/_pilatus_io.py b/src/ophyd_async/epics/adpilatus/_pilatus_io.py index 51ca65ce9c..093398ec61 100644 --- a/src/ophyd_async/epics/adpilatus/_pilatus_io.py +++ b/src/ophyd_async/epics/adpilatus/_pilatus_io.py @@ -1,6 +1,6 @@ from ophyd_async.core import StrictEnum from ophyd_async.epics import adcore -from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw_rbv +from ophyd_async.epics.core import epics_signal_r, epics_signal_rw_rbv class PilatusTriggerMode(StrictEnum): diff --git a/src/ophyd_async/epics/advimba/_vimba_io.py b/src/ophyd_async/epics/advimba/_vimba_io.py index 0dc7571b7b..c95a873831 100644 --- a/src/ophyd_async/epics/advimba/_vimba_io.py +++ b/src/ophyd_async/epics/advimba/_vimba_io.py @@ -1,6 +1,6 @@ from ophyd_async.core import StrictEnum from ophyd_async.epics import adcore -from ophyd_async.epics.signal import epics_signal_rw_rbv +from ophyd_async.epics.core import epics_signal_rw_rbv class VimbaPixelFormat(StrictEnum): diff --git a/src/ophyd_async/epics/core/__init__.py b/src/ophyd_async/epics/core/__init__.py new file mode 100644 index 0000000000..1b6904c6f0 --- /dev/null +++ b/src/ophyd_async/epics/core/__init__.py @@ -0,0 +1,26 @@ +from ._epics_connector import 
EpicsDeviceConnector, PvSuffix +from ._epics_device import EpicsDevice +from ._pvi_connector import PviDeviceConnector +from ._signal import ( + CaSignalBackend, + PvaSignalBackend, + epics_signal_r, + epics_signal_rw, + epics_signal_rw_rbv, + epics_signal_w, + epics_signal_x, +) + +__all__ = [ + "PviDeviceConnector", + "EpicsDeviceConnector", + "PvSuffix", + "EpicsDevice", + "CaSignalBackend", + "PvaSignalBackend", + "epics_signal_r", + "epics_signal_rw", + "epics_signal_rw_rbv", + "epics_signal_w", + "epics_signal_x", +] diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/core/_aioca.py similarity index 98% rename from src/ophyd_async/epics/signal/_aioca.py rename to src/ophyd_async/epics/core/_aioca.py index b7add417b4..a11821be37 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/core/_aioca.py @@ -24,7 +24,6 @@ Array1D, Callback, NotConnected, - SignalBackend, SignalDatatype, SignalDatatypeT, SignalMetadata, @@ -34,7 +33,7 @@ wait_for_connection, ) -from ._common import format_datatype, get_supported_values +from ._util import EpicsSignalBackend, format_datatype, get_supported_values def _limits_from_augmented_value(value: AugmentedValue) -> Limits: @@ -227,19 +226,17 @@ def _use_pyepics_context_if_imported(): _tried_pyepics = True -class CaSignalBackend(SignalBackend[SignalDatatypeT]): +class CaSignalBackend(EpicsSignalBackend[SignalDatatypeT]): def __init__( self, datatype: type[SignalDatatypeT] | None, read_pv: str = "", write_pv: str = "", ): - self.read_pv = read_pv - self.write_pv = write_pv self.converter: CaConverter = DisconnectedCaConverter(float, dbr.DBR_DOUBLE) self.initial_values: dict[str, AugmentedValue] = {} self.subscription: Subscription | None = None - super().__init__(datatype) + super().__init__(datatype, read_pv, write_pv) def source(self, name: str, read: bool): return f"ca://{self.read_pv if read else self.write_pv}" diff --git a/src/ophyd_async/epics/core/_epics_connector.py 
b/src/ophyd_async/epics/core/_epics_connector.py new file mode 100644 index 0000000000..16cb103331 --- /dev/null +++ b/src/ophyd_async/epics/core/_epics_connector.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from ophyd_async.core import Device, DeviceConnector, DeviceFiller + +from ._signal import EpicsSignalBackend, get_signal_backend_type, split_protocol_from_pv + + +@dataclass +class PvSuffix: + read_suffix: str + write_suffix: str | None = None + + @classmethod + def rbv(cls, write_suffix: str, rbv_suffix: str = "_RBV") -> PvSuffix: + return cls(write_suffix + rbv_suffix, write_suffix) + + +def fill_backend_with_prefix( + prefix: str, backend: EpicsSignalBackend, annotations: list[Any] +): + unhandled = [] + while annotations: + annotation = annotations.pop(0) + if isinstance(annotation, PvSuffix): + backend.read_pv = prefix + annotation.read_suffix + backend.write_pv = prefix + ( + annotation.write_suffix or annotation.read_suffix + ) + else: + unhandled.append(annotation) + annotations.extend(unhandled) + # These leftover annotations will now be handled by the iterator + + +class EpicsDeviceConnector(DeviceConnector): + def __init__(self, prefix: str) -> None: + self.prefix = prefix + + def create_children_from_annotations(self, device: Device): + if not hasattr(self, "filler"): + protocol, prefix = split_protocol_from_pv(self.prefix) + self.filler = DeviceFiller( + device, + signal_backend_factory=get_signal_backend_type(protocol), + device_connector_factory=DeviceConnector, + ) + for backend, annotations in self.filler.create_signals_from_annotations(): + fill_backend_with_prefix(prefix, backend, annotations) + + list(self.filler.create_devices_from_annotations()) diff --git a/src/ophyd_async/epics/core/_epics_device.py b/src/ophyd_async/epics/core/_epics_device.py new file mode 100644 index 0000000000..d72c88f8b8 --- /dev/null +++ b/src/ophyd_async/epics/core/_epics_device.py @@ 
-0,0 +1,13 @@ +from ophyd_async.core import Device + +from ._epics_connector import EpicsDeviceConnector +from ._pvi_connector import PviDeviceConnector + + +class EpicsDevice(Device): + def __init__(self, prefix: str, with_pvi: bool = False, name: str = ""): + if with_pvi: + connector = PviDeviceConnector(prefix) + else: + connector = EpicsDeviceConnector(prefix) + super().__init__(name=name, connector=connector) diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/core/_p4p.py similarity index 98% rename from src/ophyd_async/epics/signal/_p4p.py rename to src/ophyd_async/epics/core/_p4p.py index 737b60fc9c..423839f5ee 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/core/_p4p.py @@ -18,7 +18,6 @@ Array1D, Callback, NotConnected, - SignalBackend, SignalDatatype, SignalDatatypeT, SignalMetadata, @@ -30,7 +29,7 @@ wait_for_connection, ) -from ._common import format_datatype, get_supported_values +from ._util import EpicsSignalBackend, format_datatype, get_supported_values def _limits_from_value(value: Any) -> Limits: @@ -293,19 +292,17 @@ def _pva_request_string(fields: Sequence[str]) -> str: return f"field({','.join(fields)})" -class PvaSignalBackend(SignalBackend[SignalDatatypeT]): +class PvaSignalBackend(EpicsSignalBackend[SignalDatatypeT]): def __init__( self, datatype: type[SignalDatatypeT] | None, read_pv: str = "", write_pv: str = "", ): - self.read_pv = read_pv - self.write_pv = write_pv self.converter: PvaConverter = DisconnectedPvaConverter(float) self.initial_values: dict[str, Any] = {} self.subscription: Subscription | None = None - super().__init__(datatype) + super().__init__(datatype, read_pv, write_pv) def source(self, name: str, read: bool): return f"pva://{self.read_pv if read else self.write_pv}" diff --git a/src/ophyd_async/epics/core/_pvi_connector.py b/src/ophyd_async/epics/core/_pvi_connector.py new file mode 100644 index 0000000000..812e4ec473 --- /dev/null +++ 
b/src/ophyd_async/epics/core/_pvi_connector.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from unittest.mock import Mock + +from ophyd_async.core import ( + Device, + DeviceConnector, + DeviceFiller, + Signal, + SignalR, + SignalRW, + SignalX, +) + +from ._epics_connector import fill_backend_with_prefix +from ._signal import PvaSignalBackend, pvget_with_timeout + +Entry = dict[str, str] + + +def _get_signal_details(entry: Entry) -> tuple[type[Signal], str, str]: + match entry: + case {"r": read_pv}: + return SignalR, read_pv, read_pv + case {"r": read_pv, "w": write_pv}: + return SignalRW, read_pv, write_pv + case {"rw": read_write_pv}: + return SignalRW, read_write_pv, read_write_pv + case {"x": execute_pv}: + return SignalX, execute_pv, execute_pv + case _: + raise TypeError(f"Can't process entry {entry}") + + +class PviDeviceConnector(DeviceConnector): + def __init__(self, prefix: str = "") -> None: + # TODO: what happens if we get a leading "pva://" here? + self.prefix = prefix + self.pvi_pv = prefix + "PVI" + + def create_children_from_annotations(self, device: Device): + if not hasattr(self, "filler"): + self.filler = DeviceFiller( + device=device, + signal_backend_factory=PvaSignalBackend, + device_connector_factory=PviDeviceConnector, + ) + # Devices will be created with unfilled PviDeviceConnectors + list(self.filler.create_devices_from_annotations(filled=False)) + # Signals can be filled in with EpicsSignalSuffix and checked at runtime + for backend, annotations in self.filler.create_signals_from_annotations( + filled=False + ): + fill_backend_with_prefix(self.prefix, backend, annotations) + self.filler.check_created() + + def _fill_child(self, name: str, entry: Entry, vector_index: int | None = None): + if set(entry) == {"d"}: + connector = self.filler.fill_child_device(name, vector_index=vector_index) + connector.pvi_pv = entry["d"] + else: + signal_type, read_pv, write_pv = _get_signal_details(entry) + backend = 
self.filler.fill_child_signal(name, signal_type, vector_index) + backend.read_pv = read_pv + backend.write_pv = write_pv + + async def connect( + self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool + ) -> None: + if mock: + # Make 2 entries for each DeviceVector + self.filler.create_device_vector_entries_to_mock(2) + else: + pvi_structure = await pvget_with_timeout(self.pvi_pv, timeout) + entries: dict[str, Entry | list[Entry | None]] = pvi_structure[ + "value" + ].todict() + # Fill based on what PVI gives us + for name, entry in entries.items(): + if isinstance(entry, dict): + # This is a child + self._fill_child(name, entry) + else: + # This is a DeviceVector of children + for i, e in enumerate(entry): + if e: + self._fill_child(name, e, i) + # Check that all the requested children have been filled + self.filler.check_filled(f"{self.pvi_pv}: {entries}") + # Set the name of the device to name all children + device.set_name(device.name) + return await super().connect(device, mock, timeout, force_reconnect) diff --git a/src/ophyd_async/epics/signal/_signal.py b/src/ophyd_async/epics/core/_signal.py similarity index 78% rename from src/ophyd_async/epics/signal/_signal.py rename to src/ophyd_async/epics/core/_signal.py index 180ed2a6e3..285a846d23 100644 --- a/src/ophyd_async/epics/signal/_signal.py +++ b/src/ophyd_async/epics/core/_signal.py @@ -14,13 +14,7 @@ get_unique, ) - -def _make_unavailable_class(error: Exception) -> type: - class TransportNotAvailable: - def __init__(*args, **kwargs): - raise NotImplementedError("Transport not available") from error - - return TransportNotAvailable +from ._util import EpicsSignalBackend class EpicsProtocol(Enum): @@ -30,10 +24,26 @@ class EpicsProtocol(Enum): _default_epics_protocol = EpicsProtocol.CA + +def _make_unavailable_function(error: Exception): + def transport_not_available(*args, **kwargs): + raise NotImplementedError("Transport not available") from error + + return 
transport_not_available + + +def _make_unavailable_class(error: Exception) -> type[EpicsSignalBackend]: + class TransportNotAvailable(EpicsSignalBackend): + __init__ = _make_unavailable_function(error) + + return TransportNotAvailable + + try: - from ._p4p import PvaSignalBackend + from ._p4p import PvaSignalBackend, pvget_with_timeout except ImportError as pva_error: PvaSignalBackend = _make_unavailable_class(pva_error) + pvget_with_timeout = _make_unavailable_function(pva_error) else: _default_epics_protocol = EpicsProtocol.PVA @@ -45,7 +55,7 @@ class EpicsProtocol(Enum): _default_epics_protocol = EpicsProtocol.CA -def _protocol_pv(pv: str) -> tuple[EpicsProtocol, str]: +def split_protocol_from_pv(pv: str) -> tuple[EpicsProtocol, str]: split = pv.split("://", 1) if len(split) > 1: # We got something like pva://mydevice, so use specified comms mode @@ -57,18 +67,23 @@ def _protocol_pv(pv: str) -> tuple[EpicsProtocol, str]: return protocol, pv +def get_signal_backend_type(protocol: EpicsProtocol) -> type[EpicsSignalBackend]: + match protocol: + case EpicsProtocol.CA: + return CaSignalBackend + case EpicsProtocol.PVA: + return PvaSignalBackend + + def _epics_signal_backend( datatype: type[SignalDatatypeT] | None, read_pv: str, write_pv: str ) -> SignalBackend[SignalDatatypeT]: """Create an epics signal backend.""" - r_protocol, r_pv = _protocol_pv(read_pv) - w_protocol, w_pv = _protocol_pv(write_pv) + r_protocol, r_pv = split_protocol_from_pv(read_pv) + w_protocol, w_pv = split_protocol_from_pv(write_pv) protocol = get_unique({read_pv: r_protocol, write_pv: w_protocol}, "protocols") - match protocol: - case EpicsProtocol.CA: - return CaSignalBackend(datatype, r_pv, w_pv) - case EpicsProtocol.PVA: - return PvaSignalBackend(datatype, r_pv, w_pv) + signal_backend_type = get_signal_backend_type(protocol) + return signal_backend_type(datatype, r_pv, w_pv) def epics_signal_rw( diff --git a/src/ophyd_async/epics/signal/_common.py b/src/ophyd_async/epics/core/_util.py 
similarity index 76% rename from src/ophyd_async/epics/signal/_common.py rename to src/ophyd_async/epics/core/_util.py index d11c85be54..56cd058515 100644 --- a/src/ophyd_async/epics/signal/_common.py +++ b/src/ophyd_async/epics/core/_util.py @@ -3,7 +3,13 @@ import numpy as np -from ophyd_async.core import SubsetEnum, get_dtype, get_enum_cls +from ophyd_async.core import ( + SignalBackend, + SignalDatatypeT, + SubsetEnum, + get_dtype, + get_enum_cls, +) def get_supported_values( @@ -41,3 +47,15 @@ def format_datatype(datatype: Any) -> str: return datatype.__name__ else: return str(datatype) + + +class EpicsSignalBackend(SignalBackend[SignalDatatypeT]): + def __init__( + self, + datatype: type[SignalDatatypeT] | None, + read_pv: str = "", + write_pv: str = "", + ): + self.read_pv = read_pv + self.write_pv = write_pv + super().__init__(datatype) diff --git a/src/ophyd_async/epics/demo/_mover.py b/src/ophyd_async/epics/demo/_mover.py index 72266de846..4c1e35fa8d 100644 --- a/src/ophyd_async/epics/demo/_mover.py +++ b/src/ophyd_async/epics/demo/_mover.py @@ -8,15 +8,14 @@ DEFAULT_TIMEOUT, AsyncStatus, CalculatableTimeout, - ConfigSignal, Device, - HintedSignal, StandardReadable, WatchableAsyncStatus, WatcherUpdate, observe_value, ) -from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw, epics_signal_x +from ophyd_async.core import StandardReadableFormat as Format +from ophyd_async.epics.core import epics_signal_r, epics_signal_rw, epics_signal_x class Mover(StandardReadable, Movable, Stoppable): @@ -24,9 +23,9 @@ class Mover(StandardReadable, Movable, Stoppable): def __init__(self, prefix: str, name="") -> None: # Define some signals - with self.add_children_as_readables(HintedSignal): + with self.add_children_as_readables(Format.HINTED_SIGNAL): self.readback = epics_signal_r(float, prefix + "Readback") - with self.add_children_as_readables(ConfigSignal): + with self.add_children_as_readables(Format.CONFIG_SIGNAL): self.velocity = epics_signal_rw(float, 
prefix + "Velocity") self.units = epics_signal_r(str, prefix + "Readback.EGU") self.setpoint = epics_signal_rw(float, prefix + "Setpoint") diff --git a/src/ophyd_async/epics/demo/_sensor.py b/src/ophyd_async/epics/demo/_sensor.py index 5235fe0aba..1004a04dae 100644 --- a/src/ophyd_async/epics/demo/_sensor.py +++ b/src/ophyd_async/epics/demo/_sensor.py @@ -1,11 +1,14 @@ +from typing import Annotated as A + from ophyd_async.core import ( - ConfigSignal, DeviceVector, - HintedSignal, + SignalR, + SignalRW, StandardReadable, StrictEnum, ) -from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw +from ophyd_async.core import StandardReadableFormat as Format +from ophyd_async.epics.core import EpicsDevice, PvSuffix class EnergyMode(StrictEnum): @@ -17,17 +20,11 @@ class EnergyMode(StrictEnum): high = "High Energy" -class Sensor(StandardReadable): +class Sensor(StandardReadable, EpicsDevice): """A demo sensor that produces a scalar value based on X and Y Movers""" - def __init__(self, prefix: str, name="") -> None: - # Define some signals - with self.add_children_as_readables(HintedSignal): - self.value = epics_signal_r(float, prefix + "Value") - with self.add_children_as_readables(ConfigSignal): - self.mode = epics_signal_rw(EnergyMode, prefix + "Mode") - - super().__init__(name=name) + value: A[SignalR[float], PvSuffix("Value"), Format.HINTED_SIGNAL] + mode: A[SignalRW[EnergyMode], PvSuffix("Mode"), Format.CONFIG_SIGNAL] class SensorGroup(StandardReadable): diff --git a/src/ophyd_async/epics/eiger/_eiger_io.py b/src/ophyd_async/epics/eiger/_eiger_io.py index ed61c0b326..ef4451aa7d 100644 --- a/src/ophyd_async/epics/eiger/_eiger_io.py +++ b/src/ophyd_async/epics/eiger/_eiger_io.py @@ -1,5 +1,5 @@ from ophyd_async.core import Device, StrictEnum -from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw_rbv, epics_signal_w +from ophyd_async.epics.core import epics_signal_r, epics_signal_rw_rbv, epics_signal_w class EigerTriggerMode(StrictEnum): diff 
--git a/src/ophyd_async/epics/eiger/_odin_io.py b/src/ophyd_async/epics/eiger/_odin_io.py index 321b0d7bb6..19e0a0fc4e 100644 --- a/src/ophyd_async/epics/eiger/_odin_io.py +++ b/src/ophyd_async/epics/eiger/_odin_io.py @@ -15,7 +15,7 @@ observe_value, set_and_wait_for_value, ) -from ophyd_async.epics.signal import ( +from ophyd_async.epics.core import ( epics_signal_r, epics_signal_rw, epics_signal_rw_rbv, diff --git a/src/ophyd_async/epics/motor.py b/src/ophyd_async/epics/motor.py index a890249561..f03a29bcb0 100644 --- a/src/ophyd_async/epics/motor.py +++ b/src/ophyd_async/epics/motor.py @@ -14,14 +14,13 @@ DEFAULT_TIMEOUT, AsyncStatus, CalculatableTimeout, - ConfigSignal, - HintedSignal, StandardReadable, WatchableAsyncStatus, WatcherUpdate, observe_value, ) -from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw, epics_signal_x +from ophyd_async.core import StandardReadableFormat as Format +from ophyd_async.epics.core import epics_signal_r, epics_signal_rw, epics_signal_x class MotorLimitsException(Exception): @@ -61,11 +60,11 @@ class Motor(StandardReadable, Locatable, Stoppable, Flyable, Preparable): def __init__(self, prefix: str, name="") -> None: # Define some signals - with self.add_children_as_readables(ConfigSignal): + with self.add_children_as_readables(Format.CONFIG_SIGNAL): self.motor_egu = epics_signal_r(str, prefix + ".EGU") self.velocity = epics_signal_rw(float, prefix + ".VELO") - with self.add_children_as_readables(HintedSignal): + with self.add_children_as_readables(Format.HINTED_SIGNAL): self.user_readback = epics_signal_r(float, prefix + ".RBV") self.user_setpoint = epics_signal_rw(float, prefix + ".VAL") diff --git a/src/ophyd_async/epics/pvi/__init__.py b/src/ophyd_async/epics/pvi/__init__.py deleted file mode 100644 index 1352036600..0000000000 --- a/src/ophyd_async/epics/pvi/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._pvi import PviDeviceConnector - -__all__ = ["PviDeviceConnector"] diff --git 
a/src/ophyd_async/epics/pvi/_pvi.py b/src/ophyd_async/epics/pvi/_pvi.py deleted file mode 100644 index 5bd7a38ef8..0000000000 --- a/src/ophyd_async/epics/pvi/_pvi.py +++ /dev/null @@ -1,73 +0,0 @@ -from __future__ import annotations - -from unittest.mock import Mock - -from ophyd_async.core import ( - Device, - DeviceConnector, - DeviceFiller, - Signal, - SignalR, - SignalRW, - SignalX, -) -from ophyd_async.epics.signal import ( - PvaSignalBackend, - pvget_with_timeout, -) - - -def _get_signal_details(entry: dict[str, str]) -> tuple[type[Signal], str, str]: - match entry: - case {"r": read_pv}: - return SignalR, read_pv, read_pv - case {"r": read_pv, "w": write_pv}: - return SignalRW, read_pv, write_pv - case {"rw": read_write_pv}: - return SignalRW, read_write_pv, read_write_pv - case {"x": execute_pv}: - return SignalX, execute_pv, execute_pv - case _: - raise TypeError(f"Can't process entry {entry}") - - -class PviDeviceConnector(DeviceConnector): - def __init__(self, pvi_pv: str = "") -> None: - self.pvi_pv = pvi_pv - - def create_children_from_annotations(self, device: Device): - self._filler = DeviceFiller( - device=device, - signal_backend_factory=PvaSignalBackend, - device_connector_factory=PviDeviceConnector, - ) - - async def connect( - self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool - ) -> None: - if mock: - # Make 2 entries for each DeviceVector - self._filler.make_soft_device_vector_entries(2) - else: - pvi_structure = await pvget_with_timeout(self.pvi_pv, timeout) - entries: dict[str, dict[str, str]] = pvi_structure["value"].todict() - # Ensure we have device vectors for everything that should be there - self._filler.make_device_vectors(list(entries)) - for name, entry in entries.items(): - if set(entry) == {"d"}: - connector = self._filler.make_child_device(name) - connector.pvi_pv = entry["d"] - else: - signal_type, read_pv, write_pv = _get_signal_details(entry) - backend = self._filler.make_child_signal(name, 
signal_type) - backend.read_pv = read_pv - backend.write_pv = write_pv - # Check that all the requested children have been created - if unfilled := self._filler.unfilled(): - raise RuntimeError( - f"{device.name}: cannot provision {unfilled} from " - f"{self.pvi_pv}: {entries}" - ) - # Set the name of the device to name all children - device.set_name(device.name) - return await super().connect(device, mock, timeout, force_reconnect) diff --git a/src/ophyd_async/epics/signal.py b/src/ophyd_async/epics/signal.py new file mode 100644 index 0000000000..4cda92c14a --- /dev/null +++ b/src/ophyd_async/epics/signal.py @@ -0,0 +1,11 @@ +# back compat +import warnings + +from .core import * # noqa: F403 + +warnings.warn( + DeprecationWarning( + "Use `ophyd_async.epics.core` instead of `ophyd_async.epics.signal` and `pvi`" + ), + stacklevel=2, +) diff --git a/src/ophyd_async/epics/signal/__init__.py b/src/ophyd_async/epics/signal/__init__.py deleted file mode 100644 index 703880c9ab..0000000000 --- a/src/ophyd_async/epics/signal/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from ._common import get_supported_values -from ._p4p import PvaSignalBackend, pvget_with_timeout -from ._signal import ( - epics_signal_r, - epics_signal_rw, - epics_signal_rw_rbv, - epics_signal_w, - epics_signal_x, -) - -__all__ = [ - "get_supported_values", - "PvaSignalBackend", - "pvget_with_timeout", - "epics_signal_r", - "epics_signal_rw", - "epics_signal_rw_rbv", - "epics_signal_w", - "epics_signal_x", -] diff --git a/src/ophyd_async/fastcs/core.py b/src/ophyd_async/fastcs/core.py index bd2e32a033..d7561c7578 100644 --- a/src/ophyd_async/fastcs/core.py +++ b/src/ophyd_async/fastcs/core.py @@ -1,9 +1,9 @@ from ophyd_async.core import Device, DeviceConnector -from ophyd_async.epics.pvi import PviDeviceConnector +from ophyd_async.epics.core import PviDeviceConnector def fastcs_connector(device: Device, uri: str) -> DeviceConnector: # TODO: add Tango support based on uri scheme - connector = 
PviDeviceConnector(uri + "PVI") + connector = PviDeviceConnector(uri) connector.create_children_from_annotations(device) return connector diff --git a/src/ophyd_async/sim/demo/_sim_motor.py b/src/ophyd_async/sim/demo/_sim_motor.py index eaca21e45d..f1c7f043df 100644 --- a/src/ophyd_async/sim/demo/_sim_motor.py +++ b/src/ophyd_async/sim/demo/_sim_motor.py @@ -6,8 +6,6 @@ from ophyd_async.core import ( AsyncStatus, - ConfigSignal, - HintedSignal, StandardReadable, WatchableAsyncStatus, WatcherUpdate, @@ -15,6 +13,7 @@ soft_signal_r_and_setter, soft_signal_rw, ) +from ophyd_async.core import StandardReadableFormat as Format class SimMotor(StandardReadable, Movable, Stoppable): @@ -28,11 +27,11 @@ def __init__(self, name="", instant=True) -> None: - instant: bool: whether to move instantly, or with a delay """ # Define some signals - with self.add_children_as_readables(HintedSignal): + with self.add_children_as_readables(Format.HINTED_SIGNAL): self.user_readback, self._user_readback_set = soft_signal_r_and_setter( float, 0 ) - with self.add_children_as_readables(ConfigSignal): + with self.add_children_as_readables(Format.CONFIG_SIGNAL): self.velocity = soft_signal_rw(float, 0 if instant else 1.0) self.units = soft_signal_rw(str, "mm") self.user_setpoint = soft_signal_rw(float, 0) diff --git a/src/ophyd_async/tango/base_devices/_base_device.py b/src/ophyd_async/tango/base_devices/_base_device.py index f93b5a3eda..2227d5ddbc 100644 --- a/src/ophyd_async/tango/base_devices/_base_device.py +++ b/src/ophyd_async/tango/base_devices/_base_device.py @@ -48,7 +48,6 @@ def __init__( polling=self._polling, signal_polling=self._signal_polling, ) - connector.create_children_from_annotations(self) super().__init__(name=name, connector=connector) @@ -106,20 +105,24 @@ def __init__( self._signal_polling = signal_polling def create_children_from_annotations(self, device: Device): - self._filler = DeviceFiller( - device=device, - signal_backend_factory=TangoSignalBackend, - 
device_connector_factory=lambda: TangoDeviceConnector( - None, None, (False, 0.1, None, None), {} - ), - ) + if not hasattr(self, "filler"): + self.filler = DeviceFiller( + device=device, + signal_backend_factory=TangoSignalBackend, + device_connector_factory=lambda: TangoDeviceConnector( + None, None, (False, 0.1, None, None), {} + ), + ) + list(self.filler.create_devices_from_annotations(filled=False)) + list(self.filler.create_signals_from_annotations(filled=False)) + self.filler.check_created() async def connect( self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool ) -> None: if mock: # Make 2 entries for each DeviceVector - self._filler.make_soft_device_vector_entries(2) + self.filler.create_device_vector_entries_to_mock(2) else: if self.trl and self.proxy is None: self.proxy = await AsyncDeviceProxy(self.trl) @@ -138,7 +141,7 @@ async def connect( full_trl = f"{self.trl}/{name}" signal_type = await infer_signal_type(full_trl, self.proxy) if signal_type: - backend = self._filler.make_child_signal(name, signal_type) + backend = self.filler.fill_child_signal(name, signal_type) backend.datatype = await infer_python_type(full_trl, self.proxy) backend.set_trl(full_trl) if polling := self._signal_polling.get(name, ()): @@ -147,12 +150,8 @@ async def connect( elif self._polling[0]: backend.set_polling(*self._polling) backend.allow_events(False) - # Check that all the requested children have been created - if unfilled := self._filler.unfilled(): - raise RuntimeError( - f"{device.name}: cannot provision {unfilled} from " - f"{self.trl}: {children}" - ) + # Check that all the requested children have been filled + self.filler.check_filled(f"{self.trl}: {children}") # Set the name of the device to name all children device.set_name(device.name) return await super().connect(device, mock, timeout, force_reconnect) diff --git a/src/ophyd_async/tango/demo/_counter.py b/src/ophyd_async/tango/demo/_counter.py index b23392d234..402c9ddd65 100644 --- 
a/src/ophyd_async/tango/demo/_counter.py +++ b/src/ophyd_async/tango/demo/_counter.py @@ -1,12 +1,7 @@ -from ophyd_async.core import ( - DEFAULT_TIMEOUT, - AsyncStatus, - ConfigSignal, - HintedSignal, - SignalR, - SignalRW, - SignalX, -) +from typing import Annotated as A + +from ophyd_async.core import DEFAULT_TIMEOUT, AsyncStatus, SignalR, SignalRW, SignalX +from ophyd_async.core import StandardReadableFormat as Format from ophyd_async.tango import TangoReadable, tango_polling @@ -16,16 +11,11 @@ class TangoCounter(TangoReadable): # Enter the name and type of the signals you want to use # If type is None or Signal, the type will be inferred from the Tango device - counts: SignalR[int] - sample_time: SignalRW[float] + counts: A[SignalR[int], Format.HINTED_SIGNAL] + sample_time: A[SignalRW[float], Format.CONFIG_SIGNAL] start: SignalX reset_: SignalX - def __init__(self, trl: str | None = "", name=""): - super().__init__(trl, name=name) - self.add_readables([self.counts], HintedSignal) - self.add_readables([self.sample_time], ConfigSignal) - @AsyncStatus.wrap async def trigger(self) -> None: sample_time = await self.sample_time.get_value() diff --git a/src/ophyd_async/tango/demo/_mover.py b/src/ophyd_async/tango/demo/_mover.py index bb15ac1b50..c249afef2b 100644 --- a/src/ophyd_async/tango/demo/_mover.py +++ b/src/ophyd_async/tango/demo/_mover.py @@ -7,8 +7,6 @@ DEFAULT_TIMEOUT, AsyncStatus, CalculatableTimeout, - ConfigSignal, - HintedSignal, SignalR, SignalRW, SignalX, @@ -17,6 +15,7 @@ observe_value, wait_for_value, ) +from ophyd_async.core import StandardReadableFormat as Format from ophyd_async.tango import TangoReadable, tango_polling from tango import DevState @@ -33,8 +32,8 @@ class TangoMover(TangoReadable, Movable, Stoppable): def __init__(self, trl: str | None = "", name=""): super().__init__(trl, name=name) - self.add_readables([self.position], HintedSignal) - self.add_readables([self.velocity], ConfigSignal) + self.add_readables([self.position], 
Format.HINTED_SIGNAL) + self.add_readables([self.velocity], Format.CONFIG_SIGNAL) self._set_success = True @WatchableAsyncStatus.wrap diff --git a/system_tests/epics/eiger/test_eiger_system.py b/system_tests/epics/eiger/test_eiger_system.py index 320c74aecc..4cfc81dae3 100644 --- a/system_tests/epics/eiger/test_eiger_system.py +++ b/system_tests/epics/eiger/test_eiger_system.py @@ -12,8 +12,8 @@ DeviceCollector, StaticPathProvider, ) +from ophyd_async.epics.core import epics_signal_rw from ophyd_async.epics.eiger import EigerDetector, EigerTriggerInfo -from ophyd_async.epics.signal import epics_signal_rw SAVE_PATH = "/tmp" diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index e81918d2f6..9800b706e4 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -24,7 +24,7 @@ set_signal_values, walk_rw_signals, ) -from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw +from ophyd_async.epics.core import epics_signal_r, epics_signal_rw class EnumTest(StrictEnum): diff --git a/tests/core/test_flyer.py b/tests/core/test_flyer.py index ef0d0eb39d..6c894974b2 100644 --- a/tests/core/test_flyer.py +++ b/tests/core/test_flyer.py @@ -24,7 +24,7 @@ assert_emitted, observe_value, ) -from ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.epics.core import epics_signal_rw class TriggerState(StrictEnum): diff --git a/tests/core/test_mock_signal_backend.py b/tests/core/test_mock_signal_backend.py index a40faea4e5..c5824b5435 100644 --- a/tests/core/test_mock_signal_backend.py +++ b/tests/core/test_mock_signal_backend.py @@ -21,7 +21,7 @@ soft_signal_r_and_setter, soft_signal_rw, ) -from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw +from ophyd_async.epics.core import epics_signal_r, epics_signal_rw async def test_mock_signal_backend(): diff --git a/tests/core/test_readable.py b/tests/core/test_readable.py index ebe20e009d..8b39fb404d 100644 --- 
a/tests/core/test_readable.py +++ b/tests/core/test_readable.py @@ -1,5 +1,3 @@ -from inspect import ismethod -from typing import get_type_hints from unittest.mock import MagicMock import pytest @@ -17,7 +15,16 @@ SignalR, StandardReadable, soft_signal_r_and_setter, + soft_signal_rw, ) +from ophyd_async.core import StandardReadableFormat as Format + + +@pytest.mark.parametrize("wrapper", [HintedSignal, HintedSignal.uncached, ConfigSignal]) +def test_standard_readable_wrappers_raise_deprecation_warning(wrapper): + sr = StandardReadable() + with pytest.deprecated_call(): + sr.add_readables([soft_signal_rw(int)], wrapper) def test_standard_readable_hints(): @@ -144,87 +151,103 @@ def test_standard_readable_add_children_cm_filters_non_devices(): assert set(mock.call_args.args[0]) == {sr.a, sr.b} -@pytest.mark.parametrize( - "readable, expected_attr", - [ - (SignalR, "_readables"), - (AsyncReadable, "_readables"), - (AsyncConfigurable, "_configurables"), - (AsyncStageable, "_stageables"), - (HasHints, "_has_hints"), - ], -) -def test_standard_readable_add_readables_adds_to_expected_attrs( - readable, expected_attr -): - sr = StandardReadable() - - r1 = MagicMock(spec=readable) - readables = [r1] +def assert_sr_has_attrs(sr: StandardReadable, expected_attrs: dict[str, tuple]): + attrs_to_check = ( + "_describe_config_funcs", + "_read_config_funcs", + "_describe_funcs", + "_read_funcs", + "_stageables", + "_has_hints", + ) + actual = {attr: getattr(sr, attr) for attr in attrs_to_check} + expected = {attr: expected_attrs.get(attr, ()) for attr in attrs_to_check} + assert actual == expected - sr.add_readables(readables) - assert getattr(sr, expected_attr) == (r1,) +signal_r = MagicMock(spec=SignalR) +async_readable = MagicMock(spec=AsyncReadable) +async_configurable = MagicMock(spec=AsyncConfigurable) +async_stageable = MagicMock(spec=AsyncStageable) +has_hints = MagicMock(spec=HasHints) @pytest.mark.parametrize( - "wrapper, expected_attrs", + "readable, expected_attrs", [ 
- (HintedSignal, ["_readables", "_has_hints", "_stageables"]), - (HintedSignal.uncached, ["_readables", "_has_hints"]), - (ConfigSignal, ["_configurables"]), + ( + signal_r, + { + "_read_funcs": (signal_r.read,), + "_describe_funcs": (signal_r.describe,), + "_stageables": (signal_r,), + }, + ), + ( + async_readable, + { + "_read_funcs": (async_readable.read,), + "_describe_funcs": (async_readable.describe,), + }, + ), + ( + async_configurable, + { + "_read_config_funcs": (async_configurable.read_configuration,), + "_describe_config_funcs": (async_configurable.describe_configuration,), + }, + ), + (async_stageable, {"_stageables": (async_stageable,)}), + (has_hints, {"_has_hints": (has_hints,)}), ], ) -def test_standard_readable_add_readables_adds_wrapped_to_expected_attr( - wrapper, expected_attrs: list[str] +def test_standard_readable_add_readables_adds_to_expected_attrs( + readable, expected_attrs: dict[str, tuple] ): sr = StandardReadable() + sr.add_readables([readable]) + assert_sr_has_attrs(sr, expected_attrs) - r1 = MagicMock(spec=SignalR) - readables = [r1] - sr.add_readables(readables, wrapper=wrapper) - - for expected_attr in expected_attrs: - saved = getattr(sr, expected_attr) - assert len(saved) == 1 - if ismethod(wrapper): - # Convert a classmethod into its Class type. Relies on type hinting! 
- wrapper = get_type_hints(wrapper)["return"] - assert isinstance(saved[0], wrapper) +def test_standard_readable_config_signal(): + signal_r = MagicMock(spec=SignalR) + sr = StandardReadable() + sr.add_readables([signal_r], Format.CONFIG_SIGNAL) + assert sr._describe_config_funcs == (signal_r.describe,) + assert sr._read_config_funcs == (signal_r.read,) -def test_standard_readable_set_readable_signals__raises_deprecated(): +def test_standard_readable_hinted_signal(): + signal_r = MagicMock(spec=SignalR) sr = StandardReadable() - - with pytest.deprecated_call(): - sr.set_readable_signals(()) + sr.add_readables([signal_r], Format.HINTED_SIGNAL) + assert sr._describe_funcs == (signal_r.describe,) + assert sr._read_funcs == (signal_r.read,) + assert sr._stageables == (signal_r,) + assert sr._has_hints[0].device == signal_r -@pytest.mark.filterwarnings("ignore:Migrate to ") -def test_standard_readable_set_readable_signals(): +def test_standard_readable_uncached_signal(): + signal_r = MagicMock(spec=SignalR) sr = StandardReadable() + sr.add_readables([signal_r], Format.UNCACHED_SIGNAL) + assert sr._describe_funcs == (signal_r.describe,) + assert sr._read_funcs[0].signal == signal_r - readable = MagicMock(spec=SignalR) - configurable = MagicMock(spec=SignalR) - readable_uncached = MagicMock(spec=SignalR) - sr.set_readable_signals( - read=(readable,), config=(configurable,), read_uncached=(readable_uncached,) - ) - - assert len(sr._readables) == 2 - assert all(isinstance(x, HintedSignal) for x in sr._readables) - assert len(sr._configurables) == 1 - assert all(isinstance(x, ConfigSignal) for x in sr._configurables) - assert len(sr._stageables) == 1 - assert all(isinstance(x, HintedSignal) for x in sr._stageables) +def test_standard_readable_hinted_uncached_signal(): + signal_r = MagicMock(spec=SignalR) + sr = StandardReadable() + sr.add_readables([signal_r], Format.HINTED_UNCACHED_SIGNAL) + assert sr._describe_funcs == (signal_r.describe,) + assert sr._read_funcs[0].signal 
== signal_r + assert sr._has_hints[0].device == signal_r def test_standard_readable_add_children_multi_nested(): inner = StandardReadable() outer = StandardReadable() - with inner.add_children_as_readables(HintedSignal): + with inner.add_children_as_readables(Format.HINTED_SIGNAL): inner.a, _ = soft_signal_r_and_setter(float, initial_value=5.0) inner.b, _ = soft_signal_r_and_setter(float, initial_value=6.0) with outer.add_children_as_readables(): diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index a9964f5132..09542fb513 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -13,7 +13,6 @@ AsyncStatus, ConfigSignal, DeviceCollector, - HintedSignal, MockSignalBackend, NotConnected, Signal, @@ -34,7 +33,8 @@ soft_signal_rw, wait_for_value, ) -from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw +from ophyd_async.core import StandardReadableFormat as Format +from ophyd_async.epics.core import epics_signal_r, epics_signal_rw from ophyd_async.plan_stubs import ensure_connected @@ -341,9 +341,9 @@ class DummyReadable(StandardReadable): def __init__(self, prefix: str, name="") -> None: # Define some signals - with self.add_children_as_readables(HintedSignal): + with self.add_children_as_readables(Format.HINTED_SIGNAL): self.value = epics_signal_r(float, prefix + "Value") - with self.add_children_as_readables(ConfigSignal): + with self.add_children_as_readables(Format.CONFIG_SIGNAL): self.mode = epics_signal_rw(str, prefix + "Mode") self.mode2 = epics_signal_rw(str, prefix + "Mode2") # Set name and signals for read() and read_configuration() diff --git a/tests/core/test_subset_enum.py b/tests/core/test_subset_enum.py index ae9e56d2e8..4380343e2a 100644 --- a/tests/core/test_subset_enum.py +++ b/tests/core/test_subset_enum.py @@ -3,13 +3,13 @@ from p4p.nt import NTEnum from ophyd_async.core import SubsetEnum -from ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.epics.core import epics_signal_rw # Allow 
these imports from private modules for tests -from ophyd_async.epics.signal._aioca import ( +from ophyd_async.epics.core._aioca import ( make_converter as ca_make_converter, # noqa: PLC2701 ) -from ophyd_async.epics.signal._p4p import ( +from ophyd_async.epics.core._p4p import ( make_converter as pva_make_converter, # noqa: PLC2701 ) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index a40a8ed7a1..80bbc2fce4 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -11,7 +11,7 @@ SignalRW, ) from ophyd_async.core import soft_signal_rw -from ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.epics.core import epics_signal_rw class ValueErrorBackend(SoftSignalBackend): diff --git a/tests/epics/adcore/test_writers.py b/tests/epics/adcore/test_writers.py index 9c402afdda..c06044c616 100644 --- a/tests/epics/adcore/test_writers.py +++ b/tests/epics/adcore/test_writers.py @@ -12,7 +12,7 @@ set_mock_value, ) from ophyd_async.epics import adaravis, adcore, adkinetix, adpilatus, advimba -from ophyd_async.epics.signal import epics_signal_r +from ophyd_async.epics.core import epics_signal_r from ophyd_async.plan_stubs import setup_ndattributes, setup_ndstats_sum diff --git a/tests/epics/demo/test_demo.py b/tests/epics/demo/test_demo.py index 7b14154c41..838b9f5811 100644 --- a/tests/epics/demo/test_demo.py +++ b/tests/epics/demo/test_demo.py @@ -264,9 +264,12 @@ async def test_sensor_disconnected(caplog): logs = caplog.get_records("call") logs = [log for log in logs if "_signal" not in log.pathname] assert len(logs) == 2 + messages = {log.message for log in logs} - assert logs[0].message == ("signal ca://PRE:Value timed out") - assert logs[1].message == ("signal ca://PRE:Mode timed out") + assert messages == { + "signal ca://PRE:Value timed out", + "signal ca://PRE:Mode timed out", + } assert s.name == "sensor" diff --git a/tests/epics/pvi/test_pvi.py b/tests/epics/pvi/test_pvi.py index 81efdc03eb..6391f98121 100644 --- 
a/tests/epics/pvi/test_pvi.py +++ b/tests/epics/pvi/test_pvi.py @@ -1,21 +1,31 @@ +from typing import Annotated as A from typing import TypeVar +import pytest +from bluesky.protocols import HasHints, Hints + from ophyd_async.core import ( Device, DeviceCollector, DeviceVector, SignalRW, SignalX, + StandardReadable, ) -from ophyd_async.epics.pvi import PviDeviceConnector +from ophyd_async.core import StandardReadableFormat as Format +from ophyd_async.epics.core import PviDeviceConnector -class Block1(Device): +class Block1(Device, HasHints): device_vector_signal_x: DeviceVector[SignalX] device_vector_signal_rw: DeviceVector[SignalRW[float]] signal_x: SignalX signal_rw: SignalRW[int] + @property + def hints(self) -> Hints: + return {} + class Block2(Device): device_vector: DeviceVector[Block1] @@ -32,9 +42,9 @@ class Block3(Device): signal_rw: SignalRW[int] -class Block4(Device): +class Block4(StandardReadable): device_vector: DeviceVector[Block1] - device: Block1 + device: A[Block1, Format.CHILD] signal_x: SignalX signal_rw: SignalRW[int] @@ -125,6 +135,7 @@ async def test_device_create_children_from_annotations_with_device_vectors(): await device.connect(mock=True) block_1_device = device.device + assert block_1_device in device._has_hints block_2_device_vector = device.device_vector assert device.device_vector[1].name == "test_device-device_vector-1" @@ -142,3 +153,21 @@ async def test_device_create_children_from_annotations_with_device_vectors(): # The memory addresses have not changed assert device.device is block_1_device assert device.device_vector is block_2_device_vector + + +class NoSignalType(Device): + a: SignalRW + + +class NoSignalTypeInVector(Device): + a: DeviceVector[SignalRW] + + +@pytest.mark.parametrize("cls", [NoSignalType, NoSignalTypeInVector]) +async def test_no_type_annotation_blocks(cls): + with pytest.raises(TypeError) as cm: + with_pvi_connector(cls, "PREFIX:") + assert str(cm.value) == ( + f"{cls.__name__}.a: Expected SignalX or 
SignalR/W/RW[type], " + "got " + ) diff --git a/tests/epics/signal/test_common.py b/tests/epics/signal/test_common.py index 124273b2f0..db197bffc4 100644 --- a/tests/epics/signal/test_common.py +++ b/tests/epics/signal/test_common.py @@ -3,7 +3,7 @@ import pytest from ophyd_async.core import StrictEnum -from ophyd_async.epics.signal import get_supported_values +from ophyd_async.epics.core._util import get_supported_values # noqa: PLC2701 def test_given_a_non_enum_passed_to_get_supported_enum_then_raises(): diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index 250bec3567..3f79442c9f 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -35,14 +35,14 @@ load_from_yaml, save_to_yaml, ) -from ophyd_async.epics.signal import ( +from ophyd_async.epics.core import ( epics_signal_r, epics_signal_rw, epics_signal_rw_rbv, epics_signal_w, epics_signal_x, ) -from ophyd_async.epics.signal._signal import _epics_signal_backend # noqa: PLC2701 +from ophyd_async.epics.core._signal import _epics_signal_backend # noqa: PLC2701 RECORDS = str(Path(__file__).parent / "test_records.db") PV_PREFIX = "".join(random.choice(string.ascii_lowercase) for _ in range(12)) @@ -930,3 +930,8 @@ def my_plan(): yield from bps.rd(ophyd_signal) RE(my_plan()) + + +def test_signal_module_emits_deprecation_warning(): + with pytest.deprecated_call(): + import ophyd_async.epics.signal # noqa: F401 diff --git a/tests/fastcs/panda/db/panda.db b/tests/fastcs/panda/db/panda.db index f0fab40b8d..cbc7205fb0 100644 --- a/tests/fastcs/panda/db/panda.db +++ b/tests/fastcs/panda/db/panda.db @@ -122,7 +122,7 @@ record(stringin, "$(IOC_NAME=PANDAQSRV):PULSE1:_PVI") field(VAL, "$(IOC_NAME=PANDAQSRV):PULSE1:PVI") info(Q:group, { "$(IOC_NAME=PANDAQSRV):PVI": { - "value.pulse1.d": { + "value.pulse[1].d": { "+channel": "VAL", "+type": "plain" } @@ -441,7 +441,7 @@ record(stringin, "$(IOC_NAME=PANDAQSRV):SEQ1:_PVI") field(VAL, 
"$(IOC_NAME=PANDAQSRV):SEQ1:PVI") info(Q:group, { "$(IOC_NAME=PANDAQSRV):PVI": { - "value.seq1.d": { + "value.seq[1].d": { "+channel": "VAL", "+type": "plain", "+putorder":18 @@ -560,11 +560,11 @@ $(EXCLUDE_PCAP=) }) $(EXCLUDE_PCAP=)} -$(INCLUDE_EXTRA_BLOCK=#)record(ao, "$(IOC_NAME=PANDAQSRV):EXTRA1:ARM") +$(INCLUDE_EXTRA_BLOCK=#)record(ao, "$(IOC_NAME=PANDAQSRV):EXTRA1:SIG1") $(INCLUDE_EXTRA_BLOCK=#){ $(INCLUDE_EXTRA_BLOCK=#) info(Q:group, { $(INCLUDE_EXTRA_BLOCK=#) "$(IOC_NAME=PANDAQSRV):EXTRA1:PVI": { -$(INCLUDE_EXTRA_BLOCK=#) "value.arm.x": { +$(INCLUDE_EXTRA_BLOCK=#) "value.sig[1].x": { $(INCLUDE_EXTRA_BLOCK=#) "+channel": "NAME", $(INCLUDE_EXTRA_BLOCK=#) "+type": "plain" $(INCLUDE_EXTRA_BLOCK=#) } @@ -578,7 +578,7 @@ $(INCLUDE_EXTRA_BLOCK=#){ $(INCLUDE_EXTRA_BLOCK=#) field(VAL, "$(IOC_NAME=PANDAQSRV):EXTRA1:PVI") $(INCLUDE_EXTRA_BLOCK=#) info(Q:group, { $(INCLUDE_EXTRA_BLOCK=#) "$(IOC_NAME=PANDAQSRV):PVI": { -$(INCLUDE_EXTRA_BLOCK=#) "value.extra1.d": { +$(INCLUDE_EXTRA_BLOCK=#) "value.extra[1].d": { $(INCLUDE_EXTRA_BLOCK=#) "+channel": "VAL", $(INCLUDE_EXTRA_BLOCK=#) "+type": "plain" $(INCLUDE_EXTRA_BLOCK=#) } @@ -586,11 +586,11 @@ $(INCLUDE_EXTRA_BLOCK=#) } $(INCLUDE_EXTRA_BLOCK=#) }) $(INCLUDE_EXTRA_BLOCK=#)} -$(INCLUDE_EXTRA_BLOCK=#)record(ao, "$(IOC_NAME=PANDAQSRV):EXTRA2:ARM") +$(INCLUDE_EXTRA_BLOCK=#)record(ao, "$(IOC_NAME=PANDAQSRV):EXTRA2:SIG1") $(INCLUDE_EXTRA_BLOCK=#){ $(INCLUDE_EXTRA_BLOCK=#) info(Q:group, { $(INCLUDE_EXTRA_BLOCK=#) "$(IOC_NAME=PANDAQSRV):EXTRA2:PVI": { -$(INCLUDE_EXTRA_BLOCK=#) "value.arm.x": { +$(INCLUDE_EXTRA_BLOCK=#) "value.sig[1].x": { $(INCLUDE_EXTRA_BLOCK=#) "+channel": "NAME", $(INCLUDE_EXTRA_BLOCK=#) "+type": "plain" $(INCLUDE_EXTRA_BLOCK=#) } @@ -604,7 +604,7 @@ $(INCLUDE_EXTRA_BLOCK=#){ $(INCLUDE_EXTRA_BLOCK=#) field(VAL, "$(IOC_NAME=PANDAQSRV):EXTRA2:PVI") $(INCLUDE_EXTRA_BLOCK=#) info(Q:group, { $(INCLUDE_EXTRA_BLOCK=#) "$(IOC_NAME=PANDAQSRV):PVI": { -$(INCLUDE_EXTRA_BLOCK=#) "value.extra2.d": { 
+$(INCLUDE_EXTRA_BLOCK=#) "value.extra[2].d": { $(INCLUDE_EXTRA_BLOCK=#) "+channel": "VAL", $(INCLUDE_EXTRA_BLOCK=#) "+type": "plain" $(INCLUDE_EXTRA_BLOCK=#) } diff --git a/tests/fastcs/panda/test_panda_connect.py b/tests/fastcs/panda/test_panda_connect.py index 7869a8dc9c..a5a88d7d87 100644 --- a/tests/fastcs/panda/test_panda_connect.py +++ b/tests/fastcs/panda/test_panda_connect.py @@ -1,6 +1,7 @@ """Used to test setting up signals for a PandA""" import copy +import re from typing import Any import numpy as np @@ -120,8 +121,11 @@ async def test_panda_with_missing_blocks(panda_pva, panda_t): panda = panda_t("PANDAQSRVI:", name="mypanda") with pytest.raises( RuntimeError, - match="mypanda: cannot provision {'pcap'} from PANDAQSRVI:PVI: {'pulse1': " - + "{'d': 'PANDAQSRVI:PULSE1:PVI'}, 'seq1': {'d': 'PANDAQSRVI:SEQ1:PVI'}}", + match=re.escape( + "mypanda: cannot provision ['pcap'] from PANDAQSRVI:PVI: " + "{'pulse': [None, {'d': 'PANDAQSRVI:PULSE1:PVI'}]," + " 'seq': [None, {'d': 'PANDAQSRVI:SEQ1:PVI'}]}" + ), ): await panda.connect() diff --git a/tests/fastcs/panda/test_panda_control.py b/tests/fastcs/panda/test_panda_control.py index e568dc3080..4f9bb8c546 100644 --- a/tests/fastcs/panda/test_panda_control.py +++ b/tests/fastcs/panda/test_panda_control.py @@ -5,7 +5,7 @@ import pytest from ophyd_async.core import DetectorTrigger, Device, DeviceCollector, TriggerInfo -from ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.epics.core import epics_signal_rw from ophyd_async.fastcs.core import fastcs_connector from ophyd_async.fastcs.panda import CommonPandaBlocks, PandaPcapController diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index c60d9210af..fc662fde5c 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -3,7 +3,7 @@ from bluesky import RunEngine from ophyd_async.core import DeviceCollector, load_device, save_device -from ophyd_async.epics.signal 
import epics_signal_rw +from ophyd_async.epics.core import epics_signal_rw from ophyd_async.fastcs.core import fastcs_connector from ophyd_async.fastcs.panda import ( CommonPandaBlocks, diff --git a/tests/plan_stubs/test_ensure_connected.py b/tests/plan_stubs/test_ensure_connected.py index 5737deef72..62cc7d034b 100644 --- a/tests/plan_stubs/test_ensure_connected.py +++ b/tests/plan_stubs/test_ensure_connected.py @@ -1,7 +1,7 @@ import pytest from ophyd_async.core import Device, NotConnected, soft_signal_rw -from ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.epics.core import epics_signal_rw from ophyd_async.plan_stubs import ensure_connected diff --git a/tests/plan_stubs/test_fly.py b/tests/plan_stubs/test_fly.py index 1e6c6afc42..7ad50a03d9 100644 --- a/tests/plan_stubs/test_fly.py +++ b/tests/plan_stubs/test_fly.py @@ -24,7 +24,7 @@ observe_value, set_mock_value, ) -from ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.epics.core import epics_signal_rw from ophyd_async.fastcs.core import fastcs_connector from ophyd_async.fastcs.panda import ( CommonPandaBlocks, diff --git a/tests/tango/test_base_device.py b/tests/tango/test_base_device.py index d825b37138..3328ba51fe 100644 --- a/tests/tango/test_base_device.py +++ b/tests/tango/test_base_device.py @@ -1,6 +1,7 @@ import asyncio import time from enum import Enum, IntEnum +from typing import Annotated as A import bluesky.plan_stubs as bps import bluesky.plans as bp @@ -9,7 +10,8 @@ from bluesky import RunEngine import tango -from ophyd_async.core import Array1D, DeviceCollector, HintedSignal, SignalRW, T +from ophyd_async.core import Array1D, DeviceCollector, SignalRW, T +from ophyd_async.core import StandardReadableFormat as Format from ophyd_async.tango import TangoReadable, get_python_type from ophyd_async.tango.demo import ( DemoCounter, @@ -23,7 +25,6 @@ CmdArgType, DevState, ) -from tango import DeviceProxy as SyncDeviceProxy from tango.asyncio import DeviceProxy as 
AsyncDeviceProxy from tango.asyncio_executor import set_global_executor from tango.server import Device, attribute, command @@ -174,20 +175,9 @@ def raise_exception_cmd(self): # -------------------------------------------------------------------- class TestTangoReadable(TangoReadable): __test__ = False - justvalue: SignalRW[int] - array: SignalRW[Array1D[np.float64]] - limitedvalue: SignalRW[float] - - def __init__( - self, - trl: str | None = None, - device_proxy: SyncDeviceProxy | None = None, - name: str = "", - ) -> None: - super().__init__(trl, device_proxy, name=name) - self.add_readables( - [self.justvalue, self.array, self.limitedvalue], HintedSignal.uncached - ) + justvalue: A[SignalRW[int], Format.HINTED_UNCACHED_SIGNAL] + array: A[SignalRW[Array1D[np.float64]], Format.HINTED_UNCACHED_SIGNAL] + limitedvalue: A[SignalRW[float], Format.HINTED_UNCACHED_SIGNAL] # -------------------------------------------------------------------- From 3de9bd0125f48a1be019f699641ebdf08c32a474 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Fri, 8 Nov 2024 14:38:20 +0000 Subject: [PATCH 25/30] Update to copier 2.5.0 (#637) --- .copier-answers.yml | 2 +- .github/CONTRIBUTING.md | 2 +- .github/workflows/_pypi.yml | 2 +- .github/workflows/_release.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.copier-answers.yml b/.copier-answers.yml index f382bb7996..14ad920025 100644 --- a/.copier-answers.yml +++ b/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2.4.0 +_commit: 2.5.0 _src_path: gh:DiamondLightSource/python-copier-template author_email: tom.cobb@diamond.ac.uk author_name: Tom Cobb diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 06311934d0..366d18dfba 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -24,4 +24,4 @@ It is recommended that developers use a [vscode devcontainer](https://code.visua This project was created using 
the [Diamond Light Source Copier Template](https://github.com/DiamondLightSource/python-copier-template) for Python projects. -For more information on common tasks like setting up a developer environment, running the tests, and setting a pre-commit hook, see the template's [How-to guides](https://diamondlightsource.github.io/python-copier-template/2.4.0/how-to.html). +For more information on common tasks like setting up a developer environment, running the tests, and setting a pre-commit hook, see the template's [How-to guides](https://diamondlightsource.github.io/python-copier-template/2.5.0/how-to.html). diff --git a/.github/workflows/_pypi.yml b/.github/workflows/_pypi.yml index 574e299b05..8032bbaac4 100644 --- a/.github/workflows/_pypi.yml +++ b/.github/workflows/_pypi.yml @@ -16,4 +16,4 @@ jobs: - name: Publish to PyPI using trusted publishing uses: pypa/gh-action-pypi-publish@release/v1 with: - attestations: false + attestations: false diff --git a/.github/workflows/_release.yml b/.github/workflows/_release.yml index 10d8ed87d1..81b626438e 100644 --- a/.github/workflows/_release.yml +++ b/.github/workflows/_release.yml @@ -23,7 +23,7 @@ jobs: - name: Create GitHub Release # We pin to the SHA, not the tag, for security reasons. 
# https://docs.github.com/en/actions/learn-github-actions/security-hardening-for-github-actions#using-third-party-actions - uses: softprops/action-gh-release@c062e08bd532815e2082a85e87e3ef29c3e6d191 # v2.0.8 + uses: softprops/action-gh-release@e7a8f85e1c67a31e6ed99a94b41bd0b71bbee6b8 # v2.0.9 with: prerelease: ${{ contains(github.ref_name, 'a') || contains(github.ref_name, 'b') || contains(github.ref_name, 'rc') }} files: "*" From 4a95c3eb4f956b5d1cc432ef2beea2f0314fb137 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Mon, 11 Nov 2024 21:45:26 +0000 Subject: [PATCH 26/30] Speed up device creation and connection in mock mode (#641) --- src/ophyd_async/core/__init__.py | 2 + src/ophyd_async/core/_device.py | 119 ++++++++++-------- src/ophyd_async/core/_mock_signal_backend.py | 17 +-- src/ophyd_async/core/_mock_signal_utils.py | 25 ++-- src/ophyd_async/core/_signal.py | 48 ++++--- src/ophyd_async/core/_soft_signal_backend.py | 2 + src/ophyd_async/core/_utils.py | 65 ++++++++-- .../epics/adcore/_single_trigger.py | 3 +- src/ophyd_async/epics/core/_pvi_connector.py | 49 ++++---- .../plan_stubs/_ensure_connected.py | 6 +- .../tango/base_devices/_base_device.py | 73 +++++------ tests/core/test_device.py | 36 +++--- tests/core/test_signal.py | 77 +++--------- tests/epics/demo/test_demo.py | 8 +- tests/plan_stubs/test_ensure_connected.py | 4 +- 15 files changed, 289 insertions(+), 245 deletions(-) diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index f19f4cb59e..c208126d9d 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -84,6 +84,7 @@ DEFAULT_TIMEOUT, CalculatableTimeout, Callback, + LazyMock, NotConnected, Reference, StrictEnum, @@ -178,6 +179,7 @@ "DEFAULT_TIMEOUT", "CalculatableTimeout", "Callback", + "LazyMock", "CALCULATE_TIMEOUT", "NotConnected", "Reference", diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index 
1fe7855f3d..eb43abff58 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -3,17 +3,15 @@ import asyncio import sys from collections.abc import Coroutine, Iterator, Mapping, MutableMapping +from functools import cached_property from logging import LoggerAdapter, getLogger from typing import Any, TypeVar -from unittest.mock import Mock from bluesky.protocols import HasName from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop from ._protocol import Connectable -from ._utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection - -_device_mocks: dict[Device, Mock] = {} +from ._utils import DEFAULT_TIMEOUT, LazyMock, NotConnected, wait_for_connection class DeviceConnector: @@ -37,25 +35,23 @@ def create_children_from_annotations(self, device: Device): during ``__init__``. """ - async def connect( - self, - device: Device, - mock: bool | Mock, - timeout: float, - force_reconnect: bool, - ): + async def connect_mock(self, device: Device, mock: LazyMock): + # Connect serially, no errors to gather up as in mock mode + for name, child_device in device.children(): + await child_device.connect(mock=mock.child(name)) + + async def connect_real(self, device: Device, timeout: float, force_reconnect: bool): """Used during ``Device.connect``. This is called when a previous connect has not been done, or has been done in a different mock more. It should connect the Device and all its children. 
""" - coros = {} - for name, child_device in device.children(): - child_mock = getattr(mock, name) if mock else mock # Mock() or False - coros[name] = child_device.connect( - mock=child_mock, timeout=timeout, force_reconnect=force_reconnect - ) + # Connect in parallel, gathering up NotConnected errors + coros = { + name: child_device.connect(timeout=timeout, force_reconnect=force_reconnect) + for name, child_device in device.children() + } await wait_for_connection(**coros) @@ -67,9 +63,8 @@ class Device(HasName, Connectable): parent: Device | None = None # None if connect hasn't started, a Task if it has _connect_task: asyncio.Task | None = None - # If not None, then this is the mock arg of the previous connect - # to let us know if we can reuse an existing connection - _connect_mock_arg: bool | None = None + # The mock if we have connected in mock mode + _mock: LazyMock | None = None def __init__( self, name: str = "", connector: DeviceConnector | None = None @@ -83,10 +78,18 @@ def name(self) -> str: """Return the name of the Device""" return self._name + @cached_property + def _child_devices(self) -> dict[str, Device]: + return {} + def children(self) -> Iterator[tuple[str, Device]]: - for attr_name, attr in self.__dict__.items(): - if attr_name != "parent" and isinstance(attr, Device): - yield attr_name, attr + yield from self._child_devices.items() + + @cached_property + def log(self) -> LoggerAdapter: + return LoggerAdapter( + getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name} + ) def set_name(self, name: str): """Set ``self.name=name`` and each ``self.child.name=name+"-child"``. 
@@ -97,28 +100,33 @@ def set_name(self, name: str): New name to set """ self._name = name - # Ensure self.log is recreated after a name change - self.log = LoggerAdapter( - getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name} - ) + # Ensure logger is recreated after a name change + if "log" in self.__dict__: + del self.log for child_name, child in self.children(): child_name = f"{self.name}-{child_name.strip('_')}" if self.name else "" child.set_name(child_name) def __setattr__(self, name: str, value: Any) -> None: + # Bear in mind that this function is called *a lot*, so + # we need to make sure nothing expensive happens in it... if name == "parent": if self.parent not in (value, None): raise TypeError( f"Cannot set the parent of {self} to be {value}: " f"it is already a child of {self.parent}" ) - elif isinstance(value, Device): + # ...hence not doing an isinstance check for attributes we + # know not to be Devices + elif name not in _not_device_attrs and isinstance(value, Device): value.parent = self - return super().__setattr__(name, value) + self._child_devices[name] = value + # ...and avoiding the super call as we know it resolves to `object` + return object.__setattr__(self, name, value) async def connect( self, - mock: bool | Mock = False, + mock: bool | LazyMock = False, timeout: float = DEFAULT_TIMEOUT, force_reconnect: bool = False, ) -> None: @@ -133,26 +141,39 @@ async def connect( timeout: Time to wait before failing with a TimeoutError. 
""" - uses_mock = bool(mock) - can_use_previous_connect = ( - uses_mock is self._connect_mock_arg - and self._connect_task - and not (self._connect_task.done() and self._connect_task.exception()) - ) - if mock is True: - mock = Mock() # create a new Mock if one not provided - if force_reconnect or not can_use_previous_connect: - self._connect_mock_arg = uses_mock - if self._connect_mock_arg: - _device_mocks[self] = mock - coro = self._connector.connect( - device=self, mock=mock, timeout=timeout, force_reconnect=force_reconnect + if mock: + # Always connect in mock mode serially + if isinstance(mock, LazyMock): + # Use the provided mock + self._mock = mock + elif not self._mock: + # Make one + self._mock = LazyMock() + await self._connector.connect_mock(self, self._mock) + else: + # Try to cache the connect in real mode + can_use_previous_connect = ( + self._mock is None + and self._connect_task + and not (self._connect_task.done() and self._connect_task.exception()) ) - self._connect_task = asyncio.create_task(coro) - - assert self._connect_task, "Connect task not created, this shouldn't happen" - # Wait for it to complete - await self._connect_task + if force_reconnect or not can_use_previous_connect: + self._mock = None + coro = self._connector.connect_real(self, timeout, force_reconnect) + self._connect_task = asyncio.create_task(coro) + assert self._connect_task, "Connect task not created, this shouldn't happen" + # Wait for it to complete + await self._connect_task + + +_not_device_attrs = { + "_name", + "_children", + "_connector", + "_timeout", + "_mock", + "_connect_task", +} DeviceT = TypeVar("DeviceT", bound=Device) diff --git a/src/ophyd_async/core/_mock_signal_backend.py b/src/ophyd_async/core/_mock_signal_backend.py index 878313e051..43fb2ae7df 100644 --- a/src/ophyd_async/core/_mock_signal_backend.py +++ b/src/ophyd_async/core/_mock_signal_backend.py @@ -1,13 +1,13 @@ import asyncio from collections.abc import Callable from functools import 
cached_property -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock from bluesky.protocols import Descriptor, Reading from ._signal_backend import SignalBackend, SignalDatatypeT from ._soft_signal_backend import SoftSignalBackend -from ._utils import Callback +from ._utils import Callback, LazyMock class MockSignalBackend(SignalBackend[SignalDatatypeT]): @@ -16,7 +16,7 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]): def __init__( self, initial_backend: SignalBackend[SignalDatatypeT], - mock: Mock, + mock: LazyMock, ) -> None: if isinstance(initial_backend, MockSignalBackend): raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackend") @@ -34,11 +34,14 @@ def __init__( # use existing Mock if provided self.mock = mock - self.put_mock = AsyncMock(name="put", spec=Callable) - self.mock.attach_mock(self.put_mock, "put") - super().__init__(datatype=self.initial_backend.datatype) + @cached_property + def put_mock(self) -> AsyncMock: + put_mock = AsyncMock(name="put", spec=Callable) + self.mock().attach_mock(put_mock, "put") + return put_mock + def set_value(self, value: SignalDatatypeT): self.soft_backend.set_value(value) @@ -46,7 +49,7 @@ def source(self, name: str, read: bool) -> str: return f"mock+{self.initial_backend.source(name, read)}" async def connect(self, timeout: float) -> None: - pass + raise RuntimeError("It is not possible to connect a MockSignalBackend") @cached_property def put_proceeds(self) -> asyncio.Event: diff --git a/src/ophyd_async/core/_mock_signal_utils.py b/src/ophyd_async/core/_mock_signal_utils.py index 30d48dbfe0..08976a0468 100644 --- a/src/ophyd_async/core/_mock_signal_utils.py +++ b/src/ophyd_async/core/_mock_signal_utils.py @@ -2,17 +2,26 @@ from contextlib import asynccontextmanager, contextmanager from unittest.mock import AsyncMock, Mock -from ._device import Device, _device_mocks +from ._device import Device from ._mock_signal_backend import MockSignalBackend -from ._signal 
import Signal, SignalR, _mock_signal_backends +from ._signal import Signal, SignalConnector, SignalR from ._soft_signal_backend import SignalDatatypeT +from ._utils import LazyMock + + +def get_mock(device: Device | Signal) -> Mock: + mock = device._mock # noqa: SLF001 + assert isinstance(mock, LazyMock), f"Device {device} not connected in mock mode" + return mock() def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend: - assert ( - signal in _mock_signal_backends + connector = signal._connector # noqa: SLF001 + assert isinstance(connector, SignalConnector), f"Expected Signal, got {signal}" + assert isinstance( + connector.backend, MockSignalBackend ), f"Signal {signal} not connected in mock mode" - return _mock_signal_backends[signal] + return connector.backend def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT): @@ -45,12 +54,6 @@ def get_mock_put(signal: Signal) -> AsyncMock: return _get_mock_signal_backend(signal).put_mock -def get_mock(device: Device | Signal) -> Mock: - if isinstance(device, Signal): - return _get_mock_signal_backend(device).mock - return _device_mocks[device] - - def reset_mock_put_calls(signal: Signal): backend = _get_mock_signal_backend(signal) backend.put_mock.reset_mock() diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 371cb5a0de..cd6152ef08 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -4,7 +4,6 @@ import functools from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from typing import Any, Generic, cast -from unittest.mock import Mock from bluesky.protocols import ( Locatable, @@ -28,10 +27,15 @@ SignalDatatypeT, ) from ._soft_signal_backend import SoftSignalBackend -from ._status import AsyncStatus, completed_status -from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T - -_mock_signal_backends: dict[Device, MockSignalBackend] = {} +from ._status import AsyncStatus 
+from ._utils import ( + CALCULATE_TIMEOUT, + DEFAULT_TIMEOUT, + CalculatableTimeout, + Callback, + LazyMock, + T, +) async def _wait_for(coro: Awaitable[T], timeout: float | None, source: str) -> T: @@ -53,26 +57,28 @@ class SignalConnector(DeviceConnector): def __init__(self, backend: SignalBackend): self.backend = self._init_backend = backend - async def connect( - self, - device: Device, - mock: bool | Mock, - timeout: float, - force_reconnect: bool, - ): - if mock: - self.backend = MockSignalBackend(self._init_backend, mock) - _mock_signal_backends[device] = self.backend - else: - self.backend = self._init_backend + async def connect_mock(self, device: Device, mock: LazyMock): + self.backend = MockSignalBackend(self._init_backend, mock) + + async def connect_real(self, device: Device, timeout: float, force_reconnect: bool): + self.backend = self._init_backend device.log.debug(f"Connecting to {self.backend.source(device.name, read=True)}") await self.backend.connect(timeout) +class _ChildrenNotAllowed(dict[str, Device]): + def __setitem__(self, key: str, value: Device) -> None: + raise AttributeError( + f"Cannot add Device or Signal child {key}={value} of Signal, " + "make a subclass of Device instead" + ) + + class Signal(Device, Generic[SignalDatatypeT]): """A Device with the concept of a value, with R, RW, W and X flavours""" _connector: SignalConnector + _child_devices = _ChildrenNotAllowed() # type: ignore def __init__( self, @@ -88,14 +94,6 @@ def source(self) -> str: """Like ca://PV_PREFIX:SIGNAL, or "" if not set""" return self._connector.backend.source(self.name, read=True) - def __setattr__(self, name: str, value: Any) -> None: - if name != "parent" and isinstance(value, Device): - raise AttributeError( - f"Cannot add Device or Signal {value} as a child of Signal {self}, " - "make a subclass of Device instead" - ) - return super().__setattr__(name, value) - class _SignalCache(Generic[SignalDatatypeT]): def __init__(self, backend: 
SignalBackend[SignalDatatypeT], signal: Signal): diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index d0e48c7212..ba21c3ba9b 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -4,6 +4,7 @@ from abc import abstractmethod from collections.abc import Sequence from dataclasses import dataclass +from functools import lru_cache from typing import Any, Generic, get_origin import numpy as np @@ -90,6 +91,7 @@ def write_value(self, value: Any) -> TableT: raise TypeError(f"Cannot convert {value} to {self.datatype}") +@lru_cache def make_converter(datatype: type[SignalDatatype]) -> SoftConverter: enum_cls = get_enum_cls(datatype) if datatype == Sequence[str]: diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index db4afae04a..ca20d90a3c 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -14,6 +14,7 @@ get_args, get_origin, ) +from unittest.mock import Mock import numpy as np @@ -120,20 +121,29 @@ async def wait_for_connection(**coros: Awaitable[None]): Expected kwargs should be a mapping of names to coroutine tasks to execute. 
""" - results = await asyncio.gather(*coros.values(), return_exceptions=True) - exceptions = {} + exceptions: dict[str, Exception] = {} + if len(coros) == 1: + # Single device optimization + name, coro = coros.popitem() + try: + await coro + except Exception as e: + exceptions[name] = e + else: + # Use gather to connect in parallel + results = await asyncio.gather(*coros.values(), return_exceptions=True) + for name, result in zip(coros, results, strict=False): + if isinstance(result, Exception): + exceptions[name] = result - for name, result in zip(coros, results, strict=False): - if isinstance(result, Exception): - exceptions[name] = result - if not isinstance(result, NotConnected): + if exceptions: + for name, exception in exceptions.items(): + if not isinstance(exception, NotConnected): logging.exception( f"device `{name}` raised unexpected exception " - f"{type(result).__name__}", - exc_info=result, + f"{type(exception).__name__}", + exc_info=exception, ) - - if exceptions: raise NotConnected(exceptions) @@ -252,3 +262,38 @@ def __init__(self, obj: T): def __call__(self) -> T: return self._obj + + +class LazyMock: + """A lazily created Mock to be used when connecting in mock mode. + + Creating Mocks is reasonably expensive when each Device (and Signal) + requires its own, and the tree is only used when ``Signal.set()`` is + called. This class allows a tree of lazily connected Mocks to be + constructed so that when the leaf is created, so are its parents. + Any calls to the child are then accessible from the parent mock. 
+ + >>> parent = LazyMock() + >>> child = parent.child("child") + >>> child_mock = child() + >>> child_mock() # doctest: +ELLIPSIS + + >>> parent_mock = parent() + >>> parent_mock.mock_calls + [call.child()] + """ + + def __init__(self, name: str = "", parent: LazyMock | None = None) -> None: + self.parent = parent + self.name = name + self._mock: Mock | None = None + + def child(self, name: str) -> LazyMock: + return LazyMock(name, self) + + def __call__(self) -> Mock: + if self._mock is None: + self._mock = Mock(spec=object) + if self.parent is not None: + self.parent().attach_mock(self._mock, self.name) + return self._mock diff --git a/src/ophyd_async/epics/adcore/_single_trigger.py b/src/ophyd_async/epics/adcore/_single_trigger.py index 9fd81b413d..165204d371 100644 --- a/src/ophyd_async/epics/adcore/_single_trigger.py +++ b/src/ophyd_async/epics/adcore/_single_trigger.py @@ -19,7 +19,8 @@ def __init__( **plugins: NDPluginBaseIO, ) -> None: self.drv = drv - self.__dict__.update(plugins) + for k, v in plugins.items(): + setattr(self, k, v) self.add_readables( [self.drv.array_counter, *read_uncached], diff --git a/src/ophyd_async/epics/core/_pvi_connector.py b/src/ophyd_async/epics/core/_pvi_connector.py index 812e4ec473..1c5c0eceb6 100644 --- a/src/ophyd_async/epics/core/_pvi_connector.py +++ b/src/ophyd_async/epics/core/_pvi_connector.py @@ -1,7 +1,5 @@ from __future__ import annotations -from unittest.mock import Mock - from ophyd_async.core import ( Device, DeviceConnector, @@ -11,6 +9,7 @@ SignalRW, SignalX, ) +from ophyd_async.core._utils import LazyMock from ._epics_connector import fill_backend_with_prefix from ._signal import PvaSignalBackend, pvget_with_timeout @@ -64,29 +63,29 @@ def _fill_child(self, name: str, entry: Entry, vector_index: int | None = None): backend.read_pv = read_pv backend.write_pv = write_pv - async def connect( - self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool + async def connect_mock(self, device: 
Device, mock: LazyMock): + self.filler.create_device_vector_entries_to_mock(2) + # Set the name of the device to name all children + device.set_name(device.name) + return await super().connect_mock(device, mock) + + async def connect_real( + self, device: Device, timeout: float, force_reconnect: bool ) -> None: - if mock: - # Make 2 entries for each DeviceVector - self.filler.create_device_vector_entries_to_mock(2) - else: - pvi_structure = await pvget_with_timeout(self.pvi_pv, timeout) - entries: dict[str, Entry | list[Entry | None]] = pvi_structure[ - "value" - ].todict() - # Fill based on what PVI gives us - for name, entry in entries.items(): - if isinstance(entry, dict): - # This is a child - self._fill_child(name, entry) - else: - # This is a DeviceVector of children - for i, e in enumerate(entry): - if e: - self._fill_child(name, e, i) - # Check that all the requested children have been filled - self.filler.check_filled(f"{self.pvi_pv}: {entries}") + pvi_structure = await pvget_with_timeout(self.pvi_pv, timeout) + entries: dict[str, Entry | list[Entry | None]] = pvi_structure["value"].todict() + # Fill based on what PVI gives us + for name, entry in entries.items(): + if isinstance(entry, dict): + # This is a child + self._fill_child(name, entry) + else: + # This is a DeviceVector of children + for i, e in enumerate(entry): + if e: + self._fill_child(name, e, i) + # Check that all the requested children have been filled + self.filler.check_filled(f"{self.pvi_pv}: {entries}") # Set the name of the device to name all children device.set_name(device.name) - return await super().connect(device, mock, timeout, force_reconnect) + return await super().connect_real(device, timeout, force_reconnect) diff --git a/src/ophyd_async/plan_stubs/_ensure_connected.py b/src/ophyd_async/plan_stubs/_ensure_connected.py index d4835b710c..0ad5cff518 100644 --- a/src/ophyd_async/plan_stubs/_ensure_connected.py +++ b/src/ophyd_async/plan_stubs/_ensure_connected.py @@ -1,13 +1,11 @@ 
-from unittest.mock import Mock - import bluesky.plan_stubs as bps -from ophyd_async.core import DEFAULT_TIMEOUT, Device, wait_for_connection +from ophyd_async.core import DEFAULT_TIMEOUT, Device, LazyMock, wait_for_connection def ensure_connected( *devices: Device, - mock: bool | Mock = False, + mock: bool | LazyMock = False, timeout: float = DEFAULT_TIMEOUT, force_reconnect=False, ): diff --git a/src/ophyd_async/tango/base_devices/_base_device.py b/src/ophyd_async/tango/base_devices/_base_device.py index 2227d5ddbc..1f98c4bd4f 100644 --- a/src/ophyd_async/tango/base_devices/_base_device.py +++ b/src/ophyd_async/tango/base_devices/_base_device.py @@ -1,9 +1,9 @@ from __future__ import annotations from typing import TypeVar -from unittest.mock import Mock from ophyd_async.core import Device, DeviceConnector, DeviceFiller +from ophyd_async.core._utils import LazyMock from ophyd_async.tango.signal import ( TangoSignalBackend, infer_python_type, @@ -117,41 +117,42 @@ def create_children_from_annotations(self, device: Device): list(self.filler.create_signals_from_annotations(filled=False)) self.filler.check_created() - async def connect( - self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool - ) -> None: - if mock: - # Make 2 entries for each DeviceVector - self.filler.create_device_vector_entries_to_mock(2) + async def connect_mock(self, device: Device, mock: LazyMock): + # Make 2 entries for each DeviceVector + self.filler.create_device_vector_entries_to_mock(2) + # Set the name of the device to name all children + device.set_name(device.name) + return await super().connect_mock(device, mock) + + async def connect_real(self, device: Device, timeout: float, force_reconnect: bool): + if self.trl and self.proxy is None: + self.proxy = await AsyncDeviceProxy(self.trl) + elif self.proxy and not self.trl: + self.trl = self.proxy.name() else: - if self.trl and self.proxy is None: - self.proxy = await AsyncDeviceProxy(self.trl) - elif self.proxy and 
not self.trl: - self.trl = self.proxy.name() - else: - raise TypeError("Neither proxy nor trl supplied") - - children = sorted( - set() - .union(self.proxy.get_attribute_list()) - .union(self.proxy.get_command_list()) - ) - for name in children: - # TODO: strip attribute name - full_trl = f"{self.trl}/{name}" - signal_type = await infer_signal_type(full_trl, self.proxy) - if signal_type: - backend = self.filler.fill_child_signal(name, signal_type) - backend.datatype = await infer_python_type(full_trl, self.proxy) - backend.set_trl(full_trl) - if polling := self._signal_polling.get(name, ()): - backend.set_polling(*polling) - backend.allow_events(False) - elif self._polling[0]: - backend.set_polling(*self._polling) - backend.allow_events(False) - # Check that all the requested children have been filled - self.filler.check_filled(f"{self.trl}: {children}") + raise TypeError("Neither proxy nor trl supplied") + + children = sorted( + set() + .union(self.proxy.get_attribute_list()) + .union(self.proxy.get_command_list()) + ) + for name in children: + # TODO: strip attribute name + full_trl = f"{self.trl}/{name}" + signal_type = await infer_signal_type(full_trl, self.proxy) + if signal_type: + backend = self.filler.fill_child_signal(name, signal_type) + backend.datatype = await infer_python_type(full_trl, self.proxy) + backend.set_trl(full_trl) + if polling := self._signal_polling.get(name, ()): + backend.set_polling(*polling) + backend.allow_events(False) + elif self._polling[0]: + backend.set_polling(*self._polling) + backend.allow_events(False) + # Check that all the requested children have been filled + self.filler.check_filled(f"{self.trl}: {children}") # Set the name of the device to name all children device.set_name(device.name) - return await super().connect(device, mock, timeout, force_reconnect) + return await super().connect_real(device, timeout, force_reconnect) diff --git a/tests/core/test_device.py b/tests/core/test_device.py index 2fc127f17b..39a9b70a5d 
100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -1,4 +1,5 @@ import asyncio +import time import traceback from unittest.mock import Mock @@ -174,43 +175,50 @@ def __init__(self, name: str) -> None: super().__init__(name) +@pytest.mark.parametrize("parallel", (False, True)) +async def test_many_individual_device_connects_not_slow(parallel): + start = time.time() + bundles = [MotorBundle(f"bundle{i}") for i in range(100)] + if parallel: + for bundle in bundles: + await bundle.connect(mock=True) + else: + coros = {bundle.name: bundle.connect(mock=True) for bundle in bundles} + await wait_for_connection(**coros) + duration = time.time() - start + assert duration < 1 + + async def test_device_with_children_lazily_connects(RE): parentMotor = MotorBundle("parentMotor") for device in [parentMotor, parentMotor.X, parentMotor.Y] + list( parentMotor.V.values() ): - assert device._connect_task is None + assert device._mock is None RE(ensure_connected(parentMotor, mock=True)) for device in [parentMotor, parentMotor.X, parentMotor.Y] + list( parentMotor.V.values() ): - assert ( - device._connect_task is not None - and device._connect_task.done() - and not device._connect_task.exception() - ) + assert device._mock is not None -@pytest.mark.parametrize("use_Mock", [False, True]) -async def test_no_reconnect_signals_if_not_forced(use_Mock): +async def test_no_reconnect_signals_if_not_forced(): parent = DummyDeviceGroup("parent") - connect_mock_arg = Mock() if use_Mock else True - - async def inner_connect(mock, timeout, force_reconnect): + async def inner_connect(mock=False, timeout=None, force_reconnect=False): parent.child1.connected = True parent.child1.connect = Mock(side_effect=inner_connect) - await parent.connect(mock=connect_mock_arg, timeout=0.01) + await parent.connect(mock=False, timeout=0.01) assert parent.child1.connected assert parent.child1.connect.call_count == 1 - await parent.connect(mock=connect_mock_arg, timeout=0.01) + await 
parent.connect(mock=False, timeout=0.01) assert parent.child1.connected assert parent.child1.connect.call_count == 1 for count in range(2, 10): - await parent.connect(mock=connect_mock_arg, timeout=0.01, force_reconnect=True) + await parent.connect(mock=False, timeout=0.01, force_reconnect=True) assert parent.child1.connected assert parent.child1.connect.call_count == count diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 09542fb513..77f8e001a7 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -13,9 +13,6 @@ AsyncStatus, ConfigSignal, DeviceCollector, - MockSignalBackend, - NotConnected, - Signal, SignalR, SignalRW, SoftSignalBackend, @@ -35,80 +32,44 @@ ) from ophyd_async.core import StandardReadableFormat as Format from ophyd_async.epics.core import epics_signal_r, epics_signal_rw -from ophyd_async.plan_stubs import ensure_connected def num_occurrences(substring: str, string: str) -> int: return len(list(re.finditer(re.escape(substring), string))) +def test_cannot_add_child_to_signal(): + signal = soft_signal_rw(str) + with pytest.raises( + AttributeError, + match="Cannot add Device or Signal child foo=<.*> of Signal, " + "make a subclass of Device instead", + ): + signal.foo = signal + + async def test_signal_connects_to_previous_backend(caplog): caplog.set_level(logging.DEBUG) - int_mock_backend = MockSignalBackend(SoftSignalBackend(int), Mock()) - original_connect = int_mock_backend.connect - times_backend_connect_called = 0 - - async def new_connect(timeout=1): - nonlocal times_backend_connect_called - times_backend_connect_called += 1 - await asyncio.sleep(0.1) - await original_connect(timeout=timeout) - - int_mock_backend.connect = new_connect - signal = Signal(int_mock_backend) - await asyncio.gather(signal.connect(), signal.connect()) + signal = soft_signal_rw(int) + mock_connect = Mock(side_effect=signal._connector.backend.connect) + signal._connector.backend.connect = mock_connect + await signal.connect() 
+ assert mock_connect.call_count == 1 + assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1 + await asyncio.gather(signal.connect(), signal.connect(), signal.connect()) + assert mock_connect.call_count == 1 assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1 - assert times_backend_connect_called == 1 async def test_signal_connects_with_force_reconnect(caplog): caplog.set_level(logging.DEBUG) - signal = Signal(MockSignalBackend(SoftSignalBackend(int), Mock())) + signal = soft_signal_rw(int) await signal.connect() assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1 await signal.connect(force_reconnect=True) assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 2 -async def test_signal_lazily_connects(RE): - class MockSignalBackendFailingFirst(MockSignalBackend): - succeed_on_connect = False - - async def connect(self, timeout=DEFAULT_TIMEOUT): - if self.succeed_on_connect: - self.succeed_on_connect = False - await super().connect(timeout=timeout) - else: - self.succeed_on_connect = True - raise RuntimeError("connect fail") - - signal = SignalRW(MockSignalBackendFailingFirst(SoftSignalBackend(int), Mock())) - - with pytest.raises(RuntimeError, match="connect fail"): - await signal.connect(mock=False) - - assert ( - signal._connect_task - and signal._connect_task.done() - and signal._connect_task.exception() - ) - - RE(ensure_connected(signal, mock=False)) - assert ( - signal._connect_task - and signal._connect_task.done() - and not signal._connect_task.exception() - ) - - with pytest.raises(NotConnected, match="RuntimeError: connect fail"): - RE(ensure_connected(signal, mock=False, force_reconnect=True)) - assert ( - signal._connect_task - and signal._connect_task.done() - and signal._connect_task.exception() - ) - - async def time_taken_by(coro) -> float: start = time.monotonic() await coro diff --git a/tests/epics/demo/test_demo.py b/tests/epics/demo/test_demo.py index 
838b9f5811..e29e13bf29 100644 --- a/tests/epics/demo/test_demo.py +++ b/tests/epics/demo/test_demo.py @@ -10,6 +10,7 @@ from ophyd_async.core import ( DeviceCollector, + LazyMock, NotConnected, assert_emitted, assert_reading, @@ -198,9 +199,10 @@ async def test_retrieve_mock_and_assert(mock_mover: demo.Mover): async def test_mocks_in_device_share_parent(): - mock = Mock() - async with DeviceCollector(mock=mock): - mock_mover = demo.Mover("BLxxI-MO-TABLE-01:Y:") + lm = LazyMock() + mock_mover = demo.Mover("BLxxI-MO-TABLE-01:Y:") + await mock_mover.connect(mock=lm) + mock = lm() assert get_mock(mock_mover) is mock assert get_mock(mock_mover.setpoint) is mock.setpoint diff --git a/tests/plan_stubs/test_ensure_connected.py b/tests/plan_stubs/test_ensure_connected.py index 62cc7d034b..4665ddcee9 100644 --- a/tests/plan_stubs/test_ensure_connected.py +++ b/tests/plan_stubs/test_ensure_connected.py @@ -31,8 +31,8 @@ def connect(): device2 = MyDevice("PREFIX2", name="device2") def connect_with_mocking(): - assert device2.signal._connect_task is None + assert device2.signal._mock is None yield from ensure_connected(device2, mock=True, timeout=0.1) - assert device2.signal._connect_task.done() + assert device2.signal._mock is not None RE(connect_with_mocking()) From 040f7a481a9aed1b6fd44d8034550f3581e99872 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Tue, 12 Nov 2024 14:15:34 +0000 Subject: [PATCH 27/30] Set parent of children of DeviceVector passed at init (#644) Fixes #643 --- src/ophyd_async/core/_device.py | 3 ++- tests/core/test_device.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index eb43abff58..9aef7d136f 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -193,7 +193,8 @@ def __init__( children: Mapping[int, DeviceT], name: str = "", ) -> None: - self._children = dict(children) + 
self._children: dict[int, DeviceT] = {} + self.update(children) super().__init__(name=name) def __setattr__(self, name: str, child: Any) -> None: diff --git a/tests/core/test_device.py b/tests/core/test_device.py index 39a9b70a5d..833ca5541a 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -117,10 +117,15 @@ async def test_device_with_device_collector(): parent = DummyDeviceGroup("parent") assert parent.name == "parent" + assert parent.parent is None assert parent.child1.name == "parent-child1" + assert parent.child1.parent == parent assert parent._child2.name == "parent-child2" + assert parent._child2.parent == parent assert parent.dict_with_children.name == "parent-dict_with_children" + assert parent.dict_with_children.parent == parent assert parent.dict_with_children[123].name == "parent-dict_with_children-123" + assert parent.dict_with_children[123].parent == parent.dict_with_children assert parent.child1.connected assert parent.dict_with_children[123].connected From ad910fb0abc028d3d3f8315a16f43578921795e9 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:19:00 +0000 Subject: [PATCH 28/30] Fix some small issues discovered in testing (#646) - Raise `NotConnected` errors in mock mode too when failing on creation of mock signal because of invalid datatype - Support bare `np.ndarray` in SoftSignalBackend - Support tagged and variant unions for NTNDArray - Check that Devices have unique names in `ensure_connected` --- src/ophyd_async/core/_device.py | 8 +++- src/ophyd_async/core/_soft_signal_backend.py | 8 +++- src/ophyd_async/core/_utils.py | 36 ++++++++-------- src/ophyd_async/epics/core/_p4p.py | 1 + .../plan_stubs/_ensure_connected.py | 29 ++++++++----- tests/core/test_device_save_loader.py | 25 +++++------ tests/core/test_signal.py | 42 +++++++++++++------ tests/core/test_soft_signal_backend.py | 27 ++++++------ tests/core/test_utils.py | 24 +++++++++++ 
tests/plan_stubs/test_ensure_connected.py | 14 +++++++ tests/test_data/test_yaml_save.yml | 1 - 11 files changed, 138 insertions(+), 77 deletions(-) diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index 9aef7d136f..11501dd228 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -37,8 +37,14 @@ def create_children_from_annotations(self, device: Device): async def connect_mock(self, device: Device, mock: LazyMock): # Connect serially, no errors to gather up as in mock mode + exceptions: dict[str, Exception] = {} for name, child_device in device.children(): - await child_device.connect(mock=mock.child(name)) + try: + await child_device.connect(mock=mock.child(name)) + except Exception as e: + exceptions[name] = e + if exceptions: + raise NotConnected.with_other_exceptions_logged(exceptions) async def connect_real(self, device: Device, timeout: float, force_reconnect: bool): """Used during ``Device.connect``. diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index ba21c3ba9b..cbafe9384d 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from dataclasses import dataclass from functools import lru_cache -from typing import Any, Generic, get_origin +from typing import Any, Generic, get_args, get_origin import numpy as np from bluesky.protocols import Reading @@ -58,7 +58,7 @@ def write_value(self, value: Any) -> Sequence[EnumT]: @dataclass class NDArraySoftConverter(SoftConverter[Array1D]): - datatype: np.dtype + datatype: np.dtype | None = None def write_value(self, value: Any) -> Array1D: return np.array(() if value is None else value, dtype=self.datatype) @@ -98,7 +98,11 @@ def make_converter(datatype: type[SignalDatatype]) -> SoftConverter: return SequenceStrSoftConverter() elif get_origin(datatype) == Sequence and enum_cls: return 
SequenceEnumSoftConverter(enum_cls) + elif datatype is np.ndarray: + return NDArraySoftConverter() elif get_origin(datatype) == np.ndarray: + if datatype not in get_args(SignalDatatype): + raise TypeError(f"Expected Array1D[dtype], got {datatype}") return NDArraySoftConverter(get_dtype(datatype)) elif enum_cls: return EnumSoftConverter(enum_cls) diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index ca20d90a3c..2aa4b1c717 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -2,18 +2,10 @@ import asyncio import logging -from collections.abc import Awaitable, Callable, Iterable, Sequence +from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence from dataclasses import dataclass from enum import Enum, EnumMeta -from typing import ( - Any, - Generic, - Literal, - ParamSpec, - TypeVar, - get_args, - get_origin, -) +from typing import Any, Generic, Literal, ParamSpec, TypeVar, get_args, get_origin from unittest.mock import Mock import numpy as np @@ -22,7 +14,7 @@ P = ParamSpec("P") Callback = Callable[[T], None] DEFAULT_TIMEOUT = 10.0 -ErrorText = str | dict[str, Exception] +ErrorText = str | Mapping[str, Exception] class StrictEnum(str, Enum): @@ -100,6 +92,19 @@ def format_error_string(self, indent="") -> str: def __str__(self) -> str: return self.format_error_string(indent="") + @classmethod + def with_other_exceptions_logged( + cls, exceptions: Mapping[str, Exception] + ) -> NotConnected: + for name, exception in exceptions.items(): + if not isinstance(exception, NotConnected): + logging.exception( + f"device `{name}` raised unexpected exception " + f"{type(exception).__name__}", + exc_info=exception, + ) + return NotConnected(exceptions) + @dataclass(frozen=True) class WatcherUpdate(Generic[T]): @@ -137,14 +142,7 @@ async def wait_for_connection(**coros: Awaitable[None]): exceptions[name] = result if exceptions: - for name, exception in exceptions.items(): - if not isinstance(exception, 
NotConnected): - logging.exception( - f"device `{name}` raised unexpected exception " - f"{type(exception).__name__}", - exc_info=exception, - ) - raise NotConnected(exceptions) + raise NotConnected.with_other_exceptions_logged(exceptions) def get_dtype(datatype: type) -> np.dtype: diff --git a/src/ophyd_async/epics/core/_p4p.py b/src/ophyd_async/epics/core/_p4p.py index 423839f5ee..c79b70debe 100644 --- a/src/ophyd_async/epics/core/_p4p.py +++ b/src/ophyd_async/epics/core/_p4p.py @@ -188,6 +188,7 @@ def write_value(self, value: BaseModel | dict[str, Any]) -> Any: ("epics:nt/NTScalarArray:1.0", "as"): (Sequence[str], PvaConverter), ("epics:nt/NTTable:1.0", "S"): (Table, PvaTableConverter), ("epics:nt/NTNDArray:1.0", "v"): (np.ndarray, PvaNDArrayConverter), + ("epics:nt/NTNDArray:1.0", "U"): (np.ndarray, PvaNDArrayConverter), } diff --git a/src/ophyd_async/plan_stubs/_ensure_connected.py b/src/ophyd_async/plan_stubs/_ensure_connected.py index 0ad5cff518..2d9a8cc85a 100644 --- a/src/ophyd_async/plan_stubs/_ensure_connected.py +++ b/src/ophyd_async/plan_stubs/_ensure_connected.py @@ -1,3 +1,5 @@ +from collections.abc import Awaitable + import bluesky.plan_stubs as bps from ophyd_async.core import DEFAULT_TIMEOUT, Device, LazyMock, wait_for_connection @@ -9,18 +11,23 @@ def ensure_connected( timeout: float = DEFAULT_TIMEOUT, force_reconnect=False, ): - (connect_task,) = yield from bps.wait_for( - [ - lambda: wait_for_connection( - **{ - device.name: device.connect( - mock=mock, timeout=timeout, force_reconnect=force_reconnect - ) - for device in devices - } + device_names = [device.name for device in devices] + non_unique = { + device: device.name for device in devices if device_names.count(device.name) > 1 + } + if non_unique: + raise ValueError(f"Devices do not have unique names {non_unique}") + + def connect_devices() -> Awaitable[None]: + coros = { + device.name: device.connect( + mock=mock, timeout=timeout, force_reconnect=force_reconnect ) - ] - ) + for device in 
devices + } + return wait_for_connection(**coros) + + (connect_task,) = yield from bps.wait_for([connect_devices]) if connect_task and connect_task.exception() is not None: raise connect_task.exception() diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index 9800b706e4..8106311ada 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -4,7 +4,6 @@ from unittest.mock import patch import numpy as np -import numpy.typing as npt import pytest import yaml from bluesky.run_engine import RunEngine @@ -71,17 +70,16 @@ def __init__(self, name: str): self.pv_str: SignalRW = epics_signal_rw(str, "PV2") self.pv_enum_str: SignalRW = epics_signal_rw(MyEnum, "PV3") self.pv_enum: SignalRW = epics_signal_rw(MyEnum, "PV4") - self.pv_array_int8 = epics_signal_rw(npt.NDArray[np.int8], "PV5") - self.pv_array_uint8 = epics_signal_rw(npt.NDArray[np.uint8], "PV6") - self.pv_array_int16 = epics_signal_rw(npt.NDArray[np.int16], "PV7") - self.pv_array_uint16 = epics_signal_rw(npt.NDArray[np.uint16], "PV8") - self.pv_array_int32 = epics_signal_rw(npt.NDArray[np.int32], "PV9") - self.pv_array_uint32 = epics_signal_rw(npt.NDArray[np.uint32], "PV10") - self.pv_array_int64 = epics_signal_rw(npt.NDArray[np.int64], "PV11") - self.pv_array_uint64 = epics_signal_rw(npt.NDArray[np.uint64], "PV12") - self.pv_array_float32 = epics_signal_rw(npt.NDArray[np.float32], "PV13") - self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14") - self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15") + self.pv_array_int8 = epics_signal_rw(Array1D[np.int8], "PV5") + self.pv_array_uint8 = epics_signal_rw(Array1D[np.uint8], "PV6") + self.pv_array_int16 = epics_signal_rw(Array1D[np.int16], "PV7") + self.pv_array_uint16 = epics_signal_rw(Array1D[np.uint16], "PV8") + self.pv_array_int32 = epics_signal_rw(Array1D[np.int32], "PV9") + self.pv_array_uint32 = epics_signal_rw(Array1D[np.uint32], "PV10") + 
self.pv_array_int64 = epics_signal_rw(Array1D[np.int64], "PV11") + self.pv_array_uint64 = epics_signal_rw(Array1D[np.uint64], "PV12") + self.pv_array_float32 = epics_signal_rw(Array1D[np.float32], "PV13") + self.pv_array_float64 = epics_signal_rw(Array1D[np.float64], "PV14") self.pv_array_str = epics_signal_rw(Sequence[str], "PV16") self.pv_protocol_device_abstraction = epics_signal_rw(Table, "pva://PV17") super().__init__(name) @@ -168,9 +166,6 @@ async def test_save_device_all_types( ) await pv.set(data) - await device_all_types.pv_array_npstr.set( - np.array(["one", "two", "three"], dtype=np.str_), - ) await device_all_types.pv_array_str.set( ["one", "two", "three"], ) diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 77f8e001a7..79923ff61e 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -5,6 +5,8 @@ from asyncio import Event from unittest.mock import ANY, Mock +import numpy as np +import numpy.typing as npt import pytest from bluesky.protocols import Reading @@ -362,21 +364,35 @@ async def test_subscription_logs(caplog): assert "Closing subscription on source" in caplog.text -async def test_signal_unknown_datatype(): - class SomeClass: - def __init__(self): - self.some_attribute = "some_attribute" +class SomeClass: + def __init__(self): + self.some_attribute = "some_attribute" - def some_function(self): - pass + def some_function(self): + pass - err_str = ( - "Can't make converter for .SomeClass'>" - ) + +@pytest.mark.parametrize( + "datatype,err", + [ + (SomeClass, "Can't make converter for %s"), + (object, "Can't make converter for %s"), + (dict, "Can't make converter for %s"), + (npt.NDArray[np.float64], "Expected Array1D[dtype], got %s"), + ], +) +async def test_signal_unknown_datatype(datatype, err): + err_str = re.escape(err % datatype) with pytest.raises(TypeError, match=err_str): - await epics_signal_rw(SomeClass, "pva://mock_signal").connect(mock=True) + await epics_signal_rw(datatype, 
"pva://mock_signal").connect(mock=True) with pytest.raises(TypeError, match=err_str): - await epics_signal_rw(SomeClass, "ca://mock_signal").connect(mock=True) + await epics_signal_rw(datatype, "ca://mock_signal").connect(mock=True) with pytest.raises(TypeError, match=err_str): - soft_signal_rw(SomeClass) + soft_signal_rw(datatype) + + +async def test_soft_signal_ndarray_can_change_dtype(): + sig = soft_signal_rw(np.ndarray) + for dtype in (np.int32, np.float64): + await sig.set(np.arange(4, dtype=dtype)) + assert (await sig.get_value()).dtype == dtype diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py index fc60a2bbfa..89eb87cc36 100644 --- a/tests/core/test_soft_signal_backend.py +++ b/tests/core/test_soft_signal_backend.py @@ -5,11 +5,11 @@ from typing import Any import numpy as np -import numpy.typing as npt import pytest from bluesky.protocols import Reading from ophyd_async.core import ( + Array1D, SignalBackend, SoftSignalBackend, StrictEnum, @@ -81,20 +81,17 @@ def close(self): (float, 0.0, 43.5, number_d, " None: + self.sig = epics_signal_rw(object, "", "") + super().__init__(name) + + +async def test_error_handling_device_collector_mock(): + with pytest.raises(NotConnected) as e: + async with DeviceCollector(mock=True): + device = BadDatatypeDevice() + device2 = BadDatatypeDevice() + expected_output = NotConnected( + { + "device": NotConnected( + {"sig": TypeError(f"Can't make converter for {object}")} + ), + "device2": NotConnected( + {"sig": TypeError(f"Can't make converter for {object}")} + ), + } + ) + assert str(expected_output) == str(e.value) + + async def test_error_handling_device_collector(caplog): caplog.set_level(10) with pytest.raises(NotConnected) as e: diff --git a/tests/plan_stubs/test_ensure_connected.py b/tests/plan_stubs/test_ensure_connected.py index 4665ddcee9..d96184551e 100644 --- a/tests/plan_stubs/test_ensure_connected.py +++ b/tests/plan_stubs/test_ensure_connected.py @@ -1,3 +1,5 @@ 
+import re + import pytest from ophyd_async.core import Device, NotConnected, soft_signal_rw @@ -36,3 +38,15 @@ def connect_with_mocking(): assert device2.signal._mock is not None RE(connect_with_mocking()) + + +def test_ensure_connected_fails_for_non_unique_device_names(RE): + d1 = Device("dupe") + d2 = Device("dupe") + d3 = Device("ok") + non_unique = {d1: "dupe", d2: "dupe"} + with pytest.raises( + ValueError, + match=re.escape(f"Devices do not have unique names {non_unique}"), + ): + RE(ensure_connected(d1, d2, d3)) diff --git a/tests/test_data/test_yaml_save.yml b/tests/test_data/test_yaml_save.yml index d00d38dff9..6a0613d1a8 100644 --- a/tests/test_data/test_yaml_save.yml +++ b/tests/test_data/test_yaml_save.yml @@ -24,7 +24,6 @@ pv_array_int32: [-2147483648, 2147483647, 0, 1, 2, 3, 4] pv_array_int64: [-9223372036854775808, 9223372036854775807, 0, 1, 2, 3, 4] pv_array_int8: [-128, 127, 0, 1, 2, 3, 4] - pv_array_npstr: [one, two, three] pv_array_str: - one - two From bc153ead462d2f222f0d8aeb859aff214f22b277 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Thu, 14 Nov 2024 13:50:27 +0000 Subject: [PATCH 29/30] Yield in each loop of observe_value (#648) This helps in the very specific case of an observe_value directly or indirectly modifying the signal that is being updated. This creates a busy loop which will not be interrupted by wrapping in asyncio.wait_for. 
To demonstrate, added test_observe_value_times_out_with_no_external_task --- src/ophyd_async/core/_signal.py | 5 +- tests/core/test_observe.py | 90 +++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 tests/core/test_observe.py diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index cd6152ef08..d7cbb179de 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -502,7 +502,10 @@ def queue_value(value: SignalDatatypeT, signal=signal): try: while True: - item = await get_value() + # yield here in case something else is filling the queue + # like in test_observe_value_times_out_with_no_external_task() + await asyncio.sleep(0) + item = await asyncio.wait_for(q.get(), timeout) if done_status and item is done_status: if exc := done_status.exception(): raise exc diff --git a/tests/core/test_observe.py b/tests/core/test_observe.py new file mode 100644 index 0000000000..e68b08465e --- /dev/null +++ b/tests/core/test_observe.py @@ -0,0 +1,90 @@ +import asyncio +import time + +import pytest + +from ophyd_async.core import AsyncStatus, observe_value, soft_signal_r_and_setter + + +async def test_observe_value_working_correctly(): + sig, setter = soft_signal_r_and_setter(float) + + async def tick(): + for i in range(2): + await asyncio.sleep(0.01) + setter(i + 1) + + recv = [] + status = AsyncStatus(tick()) + async for val in observe_value(sig, done_status=status): + recv.append(val) + assert recv == [0, 1, 2] + await status + + +async def test_observe_value_times_out(): + sig, setter = soft_signal_r_and_setter(float) + + async def tick(): + for i in range(5): + await asyncio.sleep(0.1) + setter(i + 1) + + recv = [] + + async def watch(): + async for val in observe_value(sig): + recv.append(val) + + t = asyncio.create_task(tick()) + start = time.time() + try: + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(watch(), timeout=0.2) + assert recv == [0, 1] + 
assert time.time() - start == pytest.approx(0.2, abs=0.05) + finally: + t.cancel() + + +async def test_observe_value_times_out_with_busy_sleep(): + sig, setter = soft_signal_r_and_setter(float) + + async def tick(): + for i in range(5): + await asyncio.sleep(0.1) + setter(i + 1) + + recv = [] + + async def watch(): + async for val in observe_value(sig): + time.sleep(0.15) + recv.append(val) + + t = asyncio.create_task(tick()) + start = time.time() + try: + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(watch(), timeout=0.2) + assert recv == [0, 1] + assert time.time() - start == pytest.approx(0.3, abs=0.05) + finally: + t.cancel() + + +async def test_observe_value_times_out_with_no_external_task(): + sig, setter = soft_signal_r_and_setter(float) + + recv = [] + + async def watch(): + async for val in observe_value(sig): + recv.append(val) + setter(val + 1) + + start = time.time() + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(watch(), timeout=0.1) + assert recv + assert time.time() - start == pytest.approx(0.1, abs=0.05) From 3a128381ea4db11b43e12c426a622d52ca8de5e3 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Thu, 14 Nov 2024 13:55:00 +0000 Subject: [PATCH 30/30] Add introspection of the errors that make up NotConnected (#649) --- src/ophyd_async/core/_utils.py | 7 +++++++ tests/core/test_utils.py | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index 2aa4b1c717..edb8b2c4b2 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -62,6 +62,13 @@ def __init__(self, errors: ErrorText): self._errors = errors + @property + def sub_errors(self) -> Mapping[str, Exception]: + if isinstance(self._errors, dict): + return self._errors.copy() + else: + return {} + def _format_sub_errors(self, name: str, error: Exception, indent="") -> str: if isinstance(error, NotConnected): error_txt 
= ":" + error.format_error_string(indent + self._indent_width) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index 7af2e507a7..2dea901a27 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -197,6 +197,14 @@ async def test_error_handling_device_collector_mock(): assert str(expected_output) == str(e.value) +def test_introspecting_sub_errors(): + sub_error1 = NotConnected("bad") + assert sub_error1.sub_errors == {} + sub_error2 = ValueError("very bad") + error = NotConnected({"child1": sub_error1, "child2": sub_error2}) + assert error.sub_errors == {"child1": sub_error1, "child2": sub_error2} + + async def test_error_handling_device_collector(caplog): caplog.set_level(10) with pytest.raises(NotConnected) as e: