From a0675e52b2627fa1f4032b83c3d8d21cbea00b59 Mon Sep 17 00:00:00 2001 From: AlexWells Date: Thu, 18 Apr 2024 14:40:02 +0100 Subject: [PATCH] Reduce boilerplate in StandardReadable This reduces the amount of duplication and repetition when adding signals to a StandardReadable. As part of this, classes defining the types of signal have been created, which control the behaviour of the Signal being registered Signals must be registered either using the "add_children_as_readables" contextmanager, or the "add_readables" function. --- src/ophyd_async/core/__init__.py | 8 +- src/ophyd_async/core/detector.py | 5 +- src/ophyd_async/core/standard_readable.py | 168 ++++++++++++++---- .../epics/areadetector/single_trigger_det.py | 19 +- src/ophyd_async/epics/demo/__init__.py | 36 ++-- src/ophyd_async/epics/motion/motor.py | 17 +- src/ophyd_async/protocols/__init__.py | 4 +- src/ophyd_async/protocols/protocols.py | 23 +++ src/ophyd_async/sim/sim_pattern_generator.py | 4 +- tests/sim/test_sim_detector.py | 2 +- 10 files changed, 209 insertions(+), 77 deletions(-) diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 103638019d..eaf0ad6165 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -39,7 +39,11 @@ ) from .signal_backend import SignalBackend from .sim_signal_backend import SimSignalBackend -from .standard_readable import StandardReadable +from .standard_readable import ( + ConfigSignal, + HintedSignal, + StandardReadable, +) from .utils import ( DEFAULT_TIMEOUT, Callback, @@ -80,6 +84,8 @@ "ShapeProvider", "StaticDirectoryProvider", "StandardReadable", + "ConfigSignal", + "HintedSignal", "TriggerInfo", "TriggerLogic", "HardwareTriggeredFlyable", diff --git a/src/ophyd_async/core/detector.py b/src/ophyd_async/core/detector.py index b2b2ec890f..416cd88ef7 100644 --- a/src/ophyd_async/core/detector.py +++ b/src/ophyd_async/core/detector.py @@ -33,7 +33,6 @@ from .async_status import AsyncStatus from .device import Device -from .signal import SignalR from .utils import DEFAULT_TIMEOUT, merge_gathered_dicts T = TypeVar("T") @@ -161,7 +160,7 @@ def __init__( self, controller: DetectorControl, writer: DetectorWriter, - config_sigs: Sequence[SignalR] = (), + config_sigs: Sequence[AsyncReadable] = (), name: str = "", writer_timeout: float = DEFAULT_TIMEOUT, ) -> None: @@ -214,7 +213,7 @@ async def stage(self) -> None: async def _check_config_sigs(self): """Checks configuration signals are named and connected.""" for signal in self._config_sigs: - if signal._name == "": + if signal.name == "": raise Exception( "config signal must be named before it is passed to the detector" ) diff --git a/src/ophyd_async/core/standard_readable.py b/src/ophyd_async/core/standard_readable.py index 3c37d5a332..89537b445e 100644 --- a/src/ophyd_async/core/standard_readable.py +++ b/src/ophyd_async/core/standard_readable.py @@ -1,8 +1,9 @@ -from typing import Dict, Sequence, Tuple +from contextlib import contextmanager +from typing import Dict, Generator, List, Optional, Sequence, Type, Union -from bluesky.protocols import Descriptor, Reading, Stageable +from bluesky.protocols import Descriptor, HasHints, Hints, Reading, Stageable -from ophyd_async.protocols import AsyncConfigurable, AsyncReadable +from ophyd_async.protocols import AsyncConfigurable, AsyncReadable, AsyncStageable from .async_status import AsyncStatus from .device import Device @@ -10,7 +11,9 @@ from .utils import merge_gathered_dicts -class StandardReadable(Device, AsyncReadable, AsyncConfigurable, Stageable): +class StandardReadable( + Device, AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints +): """Device that owns its children and provides useful default behavior. - When its name is set it renames child Devices @@ -18,57 +21,148 @@ class StandardReadable(Device, AsyncReadable, AsyncConfigurable, Stageable): - These signals will be subscribed for read() between stage() and unstage() """ - _read_signals: Tuple[SignalR, ...] = () - _configuration_signals: Tuple[SignalR, ...] = () - _read_uncached_signals: Tuple[SignalR, ...] = () + _readables: List[AsyncReadable] = [] + _configurables: List[AsyncConfigurable] = [] + _stageables: List[AsyncStageable] = [] - def set_readable_signals( - self, - read: Sequence[SignalR] = (), - config: Sequence[SignalR] = (), - read_uncached: Sequence[SignalR] = (), - ): - """ - Parameters - ---------- - read: - Signals to make up :meth:`~StandardReadable.read` - conf: - Signals to make up :meth:`~StandardReadable.read_configuration` - read_uncached: - Signals to make up :meth:`~StandardReadable.read` that won't be cached - """ - self._read_signals = tuple(read) - self._configuration_signals = tuple(config) - self._read_uncached_signals = tuple(read_uncached) + _hints: Hints = {} @AsyncStatus.wrap async def stage(self) -> None: - for sig in self._read_signals + self._configuration_signals: + for sig in self._stageables: await sig.stage().task @AsyncStatus.wrap async def unstage(self) -> None: - for sig in self._read_signals + self._configuration_signals: + for sig in self._stageables: await sig.unstage().task async def describe_configuration(self) -> Dict[str, Descriptor]: return await merge_gathered_dicts( - [sig.describe() for sig in self._configuration_signals] + [sig.describe_configuration() for sig in self._configurables] ) async def read_configuration(self) -> Dict[str, Reading]: return await merge_gathered_dicts( - [sig.read() for sig in self._configuration_signals] + [sig.read_configuration() for sig in self._configurables] ) async def describe(self) -> Dict[str, Descriptor]: - return await merge_gathered_dicts( - [sig.describe() for sig in self._read_signals + self._read_uncached_signals] - ) + return await merge_gathered_dicts([sig.describe() for sig in self._readables]) async def read(self) -> Dict[str, Reading]: - return await merge_gathered_dicts( - [sig.read() for sig in self._read_signals] - + [sig.read(cached=False) for sig in self._read_uncached_signals] - ) + return await merge_gathered_dicts([sig.read() for sig in self._readables]) + + @property + def hints(self) -> Hints: + return self._hints + + @contextmanager + def add_children_as_readables( + self, + wrapper: Optional[Type[Union["ConfigSignal", "HintedSignal"]]] = None, + ) -> Generator[None, None, None]: + dict_copy = self.__dict__.copy() + + yield + + # Set symmetric difference operator gives all newly added items + new_attributes = dict_copy.items() ^ self.__dict__.items() + new_signals: List[SignalR] = [x[1] for x in new_attributes] + + self._wrap_signals(wrapper, new_signals) + + def add_readables( + self, + wrapper: Type[Union["ConfigSignal", "HintedSignal"]], + *signals: SignalR, + ) -> None: + + self._wrap_signals(wrapper, signals) + + def _wrap_signals( + self, + wrapper: Optional[Type[Union["ConfigSignal", "HintedSignal"]]], + signals: Sequence[SignalR], + ): + + for signal in signals: + obj: Union[SignalR, "ConfigSignal", "HintedSignal"] = signal + if wrapper: + obj = wrapper(signal) + + if isinstance(obj, AsyncReadable): + self._readables.append(obj) + + if isinstance(obj, AsyncConfigurable): + self._configurables.append(obj) + + if isinstance(obj, AsyncStageable): + self._stageables.append(obj) + + if isinstance(obj, HasHints): + new_hint = obj.hints + + # Merge the existing and new hints, based on the type of the value. + # This avoids default dict merge behaviour that overrided the values; + # we want to combine them when they are Sequences, and ensure they are + # identical when string values. + for key, value in new_hint.items(): + if isinstance(value, Sequence): + if key in self._hints: + self._hints[key] = ( # type: ignore[literal-required] + self._hints[key] # type: ignore[literal-required] + + value + ) + else: + self._hints[key] = value # type: ignore[literal-required] + elif isinstance(value, str): + if key in self._hints: + assert ( + self._hints[key] # type: ignore[literal-required] + == value + ), "Hints value may not be overridden" + else: + self._hints[key] = value # type: ignore[literal-required] + else: + raise AssertionError("Unknown type in Hints dictionary") + + +class ConfigSignal(AsyncConfigurable): + + def __init__(self, signal: SignalR) -> None: + self.signal = signal + + async def read_configuration(self) -> Dict[str, Reading]: + return await self.signal.read() + + async def describe_configuration(self) -> Dict[str, Descriptor]: + return await self.signal.describe() + + +class HintedSignal(HasHints, AsyncReadable): + + def __init__(self, signal: SignalR, cached: Optional[bool] = None) -> None: + self.signal = signal + self.cached = cached + if cached: + self.stage = signal.stage + self.unstage = signal.unstage + + async def read(self) -> Dict[str, Reading]: + return await self.signal.read(cached=self.cached) + + async def describe(self) -> Dict[str, Descriptor]: + return await self.signal.describe() + + @property + def name(self) -> str: + return self.signal.name + + @property + def hints(self) -> Hints: + return {"fields": [self.signal.name]} + + @classmethod + def uncached(cls, signal: SignalR): + return cls(signal, cached=False) diff --git a/src/ophyd_async/epics/areadetector/single_trigger_det.py b/src/ophyd_async/epics/areadetector/single_trigger_det.py index ae76c0475a..4503aa0508 100644 --- a/src/ophyd_async/epics/areadetector/single_trigger_det.py +++ b/src/ophyd_async/epics/areadetector/single_trigger_det.py @@ -3,7 +3,13 @@ from bluesky.protocols import Triggerable -from ophyd_async.core import AsyncStatus, SignalR, StandardReadable +from ophyd_async.core import ( + AsyncStatus, + ConfigSignal, + HintedSignal, + SignalR, + StandardReadable, +) from .drivers.ad_base import ADBase from .utils import ImageMode @@ -20,12 +26,13 @@ def __init__( ) -> None: self.drv = drv self.__dict__.update(plugins) - self.set_readable_signals( - # Can't subscribe to read signals as race between monitor coming back and - # caput callback on acquire - read_uncached=[self.drv.array_counter] + list(read_uncached), - config=[self.drv.acquire_time], + + self.add_readables( + HintedSignal.uncached, self.drv.array_counter, *read_uncached ) + + self.add_readables(ConfigSignal, self.drv.acquire_time) + super().__init__(name=name) @AsyncStatus.wrap diff --git a/src/ophyd_async/epics/demo/__init__.py b/src/ophyd_async/epics/demo/__init__.py index 73833c1731..ff5c83867f 100644 --- a/src/ophyd_async/epics/demo/__init__.py +++ b/src/ophyd_async/epics/demo/__init__.py @@ -14,7 +14,14 @@ import numpy as np from bluesky.protocols import Movable, Stoppable -from ophyd_async.core import AsyncStatus, Device, StandardReadable, observe_value +from ophyd_async.core import ( + AsyncStatus, + ConfigSignal, + Device, + HintedSignal, + StandardReadable, + observe_value, +) from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x @@ -33,13 +40,11 @@ class Sensor(StandardReadable): def __init__(self, prefix: str, name="") -> None: # Define some signals - self.value = epics_signal_r(float, prefix + "Value") - self.mode = epics_signal_rw(EnergyMode, prefix + "Mode") - # Set name and signals for read() and read_configuration() - self.set_readable_signals( - read=[self.value], - config=[self.mode], - ) + with self.add_children_as_readables(HintedSignal): + self.value = epics_signal_r(float, prefix + "Value") + with self.add_children_as_readables(ConfigSignal): + self.mode = epics_signal_rw(EnergyMode, prefix + "Mode") + super().__init__(name=name) @@ -49,19 +54,18 @@ class Mover(StandardReadable, Movable, Stoppable): def __init__(self, prefix: str, name="") -> None: # Define some signals self.setpoint = epics_signal_rw(float, prefix + "Setpoint") - self.readback = epics_signal_r(float, prefix + "Readback") - self.velocity = epics_signal_rw(float, prefix + "Velocity") - self.units = epics_signal_r(str, prefix + "Readback.EGU") self.precision = epics_signal_r(int, prefix + "Readback.PREC") # Signals that collide with standard methods should have a trailing underscore self.stop_ = epics_signal_x(prefix + "Stop.PROC") # Whether set() should complete successfully or not self._set_success = True - # Set name and signals for read() and read_configuration() - self.set_readable_signals( - read=[self.readback], - config=[self.velocity, self.units], - ) + + with self.add_children_as_readables(HintedSignal): + self.readback = epics_signal_r(float, prefix + "Readback") + with self.add_children_as_readables(ConfigSignal): + self.velocity = epics_signal_rw(float, prefix + "Velocity") + self.units = epics_signal_r(str, prefix + "Readback.EGU") + super().__init__(name=name) def set_name(self, name: str): diff --git a/src/ophyd_async/epics/motion/motor.py b/src/ophyd_async/epics/motion/motor.py index ef66a13520..d5d7cce091 100644 --- a/src/ophyd_async/epics/motion/motor.py +++ b/src/ophyd_async/epics/motion/motor.py @@ -4,7 +4,7 @@ from bluesky.protocols import Movable, Stoppable -from ophyd_async.core import AsyncStatus, StandardReadable +from ophyd_async.core import AsyncStatus, ConfigSignal, HintedSignal, StandardReadable from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x @@ -15,25 +15,24 @@ class Motor(StandardReadable, Movable, Stoppable): def __init__(self, prefix: str, name="") -> None: # Define some signals self.user_setpoint = epics_signal_rw(float, prefix + ".VAL") - self.user_readback = epics_signal_r(float, prefix + ".RBV") - self.velocity = epics_signal_rw(float, prefix + ".VELO") self.max_velocity = epics_signal_r(float, prefix + ".VMAX") self.acceleration_time = epics_signal_rw(float, prefix + ".ACCL") - self.motor_egu = epics_signal_r(str, prefix + ".EGU") self.precision = epics_signal_r(int, prefix + ".PREC") self.deadband = epics_signal_r(float, prefix + ".RDBD") self.motor_done_move = epics_signal_r(float, prefix + ".DMOV") self.low_limit_travel = epics_signal_rw(int, prefix + ".LLM") self.high_limit_travel = epics_signal_rw(int, prefix + ".HLM") + with self.add_children_as_readables(ConfigSignal): + self.motor_egu = epics_signal_r(str, prefix + ".EGU") + self.velocity = epics_signal_rw(float, prefix + ".VELO") + + with self.add_children_as_readables(HintedSignal): + self.user_readback = epics_signal_r(float, prefix + ".RBV") + self.motor_stop = epics_signal_x(prefix + ".STOP") # Whether set() should complete successfully or not self._set_success = True - # Set name and signals for read() and read_configuration() - self.set_readable_signals( - read=[self.user_readback], - config=[self.velocity, self.motor_egu], - ) super().__init__(name=name) def set_name(self, name: str): diff --git a/src/ophyd_async/protocols/__init__.py b/src/ophyd_async/protocols/__init__.py index 63fdf7025e..81c0f8827b 100644 --- a/src/ophyd_async/protocols/__init__.py +++ b/src/ophyd_async/protocols/__init__.py @@ -1,3 +1,3 @@ -from .protocols import AsyncConfigurable, AsyncPausable, AsyncReadable +from .protocols import AsyncConfigurable, AsyncPausable, AsyncReadable, AsyncStageable -__all__ = ["AsyncReadable", "AsyncConfigurable", "AsyncPausable"] +__all__ = ["AsyncReadable", "AsyncConfigurable", "AsyncPausable", "AsyncStageable"] diff --git a/src/ophyd_async/protocols/protocols.py b/src/ophyd_async/protocols/protocols.py index 51169d5418..42438d14f1 100644 --- a/src/ophyd_async/protocols/protocols.py +++ b/src/ophyd_async/protocols/protocols.py @@ -3,6 +3,8 @@ from bluesky.protocols import Descriptor, HasName, Reading +from ophyd_async.core.async_status import AsyncStatus + @runtime_checkable class AsyncReadable(HasName, Protocol): @@ -71,3 +73,24 @@ async def pause(self) -> None: async def resume(self) -> None: """Perform device-specific work when the RunEngine resumes after a pause.""" ... + + +@runtime_checkable +class AsyncStageable(Protocol): + @abstractmethod + def stage(self) -> AsyncStatus: + """An optional hook for "setting up" the device for acquisition. + + It should return a ``Status`` that is marked done when the device is + done staging. + """ + ... + + @abstractmethod + def unstage(self) -> AsyncStatus: + """A hook for "cleaning up" the device after acquisition. + + It should return a ``Status`` that is marked done when the device is finished + unstaging. + """ + ... diff --git a/src/ophyd_async/sim/sim_pattern_generator.py b/src/ophyd_async/sim/sim_pattern_generator.py index 3c2c54d8af..3494221c70 100644 --- a/src/ophyd_async/sim/sim_pattern_generator.py +++ b/src/ophyd_async/sim/sim_pattern_generator.py @@ -3,7 +3,7 @@ from ophyd_async.core import DirectoryProvider, StaticDirectoryProvider from ophyd_async.core.detector import StandardDetector -from ophyd_async.core.signal import SignalR +from ophyd_async.protocols.protocols import AsyncReadable from ophyd_async.sim.pattern_generator import PatternGenerator from .sim_pattern_detector_control import SimPatternDetectorControl @@ -14,7 +14,7 @@ class SimPatternDetector(StandardDetector): def __init__( self, path: Path, - config_sigs: Sequence[SignalR] = [], + config_sigs: Sequence[AsyncReadable] = [], name: str = "sim_pattern_detector", writer_timeout: float = 1, ) -> None: diff --git a/tests/sim/test_sim_detector.py b/tests/sim/test_sim_detector.py index 808dffee90..338cbf342a 100644 --- a/tests/sim/test_sim_detector.py +++ b/tests/sim/test_sim_detector.py @@ -31,7 +31,7 @@ async def test_writes_pattern_to_file( sim_pattern_detector: SimPatternDetector, sim_motor: motor.Motor, tmp_path ): sim_pattern_detector = SimPatternDetector( - config_sigs=[*sim_motor._read_signals], path=tmp_path + config_sigs=[*sim_motor._readables], path=tmp_path ) images_number = 2