diff --git a/docs/examples/tango_demo.py b/docs/examples/tango_demo.py new file mode 100644 index 0000000000..503dccd1ec --- /dev/null +++ b/docs/examples/tango_demo.py @@ -0,0 +1,54 @@ +import asyncio + +import bluesky.plan_stubs as bps +import bluesky.plans as bp +from bluesky import RunEngine + +from ophyd_async.tango.demo import ( + DemoCounter, + DemoMover, + TangoDetector, +) +from tango.test_context import MultiDeviceTestContext + +content = ( + { + "class": DemoMover, + "devices": [{"name": "demo/motor/1"}], + }, + { + "class": DemoCounter, + "devices": [{"name": "demo/counter/1"}, {"name": "demo/counter/2"}], + }, +) + +tango_context = MultiDeviceTestContext(content) + + +async def main(): + with tango_context: + detector = TangoDetector( + trl="", + name="detector", + counters_kwargs={"prefix": "demo/counter/", "count": 2}, + mover_kwargs={"trl": "demo/motor/1"}, + ) + await detector.connect() + + RE = RunEngine() + + RE(bps.read(detector)) + RE(bps.mv(detector, 0)) + RE(bp.count(list(detector.counters.values()))) + + set_status = detector.set(1.0) + await asyncio.sleep(0.1) + stop_status = detector.stop() + await set_status + await stop_status + assert all([set_status.done, stop_status.done]) + assert all([set_status.success, stop_status.success]) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 23afbf19bc..7dbf22271b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,10 +33,12 @@ requires-python = ">=3.10" ca = ["aioca>=1.6"] pva = ["p4p"] sim = ["h5py"] +tango = ["pytango>=10.0.0"] dev = [ "ophyd_async[pva]", "ophyd_async[sim]", "ophyd_async[ca]", + "ophyd_async[tango]", "black", "flake8", "flake8-isort", @@ -59,6 +61,7 @@ dev = [ "pytest-asyncio", "pytest-cov", "pytest-faulthandler", + "pytest-forked", "pytest-rerunfailures", "pytest-timeout", "ruff", diff --git a/src/ophyd_async/tango/__init__.py b/src/ophyd_async/tango/__init__.py index e69de29bb2..5b45a067c4 100644 --- a/src/ophyd_async/tango/__init__.py +++ b/src/ophyd_async/tango/__init__.py @@ -0,0 +1,45 @@ +from .base_devices import ( + TangoDevice, + TangoReadable, + tango_polling, +) +from .signal import ( + AttributeProxy, + CommandProxy, + TangoSignalBackend, + __tango_signal_auto, + ensure_proper_executor, + get_dtype_extended, + get_python_type, + get_tango_trl, + get_trl_descriptor, + infer_python_type, + infer_signal_character, + make_backend, + tango_signal_r, + tango_signal_rw, + tango_signal_w, + tango_signal_x, +) + +__all__ = [ + "TangoDevice", + "TangoReadable", + "tango_polling", + "TangoSignalBackend", + "get_python_type", + "get_dtype_extended", + "get_trl_descriptor", + "get_tango_trl", + "infer_python_type", + "infer_signal_character", + "make_backend", + "AttributeProxy", + "CommandProxy", + "ensure_proper_executor", + "__tango_signal_auto", + "tango_signal_r", + "tango_signal_rw", + "tango_signal_w", + "tango_signal_x", +] diff --git a/src/ophyd_async/tango/base_devices/__init__.py b/src/ophyd_async/tango/base_devices/__init__.py new file mode 100644 index 0000000000..ecba4e1b23 --- /dev/null +++ b/src/ophyd_async/tango/base_devices/__init__.py @@ -0,0 +1,4 @@ +from ._base_device import TangoDevice, tango_polling +from ._tango_readable import TangoReadable + +__all__ = ["TangoDevice", "TangoReadable", "tango_polling"] diff --git a/src/ophyd_async/tango/base_devices/_base_device.py b/src/ophyd_async/tango/base_devices/_base_device.py new file mode 100644 index 0000000000..9d01539263 --- /dev/null +++ b/src/ophyd_async/tango/base_devices/_base_device.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +from typing import ( + TypeVar, + get_args, + get_origin, + get_type_hints, +) + +from ophyd_async.core import ( + DEFAULT_TIMEOUT, + Device, + Signal, +) +from ophyd_async.tango.signal import ( + TangoSignalBackend, + __tango_signal_auto, + make_backend, +) +from tango import DeviceProxy as DeviceProxy +from tango.asyncio import DeviceProxy as AsyncDeviceProxy + +T = TypeVar("T") + + +class TangoDevice(Device): + """ + General class for TangoDevices. Extends Device to provide attributes for Tango + devices. + + Parameters + ---------- + trl: str + Tango resource locator, typically of the device server. + device_proxy: Optional[Union[AsyncDeviceProxy, SyncDeviceProxy]] + Asynchronous or synchronous DeviceProxy object for the device. If not provided, + an asynchronous DeviceProxy object will be created using the trl and awaited + when the device is connected. + """ + + trl: str = "" + proxy: DeviceProxy | None = None + _polling: tuple[bool, float, float | None, float | None] = (False, 0.1, None, 0.1) + _signal_polling: dict[str, tuple[bool, float, float, float]] = {} + _poll_only_annotated_signals: bool = True + + def __init__( + self, + trl: str | None = None, + device_proxy: DeviceProxy | None = None, + name: str = "", + ) -> None: + self.trl = trl if trl else "" + self.proxy = device_proxy + tango_create_children_from_annotations(self) + super().__init__(name=name) + + def set_trl(self, trl: str): + """Set the Tango resource locator.""" + if not isinstance(trl, str): + raise ValueError("TRL must be a string.") + self.trl = trl + + async def connect( + self, + mock: bool = False, + timeout: float = DEFAULT_TIMEOUT, + force_reconnect: bool = False, + ): + if self.trl and self.proxy is None: + self.proxy = await AsyncDeviceProxy(self.trl) + elif self.proxy and not self.trl: + self.trl = self.proxy.name() + + # Set the trl of the signal backends + for child in self.children(): + if isinstance(child[1], Signal): + if isinstance(child[1]._backend, TangoSignalBackend): # noqa: SLF001 + resource_name = child[0].lstrip("_") + read_trl = f"{self.trl}/{resource_name}" + child[1]._backend.set_trl(read_trl, read_trl) # noqa: SLF001 + + if self.proxy is not None: + self.register_signals() + await _fill_proxy_entries(self) + + # set_name should be called again to propagate the new signal names + self.set_name(self.name) + + # Set the polling configuration + if self._polling[0]: + for child in self.children(): + child_type = type(child[1]) + if issubclass(child_type, Signal): + if isinstance(child[1]._backend, TangoSignalBackend): # noqa: SLF001 # type: ignore + child[1]._backend.set_polling(*self._polling) # noqa: SLF001 # type: ignore + child[1]._backend.allow_events(False) # noqa: SLF001 # type: ignore + if self._signal_polling: + for signal_name, polling in self._signal_polling.items(): + if hasattr(self, signal_name): + attr = getattr(self, signal_name) + if isinstance(attr._backend, TangoSignalBackend): # noqa: SLF001 + attr._backend.set_polling(*polling) # noqa: SLF001 + attr._backend.allow_events(False) # noqa: SLF001 + + await super().connect(mock=mock, timeout=timeout) + + # Users can override this method to register new signals + def register_signals(self): + pass + + +def tango_polling( + polling: tuple[float, float, float] + | dict[str, tuple[float, float, float]] + | None = None, + signal_polling: dict[str, tuple[float, float, float]] | None = None, +): + """ + Class decorator to configure polling for Tango devices. + + This decorator allows for the configuration of both device-level and signal-level + polling for Tango devices. Polling is useful for device servers that do not support + event-driven updates. + + Parameters + ---------- + polling : Optional[Union[Tuple[float, float, float], + Dict[str, Tuple[float, float, float]]]], optional + Device-level polling configuration as a tuple of three floats representing the + polling interval, polling timeout, and polling delay. Alternatively, + a dictionary can be provided to specify signal-level polling configurations + directly. + signal_polling : Optional[Dict[str, Tuple[float, float, float]]], optional + Signal-level polling configuration as a dictionary where keys are signal names + and values are tuples of three floats representing the polling interval, polling + timeout, and polling delay. + """ + if isinstance(polling, dict): + signal_polling = polling + polling = None + + def decorator(cls): + if polling is not None: + cls._polling = (True, *polling) + if signal_polling is not None: + cls._signal_polling = {k: (True, *v) for k, v in signal_polling.items()} + return cls + + return decorator + + +def tango_create_children_from_annotations( + device: TangoDevice, included_optional_fields: tuple[str, ...] = () +): + """Initialize blocks at __init__ of `device`.""" + for name, device_type in get_type_hints(type(device)).items(): + if name in ("_name", "parent"): + continue + + # device_type, is_optional = _strip_union(device_type) + # if is_optional and name not in included_optional_fields: + # continue + # + # is_device_vector, device_type = _strip_device_vector(device_type) + # if is_device_vector: + # n_device_vector = DeviceVector() + # setattr(device, name, n_device_vector) + + # else: + origin = get_origin(device_type) + origin = origin if origin else device_type + + if issubclass(origin, Signal): + type_args = get_args(device_type) + datatype = type_args[0] if type_args else None + backend = make_backend(datatype=datatype, device_proxy=device.proxy) + setattr(device, name, origin(name=name, backend=backend)) + + elif issubclass(origin, Device) or isinstance(origin, Device): + assert callable(origin), f"{origin} is not callable." + setattr(device, name, origin()) + + +async def _fill_proxy_entries(device: TangoDevice): + if device.proxy is None: + raise RuntimeError(f"Device proxy is not connected for {device.name}") + proxy_trl = device.trl + children = [name.lstrip("_") for name, _ in device.children()] + proxy_attributes = list(device.proxy.get_attribute_list()) + proxy_commands = list(device.proxy.get_command_list()) + combined = proxy_attributes + proxy_commands + + for name in combined: + if name not in children: + full_trl = f"{proxy_trl}/{name}" + try: + auto_signal = await __tango_signal_auto( + trl=full_trl, device_proxy=device.proxy + ) + setattr(device, name, auto_signal) + except RuntimeError as e: + if "Commands with different in and out dtypes" in str(e): + print( + f"Skipping {name}. Commands with different in and out dtypes" + f" are not supported." + ) + continue + raise e + + +# def _strip_union(field: T | T) -> tuple[T, bool]: +# if get_origin(field) is Union: +# args = get_args(field) +# is_optional = type(None) in args +# for arg in args: +# if arg is not type(None): +# return arg, is_optional +# return field, False +# +# +# def _strip_device_vector(field: type[Device]) -> tuple[bool, type[Device]]: +# if get_origin(field) is DeviceVector: +# return True, get_args(field)[0] +# return False, field diff --git a/src/ophyd_async/tango/base_devices/_tango_readable.py b/src/ophyd_async/tango/base_devices/_tango_readable.py new file mode 100644 index 0000000000..4a8fe3d1a4 --- /dev/null +++ b/src/ophyd_async/tango/base_devices/_tango_readable.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from ophyd_async.core import ( + StandardReadable, +) +from ophyd_async.tango.base_devices._base_device import TangoDevice +from tango import DeviceProxy + + +class TangoReadable(TangoDevice, StandardReadable): + """ + General class for readable TangoDevices. Extends StandardReadable to provide + attributes for Tango devices. + + Usage: to proper signals mount should be awaited: + new_device = await TangoDevice() + + Attributes + ---------- + trl : str + Tango resource locator, typically of the device server. + proxy : AsyncDeviceProxy + AsyncDeviceProxy object for the device. This is created when the + device is connected. + """ + + def __init__( + self, + trl: str | None = None, + device_proxy: DeviceProxy | None = None, + name: str = "", + ) -> None: + TangoDevice.__init__(self, trl, device_proxy=device_proxy, name=name) diff --git a/src/ophyd_async/tango/demo/__init__.py b/src/ophyd_async/tango/demo/__init__.py new file mode 100644 index 0000000000..78fffae2aa --- /dev/null +++ b/src/ophyd_async/tango/demo/__init__.py @@ -0,0 +1,12 @@ +from ._counter import TangoCounter +from ._detector import TangoDetector +from ._mover import TangoMover +from ._tango import DemoCounter, DemoMover + +__all__ = [ + "DemoCounter", + "DemoMover", + "TangoCounter", + "TangoMover", + "TangoDetector", +] diff --git a/src/ophyd_async/tango/demo/_counter.py b/src/ophyd_async/tango/demo/_counter.py new file mode 100644 index 0000000000..c8903dfd6d --- /dev/null +++ b/src/ophyd_async/tango/demo/_counter.py @@ -0,0 +1,37 @@ +from ophyd_async.core import ( + DEFAULT_TIMEOUT, + AsyncStatus, + ConfigSignal, + HintedSignal, + SignalR, + SignalRW, + SignalX, +) +from ophyd_async.tango import TangoReadable, tango_polling + + +# Enable device level polling, useful for servers that do not support events +# Polling for individual signal can be enabled with a dict +@tango_polling({"counts": (1.0, 0.1, 0.1), "sample_time": (0.1, 0.1, 0.1)}) +class TangoCounter(TangoReadable): + # Enter the name and type of the signals you want to use + # If type is None or Signal, the type will be inferred from the Tango device + counts: SignalR[int] + sample_time: SignalRW[float] + start: SignalX + _reset: SignalX + + def __init__(self, trl: str | None = "", name=""): + super().__init__(trl, name=name) + self.add_readables([self.counts], HintedSignal) + self.add_readables([self.sample_time], ConfigSignal) + + @AsyncStatus.wrap + async def trigger(self) -> None: + sample_time = await self.sample_time.get_value() + timeout = sample_time + DEFAULT_TIMEOUT + await self.start.trigger(wait=True, timeout=timeout) + + @AsyncStatus.wrap + async def reset(self) -> None: + await self._reset.trigger(wait=True, timeout=DEFAULT_TIMEOUT) diff --git a/src/ophyd_async/tango/demo/_detector.py b/src/ophyd_async/tango/demo/_detector.py new file mode 100644 index 0000000000..61025c18c7 --- /dev/null +++ b/src/ophyd_async/tango/demo/_detector.py @@ -0,0 +1,42 @@ +import asyncio + +from ophyd_async.core import ( + AsyncStatus, + DeviceVector, + StandardReadable, +) + +from ._counter import TangoCounter +from ._mover import TangoMover + + +class TangoDetector(StandardReadable): + def __init__(self, mover_trl: str, counter_trls: list[str], name=""): + # A detector device may be composed of tango sub-devices + self.mover = TangoMover(mover_trl) + self.counters = DeviceVector( + {i + 1: TangoCounter(c_trl) for i, c_trl in enumerate(counter_trls)} + ) + + # Define the readables for TangoDetector + # DeviceVectors are incompatible with AsyncReadable. Ignore until fixed. + self.add_readables([self.counters, self.mover]) # type: ignore + + super().__init__(name=name) + + def set(self, value): + return self.mover.set(value) + + def stop(self, success: bool = True) -> AsyncStatus: + return self.mover.stop(success) + + @AsyncStatus.wrap + async def trigger(self): + statuses = [] + for counter in self.counters.values(): + statuses.append(counter.reset()) + await asyncio.gather(*statuses) + statuses.clear() + for counter in self.counters.values(): + statuses.append(counter.trigger()) + await asyncio.gather(*statuses) diff --git a/src/ophyd_async/tango/demo/_mover.py b/src/ophyd_async/tango/demo/_mover.py new file mode 100644 index 0000000000..ce50356a55 --- /dev/null +++ b/src/ophyd_async/tango/demo/_mover.py @@ -0,0 +1,77 @@ +import asyncio + +from bluesky.protocols import Movable, Stoppable + +from ophyd_async.core import ( + CALCULATE_TIMEOUT, + DEFAULT_TIMEOUT, + AsyncStatus, + CalculatableTimeout, + ConfigSignal, + HintedSignal, + SignalR, + SignalRW, + SignalX, + WatchableAsyncStatus, + WatcherUpdate, + observe_value, + wait_for_value, +) +from ophyd_async.tango import TangoReadable, tango_polling +from tango import DevState + + +# Enable device level polling, useful for servers that do not support events +@tango_polling((0.1, 0.1, 0.1)) +class TangoMover(TangoReadable, Movable, Stoppable): + # Enter the name and type of the signals you want to use + # If type is None or Signal, the type will be inferred from the Tango device + position: SignalRW[float] + velocity: SignalRW[float] + state: SignalR[DevState] + _stop: SignalX + + def __init__(self, trl: str | None = "", name=""): + super().__init__(trl, name=name) + self.add_readables([self.position], HintedSignal) + self.add_readables([self.velocity], ConfigSignal) + self._set_success = True + + @WatchableAsyncStatus.wrap + async def set(self, value: float, timeout: CalculatableTimeout = CALCULATE_TIMEOUT): + self._set_success = True + (old_position, velocity) = await asyncio.gather( + self.position.get_value(), self.velocity.get_value() + ) + if timeout is CALCULATE_TIMEOUT: + assert velocity > 0, "Motor has zero velocity" + timeout = abs(value - old_position) / velocity + DEFAULT_TIMEOUT + + if not (isinstance(timeout, float) or timeout is None): + raise ValueError("Timeout must be a float or None") + # For this server, set returns immediately so this status should not be awaited + await self.position.set(value, wait=False, timeout=timeout) + + move_status = AsyncStatus( + wait_for_value(self.state, DevState.ON, timeout=timeout) + ) + + try: + async for current_position in observe_value( + self.position, done_status=move_status + ): + yield WatcherUpdate( + current=current_position, + initial=old_position, + target=value, + name=self.name, + ) + except RuntimeError as exc: + self._set_success = False + raise RuntimeError("Motor was stopped") from exc + if not self._set_success: + raise RuntimeError("Motor was stopped") + + def stop(self, success: bool = True) -> AsyncStatus: + self._set_success = success + return self._stop.trigger() diff --git a/src/ophyd_async/tango/demo/_tango/__init__.py b/src/ophyd_async/tango/demo/_tango/__init__.py new file mode 100644 index 0000000000..fd7965a5f1 --- /dev/null +++ b/src/ophyd_async/tango/demo/_tango/__init__.py @@ -0,0 +1,3 @@ +from ._servers import DemoCounter, DemoMover + +__all__ = ["DemoCounter", "DemoMover"] diff --git a/src/ophyd_async/tango/demo/_tango/_servers.py b/src/ophyd_async/tango/demo/_tango/_servers.py new file mode 100644 index 0000000000..d3332f881a --- /dev/null +++ b/src/ophyd_async/tango/demo/_tango/_servers.py @@ -0,0 +1,108 @@ +import asyncio +import time + +import numpy as np + +from tango import AttrWriteType, DevState, GreenMode +from tango.server import Device, attribute, command + + +class DemoMover(Device): + green_mode = GreenMode.Asyncio + _position = 0.0 + _setpoint = 0.0 + _velocity = 0.5 + _acceleration = 0.5 + _precision = 0.1 + _stop = False + DEVICE_CLASS_INITIAL_STATE = DevState.ON + + @attribute(dtype=float, access=AttrWriteType.READ_WRITE) + async def position(self): + return self._position + + async def write_position(self, new_position): + self._setpoint = new_position + await self.move() + + @attribute(dtype=float, access=AttrWriteType.READ_WRITE) + async def velocity(self): + return self._velocity + + async def write_velocity(self, value: float): + self._velocity = value + + @attribute(dtype=DevState, access=AttrWriteType.READ) + async def state(self): + return self.get_state() + + @command + async def stop(self): + self._stop = True + + @command + async def move(self): + self.set_state(DevState.MOVING) + await self._move(self._setpoint) + self.set_state(DevState.ON) + + async def _move(self, new_position): + self._setpoint = new_position + self._stop = False + step = 0.1 + while True: + if self._stop: + self._stop = False + break + if self._position < new_position: + self._position = self._position + self._velocity * step + else: + self._position = self._position - self._velocity * step + if abs(self._position - new_position) < self._precision: + self._position = new_position + break + await asyncio.sleep(step) + + +class DemoCounter(Device): + green_mode = GreenMode.Asyncio + _counts = 0 + _sample_time = 1.0 + + @attribute(dtype=int, access=AttrWriteType.READ) + async def counts(self): + return self._counts + + @attribute(dtype=float, access=AttrWriteType.READ_WRITE) + async def sample_time(self): + return self._sample_time + + async def write_sample_time(self, value: float): + self._sample_time = value + + @attribute(dtype=DevState, access=AttrWriteType.READ) + async def state(self): + return self.get_state() + + @command + async def reset(self): + self._counts = 0 + return self._counts + + @command + async def start(self): + self._counts = 0 + if self._sample_time <= 0.0: + return + self.set_state(DevState.MOVING) + await self._trigger() + self.set_state(DevState.ON) + + async def _trigger(self): + st = time.time() + while True: + ct = time.time() + if ct - st > self._sample_time: + break + self._counts += int(np.random.normal(1000, 100)) + await asyncio.sleep(0.1) diff --git a/src/ophyd_async/tango/signal/__init__.py b/src/ophyd_async/tango/signal/__init__.py new file mode 100644 index 0000000000..8923718b6a --- /dev/null +++ b/src/ophyd_async/tango/signal/__init__.py @@ -0,0 +1,39 @@ +from ._signal import ( + __tango_signal_auto, + infer_python_type, + infer_signal_character, + make_backend, + tango_signal_r, + tango_signal_rw, + tango_signal_w, + tango_signal_x, +) +from ._tango_transport import ( + AttributeProxy, + CommandProxy, + TangoSignalBackend, + ensure_proper_executor, + get_dtype_extended, + get_python_type, + get_tango_trl, + get_trl_descriptor, +) + +__all__ = ( + "AttributeProxy", + "CommandProxy", + "ensure_proper_executor", + "TangoSignalBackend", + "get_python_type", + "get_dtype_extended", + "get_trl_descriptor", + "get_tango_trl", + "infer_python_type", + "infer_signal_character", + "make_backend", + "tango_signal_r", + "tango_signal_rw", + "tango_signal_w", + "tango_signal_x", + "__tango_signal_auto", +) diff --git a/src/ophyd_async/tango/signal/_signal.py b/src/ophyd_async/tango/signal/_signal.py new file mode 100644 index 0000000000..f9274842d5 --- /dev/null +++ b/src/ophyd_async/tango/signal/_signal.py @@ -0,0 +1,223 @@ +"""Tango Signals over Pytango""" + +from __future__ import annotations + +from enum import Enum, IntEnum + +import numpy.typing as npt + +from ophyd_async.core import DEFAULT_TIMEOUT, SignalR, SignalRW, SignalW, SignalX, T +from ophyd_async.tango.signal._tango_transport import ( + TangoSignalBackend, + get_python_type, +) +from tango import AttrDataFormat, AttrWriteType, CmdArgType, DeviceProxy, DevState +from tango.asyncio import DeviceProxy as AsyncDeviceProxy + + +def make_backend( + datatype: type[T] | None, + read_trl: str = "", + write_trl: str = "", + device_proxy: DeviceProxy | None = None, +) -> TangoSignalBackend: + return TangoSignalBackend(datatype, read_trl, write_trl, device_proxy) + + +def tango_signal_rw( + datatype: type[T], + read_trl: str, + write_trl: str = "", + device_proxy: DeviceProxy | None = None, + timeout: float = DEFAULT_TIMEOUT, + name: str = "", +) -> SignalRW[T]: + """Create a `SignalRW` backed by 1 or 2 Tango Attribute/Command + + Parameters + ---------- + datatype: + Check that the Attribute/Command is of this type + read_trl: + The Attribute/Command to read and monitor + write_trl: + If given, use this Attribute/Command to write to, otherwise use read_trl + device_proxy: + If given, this DeviceProxy will be used + timeout: + The timeout for the read and write operations + name: + The name of the Signal + """ + backend = make_backend(datatype, read_trl, write_trl or read_trl, device_proxy) + return SignalRW(backend, timeout=timeout, name=name) + + +def tango_signal_r( + datatype: type[T], + read_trl: str, + device_proxy: DeviceProxy | None = None, + timeout: float = DEFAULT_TIMEOUT, + name: str = "", +) -> SignalR[T]: + """Create a `SignalR` backed by 1 Tango Attribute/Command + + Parameters + ---------- + datatype: + Check that the Attribute/Command is of this type + read_trl: + The Attribute/Command to read and monitor + device_proxy: + If given, this DeviceProxy will be used + timeout: + The timeout for the read operation + name: + The name of the Signal + """ + backend = make_backend(datatype, read_trl, read_trl, device_proxy) + return SignalR(backend, timeout=timeout, name=name) + + +def tango_signal_w( + datatype: type[T], + write_trl: str, + device_proxy: DeviceProxy | None = None, + timeout: float = DEFAULT_TIMEOUT, + name: str = "", +) -> SignalW[T]: + """Create a `SignalW` backed by 1 Tango Attribute/Command + + Parameters + ---------- + datatype: + Check that the Attribute/Command is of this type + write_trl: + The Attribute/Command to write to + device_proxy: + If given, this DeviceProxy will be used + timeout: + The timeout for the write operation + name: + The name of the Signal + """ + backend = make_backend(datatype, write_trl, write_trl, device_proxy) + return SignalW(backend, timeout=timeout, name=name) + + +def tango_signal_x( + write_trl: str, + device_proxy: DeviceProxy | None = None, + timeout: float = DEFAULT_TIMEOUT, + name: str = "", +) -> SignalX: + """Create a `SignalX` backed by 1 Tango Attribute/Command + + Parameters + ---------- + write_trl: + The Attribute/Command to write its initial value to on execute + device_proxy: + If given, this DeviceProxy will be used + timeout: + The timeout for the command operation + name: + The name of the Signal + """ + backend = make_backend(None, write_trl, write_trl, device_proxy) + return SignalX(backend, timeout=timeout, name=name) + + +async def __tango_signal_auto( + datatype: type[T] | None = None, + *, + trl: str, + device_proxy: DeviceProxy | None, + timeout: float = DEFAULT_TIMEOUT, + name: str = "", +) -> SignalW | SignalX | SignalR | SignalRW | None: + try: + signal_character = await infer_signal_character(trl, device_proxy) + except RuntimeError as e: + if "Commands with different in and out dtypes" in str(e): + return None + else: + raise e + + if datatype is None: + datatype = await infer_python_type(trl, device_proxy) + + backend = make_backend(datatype, trl, trl, device_proxy) + if signal_character == "RW": + return SignalRW(backend=backend, timeout=timeout, name=name) + if signal_character == "R": + return SignalR(backend=backend, timeout=timeout, name=name) + if signal_character == "W": + return SignalW(backend=backend, timeout=timeout, name=name) + if signal_character == "X": + return SignalX(backend=backend, timeout=timeout, name=name) + + +async def infer_python_type( + trl: str = "", proxy: DeviceProxy | None = None +) -> object | npt.NDArray | type[DevState] | IntEnum: + device_trl, tr_name = trl.rsplit("/", 1) + if proxy is None: + dev_proxy = await AsyncDeviceProxy(device_trl) + else: + dev_proxy = proxy + + if tr_name in dev_proxy.get_command_list(): + config = await dev_proxy.get_command_config(tr_name) + isarray, py_type, _ = get_python_type(config.in_type) + elif tr_name in dev_proxy.get_attribute_list(): + config = await dev_proxy.get_attribute_config(tr_name) + isarray, py_type, _ = get_python_type(config.data_type) + if py_type is Enum: + enum_dict = {label: i for i, label in enumerate(config.enum_labels)} + py_type = IntEnum("TangoEnum", enum_dict) + if config.data_format in [AttrDataFormat.SPECTRUM, AttrDataFormat.IMAGE]: + isarray = True + else: + raise RuntimeError(f"Cannot find {tr_name} in {device_trl}") + + if py_type is CmdArgType.DevState: + py_type = DevState + + return npt.NDArray[py_type] if isarray else py_type + + +async def infer_signal_character(trl, proxy: DeviceProxy | None = None) -> str: + device_trl, tr_name = trl.rsplit("/", 1) + if proxy is None: + dev_proxy = await AsyncDeviceProxy(device_trl) + else: + dev_proxy = proxy + + if tr_name in dev_proxy.get_pipe_list(): + raise NotImplementedError("Pipes are not supported") + + if tr_name not in dev_proxy.get_attribute_list(): + if tr_name not in dev_proxy.get_command_list(): + raise RuntimeError(f"Cannot find {tr_name} in {device_trl}") + + if tr_name in dev_proxy.get_attribute_list(): + config = await dev_proxy.get_attribute_config(tr_name) + if config.writable in [AttrWriteType.READ_WRITE, AttrWriteType.READ_WITH_WRITE]: + return "RW" + elif config.writable == AttrWriteType.READ: + return "R" + else: + return "W" + + if tr_name in dev_proxy.get_command_list(): + config = await dev_proxy.get_command_config(tr_name) + if config.in_type == CmdArgType.DevVoid: + return "X" + elif config.in_type != config.out_type: + raise RuntimeError( + "Commands with different in and out dtypes are not" " supported" + ) + else: + return "RW" + raise RuntimeError(f"Unable to infer signal character for {trl}") diff --git a/src/ophyd_async/tango/signal/_tango_transport.py b/src/ophyd_async/tango/signal/_tango_transport.py new file mode 100644 index 0000000000..54cea4b610 --- /dev/null +++ b/src/ophyd_async/tango/signal/_tango_transport.py @@ -0,0 +1,764 @@ +import asyncio +import functools +import time +from abc import abstractmethod +from asyncio import CancelledError +from collections.abc import Callable, Coroutine +from enum import Enum +from typing import Any, TypeVar, cast + +import numpy as np +from bluesky.protocols import Descriptor, Reading + +from ophyd_async.core import ( + DEFAULT_TIMEOUT, + AsyncStatus, + NotConnected, + ReadingValueCallback, + SignalBackend, + T, + get_dtype, + get_unique, + wait_for_connection, +) +from tango import ( + AttrDataFormat, + AttributeInfoEx, + CmdArgType, + CommandInfo, + DevFailed, # type: ignore + DeviceProxy, + DevState, + EventType, +) +from tango.asyncio import DeviceProxy as AsyncDeviceProxy +from tango.asyncio_executor import ( + AsyncioExecutor, + get_global_executor, + set_global_executor, +) +from tango.utils import is_array, is_binary, is_bool, is_float, is_int, is_str + +# time constant to wait for timeout +A_BIT = 1e-5 + +R = TypeVar("R") + + +def ensure_proper_executor( + func: Callable[..., Coroutine[Any, Any, R]], +) -> Callable[..., Coroutine[Any, Any, R]]: + @functools.wraps(func) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> R: + current_executor: AsyncioExecutor = get_global_executor() # type: ignore + if not current_executor.in_executor_context(): # type: ignore + set_global_executor(AsyncioExecutor()) + return await func(self, *args, **kwargs) + + return cast(Callable[..., Coroutine[Any, Any, R]], wrapper) + + +def get_python_type(tango_type: CmdArgType) -> tuple[bool, object, str]: + array = is_array(tango_type) + if is_int(tango_type, True): + return array, int, "integer" + if is_float(tango_type, True): + return array, float, "number" + if is_bool(tango_type, True): + return array, bool, "integer" + if is_str(tango_type, True): + return array, str, "string" + if is_binary(tango_type, True): + return array, list[str], "string" + if tango_type == CmdArgType.DevEnum: + return array, Enum, "string" + if tango_type == CmdArgType.DevState: + return array, CmdArgType.DevState, "string" + if tango_type == CmdArgType.DevUChar: + return array, int, "integer" + if tango_type == CmdArgType.DevVoid: + return array, None, "string" + raise TypeError("Unknown TangoType") + + +class TangoProxy: + support_events: bool = True + _proxy: DeviceProxy + _name: str + + def __init__(self, device_proxy: DeviceProxy, name: str): + self._proxy = device_proxy + self._name = name + + async def connect(self) -> None: + """perform actions after proxy is connected, e.g. checks if signal + can be subscribed""" + + @abstractmethod + async def get(self) -> object: + """Get value from TRL""" + + @abstractmethod + async def get_w_value(self) -> object: + """Get last written value from TRL""" + + @abstractmethod + async def put( + self, value: object | None, wait: bool = True, timeout: float | None = None + ) -> AsyncStatus | None: + """Put value to TRL""" + + @abstractmethod + async def get_config(self) -> AttributeInfoEx | CommandInfo: + """Get TRL config async""" + + @abstractmethod + async def get_reading(self) -> Reading: + """Get reading from TRL""" + + @abstractmethod + def has_subscription(self) -> bool: + """indicates, that this trl already subscribed""" + + @abstractmethod + def subscribe_callback(self, callback: ReadingValueCallback | None): + """subscribe tango CHANGE event to callback""" + + @abstractmethod + def unsubscribe_callback(self): + """delete CHANGE event subscription""" + + @abstractmethod + def set_polling( + self, + allow_polling: bool = True, + polling_period: float = 0.1, + abs_change=None, + rel_change=None, + ): + """Set polling parameters""" + + +class AttributeProxy(TangoProxy): + _callback: ReadingValueCallback | None = None + _eid: int | None = None + _poll_task: asyncio.Task | None = None + _abs_change: float | None = None + _rel_change: float | None = 0.1 + _polling_period: float = 0.1 + _allow_polling: bool = False + exception: BaseException | None = None + _last_reading: Reading = Reading(value=None, timestamp=0, alarm_severity=0) + + async def connect(self) -> None: + try: + # I have to typehint proxy as tango.DeviceProxy because + # tango.asyncio.DeviceProxy cannot be used as a typehint. + # This means pyright will not be able to see that + # subscribe_event is awaitable. + eid = await self._proxy.subscribe_event( # type: ignore + self._name, EventType.CHANGE_EVENT, self._event_processor + ) + await self._proxy.unsubscribe_event(eid) + self.support_events = True + except Exception: + pass + + @ensure_proper_executor + async def get(self) -> Coroutine[Any, Any, object]: + attr = await self._proxy.read_attribute(self._name) + return attr.value + + @ensure_proper_executor + async def get_w_value(self) -> object: + attr = await self._proxy.read_attribute(self._name) + return attr.w_value + + @ensure_proper_executor + async def put( + self, value: object | None, wait: bool = True, timeout: float | None = None + ) -> AsyncStatus | None: + if wait: + try: + + async def _write(): + return await self._proxy.write_attribute(self._name, value) + + task = asyncio.create_task(_write()) + await asyncio.wait_for(task, timeout) + except asyncio.TimeoutError as te: + raise TimeoutError(f"{self._name} attr put failed: Timeout") from te + except DevFailed as de: + raise RuntimeError( + f"{self._name} device" f" failure: {de.args[0].desc}" + ) from de + + else: + rid = await self._proxy.write_attribute_asynch(self._name, value) + + async def wait_for_reply(rd: int, to: float | None): + start_time = time.time() + while True: + try: + # I have to typehint proxy as tango.DeviceProxy because + # tango.asyncio.DeviceProxy cannot be used as a typehint. + # This means pyright will not be able to see that + # write_attribute_reply is awaitable. + await self._proxy.write_attribute_reply(rd) # type: ignore + break + except DevFailed as exc: + if exc.args[0].reason == "API_AsynReplyNotArrived": + await asyncio.sleep(A_BIT) + if to and (time.time() - start_time > to): + raise TimeoutError( + f"{self._name} attr put failed:" f" Timeout" + ) from exc + else: + raise RuntimeError( + f"{self._name} device failure:" f" {exc.args[0].desc}" + ) from exc + + return AsyncStatus(wait_for_reply(rid, timeout)) + + @ensure_proper_executor + async def get_config(self) -> AttributeInfoEx: + return await self._proxy.get_attribute_config(self._name) + + @ensure_proper_executor + async def get_reading(self) -> Reading: + attr = await self._proxy.read_attribute(self._name) + reading = Reading( + value=attr.value, timestamp=attr.time.totime(), alarm_severity=attr.quality + ) + self._last_reading = reading + return reading + + def has_subscription(self) -> bool: + return bool(self._callback) + + def subscribe_callback(self, callback: ReadingValueCallback | None): + # If the attribute supports events, then we can subscribe to them + # If the callback is not a callable, then we raise an error + if callback is not None and not callable(callback): + raise RuntimeError("Callback must be a callable") + + self._callback = callback + if self.support_events: + """add user callback to CHANGE event subscription""" + if not self._eid: + self._eid = self._proxy.subscribe_event( + self._name, + EventType.CHANGE_EVENT, + self._event_processor, + green_mode=False, + ) + elif self._allow_polling: + """start polling if no events supported""" + if self._callback is not None: + + async def _poll(): + while True: + try: + await self.poll() + except RuntimeError as exc: + self.exception = exc + await asyncio.sleep(1) + + self._poll_task = asyncio.create_task(_poll()) + else: + self.unsubscribe_callback() + raise RuntimeError( + f"Cannot set event for {self._name}. " + "Cannot set a callback on an attribute that does not support events and" + " for which polling is disabled." + ) + + def unsubscribe_callback(self): + if self._eid: + self._proxy.unsubscribe_event(self._eid, green_mode=False) + self._eid = None + if self._poll_task: + self._poll_task.cancel() + self._poll_task = None + if self._callback is not None: + # Call the callback with the last reading + try: + self._callback(self._last_reading, self._last_reading["value"]) + except TypeError: + pass + self._callback = None + + def _event_processor(self, event): + if not event.err: + value = event.attr_value.value + reading = Reading( + value=value, + timestamp=event.get_date().totime(), + alarm_severity=event.attr_value.quality, + ) + if self._callback is not None: + self._callback(reading, value) + + async def poll(self): + """ + Poll the attribute and call the callback if the value has changed by more + than the absolute or relative change. This function is used when an attribute + that does not support events is cached or a callback is passed to it. + """ + try: + last_reading = await self.get_reading() + flag = 0 + # Initial reading + if self._callback is not None: + self._callback(last_reading, last_reading["value"]) + except Exception as e: + raise RuntimeError(f"Could not poll the attribute: {e}") from e + + try: + # If the value is a number, we can check for changes + if isinstance(last_reading["value"], int | float): + while True: + await asyncio.sleep(self._polling_period) + reading = await self.get_reading() + if reading is None or reading["value"] is None: + continue + diff = abs(reading["value"] - last_reading["value"]) + if self._abs_change is not None and diff >= abs(self._abs_change): + if self._callback is not None: + self._callback(reading, reading["value"]) + flag = 0 + + elif ( + self._rel_change is not None + and diff >= self._rel_change * abs(last_reading["value"]) + ): + if self._callback is not None: + self._callback(reading, reading["value"]) + flag = 0 + + else: + flag = (flag + 1) % 4 + if flag == 0 and self._callback is not None: + self._callback(reading, reading["value"]) + + last_reading = reading.copy() + if self._callback is None: + break + # If the value is not a number, we can only poll + else: + while True: + await asyncio.sleep(self._polling_period) + flag = (flag + 1) % 4 + if flag == 0: + reading = await self.get_reading() + if reading is None or reading["value"] is None: + continue + if isinstance(reading["value"], np.ndarray): + if not np.array_equal( + reading["value"], last_reading["value"] + ): + if self._callback is not None: + self._callback(reading, reading["value"]) + else: + break + else: + if reading["value"] != last_reading["value"]: + if self._callback is not None: + self._callback(reading, reading["value"]) + else: + break + last_reading = reading.copy() + except Exception as e: + raise RuntimeError(f"Could not poll the attribute: {e}") from e + + def set_polling( + self, + allow_polling: bool = False, + polling_period: float = 0.5, + abs_change: float | None = None, + rel_change: float | None = 0.1, + ): + """ + Set the polling parameters. + """ + self._allow_polling = allow_polling + self._polling_period = polling_period + self._abs_change = abs_change + self._rel_change = rel_change + + +class CommandProxy(TangoProxy): + _last_reading: Reading = Reading(value=None, timestamp=0, alarm_severity=0) + + def subscribe_callback(self, callback: ReadingValueCallback | None) -> None: + raise NotImplementedError("Cannot subscribe to commands") + + def unsubscribe_callback(self) -> None: + raise NotImplementedError("Cannot unsubscribe from commands") + + async def get(self) -> object: + return self._last_reading["value"] + + async def get_w_value(self) -> object: + return self._last_reading["value"] + + async def connect(self) -> None: + pass + + @ensure_proper_executor + async def put( + self, value: object | None, wait: bool = True, timeout: float | None = None + ) -> AsyncStatus | None: + if wait: + try: + + async def _put(): + return await self._proxy.command_inout(self._name, value) + + task = asyncio.create_task(_put()) + val = await asyncio.wait_for(task, timeout) + self._last_reading = Reading( + value=val, timestamp=time.time(), alarm_severity=0 + ) + except asyncio.TimeoutError as te: + raise TimeoutError(f"{self._name} command failed: Timeout") from te + except DevFailed as de: + raise RuntimeError( + f"{self._name} device" f" failure: {de.args[0].desc}" + ) from de + + else: + rid = self._proxy.command_inout_asynch(self._name, value) + + async def wait_for_reply(rd: int, to: float | None): + start_time = time.time() + while True: + try: + reply_value = self._proxy.command_inout_reply(rd) + self._last_reading = Reading( + value=reply_value, timestamp=time.time(), alarm_severity=0 + ) + break + except DevFailed as de_exc: + if de_exc.args[0].reason == "API_AsynReplyNotArrived": + await asyncio.sleep(A_BIT) + if to and time.time() - start_time > to: + raise TimeoutError( + "Timeout while waiting for command reply" + ) from de_exc + else: + raise RuntimeError( + f"{self._name} device failure:" + f" {de_exc.args[0].desc}" + ) from de_exc + + return AsyncStatus(wait_for_reply(rid, timeout)) + + @ensure_proper_executor + async def get_config(self) -> CommandInfo: + return await self._proxy.get_command_config(self._name) + + async def get_reading(self) -> Reading: + reading = Reading( + value=self._last_reading["value"], + timestamp=self._last_reading["timestamp"], + alarm_severity=self._last_reading.get("alarm_severity", 0), + ) + return reading + + def set_polling( + self, + allow_polling: bool = False, + polling_period: float = 0.5, + abs_change: float | None = None, + rel_change: float | None = 0.1, + ): + pass + + +def get_dtype_extended(datatype) -> object | None: + # DevState tango type does not have numpy equivalents + dtype = get_dtype(datatype) + if dtype == np.object_: + if datatype.__args__[1].__args__[0] == DevState: + dtype = CmdArgType.DevState + return dtype + + +def get_trl_descriptor( + datatype: type | None, + tango_resource: str, + tr_configs: dict[str, AttributeInfoEx | CommandInfo], +) -> Descriptor: + tr_dtype = {} + for tr_name, config in tr_configs.items(): + if isinstance(config, AttributeInfoEx): + _, dtype, descr = get_python_type(config.data_type) + tr_dtype[tr_name] = config.data_format, dtype, descr + elif isinstance(config, CommandInfo): + if ( + config.in_type != CmdArgType.DevVoid + and config.out_type != CmdArgType.DevVoid + and config.in_type != config.out_type + ): + raise RuntimeError( + "Commands with different in and out dtypes are not supported" + ) + array, dtype, descr = get_python_type( + config.in_type + if config.in_type != CmdArgType.DevVoid + else config.out_type + ) + tr_dtype[tr_name] = ( + AttrDataFormat.SPECTRUM if array else AttrDataFormat.SCALAR, + dtype, + descr, + ) + else: + raise RuntimeError(f"Unknown config type: {type(config)}") + tr_format, tr_dtype, tr_dtype_desc = get_unique(tr_dtype, "typeids") + + # tango commands are limited in functionality: + # they do not have info about shape and Enum labels + trl_config = list(tr_configs.values())[0] + max_x: int = ( + trl_config.max_dim_x + if hasattr(trl_config, "max_dim_x") + else np.iinfo(np.int32).max + ) + max_y: int = ( + trl_config.max_dim_y + if hasattr(trl_config, "max_dim_y") + else np.iinfo(np.int32).max + ) + # is_attr = hasattr(trl_config, "enum_labels") + # trl_choices = list(trl_config.enum_labels) if is_attr else [] + + if tr_format in [AttrDataFormat.SPECTRUM, AttrDataFormat.IMAGE]: + # This is an array + if datatype: + # Check we wanted an array of this type + dtype = get_dtype_extended(datatype) + if not dtype: + raise TypeError( + f"{tango_resource} has type [{tr_dtype}] not {datatype.__name__}" + ) + if dtype != tr_dtype: + raise TypeError(f"{tango_resource} has type [{tr_dtype}] not [{dtype}]") + + if tr_format == AttrDataFormat.SPECTRUM: + return Descriptor(source=tango_resource, dtype="array", shape=[max_x]) + elif tr_format == AttrDataFormat.IMAGE: + return Descriptor( + source=tango_resource, dtype="array", shape=[max_y, max_x] + ) + + else: + if tr_dtype in (Enum, CmdArgType.DevState): + # if tr_dtype == CmdArgType.DevState: + # trl_choices = list(DevState.names.keys()) + + if datatype: + if not issubclass(datatype, Enum | DevState): + raise TypeError( + f"{tango_resource} has type Enum not {datatype.__name__}" + ) + # if tr_dtype == Enum and is_attr: + # if isinstance(datatype, DevState): + # choices = tuple(v.name for v in datatype) + # if set(choices) != set(trl_choices): + # raise TypeError( + # f"{tango_resource} has choices {trl_choices} " + # f"not {choices}" + # ) + return Descriptor(source=tango_resource, dtype="string", shape=[]) + else: + if datatype and not issubclass(tr_dtype, datatype): + raise TypeError( + f"{tango_resource} has type {tr_dtype.__name__} " + f"not {datatype.__name__}" + ) + return Descriptor(source=tango_resource, dtype=tr_dtype_desc, shape=[]) + + raise RuntimeError(f"Error getting descriptor for {tango_resource}") + + +async def get_tango_trl( + full_trl: str, device_proxy: DeviceProxy | TangoProxy | None +) -> TangoProxy: + if isinstance(device_proxy, TangoProxy): + return device_proxy + device_trl, trl_name = full_trl.rsplit("/", 1) + trl_name = trl_name.lower() + if device_proxy is None: + device_proxy = await AsyncDeviceProxy(device_trl) + + # all attributes can be always accessible with low register + if isinstance(device_proxy, DeviceProxy): + all_attrs = [ + attr_name.lower() for attr_name in device_proxy.get_attribute_list() + ] + else: + raise TypeError( + f"device_proxy must be an instance of DeviceProxy for {full_trl}" + ) + if trl_name in all_attrs: + return AttributeProxy(device_proxy, trl_name) + + # all commands can be always accessible with low register + all_cmds = [cmd_name.lower() for cmd_name in device_proxy.get_command_list()] + if trl_name in all_cmds: + return CommandProxy(device_proxy, trl_name) + + # If version is below tango 9, then pipes are not supported + if device_proxy.info().server_version >= 9: + # all pipes can be always accessible with low register + all_pipes = [pipe_name.lower() for pipe_name in device_proxy.get_pipe_list()] + if trl_name in all_pipes: + raise NotImplementedError("Pipes are not supported") + + raise RuntimeError(f"{trl_name} cannot be found in {device_proxy.name()}") + + +class TangoSignalBackend(SignalBackend[T]): + def __init__( + self, + datatype: type[T] | None, + read_trl: str = "", + write_trl: str = "", + device_proxy: DeviceProxy | None = None, + ): + self.device_proxy = device_proxy + self.datatype = datatype + self.read_trl = read_trl + self.write_trl = write_trl + self.proxies: dict[str, TangoProxy | DeviceProxy | None] = { + read_trl: self.device_proxy, + write_trl: self.device_proxy, + } + self.trl_configs: dict[str, AttributeInfoEx] = {} + self.descriptor: Descriptor = {} # type: ignore + self._polling: tuple[bool, float, float | None, float | None] = ( + False, + 0.1, + None, + 0.1, + ) + self.support_events: bool = True + self.status: AsyncStatus | None = None + + @classmethod + def datatype_allowed(cls, dtype: Any) -> bool: + return dtype in (int, float, str, bool, np.ndarray, Enum, DevState) + + def set_trl(self, read_trl: str = "", write_trl: str = ""): + self.read_trl = read_trl + self.write_trl = write_trl if write_trl else read_trl + self.proxies = { + read_trl: self.device_proxy, + write_trl: self.device_proxy, + } + + def source(self, name: str) -> str: + return self.read_trl + + async def _connect_and_store_config(self, trl: str) -> None: + if not trl: + raise RuntimeError(f"trl not set for {self}") + try: + self.proxies[trl] = await get_tango_trl(trl, self.proxies[trl]) + if self.proxies[trl] is None: + raise NotConnected(f"Not connected to {trl}") + # Pyright does not believe that self.proxies[trl] is not None despite + # the check above + await self.proxies[trl].connect() # type: ignore + self.trl_configs[trl] = await self.proxies[trl].get_config() # type: ignore + self.proxies[trl].support_events = self.support_events # type: ignore + except CancelledError as ce: + raise NotConnected(f"Could not connect to {trl}") from ce + + async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None: + if not self.read_trl: + raise RuntimeError(f"trl not set for {self}") + if self.read_trl != self.write_trl: + # Different, need to connect both + await wait_for_connection( + read_trl=self._connect_and_store_config(self.read_trl), + write_trl=self._connect_and_store_config(self.write_trl), + ) + else: + # The same, so only need to connect one + await self._connect_and_store_config(self.read_trl) + self.proxies[self.read_trl].set_polling(*self._polling) # type: ignore + self.descriptor = get_trl_descriptor( + self.datatype, self.read_trl, self.trl_configs + ) + + async def put(self, value: T | None, wait=True, timeout=None) -> None: + if self.proxies[self.write_trl] is None: + raise NotConnected(f"Not connected to {self.write_trl}") + self.status = None + put_status = await self.proxies[self.write_trl].put(value, wait, timeout) # type: ignore + self.status = put_status + + async def get_datakey(self, source: str) -> Descriptor: + return self.descriptor + + async def get_reading(self) -> Reading: + if self.proxies[self.read_trl] is None: + raise NotConnected(f"Not connected to {self.read_trl}") + return await self.proxies[self.read_trl].get_reading() # type: ignore + + async def get_value(self) -> T: + if self.proxies[self.read_trl] is None: + raise NotConnected(f"Not connected to {self.read_trl}") + proxy = self.proxies[self.read_trl] + if proxy is None: + raise NotConnected(f"Not connected to {self.read_trl}") + return cast(T, await proxy.get()) + + async def get_setpoint(self) -> T: + if self.proxies[self.write_trl] is None: + raise NotConnected(f"Not connected to {self.write_trl}") + proxy = self.proxies[self.write_trl] + if proxy is None: + raise NotConnected(f"Not connected to {self.write_trl}") + return cast(T, await proxy.get_w_value()) + + def set_callback(self, callback: ReadingValueCallback | None) -> None: + if self.proxies[self.read_trl] is None: + raise NotConnected(f"Not connected to {self.read_trl}") + if self.support_events is False and self._polling[0] is False: + raise RuntimeError( + f"Cannot set event for {self.read_trl}. " + "Cannot set a callback on an attribute that does not support events and" + " for which polling is disabled." + ) + + if callback: + try: + assert not self.proxies[self.read_trl].has_subscription() # type: ignore + self.proxies[self.read_trl].subscribe_callback(callback) # type: ignore + except AssertionError as ae: + raise RuntimeError( + "Cannot set a callback when one" " is already set" + ) from ae + except RuntimeError as exc: + raise RuntimeError( + f"Cannot set callback" f" for {self.read_trl}. {exc}" + ) from exc + + else: + self.proxies[self.read_trl].unsubscribe_callback() # type: ignore + + def set_polling( + self, + allow_polling: bool = True, + polling_period: float = 0.1, + abs_change: float | None = None, + rel_change: float | None = 0.1, + ): + self._polling = (allow_polling, polling_period, abs_change, rel_change) + + def allow_events(self, allow: bool = True): + self.support_events = allow diff --git a/tests/conftest.py b/tests/conftest.py index fa0a8fb800..8e0cf546f8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -235,3 +235,11 @@ def one_shot_trigger_info() -> TriggerInfo: deadtime=None, livetime=None, ) + + +def pytest_collection_modifyitems(config, items): + tango_dir = "tests/tango" + + for item in items: + if tango_dir in str(item.fspath): + item.add_marker(pytest.mark.forked) diff --git a/tests/tango/test_base_device.py b/tests/tango/test_base_device.py new file mode 100644 index 0000000000..3c23461f99 --- /dev/null +++ b/tests/tango/test_base_device.py @@ -0,0 +1,400 @@ +import asyncio +import time +from enum import Enum, IntEnum + +import bluesky.plan_stubs as bps +import bluesky.plans as bp +import numpy as np +import numpy.typing as npt +import pytest +from bluesky import RunEngine + +import tango +from ophyd_async.core import DeviceCollector, HintedSignal, SignalRW, T +from ophyd_async.tango import TangoReadable, get_python_type +from ophyd_async.tango.demo import ( + DemoCounter, + DemoMover, + TangoDetector, +) +from tango import ( + AttrDataFormat, + AttrQuality, + AttrWriteType, + CmdArgType, + DevState, +) +from tango import DeviceProxy as SyncDeviceProxy +from tango.asyncio import DeviceProxy as AsyncDeviceProxy +from tango.asyncio_executor import set_global_executor +from tango.server import Device, attribute, command +from tango.test_context import MultiDeviceTestContext +from tango.test_utils import assert_close + + +class TestEnum(IntEnum): + __test__ = False + A = 0 + B = 1 + + +# -------------------------------------------------------------------- +# fixtures to run Echo device +# -------------------------------------------------------------------- + +TESTED_FEATURES = ["array", "limitedvalue", "justvalue"] + + +# -------------------------------------------------------------------- +class TestDevice(Device): + __test__ = False + + _array = [[1, 2, 3], [4, 5, 6]] + + _justvalue = 5 + _writeonly = 6 + _readonly = 7 + _slow_attribute = 1.0 + + _floatvalue = 1.0 + + _readback = 1.0 + _setpoint = 1.0 + + _label = "Test Device" + + _limitedvalue = 3 + + @attribute(dtype=float, access=AttrWriteType.READ) + def readback(self): + return self._readback + + @attribute(dtype=float, access=AttrWriteType.WRITE) + def setpoint(self): + return self._setpoint + + def write_setpoint(self, value: float): + self._setpoint = value + self._readback = value + + @attribute(dtype=str, access=AttrWriteType.READ_WRITE) + def label(self): + return self._label + + def write_label(self, value: str): + self._label = value + + @attribute(dtype=float, access=AttrWriteType.READ_WRITE) + def floatvalue(self): + return self._floatvalue + + def write_floatvalue(self, value: float): + self._floatvalue = value + + @attribute(dtype=int, access=AttrWriteType.READ_WRITE, polling_period=100) + def justvalue(self): + return self._justvalue + + def write_justvalue(self, value: int): + self._justvalue = value + + @attribute(dtype=int, access=AttrWriteType.WRITE, polling_period=100) + def writeonly(self): + return self._writeonly + + def write_writeonly(self, value: int): + self._writeonly = value + + @attribute(dtype=int, access=AttrWriteType.READ, polling_period=100) + def readonly(self): + return self._readonly + + @attribute( + dtype=float, + access=AttrWriteType.READ_WRITE, + dformat=AttrDataFormat.IMAGE, + max_dim_x=3, + max_dim_y=2, + ) + def array(self) -> list[list[float]]: + return self._array + + def write_array(self, array: list[list[float]]): + self._array = array + + @attribute( + dtype=float, + access=AttrWriteType.READ_WRITE, + min_value=0, + min_alarm=1, + min_warning=2, + max_warning=4, + max_alarm=5, + max_value=6, + ) + def limitedvalue(self) -> float: + return self._limitedvalue + + def write_limitedvalue(self, value: float): + self._limitedvalue = value + + @attribute(dtype=float, access=AttrWriteType.WRITE) + def slow_attribute(self) -> float: + return self._slow_attribute + + def write_slow_attribute(self, value: float): + time.sleep(0.2) + self._slow_attribute = value + + @attribute(dtype=float, access=AttrWriteType.READ_WRITE) + def raise_exception_attr(self) -> float: + raise + + def write_raise_exception_attr(self, value: float): + raise + + @command + def clear(self) -> str: + # self.info_stream("Received clear command") + return "Received clear command" + + @command + def slow_command(self) -> str: + time.sleep(0.2) + return "Completed slow command" + + @command + def echo(self, value: str) -> str: + return value + + @command + def raise_exception_cmd(self): + raise + + +# -------------------------------------------------------------------- +class TestTangoReadable(TangoReadable): + __test__ = False + justvalue: SignalRW[int] + array: SignalRW[npt.NDArray[float]] + limitedvalue: SignalRW[float] + + def __init__( + self, + trl: str | None = None, + device_proxy: SyncDeviceProxy | None = None, + name: str = "", + ) -> None: + super().__init__(trl, device_proxy, name=name) + self.add_readables( + [self.justvalue, self.array, self.limitedvalue], HintedSignal.uncached + ) + + +# -------------------------------------------------------------------- +async def describe_class(fqtrl): + description = {} + values = {} + dev = await AsyncDeviceProxy(fqtrl) + + for name in TESTED_FEATURES: + if name in dev.get_attribute_list(): + attr_conf = await dev.get_attribute_config(name) + attr_value = await dev.read_attribute(name) + value = attr_value.value + _, _, descr = get_python_type(attr_conf.data_type) + max_x = attr_conf.max_dim_x + max_y = attr_conf.max_dim_y + if attr_conf.data_format == AttrDataFormat.SCALAR: + is_array = False + shape = [] + elif attr_conf.data_format == AttrDataFormat.SPECTRUM: + is_array = True + shape = [max_x] + else: + is_array = True + shape = [max_y, max_x] + + elif name in dev.get_command_list(): + cmd_conf = await dev.get_command_config(name) + _, _, descr = get_python_type( + cmd_conf.in_type + if cmd_conf.in_type != CmdArgType.DevVoid + else cmd_conf.out_type + ) + is_array = False + shape = [] + value = getattr(dev, name)() + + else: + raise RuntimeError( + f"Cannot find {name} in attributes/commands (pipes are not supported!)" + ) + + description[f"test_device-{name}"] = { + "source": f"{fqtrl}/{name}", # type: ignore + "dtype": "array" if is_array else descr, + "shape": shape, + } + + values[f"test_device-{name}"] = { + "value": value, + "timestamp": pytest.approx(time.time()), + "alarm_severity": AttrQuality.ATTR_VALID, + } + + return values, description + + +# -------------------------------------------------------------------- +def get_test_descriptor(python_type: type[T], value: T, is_cmd: bool) -> dict: + if python_type in [bool, int]: + return {"dtype": "integer", "shape": []} + if python_type in [float]: + return {"dtype": "number", "shape": []} + if python_type in [str]: + return {"dtype": "string", "shape": []} + if issubclass(python_type, DevState): + return {"dtype": "string", "shape": [], "choices": list(DevState.names.keys())} + if issubclass(python_type, Enum): + return { + "dtype": "string", + "shape": [], + "choices": [] if is_cmd else [member.name for member in python_type], + } + + return { + "dtype": "array", + "shape": [np.Inf] if is_cmd else list(np.array(value).shape), + } + + +# -------------------------------------------------------------------- +@pytest.fixture(scope="module") +def tango_test_device(): + with MultiDeviceTestContext( + [{"class": TestDevice, "devices": [{"name": "test/device/1"}]}], process=True + ) as context: + yield context.get_device_access("test/device/1") + + +# -------------------------------------------------------------------- +@pytest.fixture(scope="module") +def demo_test_context(): + content = ( + { + "class": DemoMover, + "devices": [{"name": "demo/motor/1"}], + }, + { + "class": DemoCounter, + "devices": [{"name": "demo/counter/1"}, {"name": "demo/counter/2"}], + }, + ) + yield MultiDeviceTestContext(content) + + +# -------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def reset_tango_asyncio(): + set_global_executor(None) + + +# -------------------------------------------------------------------- +def compare_values(expected, received): + assert set(expected.keys()) == set(received.keys()) + for k, v in expected.items(): + for _k, _v in v.items(): + assert_close(_v, received[k][_k]) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_connect(tango_test_device): + values, description = await describe_class(tango_test_device) + + async with DeviceCollector(): + test_device = TestTangoReadable(tango_test_device) + + assert test_device.name == "test_device" + assert description == await test_device.describe() + compare_values(values, await test_device.read()) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_set_trl(tango_test_device): + values, description = await describe_class(tango_test_device) + + # async with DeviceCollector(): + # test_device = TestTangoReadable(trl=tango_test_device) + test_device = TestTangoReadable(name="test_device") + + with pytest.raises(ValueError) as excinfo: + test_device.set_trl(0) + assert "TRL must be a string." in str(excinfo.value) + + test_device.set_trl(tango_test_device) + await test_device.connect() + + assert test_device.name == "test_device" + assert description == await test_device.describe() + compare_values(values, await test_device.read()) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize("proxy", [True, False, None]) +async def test_connect_proxy(tango_test_device, proxy: bool | None): + if proxy is None: + test_device = TestTangoReadable(trl=tango_test_device) + test_device.proxy = None + await test_device.connect() + assert isinstance(test_device.proxy, tango._tango.DeviceProxy) + elif proxy: + proxy = await AsyncDeviceProxy(tango_test_device) + test_device = TestTangoReadable(device_proxy=proxy) + await test_device.connect() + assert isinstance(test_device.proxy, tango._tango.DeviceProxy) + else: + proxy = None + test_device = TestTangoReadable(device_proxy=proxy) + assert test_device + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_with_bluesky(tango_test_device): + # now let's do some bluesky stuff + RE = RunEngine() + with DeviceCollector(): + device = TestTangoReadable(tango_test_device) + RE(bp.count([device])) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_demo(demo_test_context): + with demo_test_context: + detector = TangoDetector( + name="detector", + mover_trl="demo/motor/1", + counter_trls=["demo/counter/1", "demo/counter/2"], + ) + await detector.connect() + await detector.trigger() + await detector.mover.velocity.set(0.5) + + RE = RunEngine() + + RE(bps.read(detector)) + RE(bps.mv(detector, 0)) + RE(bp.count(list(detector.counters.values()))) + + set_status = detector.set(1.0) + await asyncio.sleep(0.1) + stop_status = detector.stop() + await set_status + await stop_status + assert all([set_status.done, stop_status.done]) + assert all([set_status.success, stop_status.success]) diff --git a/tests/tango/test_tango_signals.py b/tests/tango/test_tango_signals.py new file mode 100644 index 0000000000..40a5b591aa --- /dev/null +++ b/tests/tango/test_tango_signals.py @@ -0,0 +1,775 @@ +import asyncio +import textwrap +import time +from enum import Enum, IntEnum +from random import choice +from typing import Any + +import numpy as np +import numpy.typing as npt +import pytest +from bluesky.protocols import Reading +from test_base_device import TestDevice + +from ophyd_async.core import SignalBackend, SignalR, SignalRW, SignalW, SignalX, T +from ophyd_async.tango import ( + TangoSignalBackend, + __tango_signal_auto, + tango_signal_r, + tango_signal_rw, + tango_signal_w, + tango_signal_x, +) +from tango import AttrDataFormat, AttrWriteType, DevState +from tango.asyncio import DeviceProxy +from tango.asyncio_executor import set_global_executor +from tango.server import Device, attribute, command +from tango.test_context import MultiDeviceTestContext +from tango.test_utils import assert_close + +# -------------------------------------------------------------------- +""" +Since TangoTest does not support EchoMode, we create our own Device. + +""" + + +class TestEnum(IntEnum): + __test__ = False + A = 0 + B = 1 + + +BASE_TYPES_SET = ( + # type_name, tango_name, py_type, sample_values + ("boolean", "DevBoolean", bool, (True, False)), + ("short", "DevShort", int, (1, 2, 3, 4, 5)), + ("ushort", "DevUShort", int, (1, 2, 3, 4, 5)), + ("long", "DevLong", int, (1, 2, 3, 4, 5)), + ("ulong", "DevULong", int, (1, 2, 3, 4, 5)), + ("long64", "DevLong64", int, (1, 2, 3, 4, 5)), + ("char", "DevUChar", int, (1, 2, 3, 4, 5)), + ("float", "DevFloat", float, (1.1, 2.2, 3.3, 4.4, 5.5)), + ("double", "DevDouble", float, (1.1, 2.2, 3.3, 4.4, 5.5)), + ("string", "DevString", str, ("aaa", "bbb", "ccc")), + ("state", "DevState", DevState, (DevState.ON, DevState.MOVING, DevState.ALARM)), + ("enum", "DevEnum", TestEnum, (TestEnum.A, TestEnum.B)), +) + +ATTRIBUTES_SET = [] +COMMANDS_SET = [] + +for type_name, tango_type_name, py_type, values in BASE_TYPES_SET: + ATTRIBUTES_SET.extend( + [ + ( + f"{type_name}_scalar_attr", + tango_type_name, + AttrDataFormat.SCALAR, + py_type, + choice(values), + choice(values), + ), + ( + f"{type_name}_spectrum_attr", + tango_type_name, + AttrDataFormat.SPECTRUM, + npt.NDArray[py_type], + [choice(values), choice(values), choice(values)], + [choice(values), choice(values), choice(values)], + ), + ( + f"{type_name}_image_attr", + tango_type_name, + AttrDataFormat.IMAGE, + npt.NDArray[py_type], + [ + [choice(values), choice(values), choice(values)], + [choice(values), choice(values), choice(values)], + ], + [ + [choice(values), choice(values), choice(values)], + [choice(values), choice(values), choice(values)], + ], + ), + ] + ) + + if tango_type_name == "DevUChar": + continue + else: + COMMANDS_SET.append( + ( + f"{type_name}_scalar_cmd", + tango_type_name, + AttrDataFormat.SCALAR, + py_type, + choice(values), + choice(values), + ) + ) + if tango_type_name in ["DevState", "DevEnum"]: + continue + else: + COMMANDS_SET.append( + ( + f"{type_name}_spectrum_cmd", + tango_type_name, + AttrDataFormat.SPECTRUM, + npt.NDArray[py_type], + [choice(values), choice(values), choice(values)], + [choice(values), choice(values), choice(values)], + ) + ) + + +# -------------------------------------------------------------------- +# TestDevice +# -------------------------------------------------------------------- +# -------------------------------------------------------------------- +@pytest.fixture(scope="module") +def tango_test_device(): + with MultiDeviceTestContext( + [{"class": TestDevice, "devices": [{"name": "test/device/1"}]}], process=True + ) as context: + yield context.get_device_access("test/device/1") + + +# -------------------------------------------------------------------- +# Echo device +# -------------------------------------------------------------------- +class EchoDevice(Device): + attr_values = {} + + def initialize_dynamic_attributes(self): + for name, typ, form, _, _, _ in ATTRIBUTES_SET: + attr = attribute( + name=name, + dtype=typ, + dformat=form, + access=AttrWriteType.READ_WRITE, + fget=self.read, + fset=self.write, + max_dim_x=3, + max_dim_y=2, + enum_labels=[member.name for member in TestEnum], + ) + self.add_attribute(attr) + self.set_change_event(name, True, False) + + for name, typ, form, _, _, _ in COMMANDS_SET: + cmd = command( + f=getattr(self, name), + dtype_in=typ, + dformat_in=form, + dtype_out=typ, + dformat_out=form, + ) + self.add_command(cmd) + + def read(self, attr): + attr.set_value(self.attr_values[attr.get_name()]) + + def write(self, attr): + new_value = attr.get_write_value() + self.attr_values[attr.get_name()] = new_value + self.push_change_event(attr.get_name(), new_value) + + echo_command_code = textwrap.dedent( + """\ + def echo_command(self, arg): + return arg + """ + ) + + for name, _, _, _, _, _ in COMMANDS_SET: + exec(echo_command_code.replace("echo_command", name)) + + +# -------------------------------------------------------------------- +def assert_enum(initial_value, readout_value): + if type(readout_value) in [list, tuple]: + for _initial_value, _readout_value in zip( + initial_value, readout_value, strict=False + ): + assert_enum(_initial_value, _readout_value) + else: + assert initial_value == readout_value + + +# -------------------------------------------------------------------- +# fixtures to run Echo device +# -------------------------------------------------------------------- +@pytest.fixture(scope="module") +def echo_device(): + with MultiDeviceTestContext( + [{"class": EchoDevice, "devices": [{"name": "test/device/1"}]}], process=True + ) as context: + yield context.get_device_access("test/device/1") + + +# -------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def reset_tango_asyncio(): + set_global_executor(None) + + +# -------------------------------------------------------------------- +# helpers to run tests +# -------------------------------------------------------------------- +def get_test_descriptor(python_type: type[T], value: T, is_cmd: bool) -> dict: + if python_type in [bool, int]: + return {"dtype": "integer", "shape": []} + if python_type in [float]: + return {"dtype": "number", "shape": []} + if python_type in [str]: + return {"dtype": "string", "shape": []} + if issubclass(python_type, DevState): + return {"dtype": "string", "shape": []} + if issubclass(python_type, Enum): + return { + "dtype": "string", + "shape": [], + } + + return { + "dtype": "array", + "shape": [np.iinfo(np.int32).max] if is_cmd else list(np.array(value).shape), + } + + +# -------------------------------------------------------------------- +async def make_backend( + typ: type | None, + pv: str, + connect: bool = True, + allow_events: bool | None = True, +) -> TangoSignalBackend: + backend = TangoSignalBackend(typ, pv, pv) + backend.allow_events(allow_events) + if connect: + await asyncio.wait_for(backend.connect(), 10) + return backend + + +# -------------------------------------------------------------------- +async def prepare_device(echo_device: str, pv: str, put_value: T) -> None: + proxy = await DeviceProxy(echo_device) + setattr(proxy, pv, put_value) + + +# -------------------------------------------------------------------- +class MonitorQueue: + def __init__(self, backend: SignalBackend): + self.updates: asyncio.Queue[tuple[Reading, Any]] = asyncio.Queue() + self.backend = backend + self.subscription = backend.set_callback(self.add_reading_value) + + def add_reading_value(self, reading: Reading, value): + self.updates.put_nowait((reading, value)) + + async def assert_updates(self, expected_value): + expected_reading = { + "timestamp": pytest.approx(time.time(), rel=0.1), + "alarm_severity": 0, + } + update_reading, update_value = await self.updates.get() + get_reading = await self.backend.get_reading() + # If update_value is a numpy.ndarray, convert it to a list + if isinstance(update_value, np.ndarray): + update_value = update_value.tolist() + assert_close(update_value, expected_value) + assert_close(await self.backend.get_value(), expected_value) + + update_reading = dict(update_reading) + update_value = update_reading.pop("value") + + get_reading = dict(get_reading) + get_value = get_reading.pop("value") + + assert update_reading == expected_reading == get_reading + assert_close(update_value, expected_value) + assert_close(get_value, expected_value) + + def close(self): + self.backend.set_callback(None) + + +# -------------------------------------------------------------------- +async def assert_monitor_then_put( + echo_device: str, + pv: str, + initial_value: T, + put_value: T, + descriptor: dict, + datatype: type[T] | None = None, +): + await prepare_device(echo_device, pv, initial_value) + source = echo_device + "/" + pv + backend = await make_backend(datatype, source, allow_events=True) + # Make a monitor queue that will monitor for updates + q = MonitorQueue(backend) + try: + assert dict(source=source, **descriptor) == await backend.get_datakey("") + # Check initial value + await q.assert_updates(initial_value) + # Put to new value and check that + await backend.put(put_value, wait=True) + assert_close(put_value, await backend.get_setpoint()) + await q.assert_updates(put_value) + finally: + q.close() + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "pv, tango_type, d_format, py_type, initial_value, put_value", + ATTRIBUTES_SET, + ids=[x[0] for x in ATTRIBUTES_SET], +) +async def test_backend_get_put_monitor_attr( + echo_device: str, + pv: str, + tango_type: str, + d_format: AttrDataFormat, + py_type: type[T], + initial_value: T, + put_value: T, +): + try: + # Set a timeout for the operation to prevent it from running indefinitely + await asyncio.wait_for( + assert_monitor_then_put( + echo_device, + pv, + initial_value, + put_value, + get_test_descriptor(py_type, initial_value, False), + py_type, + ), + timeout=10, # Timeout in seconds + ) + except asyncio.TimeoutError: + pytest.fail("Test timed out") + except Exception as e: + pytest.fail(f"Test failed with exception: {e}") + + +# -------------------------------------------------------------------- +async def assert_put_read( + echo_device: str, + pv: str, + put_value: T, + descriptor: dict, + datatype: type[T] | None = None, +): + source = echo_device + "/" + pv + backend = await make_backend(datatype, source) + # Make a monitor queue that will monitor for updates + assert dict(source=source, **descriptor) == await backend.get_datakey("") + # Put to new value and check that + await backend.put(put_value, wait=True) + + expected_reading = { + "timestamp": pytest.approx(time.time(), rel=0.1), + "alarm_severity": 0, + } + + assert_close(await backend.get_value(), put_value) + + get_reading = dict(await backend.get_reading()) + get_reading.pop("value") + assert expected_reading == get_reading + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "pv, tango_type, d_format, py_type, initial_value, put_value", + COMMANDS_SET, + ids=[x[0] for x in COMMANDS_SET], +) +async def test_backend_get_put_monitor_cmd( + echo_device: str, + pv: str, + tango_type: str, + d_format: AttrDataFormat, + py_type: type[T], + initial_value: T, + put_value: T, +): + # With the given datatype, check we have the correct initial value and putting works + descriptor = get_test_descriptor(py_type, initial_value, True) + await assert_put_read(echo_device, pv, put_value, descriptor, py_type) + # # With guessed datatype, check we can set it back to the initial value + await assert_put_read(echo_device, pv, put_value, descriptor) + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + await asyncio.gather(*tasks) + del echo_device + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy", + [ + ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + use_proxy, + ) + for ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + ) in ATTRIBUTES_SET + for use_proxy in [True, False] + ], + ids=[f"{x[0]}_{use_proxy}" for x in ATTRIBUTES_SET for use_proxy in [True, False]], +) +async def test_tango_signal_r( + echo_device: str, + pv: str, + tango_type: str, + d_format: AttrDataFormat, + py_type: type[T], + initial_value: T, + put_value: T, + use_proxy: bool, +): + await prepare_device(echo_device, pv, initial_value) + source = echo_device + "/" + pv + proxy = await DeviceProxy(echo_device) if use_proxy else None + + timeout = 0.2 + signal = tango_signal_r( + datatype=py_type, + read_trl=source, + device_proxy=proxy, + timeout=timeout, + name="test_signal", + ) + await signal.connect() + reading = await signal.read() + assert_close(reading["test_signal"]["value"], initial_value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy", + [ + ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + use_proxy, + ) + for ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + ) in ATTRIBUTES_SET + for use_proxy in [True, False] + ], + ids=[f"{x[0]}_{use_proxy}" for x in ATTRIBUTES_SET for use_proxy in [True, False]], +) +async def test_tango_signal_w( + echo_device: str, + pv: str, + tango_type: str, + d_format: AttrDataFormat, + py_type: type[T], + initial_value: T, + put_value: T, + use_proxy: bool, +): + await prepare_device(echo_device, pv, initial_value) + source = echo_device + "/" + pv + proxy = await DeviceProxy(echo_device) if use_proxy else None + + timeout = 0.2 + signal = tango_signal_w( + datatype=py_type, + write_trl=source, + device_proxy=proxy, + timeout=timeout, + name="test_signal", + ) + await signal.connect() + status = signal.set(put_value, wait=True, timeout=timeout) + await status + assert status.done is True and status.success is True + + status = signal.set(put_value, wait=False, timeout=timeout) + await status + assert status.done is True and status.success is True + + status = signal.set(put_value, wait=True) + await status + assert status.done is True and status.success is True + + status = signal.set(put_value, wait=False) + await status + assert status.done is True and status.success is True + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy", + [ + ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + use_proxy, + ) + for ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + ) in ATTRIBUTES_SET + for use_proxy in [True, False] + ], + ids=[f"{x[0]}_{use_proxy}" for x in ATTRIBUTES_SET for use_proxy in [True, False]], +) +async def test_tango_signal_rw( + echo_device: str, + pv: str, + tango_type: str, + d_format: AttrDataFormat, + py_type: type[T], + initial_value: T, + put_value: T, + use_proxy: bool, +): + await prepare_device(echo_device, pv, initial_value) + source = echo_device + "/" + pv + proxy = await DeviceProxy(echo_device) if use_proxy else None + + timeout = 0.2 + signal = tango_signal_rw( + datatype=py_type, + read_trl=source, + write_trl=source, + device_proxy=proxy, + timeout=timeout, + name="test_signal", + ) + await signal.connect() + reading = await signal.read() + assert_close(reading["test_signal"]["value"], initial_value) + await signal.set(put_value) + location = await signal.locate() + assert_close(location["setpoint"], put_value) + assert_close(location["readback"], put_value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize("use_proxy", [True, False]) +async def test_tango_signal_x(tango_test_device: str, use_proxy: bool): + proxy = await DeviceProxy(tango_test_device) if use_proxy else None + timeout = 0.2 + signal = tango_signal_x( + write_trl=tango_test_device + "/" + "clear", + device_proxy=proxy, + timeout=timeout, + name="test_signal", + ) + await signal.connect() + status = signal.trigger() + await status + assert status.done is True and status.success is True + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy", + [ + ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + use_proxy, + ) + for ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + ) in ATTRIBUTES_SET + for use_proxy in [True, False] + ], + ids=[f"{x[0]}_{use_proxy}" for x in ATTRIBUTES_SET for use_proxy in [True, False]], +) +async def test_tango_signal_auto_attrs( + echo_device: str, + pv: str, + tango_type: str, + d_format: AttrDataFormat, + py_type: type[T], + initial_value: T, + put_value: T, + use_proxy: bool, +): + await prepare_device(echo_device, pv, initial_value) + source = echo_device + "/" + pv + proxy = await DeviceProxy(echo_device) if use_proxy else None + timeout = 0.2 + + async def _test_signal(dtype, proxy): + signal = await __tango_signal_auto( + datatype=dtype, + trl=source, + device_proxy=proxy, + timeout=timeout, + name="test_signal", + ) + assert type(signal) is SignalRW + await signal.connect() + reading = await signal.read() + value = reading["test_signal"]["value"] + if isinstance(value, np.ndarray): + value = value.tolist() + assert_close(value, initial_value) + + await signal.set(put_value, wait=True, timeout=timeout) + reading = await signal.read() + value = reading["test_signal"]["value"] + if isinstance(value, np.ndarray): + value = value.tolist() + assert_close(value, put_value) + + dtype = py_type + await _test_signal(dtype, proxy) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "pv, tango_type, d_format, py_type, initial_value, put_value, use_dtype, use_proxy", + [ + ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + use_dtype, + use_proxy, + ) + for ( + pv, + tango_type, + d_format, + py_type, + initial_value, + put_value, + ) in COMMANDS_SET + for use_dtype in [True, False] + for use_proxy in [True, False] + ], + ids=[ + f"{x[0]}_{use_dtype}_{use_proxy}" + for x in COMMANDS_SET + for use_dtype in [True, False] + for use_proxy in [True, False] + ], +) +async def test_tango_signal_auto_cmds( + echo_device: str, + pv: str, + tango_type: str, + d_format: AttrDataFormat, + py_type: type[T], + initial_value: T, + put_value: T, + use_dtype: bool, + use_proxy: bool, +): + source = echo_device + "/" + pv + timeout = 0.2 + + async def _test_signal(dtype, proxy): + signal = await __tango_signal_auto( + datatype=dtype, + trl=source, + device_proxy=proxy, + name="test_signal", + timeout=timeout, + ) + # Ophyd SignalX does not support types + assert type(signal) in [SignalR, SignalRW, SignalW] + await signal.connect() + assert signal + reading = await signal.read() + assert reading["test_signal"]["value"] is None + + await signal.set(put_value, wait=True, timeout=0.1) + reading = await signal.read() + value = reading["test_signal"]["value"] + if isinstance(value, np.ndarray): + value = value.tolist() + assert_close(value, put_value) + + proxy = await DeviceProxy(echo_device) if use_proxy else None + dtype = py_type if use_dtype else None + await _test_signal(dtype, proxy) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize("use_proxy", [True, False]) +async def test_tango_signal_auto_cmds_void(tango_test_device: str, use_proxy: bool): + proxy = await DeviceProxy(tango_test_device) if use_proxy else None + signal = await __tango_signal_auto( + datatype=None, + trl=tango_test_device + "/" + "clear", + device_proxy=proxy, + ) + assert type(signal) is SignalX + await signal.connect() + assert signal + await signal.trigger(wait=True) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_signal_auto_badtrl(tango_test_device: str): + proxy = await DeviceProxy(tango_test_device) + with pytest.raises(RuntimeError) as exc_info: + await __tango_signal_auto( + datatype=None, + trl=tango_test_device + "/" + "badtrl", + device_proxy=proxy, + ) + assert f"Cannot find badtrl in {tango_test_device}" in str(exc_info.value) diff --git a/tests/tango/test_tango_transport.py b/tests/tango/test_tango_transport.py new file mode 100644 index 0000000000..4e951ba80e --- /dev/null +++ b/tests/tango/test_tango_transport.py @@ -0,0 +1,854 @@ +import asyncio +from enum import Enum + +import numpy as np +import numpy.typing as npt +import pytest +from test_base_device import TestDevice +from test_tango_signals import ( + EchoDevice, + make_backend, + prepare_device, +) + +from ophyd_async.core import ( + NotConnected, +) +from ophyd_async.tango import ( + AttributeProxy, + CommandProxy, + TangoSignalBackend, + ensure_proper_executor, + get_dtype_extended, + get_python_type, + get_tango_trl, + get_trl_descriptor, +) +from tango import ( + CmdArgType, + DevState, +) +from tango.asyncio import DeviceProxy +from tango.asyncio_executor import ( + AsyncioExecutor, + get_global_executor, + set_global_executor, +) +from tango.test_context import MultiDeviceTestContext + + +# -------------------------------------------------------------------- +@pytest.fixture(scope="module") +def tango_test_device(): + with MultiDeviceTestContext( + [{"class": TestDevice, "devices": [{"name": "test/device/1"}]}], process=True + ) as context: + yield context.get_device_access("test/device/1") + + +# -------------------------------------------------------------------- +@pytest.fixture(scope="module") +def echo_device(): + with MultiDeviceTestContext( + [{"class": EchoDevice, "devices": [{"name": "test/device/1"}]}], process=True + ) as context: + yield context.get_device_access("test/device/1") + + +# -------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def reset_tango_asyncio(): + set_global_executor(None) + + +# -------------------------------------------------------------------- +class HelperClass: + @ensure_proper_executor + async def mock_func(self): + return "executed" + + +# Test function +@pytest.mark.asyncio +async def test_ensure_proper_executor(): + # Instantiate the helper class and call the decorated method + helper_instance = HelperClass() + result = await helper_instance.mock_func() + + # Assertions + assert result == "executed" + assert isinstance(get_global_executor(), AsyncioExecutor) + + +# -------------------------------------------------------------------- +@pytest.mark.parametrize( + "tango_type, expected", + [ + (CmdArgType.DevVoid, (False, None, "string")), + (CmdArgType.DevBoolean, (False, bool, "integer")), + (CmdArgType.DevShort, (False, int, "integer")), + (CmdArgType.DevLong, (False, int, "integer")), + (CmdArgType.DevFloat, (False, float, "number")), + (CmdArgType.DevDouble, (False, float, "number")), + (CmdArgType.DevUShort, (False, int, "integer")), + (CmdArgType.DevULong, (False, int, "integer")), + (CmdArgType.DevString, (False, str, "string")), + (CmdArgType.DevVarCharArray, (True, list[str], "string")), + (CmdArgType.DevVarShortArray, (True, int, "integer")), + (CmdArgType.DevVarLongArray, (True, int, "integer")), + (CmdArgType.DevVarFloatArray, (True, float, "number")), + (CmdArgType.DevVarDoubleArray, (True, float, "number")), + (CmdArgType.DevVarUShortArray, (True, int, "integer")), + (CmdArgType.DevVarULongArray, (True, int, "integer")), + (CmdArgType.DevVarStringArray, (True, str, "string")), + # (CmdArgType.DevVarLongStringArray, (True, str, "string")), + # (CmdArgType.DevVarDoubleStringArray, (True, str, "string")), + (CmdArgType.DevState, (False, CmdArgType.DevState, "string")), + (CmdArgType.ConstDevString, (False, str, "string")), + (CmdArgType.DevVarBooleanArray, (True, bool, "integer")), + (CmdArgType.DevUChar, (False, int, "integer")), + (CmdArgType.DevLong64, (False, int, "integer")), + (CmdArgType.DevULong64, (False, int, "integer")), + (CmdArgType.DevVarLong64Array, (True, int, "integer")), + (CmdArgType.DevVarULong64Array, (True, int, "integer")), + (CmdArgType.DevEncoded, (False, list[str], "string")), + (CmdArgType.DevEnum, (False, Enum, "string")), + # (CmdArgType.DevPipeBlob, (False, list[str], "string")), + (float, (False, float, "number")), + ], +) +def test_get_python_type(tango_type, expected): + if tango_type is not float: + assert get_python_type(tango_type) == expected + else: + # get_python_type should raise a TypeError + with pytest.raises(TypeError) as exc_info: + get_python_type(tango_type) + assert str(exc_info.value) == "Unknown TangoType" + + +# -------------------------------------------------------------------- +@pytest.mark.parametrize( + "datatype, expected", + [ + (npt.NDArray[np.float64], np.dtype("float64")), + (npt.NDArray[np.int8], np.dtype("int8")), + (npt.NDArray[np.uint8], np.dtype("uint8")), + (npt.NDArray[np.int32], np.dtype("int32")), + (npt.NDArray[np.int64], np.dtype("int64")), + (npt.NDArray[np.uint16], np.dtype("uint16")), + (npt.NDArray[np.uint32], np.dtype("uint32")), + (npt.NDArray[np.uint64], np.dtype("uint64")), + (npt.NDArray[np.bool_], np.dtype("bool")), + (npt.NDArray[DevState], CmdArgType.DevState), + (npt.NDArray[np.str_], np.dtype("str")), + (npt.NDArray[np.float32], np.dtype("float32")), + (npt.NDArray[np.complex64], np.dtype("complex64")), + (npt.NDArray[np.complex128], np.dtype("complex128")), + ], +) +def test_get_dtype_extended(datatype, expected): + assert get_dtype_extended(datatype) == expected + + +# -------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "datatype, tango_resource, expected_descriptor", + [ + ( + int, + "test/device/1/justvalue", + {"source": "test/device/1/justvalue", "dtype": "integer", "shape": []}, + ), + ( + float, + "test/device/1/limitedvalue", + {"source": "test/device/1/limitedvalue", "dtype": "number", "shape": []}, + ), + ( + npt.NDArray[float], + "test/device/1/array", + {"source": "test/device/1/array", "dtype": "array", "shape": [2, 3]}, + ), + # Add more test cases as needed + ], +) +async def test_get_trl_descriptor( + tango_test_device, datatype, tango_resource, expected_descriptor +): + proxy = await DeviceProxy(tango_test_device) + tr_configs = { + tango_resource.split("/")[-1]: await proxy.get_attribute_config( + tango_resource.split("/")[-1] + ) + } + descriptor = get_trl_descriptor(datatype, tango_resource, tr_configs) + assert descriptor == expected_descriptor + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "trl, proxy_needed, expected_type, should_raise", + [ + ("test/device/1/justvalue", True, AttributeProxy, False), + ("test/device/1/justvalue", False, AttributeProxy, False), + ("test/device/1/clear", True, CommandProxy, False), + ("test/device/1/clear", False, CommandProxy, False), + ("test/device/1/nonexistent", True, None, True), + ], +) +async def test_get_tango_trl( + tango_test_device, trl, proxy_needed, expected_type, should_raise +): + proxy = await DeviceProxy(tango_test_device) if proxy_needed else None + if should_raise: + with pytest.raises(RuntimeError): + await get_tango_trl(trl, proxy) + else: + result = await get_tango_trl(trl, proxy) + assert isinstance(result, expected_type) + + +# -------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.parametrize("attr", ["justvalue", "array"]) +async def test_attribute_proxy_get(tango_test_device, attr): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, attr) + val = None + val = await attr_proxy.get() + assert val is not None + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "attr, wait", + [("justvalue", True), ("justvalue", False), ("array", True), ("array", False)], +) +async def test_attribute_proxy_put(tango_test_device, attr, wait): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, attr) + + old_value = await attr_proxy.get() + new_value = old_value + 1 + status = await attr_proxy.put(new_value, wait=wait, timeout=0.1) + if status: + await status + else: + if not wait: + raise AssertionError("If wait is False, put should return a status object") + updated_value = await attr_proxy.get() + if isinstance(new_value, np.ndarray): + assert np.all(updated_value == new_value) + else: + assert updated_value == new_value + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize("wait", [True, False]) +async def test_attribute_proxy_put_force_timeout(tango_test_device, wait): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, "slow_attribute") + with pytest.raises(TimeoutError) as exc_info: + status = await attr_proxy.put(3.0, wait=wait, timeout=0.1) + await status + assert "attr put failed" in str(exc_info.value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize("wait", [True, False]) +async def test_attribute_proxy_put_exceptions(tango_test_device, wait): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, "raise_exception_attr") + with pytest.raises(RuntimeError) as exc_info: + status = await attr_proxy.put(3.0, wait=wait) + await status + assert "device failure" in str(exc_info.value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize( + "attr, new_value", [("justvalue", 10), ("array", np.array([[2, 3, 4], [5, 6, 7]]))] +) +async def test_attribute_proxy_get_w_value(tango_test_device, attr, new_value): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, attr) + await attr_proxy.put(new_value) + attr_proxy_value = await attr_proxy.get() + if isinstance(new_value, np.ndarray): + assert np.all(attr_proxy_value == new_value) + else: + assert attr_proxy_value == new_value + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_attribute_get_config(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, "justvalue") + config = await attr_proxy.get_config() + assert config.writable is not None + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_attribute_get_reading(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, "justvalue") + reading = await attr_proxy.get_reading() + assert reading["value"] is not None + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_attribute_has_subscription(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, "justvalue") + expected = bool(attr_proxy._callback) + has_subscription = attr_proxy.has_subscription() + assert has_subscription is expected + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_attribute_subscribe_callback(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + backend = await make_backend(float, source) + attr_proxy = backend.proxies[source] + val = None + + def callback(reading, value): + nonlocal val + val = value + + attr_proxy.subscribe_callback(callback) + assert attr_proxy.has_subscription() + old_value = await attr_proxy.get() + new_value = old_value + 1 + await attr_proxy.put(new_value) + await asyncio.sleep(0.2) + attr_proxy.unsubscribe_callback() + assert val == new_value + + attr_proxy.set_polling(False) + attr_proxy.support_events = False + with pytest.raises(RuntimeError) as exc_info: + attr_proxy.subscribe_callback(callback) + assert "Cannot set a callback" in str(exc_info.value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_attribute_unsubscribe_callback(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + backend = await make_backend(float, source) + attr_proxy = backend.proxies[source] + + def callback(reading, value): + pass + + attr_proxy.subscribe_callback(callback) + assert attr_proxy.has_subscription() + attr_proxy.unsubscribe_callback() + assert not attr_proxy.has_subscription() + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_attribute_set_polling(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, "justvalue") + attr_proxy.set_polling(True, 0.1, 1, 0.1) + assert attr_proxy._allow_polling + assert attr_proxy._polling_period == 0.1 + assert attr_proxy._abs_change == 1 + assert attr_proxy._rel_change == 0.1 + attr_proxy.set_polling(False) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_attribute_poll(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, "floatvalue") + attr_proxy.support_events = False + + def callback(reading, value): + nonlocal val + val = value + + def bad_callback(): + pass + + # Test polling with absolute change + val = None + + attr_proxy.set_polling(True, 0.1, 1, 1.0) + attr_proxy.subscribe_callback(callback) + current_value = await attr_proxy.get() + new_value = current_value + 2 + await attr_proxy.put(new_value) + polling_period = attr_proxy._polling_period + await asyncio.sleep(polling_period) + assert val is not None + attr_proxy.unsubscribe_callback() + + # Test polling with relative change + val = None + attr_proxy.set_polling(True, 0.1, 100, 0.1) + attr_proxy.subscribe_callback(callback) + current_value = await attr_proxy.get() + new_value = current_value * 2 + await attr_proxy.put(new_value) + polling_period = attr_proxy._polling_period + await asyncio.sleep(polling_period) + assert val is not None + attr_proxy.unsubscribe_callback() + + # Test polling with small changes. This should not update last_reading + attr_proxy.set_polling(True, 0.1, 100, 1.0) + attr_proxy.subscribe_callback(callback) + await asyncio.sleep(0.2) + current_value = await attr_proxy.get() + new_value = current_value + 1 + val = None + await attr_proxy.put(new_value) + polling_period = attr_proxy._polling_period + await asyncio.sleep(polling_period * 2) + assert val is None + attr_proxy.unsubscribe_callback() + + # Test polling with bad callback + attr_proxy.subscribe_callback(bad_callback) + await asyncio.sleep(0.2) + assert "Could not poll the attribute" in str(attr_proxy.exception) + attr_proxy.unsubscribe_callback() + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize("attr", ["array", "label"]) +async def test_attribute_poll_stringsandarrays(tango_test_device, attr): + device_proxy = await DeviceProxy(tango_test_device) + attr_proxy = AttributeProxy(device_proxy, attr) + attr_proxy.support_events = False + + def callback(reading, value): + nonlocal val + val = value + + val = None + attr_proxy.set_polling(True, 0.1) + attr_proxy.subscribe_callback(callback) + await asyncio.sleep(0.2) + assert val is not None + if isinstance(val, np.ndarray): + await attr_proxy.put(np.array([[2, 3, 4], [5, 6, 7]])) + await asyncio.sleep(0.5) + assert np.all(val == np.array([[2, 3, 4], [5, 6, 7]])) + if isinstance(val, str): + await attr_proxy.put("new label") + await asyncio.sleep(0.5) + assert val == "new label" + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_attribute_poll_exceptions(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + # Try to poll a non-existent attribute + attr_proxy = AttributeProxy(device_proxy, "nonexistent") + attr_proxy.support_events = False + attr_proxy.set_polling(True, 0.1) + + def callback(reading, value): + pass + + attr_proxy.subscribe_callback(callback) + await asyncio.sleep(0.2) + assert "Could not poll the attribute" in str(attr_proxy.exception) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_command_proxy_put_wait(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + cmd_proxy = CommandProxy(device_proxy, "clear") + + cmd_proxy._last_reading = None + await cmd_proxy.put(None, wait=True) + assert cmd_proxy._last_reading["value"] == "Received clear command" + + # Force timeout + cmd_proxy = CommandProxy(device_proxy, "slow_command") + cmd_proxy._last_reading = None + with pytest.raises(TimeoutError) as exc_info: + await cmd_proxy.put(None, wait=True, timeout=0.1) + assert "command failed" in str(exc_info.value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_command_proxy_put_nowait(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + cmd_proxy = CommandProxy(device_proxy, "slow_command") + + # Reply before timeout + cmd_proxy._last_reading = None + status = await cmd_proxy.put(None, wait=False, timeout=0.5) + assert cmd_proxy._last_reading is None + await status + assert cmd_proxy._last_reading["value"] == "Completed slow command" + + # Timeout + cmd_proxy._last_reading = None + status = await cmd_proxy.put(None, wait=False, timeout=0.1) + with pytest.raises(TimeoutError) as exc_info: + await status + assert str(exc_info.value) == "Timeout while waiting for command reply" + + # No timeout + cmd_proxy._last_reading = None + status = await cmd_proxy.put(None, wait=False) + assert cmd_proxy._last_reading is None + await status + assert cmd_proxy._last_reading["value"] == "Completed slow command" + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize("wait", [True, False]) +async def test_command_proxy_put_exceptions(tango_test_device, wait): + device_proxy = await DeviceProxy(tango_test_device) + cmd_proxy = CommandProxy(device_proxy, "raise_exception_cmd") + with pytest.raises(RuntimeError) as exc_info: + await cmd_proxy.put(None, wait=True) + assert "device failure" in str(exc_info.value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_command_get(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + cmd_proxy = CommandProxy(device_proxy, "clear") + await cmd_proxy.put(None, wait=True, timeout=1.0) + reading = cmd_proxy._last_reading + assert reading["value"] is not None + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_command_get_config(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + cmd_proxy = CommandProxy(device_proxy, "clear") + config = await cmd_proxy.get_config() + assert config.out_type is not None + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_command_get_reading(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + cmd_proxy = CommandProxy(device_proxy, "clear") + await cmd_proxy.put(None, wait=True, timeout=1.0) + reading = await cmd_proxy.get_reading() + assert reading["value"] is not None + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_command_set_polling(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + cmd_proxy = CommandProxy(device_proxy, "clear") + cmd_proxy.set_polling(True, 0.1) + # Set polling in the command proxy currently does nothing + assert True + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_init(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + assert transport is not None + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_source(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source) + transport_source = transport.source("") + assert transport_source == source + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_datatype_allowed(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + backend = await make_backend(float, source) + + assert backend.datatype_allowed(int) + assert backend.datatype_allowed(float) + assert backend.datatype_allowed(str) + assert backend.datatype_allowed(bool) + assert backend.datatype_allowed(np.ndarray) + assert backend.datatype_allowed(Enum) + assert backend.datatype_allowed(DevState) + assert not backend.datatype_allowed(list) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_connect(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + backend = await make_backend(float, source, connect=False) + assert backend is not None + await backend.connect() + backend.read_trl = "" + with pytest.raises(RuntimeError) as exc_info: + await backend.connect() + assert "trl not set" in str(exc_info.value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_connect_and_store_config(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + await transport._connect_and_store_config(source) + assert transport.trl_configs[source] is not None + + with pytest.raises(RuntimeError) as exc_info: + await transport._connect_and_store_config("") + assert "trl not set" in str(exc_info.value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_put(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + + with pytest.raises(NotConnected) as exc_info: + await transport.put(1.0) + assert "Not connected" in str(exc_info.value) + + await transport.connect() + source = transport.source("") + await transport.put(2.0) + val = await transport.proxies[source].get_w_value() + assert val == 2.0 + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_get_datakey(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + await transport.connect() + datakey = await transport.get_datakey(source) + assert datakey["source"] == source + assert datakey["dtype"] == "number" + assert datakey["shape"] == [] + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_get_reading(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + + with pytest.raises(NotConnected) as exc_info: + await transport.put(1.0) + assert "Not connected" in str(exc_info.value) + + await transport.connect() + reading = await transport.get_reading() + assert reading["value"] == 1.0 + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_get_value(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + + with pytest.raises(NotConnected) as exc_info: + await transport.put(1.0) + assert "Not connected" in str(exc_info.value) + + await transport.connect() + value = await transport.get_value() + assert value == 1.0 + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_get_setpoint(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + + with pytest.raises(NotConnected) as exc_info: + await transport.put(1.0) + assert "Not connected" in str(exc_info.value) + + await transport.connect() + new_setpoint = 2.0 + await transport.put(new_setpoint) + setpoint = await transport.get_setpoint() + assert setpoint == new_setpoint + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_set_callback(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + + with pytest.raises(NotConnected) as exc_info: + await transport.put(1.0) + assert "Not connected" in str(exc_info.value) + + await transport.connect() + val = None + + def callback(reading, value): + nonlocal val + val = value + + # Correct usage + transport.set_callback(callback) + current_value = await transport.get_value() + new_value = current_value + 2 + await transport.put(new_value) + await asyncio.sleep(0.1) + assert val == new_value + + # Try to add second callback + with pytest.raises(RuntimeError) as exc_info: + transport.set_callback(callback) + assert "Cannot set a callback when one is already set" + + transport.set_callback(None) + + # Try to add a callback to a non-callable proxy + transport.allow_events(False) + transport.set_polling(False) + with pytest.raises(RuntimeError) as exc_info: + transport.set_callback(callback) + assert "Cannot set event" in str(exc_info.value) + + # Try to add a non-callable callback + transport.allow_events(True) + with pytest.raises(RuntimeError) as exc_info: + transport.set_callback(1) + assert "Callback must be a callable" in str(exc_info.value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_set_polling(echo_device): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + transport.set_polling(True, 0.1, 1, 0.1) + assert transport._polling == (True, 0.1, 1, 0.1) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +@pytest.mark.parametrize("allow", [True, False]) +async def test_tango_transport_allow_events(echo_device, allow): + await prepare_device(echo_device, "float_scalar_attr", 1.0) + source = echo_device + "/" + "float_scalar_attr" + transport = await make_backend(float, source, connect=False) + transport.allow_events(allow) + assert transport.support_events == allow + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_read_and_write_trl(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + trl = device_proxy.dev_name() + read_trl = trl + "/" + "readback" + write_trl = trl + "/" + "setpoint" + + # Test with existing proxy + transport = TangoSignalBackend(float, read_trl, write_trl, device_proxy) + await transport.connect() + reading = await transport.get_reading() + initial_value = reading["value"] + new_value = initial_value + 1.0 + await transport.put(new_value) + updated_value = await transport.get_value() + assert updated_value == new_value + + # Without pre-existing proxy + transport = TangoSignalBackend(float, read_trl, write_trl, None) + await transport.connect() + reading = await transport.get_reading() + initial_value = reading["value"] + new_value = initial_value + 1.0 + await transport.put(new_value) + updated_value = await transport.get_value() + assert updated_value == new_value + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_read_only_trl(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + trl = device_proxy.dev_name() + read_trl = trl + "/" + "readonly" + + # Test with existing proxy + transport = TangoSignalBackend(int, read_trl, read_trl, device_proxy) + await transport.connect() + with pytest.raises(RuntimeError) as exc_info: + await transport.put(1) + assert "is not writable" in str(exc_info.value) + + +# -------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_tango_transport_nonexistent_trl(tango_test_device): + device_proxy = await DeviceProxy(tango_test_device) + trl = device_proxy.dev_name() + nonexistent_trl = trl + "/" + "nonexistent" + + # Test with existing proxy + transport = TangoSignalBackend(int, nonexistent_trl, nonexistent_trl, device_proxy) + with pytest.raises(RuntimeError) as exc_info: + await transport.connect() + assert "cannot be found" in str(exc_info.value) + + # Without pre-existing proxy + transport = TangoSignalBackend(int, nonexistent_trl, nonexistent_trl, None) + with pytest.raises(RuntimeError) as exc_info: + await transport.connect() + assert "cannot be found" in str(exc_info.value)