From fd5e699a154f8619a3309479378d7efec7843d73 Mon Sep 17 00:00:00 2001 From: Tom Cobb Date: Tue, 24 Sep 2024 16:08:27 +0000 Subject: [PATCH 1/6] WIP --- pyproject.toml | 10 +- src/ophyd_async/core/__init__.py | 31 +- src/ophyd_async/core/_device.py | 123 ++-- src/ophyd_async/core/_mock_signal_backend.py | 84 --- src/ophyd_async/core/_mock_signal_utils.py | 11 +- src/ophyd_async/core/_protocol.py | 13 +- src/ophyd_async/core/_signal.py | 254 ++++----- src/ophyd_async/core/_signal_backend.py | 191 ++++--- src/ophyd_async/core/_soft_signal_backend.py | 364 ++++++------ src/ophyd_async/core/_table.py | 15 +- src/ophyd_async/core/_utils.py | 49 +- src/ophyd_async/epics/signal/__init__.py | 4 +- src/ophyd_async/epics/signal/_aioca.py | 533 ++++++++---------- src/ophyd_async/epics/signal/_common.py | 72 +-- .../epics/signal/_epics_transport.py | 34 -- src/ophyd_async/epics/signal/_p4p.py | 7 +- src/ophyd_async/epics/signal/_signal.py | 104 +++- src/ophyd_async/py.typed | 0 tests/core/test_soft_signal_backend.py | 121 ++-- tests/epics/signal/test_common.py | 19 +- tests/epics/signal/test_records.db | 2 +- tests/epics/signal/test_signals.py | 117 ++-- 22 files changed, 1036 insertions(+), 1122 deletions(-) delete mode 100644 src/ophyd_async/core/_mock_signal_backend.py delete mode 100644 src/ophyd_async/epics/signal/_epics_transport.py create mode 100644 src/ophyd_async/py.typed diff --git a/pyproject.toml b/pyproject.toml index 23afbf19bc..3923c08918 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,12 +13,12 @@ classifiers = [ description = "Asynchronous Bluesky hardware abstraction code, compatible with control systems like EPICS and Tango" dependencies = [ "networkx>=2.0", - "numpy<2.0.0", + "numpy", "packaging", "pint", "bluesky>=1.13.0a3", - "event_model", - "p4p", + "event-model @ git+https://github.com/bluesky/event-model@main", + "p4p>=4.2.0a3", "pyyaml", "colorlog", "pydantic>=2.0", @@ -37,10 +37,6 @@ dev = [ "ophyd_async[pva]", "ophyd_async[sim]", "ophyd_async[ca]", - "black", - "flake8", - "flake8-isort", - "Flake8-pyproject", "inflection", "ipython", "ipywidgets", diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index be38555d10..8873544b49 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -19,7 +19,6 @@ from ._flyer import StandardFlyer, TriggerLogic from ._hdf_dataset import HDFDataset, HDFFile from ._log import config_ophyd_async_logging -from ._mock_signal_backend import MockSignalBackend from ._mock_signal_utils import ( callback_on_mock_put, get_mock_put, @@ -62,22 +61,33 @@ wait_for_value, ) from ._signal_backend import ( - RuntimeSubsetEnum, + Array1D, SignalBackend, - SubsetEnum, + SignalConnector, + SignalDatatype, + SignalDatatypeT, + make_datakey, +) +from ._soft_signal_backend import ( + MockSignalBackend, + SignalMetadata, + SoftSignalBackend, + SoftSignalConnector, ) -from ._soft_signal_backend import SignalMetadata, SoftSignalBackend from ._status import AsyncStatus, WatchableAsyncStatus, completed_status from ._table import Table from ._utils import ( CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, + Callback, NotConnected, - ReadingValueCallback, + StrictEnum, + SubsetEnum, T, WatcherUpdate, get_dtype, + get_enum_cls, get_unique, in_micros, is_pydantic_model, @@ -146,22 +156,29 @@ "soft_signal_r_and_setter", "soft_signal_rw", "wait_for_value", - "RuntimeSubsetEnum", + "Array1D", "SignalBackend", + "SignalConnector", + "make_datakey", + "StrictEnum", "SubsetEnum", + "SignalDatatype", + "SignalDatatypeT", "SignalMetadata", "SoftSignalBackend", + "SoftSignalConnector", "AsyncStatus", "WatchableAsyncStatus", "DEFAULT_TIMEOUT", "CalculatableTimeout", + "Callback", "CALCULATE_TIMEOUT", "NotConnected", - "ReadingValueCallback", "Table", "T", "WatcherUpdate", "get_dtype", + "get_enum_cls", "get_unique", "in_micros", "is_pydantic_model", diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index b4305c2bfa..963888938a 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -1,15 +1,12 @@ -"""Base device""" +from __future__ import annotations import asyncio import sys -from collections.abc import Coroutine, Generator, Iterator +from abc import abstractmethod +from collections.abc import Callable, Coroutine, Iterator, Mapping from functools import cached_property from logging import LoggerAdapter, getLogger -from typing import ( - Any, - Optional, - TypeVar, -) +from typing import Any, Generic, TypeVar, cast from bluesky.protocols import HasName from bluesky.run_engine import call_in_bluesky_event_loop @@ -17,7 +14,28 @@ from ._utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection -class Device(HasName): +class DeviceConnector: + @abstractmethod + async def connect( + self, device: Device, mock: bool, timeout: float, force_reconnect: bool + ) -> None: ... + + +class DeviceChildConnector(DeviceConnector): + async def connect( + self, device: Device, mock: bool, timeout: float, force_reconnect: bool + ) -> Any: + coros = { + name: child_device.connect(mock, timeout, force_reconnect) + for name, child_device in device.children().items() + } + await wait_for_connection(**coros) + + +DeviceConnectorType = TypeVar("DeviceConnectorType", bound=DeviceConnector) + + +class Device(HasName, Generic[DeviceConnectorType]): """Common base class for all Ophyd Async Devices. By default, names and connects all Device children. @@ -25,15 +43,21 @@ class Device(HasName): _name: str = "" #: The parent Device if it exists - parent: Optional["Device"] = None + parent: Device | None = None # None if connect hasn't started, a Task if it has _connect_task: asyncio.Task | None = None + # The value of the mock arg to connect + _connect_mock: bool | None = None + # The connector to use + _connector: DeviceConnectorType = DeviceChildConnector() - # Used to check if the previous connect was mocked, - # if the next mock value differs then we fail - _previous_connect_was_mock = None - - def __init__(self, name: str = "") -> None: + def __init__( + self, + name: str = "", + connector: DeviceConnectorType | None = None, + ) -> None: + if connector is not None: + self._connector = connector self.set_name(name) @property @@ -47,10 +71,12 @@ def log(self): getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name} ) - def children(self) -> Iterator[tuple[str, "Device"]]: - for attr_name, attr in self.__dict__.items(): - if attr_name != "parent" and isinstance(attr, Device): - yield attr_name, attr + def children(self) -> dict[str, Device]: + return { + attr_name: attr + for attr_name, attr in self.__dict__.items() + if attr_name != "parent" and isinstance(attr, Device) + } def set_name(self, name: str): """Set ``self.name=name`` and each ``self.child.name=name+"-child"``. @@ -66,7 +92,7 @@ def set_name(self, name: str): del self.log self._name = name - for attr_name, child in self.children(): + for attr_name, child in self.children().items(): child_name = f"{name}-{attr_name.rstrip('_')}" if name else "" child.set_name(child_name) child.parent = self @@ -89,40 +115,27 @@ async def connect( Time to wait before failing with a TimeoutError. """ - if ( - self._previous_connect_was_mock is not None - and self._previous_connect_was_mock != mock - ): - raise RuntimeError( - f"`connect(mock={mock})` called on a `Device` where the previous " - f"connect was `mock={self._previous_connect_was_mock}`. Changing mock " - "value between connects is not permitted." - ) - self._previous_connect_was_mock = mock - # If previous connect with same args has started and not errored, can use it - can_use_previous_connect = self._connect_task and not ( - self._connect_task.done() and self._connect_task.exception() + can_use_previous_connect = ( + mock is self._connect_mock + and self._connect_task + and not (self._connect_task.done() and self._connect_task.exception()) ) if force_reconnect or not can_use_previous_connect: - # Kick off a connection - coros = { - name: child_device.connect( - mock, timeout=timeout, force_reconnect=force_reconnect - ) - for name, child_device in self.children() - } - self._connect_task = asyncio.create_task(wait_for_connection(**coros)) - + # Use the connector to make a new connection + self._connect_mock = mock + self._connect_task = asyncio.create_task( + self._connector.connect(self, mock, timeout, force_reconnect) + ) assert self._connect_task, "Connect task not created, this shouldn't happen" # Wait for it to complete await self._connect_task -VT = TypeVar("VT", bound=Device) +DeviceType = TypeVar("DeviceType", bound=Device) -class DeviceVector(dict[int, VT], Device): +class DeviceVector(Mapping[int, DeviceType], Device): """ Defines device components with indices. @@ -131,10 +144,26 @@ class DeviceVector(dict[int, VT], Device): :class:`~ophyd_async.epics.demo.DynamicSensorGroup` """ - def children(self) -> Generator[tuple[str, Device], None, None]: - for attr_name, attr in self.items(): - if isinstance(attr, Device): - yield str(attr_name), attr + def __init__( + self, + children: dict[int, DeviceType], + name: str = "", + connector: DeviceConnector | None = None, + ) -> None: + self._children = children + super().__init__(name, connector) + + def __getitem__(self, key: int) -> DeviceType: + return self._children[key] + + def __iter__(self) -> Iterator[int]: + yield from self._children + + def __len__(self) -> int: + return len(self._children) + + def children(self) -> dict[str, Device]: + return {str(key): value for key, value in self.items()} class DeviceCollector: diff --git a/src/ophyd_async/core/_mock_signal_backend.py b/src/ophyd_async/core/_mock_signal_backend.py deleted file mode 100644 index 029881cd96..0000000000 --- a/src/ophyd_async/core/_mock_signal_backend.py +++ /dev/null @@ -1,84 +0,0 @@ -import asyncio -from collections.abc import Callable -from functools import cached_property -from unittest.mock import AsyncMock - -from bluesky.protocols import Descriptor, Reading - -from ._signal_backend import SignalBackend -from ._soft_signal_backend import SoftSignalBackend -from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T - - -class MockSignalBackend(SignalBackend[T]): - """Signal backend for testing, created by ``Device.connect(mock=True)``.""" - - def __init__( - self, - datatype: type[T] | None = None, - initial_backend: SignalBackend[T] | None = None, - ) -> None: - if isinstance(initial_backend, MockSignalBackend): - raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackends") - - self.initial_backend = initial_backend - - if datatype is None: - assert ( - self.initial_backend - ), "Must supply either initial_backend or datatype" - datatype = self.initial_backend.datatype - - self.datatype = datatype - - if not isinstance(self.initial_backend, SoftSignalBackend): - # If the backend is a hard signal backend, or not provided, - # then we create a soft signal to mimic it - - self.soft_backend = SoftSignalBackend(datatype=datatype) - else: - self.soft_backend = self.initial_backend - - def source(self, name: str) -> str: - if self.initial_backend: - return f"mock+{self.initial_backend.source(name)}" - return f"mock+{name}" - - async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None: - pass - - @cached_property - def put_mock(self) -> AsyncMock: - return AsyncMock(name="put", spec=Callable) - - @cached_property - def put_proceeds(self) -> asyncio.Event: - put_proceeds = asyncio.Event() - put_proceeds.set() - return put_proceeds - - async def put(self, value: T | None, wait=True, timeout=None): - await self.put_mock(value, wait=wait, timeout=timeout) - await self.soft_backend.put(value, wait=wait, timeout=timeout) - - if wait: - await asyncio.wait_for(self.put_proceeds.wait(), timeout=timeout) - - def set_value(self, value: T): - self.soft_backend.set_value(value) - - async def get_reading(self) -> Reading: - return await self.soft_backend.get_reading() - - async def get_value(self) -> T: - return await self.soft_backend.get_value() - - async def get_setpoint(self) -> T: - """For a soft signal, the setpoint and readback values are the same.""" - return await self.soft_backend.get_setpoint() - - async def get_datakey(self, source: str) -> Descriptor: - return await self.soft_backend.get_datakey(source) - - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: - self.soft_backend.set_callback(callback) diff --git a/src/ophyd_async/core/_mock_signal_utils.py b/src/ophyd_async/core/_mock_signal_utils.py index 33c0f677ba..9eb1cf63e2 100644 --- a/src/ophyd_async/core/_mock_signal_utils.py +++ b/src/ophyd_async/core/_mock_signal_utils.py @@ -3,13 +3,12 @@ from typing import Any from unittest.mock import AsyncMock -from ._mock_signal_backend import MockSignalBackend from ._signal import Signal -from ._utils import T +from ._soft_signal_backend import MockSignalBackend, SignalDatatypeT def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend: - backend = signal._backend # noqa:SLF001 + backend = signal._connector.backend # noqa:SLF001 assert isinstance(backend, MockSignalBackend), ( "Expected to receive a `MockSignalBackend`, instead " f" received {type(backend)}. " @@ -17,7 +16,7 @@ def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend: return backend -def set_mock_value(signal: Signal[T], value: T): +def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT): """Set the value of a signal that is in mock mode.""" backend = _get_mock_signal_backend(signal) backend.set_value(value) @@ -143,7 +142,9 @@ def _unset_side_effect_cm(put_mock: AsyncMock): def callback_on_mock_put( - signal: Signal[T], callback: Callable[[T], None] | Callable[[T], Awaitable[None]] + signal: Signal[SignalDatatypeT], + callback: Callable[[SignalDatatypeT], None] + | Callable[[SignalDatatypeT], Awaitable[None]], ): """For setting a callback when a backend is put to. diff --git a/src/ophyd_async/core/_protocol.py b/src/ophyd_async/core/_protocol.py index 3978f39cc8..75bdcf16a6 100644 --- a/src/ophyd_async/core/_protocol.py +++ b/src/ophyd_async/core/_protocol.py @@ -10,13 +10,24 @@ runtime_checkable, ) -from bluesky.protocols import HasName, Reading +from bluesky.protocols import HasName, ReadingOptional from event_model import DataKey +from ._utils import T + if TYPE_CHECKING: from ._status import AsyncStatus +class Reading(ReadingOptional, Generic[T]): + """A dictionary containing the value and timestamp of a piece of scan data""" + + #: The current value, as a JSON encodable type or numpy array + value: T + #: Timestamp in seconds since the UNIX epoch + timestamp: float + + @runtime_checkable class AsyncReadable(HasName, Protocol): @abstractmethod diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index d4e4d7bbb9..93e2d94a70 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -3,27 +3,28 @@ import asyncio import functools from collections.abc import AsyncGenerator, Callable, Mapping -from typing import Any, Generic, TypeVar, cast +from typing import Any, Generic, cast from bluesky.protocols import ( Locatable, Location, Movable, - Reading, Status, Subscribable, ) from event_model import DataKey from ._device import Device -from ._mock_signal_backend import MockSignalBackend -from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable -from ._signal_backend import SignalBackend -from ._soft_signal_backend import SignalMetadata, SoftSignalBackend +from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable, Reading +from ._signal_backend import ( + SignalBackend, + SignalConnector, + SignalDatatypeT, + SignalDatatypeV, +) +from ._soft_signal_backend import SoftSignalConnector from ._status import AsyncStatus -from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T - -S = TypeVar("S") +from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback def _add_timeout(func): @@ -34,96 +35,43 @@ async def wrapper(self: Signal, *args, **kwargs): return wrapper -def _fail(*args, **kwargs): - raise RuntimeError("Signal has not been supplied a backend yet") - - -class DisconnectedBackend(SignalBackend): - source = connect = put = get_datakey = get_reading = get_value = get_setpoint = ( - set_callback - ) = _fail - - -DISCONNECTED_BACKEND = DisconnectedBackend() - - -class Signal(Device, Generic[T]): +class Signal(Device[SignalConnector[SignalDatatypeT]]): """A Device with the concept of a value, with R, RW, W and X flavours""" def __init__( self, - backend: SignalBackend[T] = DISCONNECTED_BACKEND, + connector: SignalConnector[SignalDatatypeT], timeout: float | None = DEFAULT_TIMEOUT, name: str = "", ) -> None: self._timeout = timeout - self._backend = backend - super().__init__(name) - - async def connect( - self, - mock=False, - timeout=DEFAULT_TIMEOUT, - force_reconnect: bool = False, - backend: SignalBackend[T] | None = None, - ): - if backend: - if ( - self._backend is not DISCONNECTED_BACKEND - and backend is not self._backend - ): - raise ValueError("Backend at connection different from previous one.") - - self._backend = backend - if ( - self._previous_connect_was_mock is not None - and self._previous_connect_was_mock != mock - ): - raise RuntimeError( - f"`connect(mock={mock})` called on a `Signal` where the previous " - f"connect was `mock={self._previous_connect_was_mock}`. Changing mock " - "value between connects is not permitted." - ) - self._previous_connect_was_mock = mock - - if mock and not issubclass(type(self._backend), MockSignalBackend): - # Using a soft backend, look to the initial value - self._backend = MockSignalBackend(initial_backend=self._backend) - - if self._backend is None: - raise RuntimeError("`connect` called on signal without backend") - - can_use_previous_connection: bool = self._connect_task is not None and not ( - self._connect_task.done() and self._connect_task.exception() - ) - - if force_reconnect or not can_use_previous_connection: - self.log.debug(f"Connecting to {self.source}") - self._connect_task = asyncio.create_task( - self._backend.connect(timeout=timeout) - ) - else: - self.log.debug(f"Reusing previous connection to {self.source}") - assert ( - self._connect_task - ), "this assert is for type analysis and will never fail" - await self._connect_task + super().__init__(name, connector) @property def source(self) -> str: """Like ca://PV_PREFIX:SIGNAL, or "" if not set""" - return self._backend.source(self.name) + source = self._connector.source(self.name) + if self._connect_mock: + return f"mock+{source}" + else: + return source + + def __setattr__(self, name: str, value: Any) -> None: + if name != "parent" and isinstance(value, Device): + raise AttributeError( + f"Cannot add Device {value} as a child of Signal {self}, " + "make a subclass of Device instead" + ) + return super().__setattr__(name, value) -class _SignalCache(Generic[T]): - def __init__(self, backend: SignalBackend[T], signal: Signal): +class _SignalCache(Generic[SignalDatatypeT]): + def __init__(self, backend: SignalBackend[SignalDatatypeT], signal: Signal): self._signal = signal self._staged = False self._listeners: dict[Callback, bool] = {} self._valid = asyncio.Event() - self._reading: Reading | None = None - self._value: T | None = None - + self._reading: Reading[SignalDatatypeT] | None = None self.backend = backend signal.log.debug(f"Making subscription on source {signal.source}") backend.set_callback(self._callback) @@ -132,30 +80,29 @@ def close(self): self.backend.set_callback(None) self._signal.log.debug(f"Closing subscription on source {self._signal.source}") - async def get_reading(self) -> Reading: + async def get_reading(self) -> Reading[SignalDatatypeT]: await self._valid.wait() assert self._reading is not None, "Monitor not working" return self._reading - async def get_value(self) -> T: - await self._valid.wait() - assert self._value is not None, "Monitor not working" - return self._value + async def get_value(self) -> SignalDatatypeT: + reading = await self.get_reading() + return reading["value"] - def _callback(self, reading: Reading, value: T): + def _callback(self, reading: Reading[SignalDatatypeT]): self._signal.log.debug( f"Updated subscription: reading of source {self._signal.source} changed" f"from {self._reading} to {reading}" ) self._reading = reading - self._value = value self._valid.set() for function, want_value in self._listeners.items(): self._notify(function, want_value) def _notify(self, function: Callback, want_value: bool): + assert self._reading, "Monitor not working" if want_value: - function(self._value) + function(self._reading["value"]) else: function({self._signal.name: self._reading}) @@ -173,7 +120,7 @@ def set_staged(self, staged: bool): return self._staged or bool(self._listeners) -class SignalR(Signal[T], AsyncReadable, AsyncStageable, Subscribable): +class SignalR(Signal[SignalDatatypeT], AsyncReadable, AsyncStageable, Subscribable): """Signal that can be read from and monitored""" _cache: _SignalCache | None = None @@ -186,11 +133,11 @@ def _backend_or_cache(self, cached: bool | None) -> _SignalCache | SignalBackend assert self._cache, f"{self.source} not being monitored" return self._cache else: - return self._backend + return self._connector.backend def _get_cache(self) -> _SignalCache: if not self._cache: - self._cache = _SignalCache(self._backend, self) + self._cache = _SignalCache(self._connector.backend, self) return self._cache def _del_cache(self, needed: bool): @@ -206,16 +153,16 @@ async def read(self, cached: bool | None = None) -> dict[str, Reading]: @_add_timeout async def describe(self) -> dict[str, DataKey]: """Return a single item dict with the descriptor in it""" - return {self.name: await self._backend.get_datakey(self.source)} + return {self.name: await self._connector.backend.get_datakey(self.source)} @_add_timeout - async def get_value(self, cached: bool | None = None) -> T: + async def get_value(self, cached: bool | None = None) -> SignalDatatypeT: """The current value""" value = await self._backend_or_cache(cached).get_value() self.log.debug(f"get_value() on source {self.source} returned {value}") return value - def subscribe_value(self, function: Callback[T]): + def subscribe_value(self, function: Callback[SignalDatatypeT]): """Subscribe to updates in value of a device""" self._get_cache().subscribe(function, want_value=True) @@ -238,85 +185,84 @@ async def unstage(self) -> None: self._del_cache(self._get_cache().set_staged(False)) -class SignalW(Signal[T], Movable): +class SignalW(Signal[SignalDatatypeT], Movable): """Signal that can be set""" - def set( - self, value: T, wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT - ) -> AsyncStatus: + @AsyncStatus.wrap + async def set( + self, + value: SignalDatatypeT, + wait=True, + timeout: CalculatableTimeout = CALCULATE_TIMEOUT, + ) -> None: """Set the value and return a status saying when it's done""" if timeout is CALCULATE_TIMEOUT: timeout = self._timeout - - async def do_set(): - self.log.debug(f"Putting value {value} to backend at source {self.source}") - await self._backend.put(value, wait=wait, timeout=timeout) - self.log.debug( - f"Successfully put value {value} to backend at source {self.source}" - ) - - return AsyncStatus(do_set()) + self.log.debug(f"Putting value {value} to backend at source {self.source}") + await self._connector.backend.put(value, wait=wait, timeout=timeout) + self.log.debug( + f"Successfully put value {value} to backend at source {self.source}" + ) -class SignalRW(SignalR[T], SignalW[T], Locatable): +class SignalRW(SignalR[SignalDatatypeT], SignalW[SignalDatatypeT], Locatable): """Signal that can be both read and set""" async def locate(self) -> Location: - location: Location = { - "setpoint": await self._backend.get_setpoint(), - "readback": await self.get_value(), - } - return location + """Return the setpoint and readback.""" + setpoint, readback = asyncio.gather( + self._connector.backend.get_setpoint(), self.get_value() + ) + return Location(setpoint=setpoint, readback=readback) class SignalX(Signal): """Signal that puts the default value""" - def trigger( + @AsyncStatus.wrap + async def trigger( self, wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT - ) -> AsyncStatus: + ) -> None: """Trigger the action and return a status saying when it's done""" if timeout is CALCULATE_TIMEOUT: timeout = self._timeout - coro = self._backend.put(None, wait=wait, timeout=timeout) - return AsyncStatus(coro) + self.log.debug(f"Putting default value to backend at source {self.source}") + await self._connector.backend.put(None, wait=wait, timeout=timeout) + self.log.debug( + f"Successfully put default value to backend at source {self.source}" + ) def soft_signal_rw( - datatype: type[T] | None = None, - initial_value: T | None = None, + datatype: type[SignalDatatypeT], + initial_value: SignalDatatypeT | None = None, name: str = "", units: str | None = None, precision: int | None = None, -) -> SignalRW[T]: +) -> SignalRW[SignalDatatypeT]: """Creates a read-writable Signal with a SoftSignalBackend. May pass metadata, which are propagated into describe. """ - metadata = SignalMetadata(units=units, precision=precision) - signal = SignalRW( - SoftSignalBackend(datatype, initial_value, metadata=metadata), - name=name, - ) + connector = SoftSignalConnector(datatype, initial_value, units, precision) + signal = SignalRW(connector=connector, name=name) return signal def soft_signal_r_and_setter( - datatype: type[T] | None = None, - initial_value: T | None = None, + datatype: type[SignalDatatypeT], + initial_value: SignalDatatypeT | None = None, name: str = "", units: str | None = None, precision: int | None = None, -) -> tuple[SignalR[T], Callable[[T], None]]: +) -> tuple[SignalR[SignalDatatypeT], Callable[[SignalDatatypeT], None]]: """Returns a tuple of a read-only Signal and a callable through which the signal can be internally modified within the device. May pass metadata, which are propagated into describe. Use soft_signal_rw if you want a device that is externally modifiable """ - metadata = SignalMetadata(units=units, precision=precision) - backend = SoftSignalBackend(datatype, initial_value, metadata=metadata) - signal = SignalR(backend, name=name) - - return (signal, backend.set_value) + connector = SoftSignalConnector(datatype, initial_value, units, precision) + signal = SignalR(connector=connector, name=name) + return (signal, connector.set_value) def _generate_assert_error_msg(name: str, expected_result, actual_result) -> str: @@ -330,7 +276,7 @@ def _generate_assert_error_msg(name: str, expected_result, actual_result) -> str ) -async def assert_value(signal: SignalR[T], value: Any) -> None: +async def assert_value(signal: SignalR[SignalDatatypeT], value: Any) -> None: """Assert a signal's value and compare it an expected signal. Parameters @@ -440,8 +386,10 @@ def assert_emitted(docs: Mapping[str, list[dict]], **numbers: int): async def observe_value( - signal: SignalR[T], timeout: float | None = None, done_status: Status | None = None -) -> AsyncGenerator[T, None]: + signal: SignalR[SignalDatatypeT], + timeout: float | None = None, + done_status: Status | None = None, +) -> AsyncGenerator[SignalDatatypeT, None]: """Subscribe to the value of a signal so it can be iterated from. Parameters @@ -464,7 +412,7 @@ async def observe_value( do_something_with(value) """ - q: asyncio.Queue[T | Status] = asyncio.Queue() + q: asyncio.Queue[SignalDatatypeT | Status] = asyncio.Queue() if timeout is None: get_value = q.get else: @@ -485,24 +433,26 @@ async def get_value(): else: break else: - yield cast(T, item) + yield cast(SignalDatatypeT, item) finally: signal.clear_sub(q.put_nowait) -class _ValueChecker(Generic[T]): - def __init__(self, matcher: Callable[[T], bool], matcher_name: str): - self._last_value: T | None = None +class _ValueChecker(Generic[SignalDatatypeT]): + def __init__(self, matcher: Callable[[SignalDatatypeT], bool], matcher_name: str): + self._last_value: SignalDatatypeT | None = None self._matcher = matcher self._matcher_name = matcher_name - async def _wait_for_value(self, signal: SignalR[T]): + async def _wait_for_value(self, signal: SignalR[SignalDatatypeT]): async for value in observe_value(signal): self._last_value = value if self._matcher(value): return - async def wait_for_value(self, signal: SignalR[T], timeout: float | None): + async def wait_for_value( + self, signal: SignalR[SignalDatatypeT], timeout: float | None + ): try: await asyncio.wait_for(self._wait_for_value(signal), timeout) except asyncio.TimeoutError as e: @@ -513,8 +463,8 @@ async def wait_for_value(self, signal: SignalR[T], timeout: float | None): async def wait_for_value( - signal: SignalR[T], - match: T | Callable[[T], bool], + signal: SignalR[SignalDatatypeT], + match: SignalDatatypeT | Callable[[SignalDatatypeT], bool], timeout: float | None, ): """Wait for a signal to have a matching value. @@ -548,10 +498,10 @@ async def wait_for_value( async def set_and_wait_for_other_value( - set_signal: SignalW[T], - set_value: T, - read_signal: SignalR[S], - read_value: S, + set_signal: SignalW[SignalDatatypeT], + set_value: SignalDatatypeT, + read_signal: SignalR[SignalDatatypeV], + read_value: SignalDatatypeV, timeout: float = DEFAULT_TIMEOUT, set_timeout: float | None = None, ) -> AsyncStatus: @@ -608,8 +558,8 @@ async def _wait_for_value(): async def set_and_wait_for_value( - signal: SignalRW[T], - value: T, + signal: SignalRW[SignalDatatypeT], + value: SignalDatatypeT, timeout: float = DEFAULT_TIMEOUT, status_timeout: float | None = None, ) -> AsyncStatus: diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 035936f32c..27b233b489 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -1,41 +1,53 @@ from abc import abstractmethod -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Generic, - Literal, -) +from collections.abc import Sequence +from enum import Enum +from typing import Generic, TypedDict, TypeVar, get_origin -from bluesky.protocols import Reading +import numpy as np from event_model import DataKey - -from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T +from event_model.documents.event_descriptor import Dtype, Limits + +from ._device import DeviceConnector +from ._protocol import Reading +from ._table import Table +from ._utils import Callback, SubsetEnum, T, get_dtype + +DTypeScalar_co = TypeVar("DTypeScalar_co", covariant=True, bound=np.generic) +Array1D = np.ndarray[tuple[int], np.dtype[DTypeScalar_co]] +Primitive = bool | int | float | str +SignalDatatype = ( + Primitive + | Array1D[np.bool_] + | Array1D[np.int8] + | Array1D[np.uint8] + | Array1D[np.int16] + | Array1D[np.uint16] + | Array1D[np.int32] + | Array1D[np.uint32] + | Array1D[np.int64] + | Array1D[np.uint64] + | Array1D[np.float32] + | Array1D[np.float64] + | SubsetEnum + | Sequence[str] + | Sequence[SubsetEnum] + | Table +) +# TODO: These typevars will not be needed when we drop python 3.11 +# as you can do MyConverter[SignalType: SignalTypeUnion]: +# rather than MyConverter(Generic[SignalType]) +PrimitiveT = TypeVar("PrimitiveT", bound=Primitive) +SignalDatatypeT = TypeVar("SignalDatatypeT", bound=SignalDatatype) +SignalDatatypeV = TypeVar("SignalDatatypeV", bound=SignalDatatype) +EnumT = TypeVar("EnumT", bound=SubsetEnum) +TableT = TypeVar("TableT", bound=Table) -class SignalBackend(Generic[T]): +class SignalBackend(Generic[SignalDatatypeT]): """A read/write/monitor backend for a Signals""" - #: Datatype of the signal value - datatype: type[T] | None = None - - @classmethod - @abstractmethod - def datatype_allowed(cls, dtype: Any) -> bool: - """Check if a given datatype is acceptable for this signal backend.""" - - #: Like ca://PV_PREFIX:SIGNAL - @abstractmethod - def source(self, name: str) -> str: - """Return source of signal. Signals may pass a name to the backend, which can be - used or discarded.""" - - @abstractmethod - async def connect(self, timeout: float = DEFAULT_TIMEOUT): - """Connect to underlying hardware""" - @abstractmethod - async def put(self, value: T | None, wait=True, timeout=None): + async def put(self, value: SignalDatatypeT | None, wait=True, timeout=None): """Put a value to the PV, if wait then wait for completion for up to timeout""" @abstractmethod @@ -43,55 +55,104 @@ async def get_datakey(self, source: str) -> DataKey: """Metadata like source, dtype, shape, precision, units""" @abstractmethod - async def get_reading(self) -> Reading: + async def get_reading(self) -> Reading[SignalDatatypeT]: """The current value, timestamp and severity""" @abstractmethod - async def get_value(self) -> T: + async def get_value(self) -> SignalDatatypeT: """The current value""" @abstractmethod - async def get_setpoint(self) -> T: + async def get_setpoint(self) -> SignalDatatypeT: """The point that a signal was requested to move to.""" @abstractmethod - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: + def set_callback(self, callback: Callback[T] | None) -> None: """Observe changes to the current value, timestamp and severity""" -class _RuntimeSubsetEnumMeta(type): - def __str__(cls): - if hasattr(cls, "choices"): - return f"SubsetEnum{list(cls.choices)}" # type: ignore - return "SubsetEnum" - - def __getitem__(cls, _choices): - if isinstance(_choices, str): - _choices = (_choices,) - else: - if not isinstance(_choices, tuple) or not all( - isinstance(c, str) for c in _choices - ): - raise TypeError( - "Choices must be a str or a tuple of str, " f"not {type(_choices)}." - ) - if len(set(_choices)) != len(_choices): - raise TypeError("Duplicate elements in runtime enum choices.") +def _fail(*args, **kwargs): + raise RuntimeError("Signal has not been supplied a backend yet") - class _RuntimeSubsetEnum(cls): - choices = _choices - return _RuntimeSubsetEnum +class DisconnectedBackend(SignalBackend): + source = connect = put = get_datakey = get_reading = get_value = get_setpoint = ( + set_callback + ) = _fail -class RuntimeSubsetEnum(metaclass=_RuntimeSubsetEnumMeta): - choices: ClassVar[tuple[str, ...]] +class SignalConnector(DeviceConnector, Generic[SignalDatatypeT]): + backend: SignalBackend[SignalDatatypeT] = DisconnectedBackend() - def __init__(self): - raise RuntimeError("SubsetEnum cannot be instantiated") - - -if TYPE_CHECKING: - SubsetEnum = Literal -else: - SubsetEnum = RuntimeSubsetEnum + @abstractmethod + def source(self, name: str) -> str: ... + + +_primitive_dtype: dict[type[Primitive], Dtype] = { + bool: "boolean", + int: "integer", + float: "number", + str: "string", +} + + +class SignalMetadata(TypedDict, total=False): + limits: Limits + choices: list[str] + precision: int + units: str + + +def _datakey_dtype(datatype: type[SignalDatatypeT]) -> Dtype: + if get_origin(datatype) in (Sequence, np.ndarray) or issubclass(datatype, Table): + return "array" + elif issubclass(datatype, Enum): + return "string" + elif issubclass(datatype, Primitive): + return _primitive_dtype[datatype] + else: + raise TypeError(f"Can't make dtype for {datatype}") + + +def _datakey_dtype_numpy(datatype: type[SignalDatatypeT]) -> np.dtype: + if get_origin(datatype) == np.ndarray: + return get_dtype(datatype) + elif ( + get_origin(datatype) == Sequence + or datatype is str + or issubclass(datatype, Enum) + ): + return np.dtypes.StringDType() + elif issubclass(datatype, Table): + return datatype.numpy_dtype() + elif issubclass(datatype, Primitive): + return np.dtype(datatype) + else: + raise TypeError(f"Can't make dtype_numpy for {datatype}") + + +def _datakey_shape(value: SignalDatatype) -> list[int]: + if type(value) in _primitive_dtype or isinstance(value, Enum): + return [] + elif isinstance(value, np.ndarray): + return list(value.shape) + elif isinstance(value, Sequence): + return [len(value)] + else: + raise TypeError(f"Can't make shape for {value}") + + +def make_datakey( + datatype: type[SignalDatatypeT], + value: SignalDatatypeT, + source: str, + metadata: SignalMetadata, +) -> DataKey: + return DataKey( + dtype=_datakey_dtype(datatype), + shape=_datakey_shape(value), + # Ignore until https://github.com/bluesky/event-model/issues/308 + dtype_numpy=_datakey_dtype_numpy(datatype).str, # type: ignore + source=source, + **metadata, + ) diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index eb4aa47d71..aeb977ebd4 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -1,244 +1,218 @@ from __future__ import annotations -import inspect +import asyncio import time -from collections import abc -from enum import Enum -from typing import Generic, cast, get_origin +from abc import abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from functools import cached_property +from typing import Any, Generic, get_args, get_origin +from unittest.mock import AsyncMock import numpy as np -from bluesky.protocols import Reading from event_model import DataKey -from event_model.documents.event_descriptor import Dtype -from pydantic import BaseModel -from typing_extensions import TypedDict +from ophyd_async.core._device import Device + +from ._protocol import Reading from ._signal_backend import ( - RuntimeSubsetEnum, + Array1D, + EnumT, + Primitive, + PrimitiveT, SignalBackend, + SignalConnector, + SignalDatatype, + SignalDatatypeT, + SignalMetadata, + TableT, + make_datakey, ) -from ._utils import ( - DEFAULT_TIMEOUT, - ReadingValueCallback, - T, - get_dtype, - is_pydantic_model, -) - -primitive_dtypes: dict[type, Dtype] = { - str: "string", - int: "integer", - float: "number", - bool: "boolean", -} +from ._table import Table +from ._utils import Callback, get_dtype, get_enum_cls -class SignalMetadata(TypedDict): - units: str | None - precision: int | None +class SoftConverter(Generic[SignalDatatypeT]): + @abstractmethod + def write_value(self, value: Any) -> SignalDatatypeT: ... -class SoftConverter(Generic[T]): - def value(self, value: T) -> T: - return value +@dataclass +class PrimitiveSoftConverter(SoftConverter[PrimitiveT]): + datatype: type[PrimitiveT] - def write_value(self, value: T) -> T: - return value + def write_value(self, value: Any) -> PrimitiveT: + return self.datatype(value) if value else self.datatype() - def reading(self, value: T, timestamp: float, severity: int) -> Reading: - return Reading( - value=value, - timestamp=timestamp, - alarm_severity=-1 if severity > 2 else severity, - ) - def get_datakey(self, source: str, value, **metadata) -> DataKey: - dk: DataKey = {"source": source, "shape": [], **metadata} # type: ignore - dtype = type(value) - if np.issubdtype(dtype, np.integer): - dtype = int - elif np.issubdtype(dtype, np.floating): - dtype = float - assert ( - dtype in primitive_dtypes - ), f"invalid converter for value of type {type(value)}" - dk["dtype"] = primitive_dtypes[dtype] - # type ignore until https://github.com/bluesky/event-model/issues/308 - try: - dk["dtype_numpy"] = np.dtype(dtype).descr[0][1] # type: ignore - except TypeError: - dk["dtype_numpy"] = "" # type: ignore - return dk - - def make_initial_value(self, datatype: type[T] | None) -> T: - if datatype is None: - return cast(T, None) - - return datatype() - - -class SoftArrayConverter(SoftConverter): - def get_datakey(self, source: str, value, **metadata) -> DataKey: - dtype_numpy = "" - if isinstance(value, list): - if len(value) > 0: - dtype_numpy = np.dtype(type(value[0])).descr[0][1] - else: - dtype_numpy = np.dtype(value.dtype).descr[0][1] +class SequenceStrSoftConverter(SoftConverter[Sequence[str]]): + def write_value(self, value: Any) -> Sequence[str]: + return [str(v) for v in value] if value else [] - return { - "source": source, - "dtype": "array", - "dtype_numpy": dtype_numpy, # type: ignore - "shape": [len(value)], - **metadata, - } - def make_initial_value(self, datatype: type[T] | None) -> T: - if datatype is None: - return cast(T, None) +@dataclass +class SequenceEnumSoftConverter(SoftConverter[Sequence[EnumT]]): + datatype: type[EnumT] - if get_origin(datatype) == abc.Sequence: - return cast(T, []) + def write_value(self, value: Any) -> Sequence[EnumT]: + return [self.datatype(v) for v in value] if value else [] - return cast(T, datatype(shape=0)) # type: ignore +@dataclass +class NDArraySoftConverter(SoftConverter[Array1D]): + datatype: np.dtype -class SoftEnumConverter(SoftConverter): - choices: tuple[str, ...] + def write_value(self, value: Any) -> Array1D: + return np.array(value or (), dtype=self.datatype) - def __init__(self, datatype: RuntimeSubsetEnum | type[Enum]): - if issubclass(datatype, Enum): # type: ignore - self.choices = tuple(v.value for v in datatype) - else: - self.choices = datatype.choices - def write_value(self, value: Enum | str) -> str: - return value # type: ignore +@dataclass +class EnumSoftConverter(SoftConverter[EnumT]): + datatype: type[EnumT] - def get_datakey(self, source: str, value, **metadata) -> DataKey: - return { - "source": source, - "dtype": "string", - # type ignore until https://github.com/bluesky/event-model/issues/308 - "dtype_numpy": "|S40", # type: ignore - "shape": [], - "choices": self.choices, - **metadata, - } - - def make_initial_value(self, datatype: type[T] | None) -> T: - if datatype is None: - return cast(T, None) - - if issubclass(datatype, Enum): - return cast(T, list(datatype.__members__.values())[0]) # type: ignore - return cast(T, self.choices[0]) + def write_value(self, value: Any) -> EnumT: + return ( + self.datatype(value) + if value + else list(self.datatype.__members__.values())[0] + ) -class SoftPydanticModelConverter(SoftConverter): - def __init__(self, datatype: type[BaseModel]): - self.datatype = datatype +@dataclass +class TableSoftConverter(SoftConverter[TableT]): + datatype: type[TableT] - def write_value(self, value): + def write_value(self, value: Any) -> TableT: if isinstance(value, dict): return self.datatype(**value) - return value - - -def make_converter(datatype): - is_array = get_dtype(datatype) is not None - is_sequence = get_origin(datatype) == abc.Sequence - is_enum = inspect.isclass(datatype) and ( - issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) - ) - - if is_array or is_sequence: - return SoftArrayConverter() - if is_enum: - return SoftEnumConverter(datatype) # type: ignore - if is_pydantic_model(datatype): - return SoftPydanticModelConverter(datatype) # type: ignore - - return SoftConverter() - - -class SoftSignalBackend(SignalBackend[T]): + elif isinstance(value, self.datatype): + return value + elif value is None: + return self.datatype() + else: + raise TypeError(f"Cannot convert {value} to {self.datatype}") + + +def make_converter(datatype: type[SignalDatatype]) -> SoftConverter: + enum_cls = get_enum_cls(datatype) + if datatype == Sequence[str]: + return SequenceStrSoftConverter() + elif get_origin(datatype) == Sequence and enum_cls: + return SequenceEnumSoftConverter(enum_cls) + elif get_origin(datatype) == np.ndarray: + return NDArraySoftConverter(get_dtype(datatype)) + elif enum_cls: + return EnumSoftConverter(enum_cls) + elif issubclass(datatype, Table): + return TableSoftConverter(datatype) + elif issubclass(datatype, Primitive): + return PrimitiveSoftConverter(datatype) + raise TypeError(f"Can't make converter for {datatype}") + + +class SoftSignalBackend(SignalBackend[SignalDatatypeT]): """An backend to a soft Signal, for test signals see ``MockSignalBackend``.""" - _value: T - _initial_value: T | None - _timestamp: float - _severity: int - - @classmethod - def datatype_allowed(cls, dtype: type) -> bool: - return True # Any value allowed in a soft signal + _reading: Reading[SignalDatatypeT] + _callback: Callback[Reading[SignalDatatypeT]] | None = None def __init__( self, - datatype: type[T] | None, - initial_value: T | None = None, - metadata: SignalMetadata = None, # type: ignore - ) -> None: - self.datatype = datatype - self._initial_value = initial_value - self._metadata = metadata or {} - self.converter: SoftConverter = make_converter(datatype) - if self._initial_value is None: - self._initial_value = self.converter.make_initial_value(self.datatype) - else: - self._initial_value = self.converter.write_value(self._initial_value) # type: ignore - - self.callback: ReadingValueCallback[T] | None = None - self._severity = 0 - self.set_value(self._initial_value) # type: ignore - - def source(self, name: str) -> str: - return f"soft://{name}" - - async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None: - """Connection isn't required for soft signals.""" - pass + datatype: type[SignalDatatypeT] | None = None, + initial_value: SignalDatatypeT | None = None, + metadata: SignalMetadata = {}, + ): + # If not specified then default to float + self._datatype = datatype or float + # Create the right converter for the datatype + self._converter = make_converter(self._datatype) + self._initial_value = self._converter.write_value(initial_value) + self._metadata = metadata + self.set_value(self._initial_value) + + def set_value(self, value: SignalDatatypeT): + self._reading = Reading( + value=value, timestamp=time.monotonic(), alarm_severity=0 + ) + if self._callback: + self._callback(self._reading) - async def put(self, value: T | None, wait=True, timeout=None): + async def put(self, value: SignalDatatypeT | None, wait=True, timeout=None) -> None: write_value = ( - self.converter.write_value(value) + self._converter.write_value(value) if value is not None else self._initial_value ) + self.set_value(write_value) - self.set_value(write_value) # type: ignore - - def set_value(self, value: T): - """Method to bypass asynchronous logic.""" - self._value = value - self._timestamp = time.monotonic() - reading: Reading = self.converter.reading( - self._value, self._timestamp, self._severity + async def get_datakey(self, source: str) -> DataKey: + return make_datakey( + self._datatype, self._reading["value"], source, self._metadata ) - if self.callback: - self.callback(reading, self._value) + async def get_reading(self) -> Reading[SignalDatatypeT]: + return self._reading - async def get_datakey(self, source: str) -> DataKey: - return self.converter.get_datakey(source, self._value, **self._metadata) + async def get_value(self) -> SignalDatatypeT: + return self._reading["value"] - async def get_reading(self) -> Reading: - return self.converter.reading(self._value, self._timestamp, self._severity) + async def get_setpoint(self) -> SignalDatatypeT: + # For a soft signal, the setpoint and readback values are the same. + return self._reading["value"] - async def get_value(self) -> T: - return self.converter.value(self._value) + def set_callback(self, callback: Callback[Reading[SignalDatatypeT]] | None) -> None: + if callback: + assert not self._callback, "Cannot set a callback when one is already set" + callback(self._reading) + self._callback = callback - async def get_setpoint(self) -> T: - """For a soft signal, the setpoint and readback values are the same.""" - return await self.get_value() - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: - if callback: - assert not self.callback, "Cannot set a callback when one is already set" - reading: Reading = self.converter.reading( - self._value, self._timestamp, self._severity - ) - callback(reading, self._value) - self.callback = callback +class MockSignalBackend(SoftSignalBackend[SignalDatatypeT]): + @cached_property + def put_mock(self) -> AsyncMock: + return AsyncMock(name="put", spec=Callable) + + @cached_property + def put_proceeds(self) -> asyncio.Event: + put_proceeds = asyncio.Event() + put_proceeds.set() + return put_proceeds + + async def put(self, value: SignalDatatypeT | None, wait=True, timeout=None): + await self.put_mock(value, wait=wait, timeout=timeout) + await super().put(value, wait, timeout) + + if wait: + await asyncio.wait_for(self.put_proceeds.wait(), timeout=timeout) + + +@dataclass +class SoftSignalConnector(SignalConnector[SignalDatatypeT]): + datatype: type[SignalDatatypeT] + initial_value: SignalDatatypeT | None = None + units: str | None = None + precision: int | None = None + + async def connect( + self, device: Device, mock: bool, timeout: float, force_reconnect: bool + ) -> None: + # Add the extra static metadata to the dictionary + metadata: SignalMetadata = {} + if self.units is not None: + metadata["units"] = self.units + if self.precision is not None: + metadata["precision"] = self.precision + if enum_cls := get_enum_cls(self.datatype): + metadata["choices"] = [v.value for v in enum_cls] + # Create the backend + backend_cls = MockSignalBackend if mock else SoftSignalBackend + self.backend = backend_cls(self.datatype, self.initial_value, metadata) + + def source(self, name: str) -> str: + return f"soft://{name}" + + def set_value(self, value: SignalDatatypeT): + assert isinstance( + self.backend, SoftSignalBackend + ), "Cannot set soft signal value until after connect" + self.backend.set_value(value) diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py index f36b60dceb..fc712eb718 100644 --- a/src/ophyd_async/core/_table.py +++ b/src/ophyd_async/core/_table.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import TypeVar, get_args, get_origin +from typing import Any, TypeVar, get_args, get_origin import numpy as np from pydantic import BaseModel, ConfigDict, model_validator @@ -64,14 +64,15 @@ def __add__(self, right: TableSubclass) -> TableSubclass: } ) - def numpy_dtype(self) -> np.dtype: + @classmethod + def numpy_dtype(cls) -> np.dtype: dtype = [] - for field_name, field_value in self.model_fields.items(): + for field_name, field_value in cls.model_fields.items(): if np.ndarray in ( get_origin(field_value.annotation), field_value.annotation, ): - dtype.append((field_name, getattr(self, field_name).dtype)) + dtype.append((field_name, getattr(cls, field_name).dtype)) else: enum_type = get_args(field_value.annotation)[0] assert issubclass(enum_type, Enum) @@ -144,3 +145,9 @@ def validate_arrays(self) -> "Table": ) return self + + def __len__(self) -> int: + return len(next(iter(self))[1]) + + def __getitem__(self) -> Any: + raise NotImplementedError() diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index 8c90639e21..13d290feed 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -2,25 +2,29 @@ import asyncio import logging -from collections.abc import Awaitable, Callable, Iterable +from collections.abc import Awaitable, Callable, Iterable, Sequence from dataclasses import dataclass -from typing import Generic, Literal, ParamSpec, TypeVar, get_origin +from enum import Enum +from typing import Generic, Literal, ParamSpec, TypeVar, get_args, get_origin import numpy as np -from bluesky.protocols import Reading from pydantic import BaseModel T = TypeVar("T") P = ParamSpec("P") Callback = Callable[[T], None] - -#: A function that will be called with the Reading and value when the -#: monitor updates -ReadingValueCallback = Callable[[Reading, T], None] DEFAULT_TIMEOUT = 10.0 ErrorText = str | dict[str, Exception] +class SubsetEnum(str, Enum): + """All members should exist in the Backend, but there may be extras""" + + +class StrictEnum(SubsetEnum): + """All members should exist in the Backend, and there will be no extras""" + + CALCULATE_TIMEOUT = "CALCULATE_TIMEOUT" """Sentinel used to implement ``myfunc(timeout=CalculateTimeout)`` @@ -119,7 +123,7 @@ async def wait_for_connection(**coros: Awaitable[None]): raise NotConnected(exceptions) -def get_dtype(typ: type) -> np.dtype | None: +def get_dtype(datatype: type) -> np.dtype: """Get the runtime dtype from a numpy ndarray type annotation >>> import numpy.typing as npt @@ -127,11 +131,30 @@ def get_dtype(typ: type) -> np.dtype | None: >>> get_dtype(npt.NDArray[np.int8]) dtype('int8') """ - if getattr(typ, "__origin__", None) == np.ndarray: - # datatype = numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]] - # so extract numpy.float64 from it - return np.dtype(typ.__args__[1].__args__[0]) # type: ignore - return None + if not get_origin(datatype) == np.ndarray: + raise TypeError(f"Expected np.ndarray, got {datatype}") + # datatype = numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]] + # so extract numpy.float64 from it + return np.dtype(get_args(get_args(datatype)[1])[0]) + + +def get_enum_cls(datatype: type | None) -> type[SubsetEnum] | None: + """Get the runtime dtype from a numpy ndarray type annotation + + >>> import numpy.typing as npt + >>> import numpy as np + >>> get_dtype(npt.NDArray[np.int8]) + dtype('int8') + """ + if get_origin(datatype) == Sequence: + datatype = get_args(datatype)[0] + if datatype and issubclass(datatype, Enum): + if not issubclass(datatype, SubsetEnum): + raise TypeError( + f"{datatype} should inherit from ophyd_async.core.SubsetEnum " + "or ophyd_async.core.StrictEnum" + ) + return datatype def get_unique(values: dict[str, T], types: str) -> T: diff --git a/src/ophyd_async/epics/signal/__init__.py b/src/ophyd_async/epics/signal/__init__.py index 8d7628bf01..5122e38556 100644 --- a/src/ophyd_async/epics/signal/__init__.py +++ b/src/ophyd_async/epics/signal/__init__.py @@ -1,4 +1,4 @@ -from ._common import LimitPair, Limits, get_supported_values +from ._common import get_supported_values from ._p4p import PvaSignalBackend from ._signal import ( epics_signal_r, @@ -10,8 +10,6 @@ __all__ = [ "get_supported_values", - "LimitPair", - "Limits", "PvaSignalBackend", "epics_signal_r", "epics_signal_rw", diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index bdac6d878f..4e6d0d0c18 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -1,253 +1,298 @@ -import inspect -import logging import sys from collections.abc import Sequence from dataclasses import dataclass -from enum import Enum from math import isnan, nan -from typing import Any, get_origin +from typing import Any, Generic, cast, get_origin import numpy as np from aioca import ( FORMAT_CTRL, FORMAT_RAW, FORMAT_TIME, - CANothing, Subscription, caget, camonitor, caput, ) from aioca.types import AugmentedValue, Dbr, Format -from bluesky.protocols import Reading from epicscorelibs.ca import dbr from event_model import DataKey -from event_model.documents.event_descriptor import Dtype +from event_model.documents.event_descriptor import Limits, LimitsRange from ophyd_async.core import ( - DEFAULT_TIMEOUT, - NotConnected, - ReadingValueCallback, - RuntimeSubsetEnum, + Array1D, + Callback, + Device, + MockSignalBackend, SignalBackend, - T, - get_dtype, + SignalConnector, + SignalDatatype, + SignalDatatypeT, + SignalMetadata, + get_enum_cls, get_unique, + make_datakey, wait_for_connection, ) +from ophyd_async.core._protocol import Reading -from ._common import LimitPair, Limits, common_meta, get_supported_values - -dbr_to_dtype: dict[Dbr, Dtype] = { - dbr.DBR_STRING: "string", - dbr.DBR_SHORT: "integer", - dbr.DBR_FLOAT: "number", - dbr.DBR_CHAR: "string", - dbr.DBR_LONG: "integer", - dbr.DBR_DOUBLE: "number", -} - - -def _data_key_from_augmented_value( - value: AugmentedValue, - *, - choices: list[str] | None = None, - dtype: Dtype | None = None, -) -> DataKey: - """Use the return value of get with FORMAT_CTRL to construct a DataKey - describing the signal. See docstring of AugmentedValue for expected - value fields by DBR type. - - Args: - value (AugmentedValue): Description of the the return type of a DB record - choices: Optional list of enum choices to pass as metadata in the datakey - dtype: Optional override dtype when AugmentedValue is ambiguous, e.g. booleans - - Returns: - DataKey: A rich DataKey describing the DB record - """ - source = f"ca://{value.name}" - assert value.ok, f"Error reading {source}: {value}" - - scalar = value.element_count == 1 - dtype = dtype or dbr_to_dtype[value.datatype] # type: ignore - - dtype_numpy = np.dtype(dbr.DbrCodeToType[value.datatype].dtype).descr[0][1] - - d = DataKey( - source=source, - dtype=dtype if scalar else "array", - # Ignore until https://github.com/bluesky/event-model/issues/308 - dtype_numpy=dtype_numpy, # type: ignore - # strictly value.element_count >= len(value) - shape=[] if scalar else [len(value)], - ) - for key in common_meta: - attr = getattr(value, key, nan) - if isinstance(attr, str) or not isnan(attr): - d[key] = attr - - if choices is not None: - d["choices"] = choices # type: ignore - - if limits := _limits_from_augmented_value(value): - d["limits"] = limits # type: ignore - - return d +from ._common import get_supported_values def _limits_from_augmented_value(value: AugmentedValue) -> Limits: - def get_limits(limit: str) -> LimitPair: + def get_limits(limit: str) -> LimitsRange | None: low = getattr(value, f"lower_{limit}_limit", nan) high = getattr(value, f"upper_{limit}_limit", nan) - return LimitPair( - low=None if isnan(low) else low, high=None if isnan(high) else high - ) - - return Limits( - alarm=get_limits("alarm"), - control=get_limits("ctrl"), - display=get_limits("disp"), - warning=get_limits("warning"), - ) - + if not (isnan(low) and isnan(high)): + return LimitsRange( + low=None if isnan(low) else low, + high=None if isnan(high) else high, + ) -@dataclass -class CaConverter: - read_dbr: Dbr | None - write_dbr: Dbr | None + limits = Limits() + if limits_range := get_limits("alarm"): + limits["alarm"] = limits_range + if limits_range := get_limits("ctrl"): + limits["control"] = limits_range + if limits_range := get_limits("disp"): + limits["display"] = limits_range + if limits_range := get_limits("warning"): + limits["warning"] = limits_range + return limits + + +def _metadata_from_augmented_value( + value: AugmentedValue, metadata: SignalMetadata +) -> SignalMetadata: + metadata = metadata.copy() + if hasattr(value, "units"): + metadata["units"] = value.units + if hasattr(value, "precision") and not isnan(value.precision): + metadata["precision"] = value.precision + if limits := _limits_from_augmented_value(value): + metadata["limits"] = limits + return metadata + + +class CaConverter(Generic[SignalDatatypeT]): + def __init__( + self, + datatype: type[SignalDatatypeT], + read_dbr: Dbr, + write_dbr: Dbr | None = None, + metadata: SignalMetadata | None = None, + ): + self.datatype = datatype + self.read_dbr: Dbr = read_dbr + self.write_dbr: Dbr | None = write_dbr + self.metadata = metadata or SignalMetadata() - def write_value(self, value) -> Any: + def write_value(self, value: Any) -> Any: + # The ca library will do the conversion for us return value - def value(self, value: AugmentedValue): + def value(self, value: AugmentedValue) -> SignalDatatypeT: # for channel access ca_xxx classes, this # invokes __pos__ operator to return an instance of # the builtin base class return +value # type: ignore - def reading(self, value: AugmentedValue) -> Reading: - return { - "value": self.value(value), - "timestamp": value.timestamp, - "alarm_severity": -1 if value.severity > 2 else value.severity, - } - - def get_datakey(self, value: AugmentedValue) -> DataKey: - return _data_key_from_augmented_value(value) - -class CaLongStrConverter(CaConverter): +class CaLongStrConverter(CaConverter[str]): def __init__(self): - return super().__init__(dbr.DBR_CHAR_STR, dbr.DBR_CHAR_STR) + super().__init__(str, dbr.DBR_CHAR_STR, dbr.DBR_CHAR_STR) - def write_value(self, value: str): + def write_value_and_dbr(self, value: Any) -> Any: # Add a null in here as this is what the commandline caput does # TODO: this should be in the server so check if it can be pushed to asyn return value + "\0" -class CaArrayConverter(CaConverter): - def value(self, value: AugmentedValue): +class CaArrayConverter(CaConverter[np.ndarray]): + def value(self, value: AugmentedValue) -> np.ndarray: + # A less expensive conversion return np.array(value, copy=False) -@dataclass -class CaEnumConverter(CaConverter): - """To prevent issues when a signal is restarted and returns with different enum - values or orders, we put treat an Enum signal as a string, and cache the - choices on this class. - """ +class CaEnumConverter(CaConverter[str]): + def __init__(self, supported_values: dict[str, str]): + self.supported_values = supported_values + super().__init__( + str, dbr.DBR_STRING, metadata=SignalMetadata(choices=list(supported_values)) + ) - choices: dict[str, str] + def value(self, value: AugmentedValue) -> str: + return self.supported_values[str(value)] - def write_value(self, value: Enum | str): - if isinstance(value, Enum): - return value.value - else: - return value - def value(self, value: AugmentedValue): - return self.choices[value] # type: ignore +class CaSequenceStrConverter(CaConverter[Sequence[str]]): + def __init__(self): + super().__init__(Sequence[str], dbr.DBR_STRING) - def get_datakey(self, value: AugmentedValue) -> DataKey: - # Sometimes DBR_TYPE returns as String, must pass choices still - return _data_key_from_augmented_value(value, choices=list(self.choices.keys())) + def value(self, value: AugmentedValue) -> Sequence[str]: + return [str(v) for v in value] # type: ignore -@dataclass -class CaBoolConverter(CaConverter): +class CaBoolConverter(CaConverter[bool]): + def __init__(self): + super().__init__(bool, dbr.DBR_SHORT) + def value(self, value: AugmentedValue) -> bool: return bool(value) - def get_datakey(self, value: AugmentedValue) -> DataKey: - return _data_key_from_augmented_value(value, dtype="boolean") - -class DisconnectedCaConverter(CaConverter): - def __getattribute__(self, __name: str) -> Any: - raise NotImplementedError("No PV has been set as connect() has not been called") +_datatypes_from_dbr: dict[tuple[Dbr, bool], type[SignalDatatype]] = { + (dbr.DBR_STRING, False): str, + (dbr.DBR_SHORT, False): int, + (dbr.DBR_FLOAT, False): float, + (dbr.DBR_ENUM, False): str, + (dbr.DBR_CHAR, False): int, + (dbr.DBR_LONG, False): int, + (dbr.DBR_DOUBLE, False): float, + (dbr.DBR_STRING, True): Sequence[str], + (dbr.DBR_SHORT, True): Array1D[np.int16], + (dbr.DBR_FLOAT, True): Array1D[np.float32], + (dbr.DBR_ENUM, True): Sequence[str], + (dbr.DBR_CHAR, True): Array1D[np.uint8], + (dbr.DBR_LONG, True): Array1D[np.int32], + (dbr.DBR_DOUBLE, True): Array1D[np.float64], +} def make_converter( datatype: type | None, values: dict[str, AugmentedValue] ) -> CaConverter: pv = list(values)[0] - pv_dbr = get_unique({k: v.datatype for k, v in values.items()}, "datatypes") + pv_dbr = cast( + Dbr, get_unique({k: v.datatype for k, v in values.items()}, "datatypes") + ) is_array = bool([v for v in values.values() if v.element_count > 1]) - if is_array and datatype is str and pv_dbr == dbr.DBR_CHAR: - # Override waveform of chars to be treated as string - return CaLongStrConverter() - elif is_array and pv_dbr == dbr.DBR_STRING: - # Waveform of strings, check we wanted this - if datatype: - datatype_dtype = get_dtype(datatype) - if not datatype_dtype or not np.can_cast(datatype_dtype, np.str_): - raise TypeError(f"{pv} has type [str] not {datatype.__name__}") - return CaArrayConverter(pv_dbr, None) - elif is_array: - pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes") # type: ignore - # This is an array - if datatype: - # Check we wanted an array of this type - dtype = get_dtype(datatype) - if not dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not {datatype.__name__}") - if dtype != pv_dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") - return CaArrayConverter(pv_dbr, None) # type: ignore - elif pv_dbr == dbr.DBR_ENUM and datatype is bool: - # Database can't do bools, so are often representated as enums, - # CA can do int - pv_choices_len = get_unique( - {k: len(v.enums) for k, v in values.items()}, "number of choices" - ) - if pv_choices_len != 2: - raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return CaBoolConverter(dbr.DBR_SHORT, dbr.DBR_SHORT) - elif pv_dbr == dbr.DBR_ENUM: - # This is an Enum - pv_choices = get_unique( - {k: tuple(v.enums) for k, v in values.items()}, "choices" - ) - supported_values = get_supported_values(pv, datatype, pv_choices) - return CaEnumConverter(dbr.DBR_STRING, None, supported_values) + # Infer a datatype from the dbr + inferred_datatype = _datatypes_from_dbr[(pv_dbr, is_array)] + # Create the correct converter based on requested datatype + if is_array: + if pv_dbr == dbr.DBR_STRING and datatype in (None, Sequence[str]): + # Otherwise they get string if requested or inferred + return CaSequenceStrConverter() + elif pv_dbr == dbr.DBR_CHAR and datatype is str: + # Override waveform of chars to be treated as string + return CaLongStrConverter() + elif ( + datatype in (None, inferred_datatype) + and get_origin(inferred_datatype) == np.ndarray + ): + # The requested datatype matches the inferred datatype, so use that + # We verify the origin of inferred_datatype above, but pyright doesn't know + # that, so do a cast below + return CaArrayConverter(cast(type[np.ndarray], inferred_datatype), pv_dbr) else: - value = list(values.values())[0] - # Done the dbr check, so enough to check one of the values - if datatype and not isinstance(value, datatype): - # Allow int signals to represent float records when prec is 0 - is_prec_zero_float = ( - isinstance(value, float) - and get_unique({k: v.precision for k, v in values.items()}, "precision") - == 0 + if pv_dbr == dbr.DBR_ENUM: + pv_choices = get_unique( + {k: tuple(v.enums) for k, v in values.items()}, "choices" ) - if not (datatype is int and is_prec_zero_float): - raise TypeError( - f"{pv} has type {type(value).__name__.replace('ca_', '')} " - + f"not {datatype.__name__}" + if datatype is bool: + # Database can't do bools, so are often representated as enums of len 2 + if len(pv_choices) != 2: + raise TypeError(f"{pv} has {pv_choices=}, can't map to bool") + return CaBoolConverter() + elif enum_cls := get_enum_cls(datatype): + # If explicitly requested then check + supported_values = get_supported_values(pv, enum_cls, pv_choices) + return CaEnumConverter(supported_values) + else: + # Drop to string, but retain choices as metadata + return CaConverter( + str, + dbr.DBR_STRING, + metadata=SignalMetadata(choices=list(pv_choices)), ) - return CaConverter(pv_dbr, None) # type: ignore + elif ( + pv_dbr == dbr.DBR_DOUBLE + and get_unique({k: v.precision for k, v in values.items()}, "precision") + == 0 + ): + # Allow int signals to represent float records when prec is 0 + return CaConverter(int, pv_dbr) + elif datatype in (None, inferred_datatype): + # If datatype matches what we are given then allow it + return CaConverter(inferred_datatype, pv_dbr) + raise TypeError( + f"{pv} with inferred datatype {inferred_datatype} cannot be coerced to {datatype}" + ) + + +class CaSignalBackend(SignalBackend[SignalDatatypeT]): + def __init__( + self, + datatype: type[SignalDatatypeT] | None, + read_pv: str, + write_pv: str, + initial_values: dict[str, AugmentedValue], + ): + self._converter = make_converter(datatype, initial_values) + self._read_pv = read_pv + self._write_pv = write_pv + self._initial_values = initial_values + self._subscription: Subscription | None = None + + async def _caget(self, pv: str, format: Format) -> AugmentedValue: + return await caget( + pv, datatype=self._converter.read_dbr, format=format, timeout=None + ) + + def _make_reading(self, value: AugmentedValue) -> Reading[SignalDatatypeT]: + return { + "value": self._converter.value(value), + "timestamp": value.timestamp, + "alarm_severity": -1 if value.severity > 2 else value.severity, + } + + async def put(self, value: SignalDatatypeT | None, wait=True, timeout=None): + if value is None: + write_value = self._initial_values[self._write_pv] + else: + write_value = self._converter.write_value(value) + await caput( + self._write_pv, + write_value, + datatype=self._converter.write_dbr, + wait=wait, + timeout=timeout, + ) + + async def get_datakey(self, source: str) -> DataKey: + value = await self._caget(self._read_pv, FORMAT_CTRL) + metadata = _metadata_from_augmented_value(value, self._converter.metadata) + return make_datakey( + self._converter.datatype, self._converter.value(value), source, metadata + ) + + async def get_reading(self) -> Reading[SignalDatatypeT]: + value = await self._caget(self._read_pv, FORMAT_TIME) + return self._make_reading(value) + + async def get_value(self) -> SignalDatatypeT: + value = await self._caget(self._read_pv, FORMAT_RAW) + return self._converter.value(value) + + async def get_setpoint(self) -> SignalDatatypeT: + value = await self._caget(self._write_pv, FORMAT_RAW) + return self._converter.value(value) + + def set_callback(self, callback: Callback[Reading[SignalDatatypeT]] | None) -> None: + if callback: + assert ( + not self._subscription + ), "Cannot set a callback when one is already set" + self._subscription = camonitor( + self._read_pv, + lambda v: callback(self._make_reading(v)), + datatype=self._converter.read_dbr, + format=FORMAT_TIME, + ) + elif self._subscription: + self._subscription.close() + self._subscription = None _tried_pyepics = False @@ -262,117 +307,39 @@ def _use_pyepics_context_if_imported(): _tried_pyepics = True -class CaSignalBackend(SignalBackend[T]): - _ALLOWED_DATATYPES = ( - bool, - int, - float, - str, - Sequence, - Enum, - RuntimeSubsetEnum, - np.ndarray, - ) - - @classmethod - def datatype_allowed(cls, dtype: Any) -> bool: - stripped_origin = get_origin(dtype) or dtype - if dtype is None: - return True - - return inspect.isclass(stripped_origin) and issubclass( - stripped_origin, cls._ALLOWED_DATATYPES - ) +@dataclass +class CaSignalConnector(SignalConnector[SignalDatatypeT]): + datatype: type[SignalDatatypeT] | None + read_pv: str + write_pv: str + + async def connect( + self, device: Device, mock: bool, timeout: float, force_reconnect: bool + ) -> None: + if mock: + self.backend = MockSignalBackend(self.datatype) + else: + self.backend = await self.connect_epics(timeout) - def __init__(self, datatype: type[T] | None, read_pv: str, write_pv: str): - self.datatype = datatype - if not CaSignalBackend.datatype_allowed(self.datatype): - raise TypeError(f"Given datatype {self.datatype} unsupported in CA.") - self.read_pv = read_pv - self.write_pv = write_pv - self.initial_values: dict[str, AugmentedValue] = {} - self.converter: CaConverter = DisconnectedCaConverter(None, None) - self.subscription: Subscription | None = None - - def source(self, name: str): - return f"ca://{self.read_pv}" + async def connect_epics(self, timeout: float) -> CaSignalBackend: + _use_pyepics_context_if_imported() + initial_values: dict[str, AugmentedValue] = {} - async def _store_initial_value(self, pv, timeout: float = DEFAULT_TIMEOUT): - try: - self.initial_values[pv] = await caget( - pv, format=FORMAT_CTRL, timeout=timeout - ) - except CANothing as exc: - logging.debug(f"signal ca://{pv} timed out") - raise NotConnected(f"ca://{pv}") from exc + async def store_initial_value(pv: str): + initial_values[pv] = await caget(pv, format=FORMAT_CTRL, timeout=timeout) - async def connect(self, timeout: float = DEFAULT_TIMEOUT): - _use_pyepics_context_if_imported() if self.read_pv != self.write_pv: # Different, need to connect both await wait_for_connection( - read_pv=self._store_initial_value(self.read_pv, timeout=timeout), - write_pv=self._store_initial_value(self.write_pv, timeout=timeout), + read_pv=store_initial_value(self.read_pv), + write_pv=store_initial_value(self.write_pv), ) else: # The same, so only need to connect one - await self._store_initial_value(self.read_pv, timeout=timeout) - self.converter = make_converter(self.datatype, self.initial_values) - - async def put(self, value: T | None, wait=True, timeout=None): - if value is None: - write_value = self.initial_values[self.write_pv] - else: - write_value = self.converter.write_value(value) - await caput( - self.write_pv, - write_value, - datatype=self.converter.write_dbr, - wait=wait, - timeout=timeout, - ) - - async def _caget(self, format: Format) -> AugmentedValue: - return await caget( - self.read_pv, - datatype=self.converter.read_dbr, - format=format, - timeout=None, - ) - - async def get_datakey(self, source: str) -> DataKey: - value = await self._caget(FORMAT_CTRL) - return self.converter.get_datakey(value) - - async def get_reading(self) -> Reading: - value = await self._caget(FORMAT_TIME) - return self.converter.reading(value) - - async def get_value(self) -> T: - value = await self._caget(FORMAT_RAW) - return self.converter.value(value) - - async def get_setpoint(self) -> T: - value = await caget( - self.write_pv, - datatype=self.converter.read_dbr, - format=FORMAT_RAW, - timeout=None, + await store_initial_value(self.read_pv) + return CaSignalBackend( + self.datatype, self.read_pv, self.write_pv, initial_values ) - return self.converter.value(value) - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: - if callback: - assert ( - not self.subscription - ), "Cannot set a callback when one is already set" - self.subscription = camonitor( - self.read_pv, - lambda v: callback(self.converter.reading(v), self.converter.value(v)), - datatype=self.converter.read_dbr, - format=FORMAT_TIME, - ) - else: - if self.subscription: - self.subscription.close() - self.subscription = None + def source(self, name: str) -> str: + return f"ca://{self.read_pv}" diff --git a/src/ophyd_async/epics/signal/_common.py b/src/ophyd_async/epics/signal/_common.py index ae40e93029..a1862612a6 100644 --- a/src/ophyd_async/epics/signal/_common.py +++ b/src/ophyd_async/epics/signal/_common.py @@ -1,57 +1,29 @@ -import inspect -from enum import Enum +from collections.abc import Sequence -from typing_extensions import TypedDict - -from ophyd_async.core import RuntimeSubsetEnum - -common_meta = { - "units", - "precision", -} - - -class LimitPair(TypedDict): - high: float | None - low: float | None - - -class Limits(TypedDict): - alarm: LimitPair - control: LimitPair - display: LimitPair - warning: LimitPair +from ophyd_async.core import StrictEnum, get_enum_cls def get_supported_values( pv: str, - datatype: type[str] | None, - pv_choices: tuple[str, ...], + datatype: type, + pv_choices: Sequence[str], ) -> dict[str, str]: - if inspect.isclass(datatype) and issubclass(datatype, RuntimeSubsetEnum): - if not set(datatype.choices).issubset(set(pv_choices)): - raise TypeError( - f"{pv} has choices {pv_choices}, " - f"which is not a superset of {str(datatype)}." - ) - return {x: x or "_" for x in pv_choices} - elif inspect.isclass(datatype) and issubclass(datatype, Enum): - if not issubclass(datatype, str): - raise TypeError( - f"{pv} is type Enum but {datatype} does not inherit from String." - ) - - choices = tuple(v.value for v in datatype) + enum_cls = get_enum_cls(datatype) + if not enum_cls: + raise TypeError(f"{datatype} is not an Enum") + choices = [v.value for v in enum_cls] + error_msg = f"{pv} has choices {pv_choices}, but {datatype} requested {choices} " + if issubclass(enum_cls, StrictEnum): if set(choices) != set(pv_choices): - raise TypeError( - f"{pv} has choices {pv_choices}, " - f"which do not match {datatype}, which has {choices}." - ) - return {x: datatype(x) if x else "_" for x in pv_choices} - elif datatype is None or datatype is str: - return {x: x or "_" for x in pv_choices} - - raise TypeError( - f"{pv} has choices {pv_choices}. " - "Use an Enum or SubsetEnum to represent this." - ) + raise TypeError(error_msg + "to be a subset of them.") + + else: + if not set(choices).issubset(pv_choices): + raise TypeError(error_msg + "to be strictly equal to them.") + + # Take order from the pv choices + supported_values = {x: x for x in pv_choices} + # But override those that we specify via the datatype + for v in enum_cls: + supported_values[v.value] = v + return supported_values diff --git a/src/ophyd_async/epics/signal/_epics_transport.py b/src/ophyd_async/epics/signal/_epics_transport.py deleted file mode 100644 index 4737de704f..0000000000 --- a/src/ophyd_async/epics/signal/_epics_transport.py +++ /dev/null @@ -1,34 +0,0 @@ -"""EPICS Signals over CA or PVA""" - -from __future__ import annotations - -from enum import Enum - - -def _make_unavailable_class(error: Exception) -> type: - class TransportNotAvailable: - def __init__(*args, **kwargs): - raise NotImplementedError("Transport not available") from error - - return TransportNotAvailable - - -try: - from ._aioca import CaSignalBackend -except ImportError as ca_error: - CaSignalBackend = _make_unavailable_class(ca_error) - - -try: - from ._p4p import PvaSignalBackend -except ImportError as pva_error: - PvaSignalBackend = _make_unavailable_class(pva_error) - - -class _EpicsTransport(Enum): - """The sorts of transport EPICS support""" - - #: Use Channel Access (using aioca library) - ca = CaSignalBackend - #: Use PVAccess (using p4p library) - pva = PvaSignalBackend diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 6fe13d0e2c..21ae85d376 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import atexit import inspect @@ -20,8 +22,6 @@ from ophyd_async.core import ( DEFAULT_TIMEOUT, NotConnected, - ReadingValueCallback, - RuntimeSubsetEnum, SignalBackend, T, get_dtype, @@ -30,7 +30,7 @@ wait_for_connection, ) -from ._common import LimitPair, Limits, common_meta, get_supported_values +from ._common import get_supported_values # https://mdavidsaver.github.io/p4p/values.html specifier_to_dtype: dict[str, Dtype] = { @@ -390,7 +390,6 @@ class PvaSignalBackend(SignalBackend[T]): Sequence, np.ndarray, Enum, - RuntimeSubsetEnum, BaseModel, dict, ) diff --git a/src/ophyd_async/epics/signal/_signal.py b/src/ophyd_async/epics/signal/_signal.py index 6711ac734e..dd82891f4b 100644 --- a/src/ophyd_async/epics/signal/_signal.py +++ b/src/ophyd_async/epics/signal/_signal.py @@ -2,46 +2,81 @@ from __future__ import annotations +from enum import Enum + from ophyd_async.core import ( - SignalBackend, + SignalConnector, + SignalDatatypeT, SignalR, SignalRW, SignalW, SignalX, - T, get_unique, ) -from ._epics_transport import _EpicsTransport -_default_epics_transport = _EpicsTransport.ca +def _make_unavailable_class(error: Exception) -> type: + class TransportNotAvailable: + def __init__(*args, **kwargs): + raise NotImplementedError("Transport not available") from error + + return TransportNotAvailable + + +class EpicsProtocol(Enum): + ca = "ca" + pva = "pva" + + +_default_epics_protocol = EpicsProtocol.ca + +try: + from ._p4p import PvaSignalConnector +except ImportError as pva_error: + PvaSignalConnector = _make_unavailable_class(pva_error) +else: + _default_epics_protocol = EpicsProtocol.pva + +try: + from ._aioca import CaSignalConnector +except ImportError as ca_error: + CaSignalConnector = _make_unavailable_class(ca_error) +else: + _default_epics_protocol = EpicsProtocol.ca -def _transport_pv(pv: str) -> tuple[_EpicsTransport, str]: +def _protocol_pv(pv: str) -> tuple[EpicsProtocol, str]: split = pv.split("://", 1) if len(split) > 1: # We got something like pva://mydevice, so use specified comms mode - transport_str, pv = split - transport = _EpicsTransport[transport_str] + scheme, pv = split + protocol = EpicsProtocol[scheme] else: # No comms mode specified, use the default - transport = _default_epics_transport - return transport, pv + protocol = _default_epics_protocol + return protocol, pv -def _epics_signal_backend( - datatype: type[T] | None, read_pv: str, write_pv: str -) -> SignalBackend[T]: - """Create an epics signal backend.""" - r_transport, r_pv = _transport_pv(read_pv) - w_transport, w_pv = _transport_pv(write_pv) - transport = get_unique({read_pv: r_transport, write_pv: w_transport}, "transports") - return transport.value(datatype, r_pv, w_pv) +def _epics_signal_connector( + datatype: type[SignalDatatypeT] | None, read_pv: str, write_pv: str +) -> SignalConnector[SignalDatatypeT]: + """Create an epics signal connector.""" + r_protocol, r_pv = _protocol_pv(read_pv) + w_protocol, w_pv = _protocol_pv(write_pv) + protocol = get_unique({read_pv: r_protocol, write_pv: w_protocol}, "protocols") + match protocol: + case EpicsProtocol.ca: + return CaSignalConnector(datatype, r_pv, w_pv) + case EpicsProtocol.pva: + return PvaSignalConnector(datatype, r_pv, w_pv) def epics_signal_rw( - datatype: type[T], read_pv: str, write_pv: str | None = None, name: str = "" -) -> SignalRW[T]: + datatype: type[SignalDatatypeT], + read_pv: str, + write_pv: str | None = None, + name: str = "", +) -> SignalRW[SignalDatatypeT]: """Create a `SignalRW` backed by 1 or 2 EPICS PVs Parameters @@ -53,13 +88,16 @@ def epics_signal_rw( write_pv: If given, use this PV to write to, otherwise use read_pv """ - backend = _epics_signal_backend(datatype, read_pv, write_pv or read_pv) - return SignalRW(backend, name=name) + connector = _epics_signal_connector(datatype, read_pv, write_pv or read_pv) + return SignalRW(connector, name=name) def epics_signal_rw_rbv( - datatype: type[T], write_pv: str, read_suffix: str = "_RBV", name: str = "" -) -> SignalRW[T]: + datatype: type[SignalDatatypeT], + write_pv: str, + read_suffix: str = "_RBV", + name: str = "", +) -> SignalRW[SignalDatatypeT]: """Create a `SignalRW` backed by 1 or 2 EPICS PVs, with a suffix on the readback pv Parameters @@ -74,7 +112,9 @@ def epics_signal_rw_rbv( return epics_signal_rw(datatype, f"{write_pv}{read_suffix}", write_pv, name) -def epics_signal_r(datatype: type[T], read_pv: str, name: str = "") -> SignalR[T]: +def epics_signal_r( + datatype: type[SignalDatatypeT], read_pv: str, name: str = "" +) -> SignalR[SignalDatatypeT]: """Create a `SignalR` backed by 1 EPICS PV Parameters @@ -84,11 +124,13 @@ def epics_signal_r(datatype: type[T], read_pv: str, name: str = "") -> SignalR[T read_pv: The PV to read and monitor """ - backend = _epics_signal_backend(datatype, read_pv, read_pv) - return SignalR(backend, name=name) + connector = _epics_signal_connector(datatype, read_pv, read_pv) + return SignalR(connector, name=name) -def epics_signal_w(datatype: type[T], write_pv: str, name: str = "") -> SignalW[T]: +def epics_signal_w( + datatype: type[SignalDatatypeT], write_pv: str, name: str = "" +) -> SignalW[SignalDatatypeT]: """Create a `SignalW` backed by 1 EPICS PVs Parameters @@ -98,8 +140,8 @@ def epics_signal_w(datatype: type[T], write_pv: str, name: str = "") -> SignalW[ write_pv: The PV to write to """ - backend = _epics_signal_backend(datatype, write_pv, write_pv) - return SignalW(backend, name=name) + connector = _epics_signal_connector(datatype, write_pv, write_pv) + return SignalW(connector, name=name) def epics_signal_x(write_pv: str, name: str = "") -> SignalX: @@ -110,5 +152,5 @@ def epics_signal_x(write_pv: str, name: str = "") -> SignalX: write_pv: The PV to write its initial value to on trigger """ - backend: SignalBackend = _epics_signal_backend(None, write_pv, write_pv) - return SignalX(backend, name=name) + connector = _epics_signal_connector(None, write_pv, write_pv) + return SignalX(connector, name=name) diff --git a/src/ophyd_async/py.typed b/src/ophyd_async/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py index 65316d18a8..60bfd5e047 100644 --- a/tests/core/test_soft_signal_backend.py +++ b/tests/core/test_soft_signal_backend.py @@ -1,7 +1,6 @@ import asyncio import time from collections.abc import Callable, Sequence -from enum import Enum from typing import Any import numpy as np @@ -9,10 +8,17 @@ import pytest from bluesky.protocols import Reading -from ophyd_async.core import Signal, SignalBackend, SignalMetadata, SoftSignalBackend, T +from ophyd_async.core import ( + SignalBackend, + SignalR, + SoftSignalConnector, + StrictEnum, + T, + soft_signal_rw, +) -class MyEnum(str, Enum): +class MyEnum(StrictEnum): a = "Aaa" b = "Bbb" c = "Ccc" @@ -31,7 +37,7 @@ def string_d(value): def enum_d(value): - return {"dtype": "string", "shape": [], "choices": ("Aaa", "Bbb", "Ccc")} + return {"dtype": "string", "shape": [], "choices": ["Aaa", "Bbb", "Ccc"]} def waveform_d(value): @@ -41,11 +47,8 @@ def waveform_d(value): class MonitorQueue: def __init__(self, backend: SignalBackend): self.backend = backend - self.updates: asyncio.Queue[tuple[Reading, Any]] = asyncio.Queue() - backend.set_callback(self.add_reading_value) - - def add_reading_value(self, reading: Reading, value): - self.updates.put_nowait((reading, value)) + self.updates: asyncio.Queue[Reading] = asyncio.Queue() + backend.set_callback(self.updates.put_nowait) async def assert_updates(self, expected_value): expected_reading = { @@ -53,12 +56,12 @@ async def assert_updates(self, expected_value): "timestamp": pytest.approx(time.monotonic(), rel=0.1), "alarm_severity": 0, } - reading, value = await self.updates.get() + reading = await self.updates.get() backend_value = await self.backend.get_value() backend_reading = await self.backend.get_reading() - assert value == expected_value == backend_value + assert reading["value"] == expected_value == backend_value assert reading == expected_reading == backend_reading def close(self): @@ -70,19 +73,19 @@ def close(self): [ (int, 0, 43, integer_d, " None: pass - soft_signal = Signal(SoftSignalBackend(myClass)) - await soft_signal.connect() - - with pytest.raises(AssertionError): - await soft_signal._backend.get_datakey("") + soft_signal = SignalR(SoftSignalConnector(myClass)) + with pytest.raises(TypeError): + await soft_signal.connect() async def test_soft_signal_descriptor_with_metadata(): - soft_signal = Signal( - SoftSignalBackend(int, 0, metadata=SignalMetadata(units="mm", precision=0)) - ) + soft_signal = soft_signal_rw(int, 0, units="mm", precision=0) await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert datakey["units"] == "mm" - assert datakey["precision"] == 0 + datakey = await soft_signal.describe() + assert datakey[""]["units"] == "mm" + assert datakey[""]["precision"] == 0 - soft_signal = Signal(SoftSignalBackend(int, metadata=SignalMetadata(units=""))) + soft_signal = soft_signal_rw(int, units="") await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert datakey["units"] == "" - assert not hasattr(datakey, "precision") + datakey = await soft_signal.describe() + assert datakey[""]["units"] == "" + assert not hasattr(datakey[""], "precision") async def test_soft_signal_descriptor_with_no_metadata_not_passed(): - soft_signal = Signal(SoftSignalBackend(int)) - await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert not hasattr(datakey, "units") - assert not hasattr(datakey, "precision") - - soft_signal = Signal(SoftSignalBackend(int, metadata=None)) - await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert not hasattr(datakey, "units") - assert not hasattr(datakey, "precision") - - soft_signal = Signal(SoftSignalBackend(int, metadata={})) + soft_signal = soft_signal_rw(int) await soft_signal.connect() - datakey = await soft_signal._backend.get_datakey("") - assert not hasattr(datakey, "units") - assert not hasattr(datakey, "precision") + datakey = await soft_signal.describe() + assert not hasattr(datakey[""], "units") + assert not hasattr(datakey[""], "precision") diff --git a/tests/epics/signal/test_common.py b/tests/epics/signal/test_common.py index 7a16a59a51..124273b2f0 100644 --- a/tests/epics/signal/test_common.py +++ b/tests/epics/signal/test_common.py @@ -2,6 +2,7 @@ import pytest +from ophyd_async.core import StrictEnum from ophyd_async.epics.signal import get_supported_values @@ -19,7 +20,7 @@ class MyEnum(Enum): def test_given_pv_has_choices_not_in_supplied_enum_then_raises(): - class MyEnum(str, Enum): + class MyEnum(StrictEnum): TEST = "test" with pytest.raises(TypeError): @@ -27,7 +28,7 @@ class MyEnum(str, Enum): def test_given_supplied_enum_has_choices_not_in_pv_then_raises(): - class MyEnum(str, Enum): + class MyEnum(StrictEnum): TEST = "test" OTHER = "unexpected_choice" @@ -35,20 +36,8 @@ class MyEnum(str, Enum): get_supported_values("", MyEnum, ("test",)) -@pytest.mark.parametrize( - "datatype", - [None, str], -) -def test_given_no_enum_or_string_then_returns_generated_choices_enum_with_pv_choices( - datatype, -): - supported_vals = get_supported_values("", datatype, ("test",)) - assert len(supported_vals) == 1 - assert "test" in supported_vals - - def test_given_a_supplied_enum_that_matches_the_pv_choices_then_enum_type_is_returned(): - class MyEnum(str, Enum): + class MyEnum(StrictEnum): TEST_1 = "test_1" TEST_2 = "test_2" diff --git a/tests/epics/signal/test_records.db b/tests/epics/signal/test_records.db index e5aa5a776c..f5d7607873 100644 --- a/tests/epics/signal/test_records.db +++ b/tests/epics/signal/test_records.db @@ -184,7 +184,7 @@ record(waveform, "$(P)longstr") { record(lsi, "$(P)longstr2") { field(SIZV, "80") field(INP, {const:"a string that is just longer than forty characters"}) - field(PINI, "YES") + field(PINI, "YES") } record(waveform, "$(P)table:labels") { diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index e2e5c20f7d..9ae5ee84d2 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -15,14 +15,15 @@ from unittest.mock import ANY import numpy as np -import numpy.typing as npt import pytest from aioca import CANothing, purge_channel_caches from bluesky.protocols import Reading from event_model import DataKey +from event_model.documents.event_descriptor import Limits, LimitsRange from typing_extensions import TypedDict from ophyd_async.core import ( + Array1D, NotConnected, SignalBackend, SubsetEnum, @@ -30,16 +31,15 @@ load_from_yaml, save_to_yaml, ) +from ophyd_async.core._utils import StrictEnum from ophyd_async.epics.signal import ( - LimitPair, - Limits, epics_signal_r, epics_signal_rw, epics_signal_rw_rbv, epics_signal_w, epics_signal_x, ) -from ophyd_async.epics.signal._epics_transport import _EpicsTransport # noqa +from ophyd_async.epics.signal._signal import _epics_signal_connector RECORDS = str(Path(__file__).parent / "test_records.db") PV_PREFIX = "".join(random.choice(string.ascii_lowercase) for _ in range(12)) @@ -54,13 +54,12 @@ async def make_backend( self, typ: type | None, suff: str, connect=True ) -> SignalBackend: # Calculate the pv - pv = f"{PV_PREFIX}:{self.protocol}:{suff}" + pv = f"{self.protocol}://{PV_PREFIX}:{self.protocol}:{suff}" # Make and connect the backend - cls = _EpicsTransport[self.protocol].value - backend = cls(typ, pv, pv) # type: ignore + connector = _epics_signal_connector(typ, pv, pv) if connect: - await asyncio.wait_for(backend.connect(), 10) # type: ignore - return backend # type: ignore + await connector.connect(None, False, 10, False) # type: ignore + return connector.backend # Use a module level fixture per protocol so it's fast to run tests. This means @@ -128,11 +127,8 @@ def assert_types_are_equal(t_actual, t_expected, actual_value): class MonitorQueue: def __init__(self, backend: SignalBackend): self.backend = backend - self.subscription = backend.set_callback(self.add_reading_value) - self.updates: asyncio.Queue[tuple[Reading, Any]] = asyncio.Queue() - - def add_reading_value(self, reading: Reading, value): - self.updates.put_nowait((reading, value)) + self.updates: asyncio.Queue[Reading] = asyncio.Queue() + self.subscription = backend.set_callback(self.updates.put_nowait) async def assert_updates(self, expected_value, expected_type=None): expected_reading = { @@ -141,7 +137,8 @@ async def assert_updates(self, expected_value, expected_type=None): "alarm_severity": 0, } backend_reading = await asyncio.wait_for(self.backend.get_reading(), timeout=5) - reading, value = await asyncio.wait_for(self.updates.get(), timeout=5) + reading = await asyncio.wait_for(self.updates.get(), timeout=5) + value = reading["value"] backend_value = await asyncio.wait_for(self.backend.get_value(), timeout=5) assert value == expected_value == backend_value @@ -205,21 +202,25 @@ async def put_error( await backend.put(put_value, timeout=3) -class MyEnum(str, Enum): +class MyEnum(StrictEnum): a = "Aaa" b = "Bbb" c = "Ccc" -MySubsetEnum = SubsetEnum["Aaa", "Bbb", "Ccc"] +class MySubsetEnum(SubsetEnum): + a = "Aaa" + b = "Bbb" + c = "Ccc" + _metadata: dict[str, dict[str, dict[str, Any]]] = { "ca": { "boolean": {"units": ANY, "limits": ANY}, "integer": {"units": ANY, "limits": ANY}, "number": {"units": ANY, "limits": ANY, "precision": ANY}, - "enum": {"limits": ANY}, - "string": {"limits": ANY}, + "enum": {}, + "string": {}, }, "pva": { "boolean": {"limits": ANY}, @@ -270,10 +271,10 @@ def get_dtype_numpy(suffix: str) -> str: # type: ignore elif "64" in suffix: int_str += "8" else: - int_str += "4" + int_str += "8" return int_str if "str" in suffix or "enum" in suffix: - return "|S40" + return "|T16" dtype = get_dtype(suffix) dtype_numpy = get_dtype_numpy(suffix) @@ -308,70 +309,70 @@ def get_dtype_numpy(suffix: str) -> str: # type: ignore (MyEnum, "enum", MyEnum.b, MyEnum.c, {"ca", "pva"}), # numpy arrays of numpy types ( - npt.NDArray[np.int8], + Array1D[np.int8], "int8a", [-128, 127], [-8, 3, 44], {"pva"}, ), ( - npt.NDArray[np.uint8], + Array1D[np.uint8], "uint8a", [0, 255], [218], {"ca", "pva"}, ), ( - npt.NDArray[np.int16], + Array1D[np.int16], "int16a", [-32768, 32767], [-855], {"ca", "pva"}, ), ( - npt.NDArray[np.uint16], + Array1D[np.uint16], "uint16a", [0, 65535], [5666], {"pva"}, ), ( - npt.NDArray[np.int32], + Array1D[np.int32], "int32a", [-2147483648, 2147483647], [-2], {"ca", "pva"}, ), ( - npt.NDArray[np.uint32], + Array1D[np.uint32], "uint32a", [0, 4294967295], [1022233], {"pva"}, ), ( - npt.NDArray[np.int64], + Array1D[np.int64], "int64a", [-2147483649, 2147483648], [-3], {"pva"}, ), ( - npt.NDArray[np.uint64], + Array1D[np.uint64], "uint64a", [0, 4294967297], [995444], {"pva"}, ), ( - npt.NDArray[np.float32], + Array1D[np.float32], "float32a", [0.000002, -123.123], [1.0], {"ca", "pva"}, ), ( - npt.NDArray[np.float64], + Array1D[np.float64], "float64a", [0.1, -12345678.123], [0.2], @@ -385,7 +386,7 @@ def get_dtype_numpy(suffix: str) -> str: # type: ignore {"pva"}, ), ( - npt.NDArray[np.str_], + Array1D[np.str_], "stra", ["five", "six", "seven"], ["nine", "ten"], @@ -529,6 +530,12 @@ class EnumNoString(Enum): a = "Aaa" +class SubsetEnumWrongChoices(SubsetEnum): + a = "Aaa" + b = "B" + c = "Ccc" + + @pytest.mark.parametrize( "typ, suff, error", [ @@ -541,12 +548,11 @@ class EnumNoString(Enum): ), ), ( - rt_enum := SubsetEnum["Aaa", "B", "Ccc"], + SubsetEnumWrongChoices, "enum", ( "has choices ('Aaa', 'Bbb', 'Ccc'), " - # SubsetEnum string output isn't deterministic - f"which is not a superset of {str(rt_enum)}." + "which is not a superset of ('Aaa', 'Bbb', 'Ccc')." ), ), (int, "str", "has type str not int"), @@ -561,7 +567,7 @@ class EnumNoString(Enum): "Use an Enum or SubsetEnum to represent this." ), ), - (npt.NDArray[np.int32], "float64a", "has type [float64] not [int32]"), + (Array1D[np.int32], "float64a", "has type [float64] not [int32]"), ], ) async def test_backend_wrong_type_errors(ioc: IOC, typ, suff, error): @@ -595,9 +601,9 @@ def approx_table(table): class MyTable(TypedDict): - bool: npt.NDArray[np.bool_] - int: npt.NDArray[np.int32] - float: npt.NDArray[np.float64] + bool: Array1D[np.bool_] + int: Array1D[np.int32] + float: Array1D[np.float64] str: Sequence[str] enum: Sequence[MyEnum] @@ -692,11 +698,11 @@ async def test_pva_ntdarray(ioc: IOC): put = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64).reshape((2, 3)) initial = np.zeros_like(put) - backend = await ioc.make_backend(npt.NDArray[np.int64], "ntndarray") + backend = await ioc.make_backend(Array1D[np.int64], "ntndarray") # Backdoor into the "raw" data underlying the NDArray in QSrv # not supporting direct writes to NDArray at the moment. - raw_data_backend = await ioc.make_backend(npt.NDArray[np.int64], "ntndarray:data") + raw_data_backend = await ioc.make_backend(Array1D[np.int64], "ntndarray:data") # Make a monitor queue that will monitor for updates for i, p in [(initial, put), (put, initial)]: @@ -719,7 +725,7 @@ async def test_writing_to_ndarray_raises_typeerror(ioc: IOC): # CA can't do ndarray return - backend = await ioc.make_backend(npt.NDArray[np.int64], "ntndarray") + backend = await ioc.make_backend(Array1D[np.int64], "ntndarray") with pytest.raises(TypeError): await backend.put(np.zeros((6,), dtype=np.int64)) @@ -836,13 +842,13 @@ async def test_signal_returns_limits(ioc: IOC): expected_limits = Limits( # LOW, HIGH - warning=LimitPair(low=5.0, high=96.0), + warning=LimitsRange(low=5.0, high=96.0), # DRVL, DRVH - control=LimitPair(low=10.0, high=90.0), + control=LimitsRange(low=10.0, high=90.0), # LOPR, HOPR - display=LimitPair(low=0.0, high=100.0), + display=LimitsRange(low=0.0, high=100.0), # LOLO, HIHI - alarm=LimitPair(low=2.0, high=98.0), + alarm=LimitsRange(low=2.0, high=98.0), ) sig = epics_signal_rw(int, pv_name) @@ -858,13 +864,13 @@ async def test_signal_returns_partial_limits(ioc: IOC): expected_limits = Limits( # LOLO, HIHI - alarm=LimitPair(low=2.0, high=98.0), + alarm=LimitsRange(low=2.0, high=98.0), # DRVL, DRVH - control=LimitPair(low=10.0, high=90.0), + control=LimitsRange(low=10.0, high=90.0), # LOPR, HOPR - display=LimitPair(low=0.0, high=100.0), + display=LimitsRange(low=0.0, high=100.0), # HSV, LSV not set. - warning=LimitPair(low=not_set, high=not_set), + warning=LimitsRange(low=not_set, high=not_set), ) sig = epics_signal_rw(int, pv_name) @@ -880,13 +886,13 @@ async def test_signal_returns_warning_and_partial_limits(ioc: IOC): expected_limits = Limits( # HSV, LSV not set - alarm=LimitPair(low=not_set, high=not_set), + alarm=LimitsRange(low=not_set, high=not_set), # control = display if DRVL, DRVH not set - control=LimitPair(low=0.0, high=100.0), + control=LimitsRange(low=0.0, high=100.0), # LOPR, HOPR - display=LimitPair(low=0.0, high=100.0), + display=LimitsRange(low=0.0, high=100.0), # LOW, HIGH - warning=LimitPair(low=2.0, high=98.0), + warning=LimitsRange(low=2.0, high=98.0), ) sig = epics_signal_rw(int, pv_name) @@ -914,6 +920,7 @@ async def test_signals_created_for_not_prec_0_float_cannot_use_int(ioc: IOC): pv_name = f"{ioc.protocol}://{PV_PREFIX}:{ioc.protocol}:float_prec_1" sig = epics_signal_rw(int, pv_name) with pytest.raises( - TypeError, match=f"{ioc.protocol}:float_prec_1 has type float not int" + TypeError, + match=f"{ioc.protocol}:float_prec_1 with inferred datatype cannot be coerced to ", ): await sig.connect() From 2e9dd617eed3f0b17e4a56959adb89fe23c4b1f1 Mon Sep 17 00:00:00 2001 From: Tom Cobb Date: Wed, 25 Sep 2024 12:54:07 +0000 Subject: [PATCH 2/6] Slightly improve the connect signature --- src/ophyd_async/core/_device.py | 36 ++++++++++++-------- src/ophyd_async/core/_signal.py | 2 +- src/ophyd_async/core/_soft_signal_backend.py | 12 +++---- src/ophyd_async/epics/signal/_aioca.py | 8 ++--- tests/core/test_soft_signal_backend.py | 6 ++-- tests/epics/signal/test_signals.py | 3 +- 6 files changed, 35 insertions(+), 32 deletions(-) diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index 963888938a..6021bbd6df 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -15,19 +15,23 @@ class DeviceConnector: + # TODO: we will add some mechanism of invalidating the cache here later @abstractmethod async def connect( - self, device: Device, mock: bool, timeout: float, force_reconnect: bool + self, mock: bool, timeout: float, force_reconnect: bool ) -> None: ... class DeviceChildConnector(DeviceConnector): - async def connect( - self, device: Device, mock: bool, timeout: float, force_reconnect: bool - ) -> Any: + def __init__(self, children: Callable[[], dict[str, Device]]): + self._children = children + + async def connect(self, mock: bool, timeout: float, force_reconnect: bool) -> None: coros = { - name: child_device.connect(mock, timeout, force_reconnect) - for name, child_device in device.children().items() + name: child_device.connect( + mock=mock, timeout=timeout, force_reconnect=force_reconnect + ) + for name, child_device in self._children().items() } await wait_for_connection(**coros) @@ -49,15 +53,18 @@ class Device(HasName, Generic[DeviceConnectorType]): # The value of the mock arg to connect _connect_mock: bool | None = None # The connector to use - _connector: DeviceConnectorType = DeviceChildConnector() + _connector: DeviceConnectorType def __init__( self, name: str = "", connector: DeviceConnectorType | None = None, ) -> None: - if connector is not None: - self._connector = connector + if connector is None: + # TODO: this is ugly, maybe we remove the option to pass None as the + # connector so this goes away? + connector = cast(DeviceConnectorType, DeviceChildConnector(self.children)) + self._connector = connector self.set_name(name) @property @@ -125,7 +132,9 @@ async def connect( # Use the connector to make a new connection self._connect_mock = mock self._connect_task = asyncio.create_task( - self._connector.connect(self, mock, timeout, force_reconnect) + self._connector.connect( + mock=mock, timeout=timeout, force_reconnect=force_reconnect + ) ) assert self._connect_task, "Connect task not created, this shouldn't happen" # Wait for it to complete @@ -148,10 +157,9 @@ def __init__( self, children: dict[int, DeviceType], name: str = "", - connector: DeviceConnector | None = None, ) -> None: self._children = children - super().__init__(name, connector) + super().__init__(connector=DeviceChildConnector(self.children), name=name) def __getitem__(self, key: int) -> DeviceType: return self._children[key] @@ -224,12 +232,12 @@ def _caller_locals(self): ), "No previous frame to the one with self in it, this shouldn't happen" return caller_frame.f_locals - def __enter__(self) -> "DeviceCollector": + def __enter__(self) -> DeviceCollector: # Stash the names that were defined before we were called self._names_on_enter = set(self._caller_locals()) return self - async def __aenter__(self) -> "DeviceCollector": + async def __aenter__(self) -> DeviceCollector: return self.__enter__() async def _on_exit(self) -> None: diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 93e2d94a70..b604e4439a 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -45,7 +45,7 @@ def __init__( name: str = "", ) -> None: self._timeout = timeout - super().__init__(name, connector) + super().__init__(name=name, connector=connector) @property def source(self) -> str: diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index aeb977ebd4..4eca1628bc 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -6,14 +6,12 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass from functools import cached_property -from typing import Any, Generic, get_args, get_origin +from typing import Any, Generic, get_origin from unittest.mock import AsyncMock import numpy as np from event_model import DataKey -from ophyd_async.core._device import Device - from ._protocol import Reading from ._signal_backend import ( Array1D, @@ -120,14 +118,14 @@ def __init__( self, datatype: type[SignalDatatypeT] | None = None, initial_value: SignalDatatypeT | None = None, - metadata: SignalMetadata = {}, + metadata: SignalMetadata | None = None, ): # If not specified then default to float self._datatype = datatype or float # Create the right converter for the datatype self._converter = make_converter(self._datatype) self._initial_value = self._converter.write_value(initial_value) - self._metadata = metadata + self._metadata = metadata or SignalMetadata() self.set_value(self._initial_value) def set_value(self, value: SignalDatatypeT): @@ -193,9 +191,7 @@ class SoftSignalConnector(SignalConnector[SignalDatatypeT]): units: str | None = None precision: int | None = None - async def connect( - self, device: Device, mock: bool, timeout: float, force_reconnect: bool - ) -> None: + async def connect(self, mock: bool, timeout: float, force_reconnect: bool) -> None: # Add the extra static metadata to the dictionary metadata: SignalMetadata = {} if self.units is not None: diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index 4e6d0d0c18..a35a7379ce 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -22,7 +22,6 @@ from ophyd_async.core import ( Array1D, Callback, - Device, MockSignalBackend, SignalBackend, SignalConnector, @@ -217,7 +216,8 @@ def make_converter( # If datatype matches what we are given then allow it return CaConverter(inferred_datatype, pv_dbr) raise TypeError( - f"{pv} with inferred datatype {inferred_datatype} cannot be coerced to {datatype}" + f"{pv} with inferred datatype {inferred_datatype}" + f" cannot be coerced to {datatype}" ) @@ -313,9 +313,7 @@ class CaSignalConnector(SignalConnector[SignalDatatypeT]): read_pv: str write_pv: str - async def connect( - self, device: Device, mock: bool, timeout: float, force_reconnect: bool - ) -> None: + async def connect(self, mock: bool, timeout: float, force_reconnect: bool) -> None: if mock: self.backend = MockSignalBackend(self.datatype) else: diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py index 60bfd5e047..e777fea767 100644 --- a/tests/core/test_soft_signal_backend.py +++ b/tests/core/test_soft_signal_backend.py @@ -100,7 +100,7 @@ async def test_soft_signal_backend_get_put_monitor( ): connector = SoftSignalConnector(datatype=datatype) - await connector.connect(None, False, 1, False) + await connector.connect(False, 1, False) q = MonitorQueue(connector.backend) try: # Check descriptor @@ -122,7 +122,7 @@ async def test_soft_signal_backend_get_put_monitor( async def test_soft_signal_backend_enum_value_equivalence(): connector = SoftSignalConnector(MyEnum) - await connector.connect(None, False, 1, False) + await connector.connect(False, 1, False) soft_backend = connector.backend assert (await soft_backend.get_value()) is MyEnum.a await soft_backend.put(MyEnum.b) @@ -131,7 +131,7 @@ async def test_soft_signal_backend_enum_value_equivalence(): async def test_soft_signal_backend_with_numpy_typing(): connector = SoftSignalConnector(npt.NDArray[np.float64]) - await connector.connect(None, False, 1, False) + await connector.connect(False, 1, False) soft_backend = connector.backend array = await soft_backend.get_value() diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index 9ae5ee84d2..aa9973a20e 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -921,6 +921,7 @@ async def test_signals_created_for_not_prec_0_float_cannot_use_int(ioc: IOC): sig = epics_signal_rw(int, pv_name) with pytest.raises( TypeError, - match=f"{ioc.protocol}:float_prec_1 with inferred datatype cannot be coerced to ", + match=f"{ioc.protocol}:float_prec_1 with inferred datatype " + "cannot be coerced to ", ): await sig.connect() From 2f6d171693edde44fdc9624cd57a89c2c5b1d230 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Thu, 26 Sep 2024 09:23:00 +0100 Subject: [PATCH 3/6] Update src/ophyd_async/core/_device.py Co-authored-by: Eva Lott --- src/ophyd_async/core/_device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index 963888938a..80942e4457 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -47,7 +47,7 @@ class Device(HasName, Generic[DeviceConnectorType]): # None if connect hasn't started, a Task if it has _connect_task: asyncio.Task | None = None # The value of the mock arg to connect - _connect_mock: bool | None = None + _connected_in_mock_mode: bool | None = None # The connector to use _connector: DeviceConnectorType = DeviceChildConnector() From fdb450400ea5397952d6ec8d9612b109cf302ab5 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Thu, 26 Sep 2024 09:28:13 +0100 Subject: [PATCH 4/6] Update src/ophyd_async/core/_table.py Co-authored-by: Eva Lott --- src/ophyd_async/core/_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py index fc712eb718..4a6d592028 100644 --- a/src/ophyd_async/core/_table.py +++ b/src/ophyd_async/core/_table.py @@ -149,5 +149,5 @@ def validate_arrays(self) -> "Table": def __len__(self) -> int: return len(next(iter(self))[1]) - def __getitem__(self) -> Any: + def __getitem__(self, items: Tuple[str, ...]) -> Any: raise NotImplementedError() From c70d4490f62f5bbdf218f6cf73ed6f86eb7408e4 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Thu, 26 Sep 2024 09:28:51 +0100 Subject: [PATCH 5/6] Update src/ophyd_async/epics/signal/_signal.py Co-authored-by: Eva Lott --- src/ophyd_async/epics/signal/_signal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ophyd_async/epics/signal/_signal.py b/src/ophyd_async/epics/signal/_signal.py index dd82891f4b..a40e97ab3f 100644 --- a/src/ophyd_async/epics/signal/_signal.py +++ b/src/ophyd_async/epics/signal/_signal.py @@ -24,8 +24,8 @@ def __init__(*args, **kwargs): class EpicsProtocol(Enum): - ca = "ca" - pva = "pva" + CA = "ca" + PVA = "pva" _default_epics_protocol = EpicsProtocol.ca From 4a5c27f20785bed1231126e9ee7b1ac1f9f0c5a3 Mon Sep 17 00:00:00 2001 From: Tom Cobb Date: Thu, 26 Sep 2024 12:44:44 +0000 Subject: [PATCH 6/6] Convert p4p and improve ca to match --- docs/how-to/make-a-simple-device.rst | 8 +- src/ophyd_async/core/_detector.py | 9 +- src/ophyd_async/core/_device.py | 14 +- src/ophyd_async/core/_device_save_loader.py | 2 +- src/ophyd_async/core/_readable.py | 3 +- src/ophyd_async/core/_signal.py | 2 +- src/ophyd_async/core/_signal_backend.py | 35 +- src/ophyd_async/core/_table.py | 2 +- .../epics/adaravis/_aravis_controller.py | 2 +- src/ophyd_async/epics/adaravis/_aravis_io.py | 12 +- src/ophyd_async/epics/adcore/_core_io.py | 10 +- src/ophyd_async/epics/adcore/_utils.py | 25 +- .../epics/adkinetix/_kinetix_io.py | 7 +- .../epics/adpilatus/_pilatus_io.py | 5 +- src/ophyd_async/epics/advimba/_vimba_io.py | 15 +- src/ophyd_async/epics/demo/_sensor.py | 12 +- src/ophyd_async/epics/eiger/_eiger_io.py | 6 +- src/ophyd_async/epics/eiger/_odin_io.py | 4 +- src/ophyd_async/epics/signal/_aioca.py | 172 +++-- src/ophyd_async/epics/signal/_common.py | 14 + src/ophyd_async/epics/signal/_p4p.py | 630 ++++++++---------- src/ophyd_async/epics/signal/_signal.py | 12 +- src/ophyd_async/fastcs/panda/_block.py | 19 +- src/ophyd_async/fastcs/panda/_table.py | 7 +- tests/core/test_device_save_loader.py | 6 +- tests/core/test_flyer.py | 4 +- tests/epics/signal/test_signals.py | 88 ++- 27 files changed, 528 insertions(+), 597 deletions(-) diff --git a/docs/how-to/make-a-simple-device.rst b/docs/how-to/make-a-simple-device.rst index f51fea2120..40edd36426 100644 --- a/docs/how-to/make-a-simple-device.rst +++ b/docs/how-to/make-a-simple-device.rst @@ -1,6 +1,6 @@ .. note:: - Ophyd async is included on a provisional basis until the v1.0 release and + Ophyd async is included on a provisional basis until the v1.0 release and may change API on minor release numbers before then Make a Simple Device @@ -31,7 +31,7 @@ its Python type, which could be: - A primitive (`str`, `int`, `float`) - An array (`numpy.typing.NDArray` ie. ``numpy.typing.NDArray[numpy.uint16]`` or ``Sequence[str]``) - An enum (`enum.Enum`) which **must** also extend `str` - - `str` and ``EnumClass(str, Enum)`` are the only valid ``datatype`` for an enumerated signal. + - `str` and ``EnumClass(StrictEnum)`` are the only valid ``datatype`` for an enumerated signal. The rest of the arguments are PV connection information, in this case the PV suffix. @@ -45,7 +45,7 @@ Finally `super().__init__() ` is called with: without renaming All signals passed into this init method will be monitored between ``stage()`` -and ``unstage()`` and their cached values returned on ``read()`` and +and ``unstage()`` and their cached values returned on ``read()`` and ``read_configuration()`` for perfomance. Movable @@ -64,7 +64,7 @@ informing watchers of the progress. When it gets to the requested value it completes. This co-routine is wrapped in a timeout handler, and passed to an `AsyncStatus` which will start executing it as soon as the Run Engine adds a callback to it. The ``stop()`` method then pokes a PV if the move needs to be -interrupted. +interrupted. Assembly -------- diff --git a/src/ophyd_async/core/_detector.py b/src/ophyd_async/core/_detector.py index bacd43a279..f18d8bc250 100644 --- a/src/ophyd_async/core/_detector.py +++ b/src/ophyd_async/core/_detector.py @@ -4,10 +4,6 @@ import time from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence -from enum import Enum -from typing import ( - Generic, -) from bluesky.protocols import ( Collectable, @@ -26,10 +22,10 @@ from ._protocol import AsyncConfigurable, AsyncReadable from ._signal import SignalR from ._status import AsyncStatus, WatchableAsyncStatus -from ._utils import DEFAULT_TIMEOUT, T, WatcherUpdate, merge_gathered_dicts +from ._utils import DEFAULT_TIMEOUT, StrictEnum, WatcherUpdate, merge_gathered_dicts -class DetectorTrigger(str, Enum): +class DetectorTrigger(StrictEnum): """Type of mechanism for triggering a detector to take frames""" #: Detector generates internal trigger for given rate @@ -158,7 +154,6 @@ class StandardDetector( Flyable, Collectable, WritesStreamAssets, - Generic[T], ): """ Useful detector base class for step and fly scanning detectors. diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index f12b89051b..45b01e18b1 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -63,7 +63,9 @@ def __init__( if connector is None: # TODO: this is ugly, maybe we remove the option to pass None as the # connector so this goes away? - connector = cast(DeviceConnectorType, DeviceChildConnector(self.children)) + connector = cast( + DeviceConnectorType, DeviceChildConnector(lambda: self.children) + ) self._connector = connector self.set_name(name) @@ -78,6 +80,7 @@ def log(self): getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name} ) + @property def children(self) -> dict[str, Device]: return { attr_name: attr @@ -99,7 +102,7 @@ def set_name(self, name: str): del self.log self._name = name - for attr_name, child in self.children().items(): + for attr_name, child in self.children.items(): child_name = f"{name}-{attr_name.rstrip('_')}" if name else "" child.set_name(child_name) child.parent = self @@ -124,13 +127,13 @@ async def connect( # If previous connect with same args has started and not errored, can use it can_use_previous_connect = ( - mock is self._connect_mock + mock is self._connected_in_mock_mode and self._connect_task and not (self._connect_task.done() and self._connect_task.exception()) ) if force_reconnect or not can_use_previous_connect: # Use the connector to make a new connection - self._connect_mock = mock + self._connected_in_mock_mode = mock self._connect_task = asyncio.create_task( self._connector.connect( mock=mock, timeout=timeout, force_reconnect=force_reconnect @@ -159,7 +162,7 @@ def __init__( name: str = "", ) -> None: self._children = children - super().__init__(connector=DeviceChildConnector(self.children), name=name) + super().__init__(name=name) def __getitem__(self, key: int) -> DeviceType: return self._children[key] @@ -170,6 +173,7 @@ def __iter__(self) -> Iterator[int]: def __len__(self) -> int: return len(self._children) + @property def children(self) -> dict[str, Device]: return {str(key): value for key, value in self.items()} diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py index 95e936752b..0f5d0176d2 100644 --- a/src/ophyd_async/core/_device_save_loader.py +++ b/src/ophyd_async/core/_device_save_loader.py @@ -111,7 +111,7 @@ def walk_rw_signals( path_prefix = "" signals: dict[str, SignalRW[Any]] = {} - for attr_name, attr in device.children(): + for attr_name, attr in device.children.items(): dot_path = f"{path_prefix}{attr_name}" if type(attr) is SignalRW: signals[dot_path] = attr diff --git a/src/ophyd_async/core/_readable.py b/src/ophyd_async/core/_readable.py index 111a26d3b1..a4fef65e53 100644 --- a/src/ophyd_async/core/_readable.py +++ b/src/ophyd_async/core/_readable.py @@ -161,8 +161,7 @@ def add_children_as_readables( flattened_values = [] for value in new_values: if isinstance(value, DeviceVector): - children = value.children() - flattened_values.extend([x[1] for x in children]) + flattened_values.extend(value.children.values()) else: flattened_values.append(value) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index b604e4439a..d41a9e4b63 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -51,7 +51,7 @@ def __init__( def source(self) -> str: """Like ca://PV_PREFIX:SIGNAL, or "" if not set""" source = self._connector.source(self.name) - if self._connect_mock: + if self._connected_in_mock_mode: return f"mock+{source}" else: return source diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 27b233b489..43bbf9974b 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -1,6 +1,5 @@ from abc import abstractmethod from collections.abc import Sequence -from enum import Enum from typing import Generic, TypedDict, TypeVar, get_origin import numpy as np @@ -28,6 +27,7 @@ | Array1D[np.uint64] | Array1D[np.float32] | Array1D[np.float64] + | np.ndarray | SubsetEnum | Sequence[str] | Sequence[SubsetEnum] @@ -76,9 +76,12 @@ def _fail(*args, **kwargs): class DisconnectedBackend(SignalBackend): - source = connect = put = get_datakey = get_reading = get_value = get_setpoint = ( - set_callback - ) = _fail + put = _fail + get_datakey = _fail + get_reading = _fail + get_value = _fail + get_setpoint = _fail + set_callback = _fail class SignalConnector(DeviceConnector, Generic[SignalDatatypeT]): @@ -103,10 +106,14 @@ class SignalMetadata(TypedDict, total=False): units: str -def _datakey_dtype(datatype: type[SignalDatatypeT]) -> Dtype: - if get_origin(datatype) in (Sequence, np.ndarray) or issubclass(datatype, Table): +def _datakey_dtype(datatype: type[SignalDatatype]) -> Dtype: + if ( + datatype is np.ndarray + or get_origin(datatype) in (Sequence, np.ndarray) + or issubclass(datatype, Table) + ): return "array" - elif issubclass(datatype, Enum): + elif issubclass(datatype, SubsetEnum): return "string" elif issubclass(datatype, Primitive): return _primitive_dtype[datatype] @@ -114,13 +121,19 @@ def _datakey_dtype(datatype: type[SignalDatatypeT]) -> Dtype: raise TypeError(f"Can't make dtype for {datatype}") -def _datakey_dtype_numpy(datatype: type[SignalDatatypeT]) -> np.dtype: +def _datakey_dtype_numpy( + datatype: type[SignalDatatypeT], value: SignalDatatypeT +) -> np.dtype: if get_origin(datatype) == np.ndarray: + # If we are told what numpy dtype we will be, use that return get_dtype(datatype) + elif datatype is np.ndarray and isinstance(value, np.ndarray): + # If we are just told an array, get it from the value + return value.dtype elif ( get_origin(datatype) == Sequence or datatype is str - or issubclass(datatype, Enum) + or issubclass(datatype, SubsetEnum) ): return np.dtypes.StringDType() elif issubclass(datatype, Table): @@ -132,7 +145,7 @@ def _datakey_dtype_numpy(datatype: type[SignalDatatypeT]) -> np.dtype: def _datakey_shape(value: SignalDatatype) -> list[int]: - if type(value) in _primitive_dtype or isinstance(value, Enum): + if type(value) in _primitive_dtype or isinstance(value, SubsetEnum): return [] elif isinstance(value, np.ndarray): return list(value.shape) @@ -152,7 +165,7 @@ def make_datakey( dtype=_datakey_dtype(datatype), shape=_datakey_shape(value), # Ignore until https://github.com/bluesky/event-model/issues/308 - dtype_numpy=_datakey_dtype_numpy(datatype).str, # type: ignore + dtype_numpy=_datakey_dtype_numpy(datatype, value).str, # type: ignore source=source, **metadata, ) diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py index 4a6d592028..800926ae89 100644 --- a/src/ophyd_async/core/_table.py +++ b/src/ophyd_async/core/_table.py @@ -149,5 +149,5 @@ def validate_arrays(self) -> "Table": def __len__(self) -> int: return len(next(iter(self))[1]) - def __getitem__(self, items: Tuple[str, ...]) -> Any: + def __getitem__(self, items: tuple[str, ...]) -> Any: raise NotImplementedError() diff --git a/src/ophyd_async/epics/adaravis/_aravis_controller.py b/src/ophyd_async/epics/adaravis/_aravis_controller.py index 80b826db2d..e604074b0e 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_controller.py +++ b/src/ophyd_async/epics/adaravis/_aravis_controller.py @@ -69,7 +69,7 @@ def _get_trigger_info( f"use {trigger}" ) if trigger == DetectorTrigger.internal: - return AravisTriggerMode.off, "Freerun" + return AravisTriggerMode.off, AravisTriggerSource.freerun else: return (AravisTriggerMode.on, f"Line{self.gpio_number}") # type: ignore diff --git a/src/ophyd_async/epics/adaravis/_aravis_io.py b/src/ophyd_async/epics/adaravis/_aravis_io.py index 27c2898513..9707beac2d 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_io.py +++ b/src/ophyd_async/epics/adaravis/_aravis_io.py @@ -1,11 +1,9 @@ -from enum import Enum - -from ophyd_async.core import SubsetEnum +from ophyd_async.core import StrictEnum, SubsetEnum from ophyd_async.epics import adcore from ophyd_async.epics.signal import epics_signal_rw_rbv -class AravisTriggerMode(str, Enum): +class AravisTriggerMode(StrictEnum): """GigEVision GenICAM standard: on=externally triggered""" on = "On" @@ -19,7 +17,11 @@ class AravisTriggerMode(str, Enum): To prevent requiring one Enum class per possible configuration, we set as this Enum but read from the underlying signal as a str. """ -AravisTriggerSource = SubsetEnum["Freerun", "Line1"] + + +class AravisTriggerSource(SubsetEnum): + freerun = "Freerun" + line1 = "Line1" class AravisDriverIO(adcore.ADBaseIO): diff --git a/src/ophyd_async/epics/adcore/_core_io.py b/src/ophyd_async/epics/adcore/_core_io.py index 7968579117..e044b0e5d1 100644 --- a/src/ophyd_async/epics/adcore/_core_io.py +++ b/src/ophyd_async/epics/adcore/_core_io.py @@ -1,6 +1,4 @@ -from enum import Enum - -from ophyd_async.core import Device +from ophyd_async.core import Device, StrictEnum from ophyd_async.epics.signal import ( epics_signal_r, epics_signal_rw, @@ -10,7 +8,7 @@ from ._utils import ADBaseDataType, FileWriteMode, ImageMode -class Callback(str, Enum): +class Callback(StrictEnum): Enable = "Enable" Disable = "Disable" @@ -68,7 +66,7 @@ def __init__(self, prefix: str, name: str = "") -> None: super().__init__(prefix, name) -class DetectorState(str, Enum): +class DetectorState(StrictEnum): """ Default set of states of an AreaDetector driver. See definition in ADApp/ADSrc/ADDriver.h in https://github.com/areaDetector/ADCore @@ -100,7 +98,7 @@ def __init__(self, prefix: str, name: str = "") -> None: super().__init__(prefix, name=name) -class Compression(str, Enum): +class Compression(StrictEnum): none = "None" nbit = "N-bit" szip = "szip" diff --git a/src/ophyd_async/epics/adcore/_utils.py b/src/ophyd_async/epics/adcore/_utils.py index a1a21b6071..bedbd474c2 100644 --- a/src/ophyd_async/epics/adcore/_utils.py +++ b/src/ophyd_async/epics/adcore/_utils.py @@ -1,11 +1,16 @@ from dataclasses import dataclass -from enum import Enum -from ophyd_async.core import DEFAULT_TIMEOUT, SignalRW, T, wait_for_value -from ophyd_async.core._signal import SignalR +from ophyd_async.core import ( + DEFAULT_TIMEOUT, + SignalDatatypeT, + SignalR, + SignalRW, + StrictEnum, + wait_for_value, +) -class ADBaseDataType(str, Enum): +class ADBaseDataType(StrictEnum): Int8 = "Int8" UInt8 = "UInt8" Int16 = "Int16" @@ -73,25 +78,25 @@ def convert_param_dtype_to_np(datatype: str) -> str: return np_datatype -class FileWriteMode(str, Enum): +class FileWriteMode(StrictEnum): single = "Single" capture = "Capture" stream = "Stream" -class ImageMode(str, Enum): +class ImageMode(StrictEnum): single = "Single" multiple = "Multiple" continuous = "Continuous" -class NDAttributeDataType(str, Enum): +class NDAttributeDataType(StrictEnum): INT = "INT" DOUBLE = "DOUBLE" STRING = "STRING" -class NDAttributePvDbrType(str, Enum): +class NDAttributePvDbrType(StrictEnum): DBR_SHORT = "DBR_SHORT" DBR_ENUM = "DBR_ENUM" DBR_INT = "DBR_INT" @@ -122,8 +127,8 @@ class NDAttributeParam: async def stop_busy_record( - signal: SignalRW[T], - value: T, + signal: SignalRW[SignalDatatypeT], + value: SignalDatatypeT, timeout: float = DEFAULT_TIMEOUT, status_timeout: float | None = None, ) -> None: diff --git a/src/ophyd_async/epics/adkinetix/_kinetix_io.py b/src/ophyd_async/epics/adkinetix/_kinetix_io.py index 30c4ccd2c3..4b70886648 100644 --- a/src/ophyd_async/epics/adkinetix/_kinetix_io.py +++ b/src/ophyd_async/epics/adkinetix/_kinetix_io.py @@ -1,16 +1,15 @@ -from enum import Enum - +from ophyd_async.core import StrictEnum from ophyd_async.epics import adcore from ophyd_async.epics.signal import epics_signal_rw_rbv -class KinetixTriggerMode(str, Enum): +class KinetixTriggerMode(StrictEnum): internal = "Internal" edge = "Rising Edge" gate = "Exp. Gate" -class KinetixReadoutMode(str, Enum): +class KinetixReadoutMode(StrictEnum): sensitivity = 1 speed = 2 dynamic_range = 3 diff --git a/src/ophyd_async/epics/adpilatus/_pilatus_io.py b/src/ophyd_async/epics/adpilatus/_pilatus_io.py index de040b5c4f..51ca65ce9c 100644 --- a/src/ophyd_async/epics/adpilatus/_pilatus_io.py +++ b/src/ophyd_async/epics/adpilatus/_pilatus_io.py @@ -1,10 +1,9 @@ -from enum import Enum - +from ophyd_async.core import StrictEnum from ophyd_async.epics import adcore from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw_rbv -class PilatusTriggerMode(str, Enum): +class PilatusTriggerMode(StrictEnum): internal = "Internal" ext_enable = "Ext. Enable" ext_trigger = "Ext. Trigger" diff --git a/src/ophyd_async/epics/advimba/_vimba_io.py b/src/ophyd_async/epics/advimba/_vimba_io.py index ac14872ef8..0dc7571b7b 100644 --- a/src/ophyd_async/epics/advimba/_vimba_io.py +++ b/src/ophyd_async/epics/advimba/_vimba_io.py @@ -1,10 +1,9 @@ -from enum import Enum - +from ophyd_async.core import StrictEnum from ophyd_async.epics import adcore from ophyd_async.epics.signal import epics_signal_rw_rbv -class VimbaPixelFormat(str, Enum): +class VimbaPixelFormat(StrictEnum): internal = "Mono8" ext_enable = "Mono12" ext_trigger = "Ext. Trigger" @@ -12,7 +11,7 @@ class VimbaPixelFormat(str, Enum): alignment = "Alignment" -class VimbaConvertFormat(str, Enum): +class VimbaConvertFormat(StrictEnum): none = "None" mono8 = "Mono8" mono16 = "Mono16" @@ -20,7 +19,7 @@ class VimbaConvertFormat(str, Enum): rgb16 = "RGB16" -class VimbaTriggerSource(str, Enum): +class VimbaTriggerSource(StrictEnum): freerun = "Freerun" line1 = "Line1" line2 = "Line2" @@ -30,17 +29,17 @@ class VimbaTriggerSource(str, Enum): action1 = "Action1" -class VimbaOverlap(str, Enum): +class VimbaOverlap(StrictEnum): off = "Off" prev_frame = "PreviousFrame" -class VimbaOnOff(str, Enum): +class VimbaOnOff(StrictEnum): on = "On" off = "Off" -class VimbaExposeOutMode(str, Enum): +class VimbaExposeOutMode(StrictEnum): timed = "Timed" # Use ExposureTime PV trigger_width = "TriggerWidth" # Expose for length of high signal diff --git a/src/ophyd_async/epics/demo/_sensor.py b/src/ophyd_async/epics/demo/_sensor.py index 37d590d155..5235fe0aba 100644 --- a/src/ophyd_async/epics/demo/_sensor.py +++ b/src/ophyd_async/epics/demo/_sensor.py @@ -1,10 +1,14 @@ -from enum import Enum - -from ophyd_async.core import ConfigSignal, DeviceVector, HintedSignal, StandardReadable +from ophyd_async.core import ( + ConfigSignal, + DeviceVector, + HintedSignal, + StandardReadable, + StrictEnum, +) from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw -class EnergyMode(str, Enum): +class EnergyMode(StrictEnum): """Energy mode for `Sensor`""" #: Low energy mode diff --git a/src/ophyd_async/epics/eiger/_eiger_io.py b/src/ophyd_async/epics/eiger/_eiger_io.py index 1df672592d..ed61c0b326 100644 --- a/src/ophyd_async/epics/eiger/_eiger_io.py +++ b/src/ophyd_async/epics/eiger/_eiger_io.py @@ -1,10 +1,8 @@ -from enum import Enum - -from ophyd_async.core import Device +from ophyd_async.core import Device, StrictEnum from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw_rbv, epics_signal_w -class EigerTriggerMode(str, Enum): +class EigerTriggerMode(StrictEnum): internal = "ints" edge = "exts" gate = "exte" diff --git a/src/ophyd_async/epics/eiger/_odin_io.py b/src/ophyd_async/epics/eiger/_odin_io.py index c5a38a669b..34002e9794 100644 --- a/src/ophyd_async/epics/eiger/_odin_io.py +++ b/src/ophyd_async/epics/eiger/_odin_io.py @@ -1,6 +1,5 @@ import asyncio from collections.abc import AsyncGenerator, AsyncIterator -from enum import Enum from bluesky.protocols import StreamAsset from event_model import DataKey @@ -12,6 +11,7 @@ DeviceVector, NameProvider, PathProvider, + StrictEnum, observe_value, set_and_wait_for_value, ) @@ -22,7 +22,7 @@ ) -class Writing(str, Enum): +class Writing(StrictEnum): ON = "ON" OFF = "OFF" diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index a35a7379ce..3efca044bf 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -1,14 +1,16 @@ +import logging import sys from collections.abc import Sequence from dataclasses import dataclass from math import isnan, nan -from typing import Any, Generic, cast, get_origin +from typing import Any, Generic, cast import numpy as np from aioca import ( FORMAT_CTRL, FORMAT_RAW, FORMAT_TIME, + CANothing, Subscription, caget, camonitor, @@ -34,8 +36,9 @@ wait_for_connection, ) from ophyd_async.core._protocol import Reading +from ophyd_async.core._utils import NotConnected -from ._common import get_supported_values +from ._common import format_datatype, get_supported_values def _limits_from_augmented_value(value: AugmentedValue) -> Limits: @@ -97,6 +100,17 @@ def value(self, value: AugmentedValue) -> SignalDatatypeT: return +value # type: ignore +class CaArrayConverter(CaConverter[np.ndarray]): + def value(self, value: AugmentedValue) -> np.ndarray: + # A less expensive conversion + return np.array(value, copy=False) + + +class CaSequenceStrConverter(CaConverter[Sequence[str]]): + def value(self, value: AugmentedValue) -> Sequence[str]: + return [str(v) for v in value] # type: ignore + + class CaLongStrConverter(CaConverter[str]): def __init__(self): super().__init__(str, dbr.DBR_CHAR_STR, dbr.DBR_CHAR_STR) @@ -107,10 +121,12 @@ def write_value_and_dbr(self, value: Any) -> Any: return value + "\0" -class CaArrayConverter(CaConverter[np.ndarray]): - def value(self, value: AugmentedValue) -> np.ndarray: - # A less expensive conversion - return np.array(value, copy=False) +class CaBoolConverter(CaConverter[bool]): + def __init__(self): + super().__init__(bool, dbr.DBR_SHORT) + + def value(self, value: AugmentedValue) -> bool: + return bool(value) class CaEnumConverter(CaConverter[str]): @@ -124,37 +140,23 @@ def value(self, value: AugmentedValue) -> str: return self.supported_values[str(value)] -class CaSequenceStrConverter(CaConverter[Sequence[str]]): - def __init__(self): - super().__init__(Sequence[str], dbr.DBR_STRING) - - def value(self, value: AugmentedValue) -> Sequence[str]: - return [str(v) for v in value] # type: ignore - - -class CaBoolConverter(CaConverter[bool]): - def __init__(self): - super().__init__(bool, dbr.DBR_SHORT) - - def value(self, value: AugmentedValue) -> bool: - return bool(value) - - -_datatypes_from_dbr: dict[tuple[Dbr, bool], type[SignalDatatype]] = { - (dbr.DBR_STRING, False): str, - (dbr.DBR_SHORT, False): int, - (dbr.DBR_FLOAT, False): float, - (dbr.DBR_ENUM, False): str, - (dbr.DBR_CHAR, False): int, - (dbr.DBR_LONG, False): int, - (dbr.DBR_DOUBLE, False): float, - (dbr.DBR_STRING, True): Sequence[str], - (dbr.DBR_SHORT, True): Array1D[np.int16], - (dbr.DBR_FLOAT, True): Array1D[np.float32], - (dbr.DBR_ENUM, True): Sequence[str], - (dbr.DBR_CHAR, True): Array1D[np.uint8], - (dbr.DBR_LONG, True): Array1D[np.int32], - (dbr.DBR_DOUBLE, True): Array1D[np.float64], +_datatype_converter_from_dbr: dict[ + tuple[Dbr, bool], tuple[type[SignalDatatype], type[CaConverter]] +] = { + (dbr.DBR_STRING, False): (str, CaConverter), + (dbr.DBR_SHORT, False): (int, CaConverter), + (dbr.DBR_FLOAT, False): (float, CaConverter), + (dbr.DBR_ENUM, False): (str, CaConverter), + (dbr.DBR_CHAR, False): (int, CaConverter), + (dbr.DBR_LONG, False): (int, CaConverter), + (dbr.DBR_DOUBLE, False): (float, CaConverter), + (dbr.DBR_STRING, True): (Sequence[str], CaSequenceStrConverter), + (dbr.DBR_SHORT, True): (Array1D[np.int16], CaArrayConverter), + (dbr.DBR_FLOAT, True): (Array1D[np.float32], CaArrayConverter), + (dbr.DBR_ENUM, True): (Sequence[str], CaSequenceStrConverter), + (dbr.DBR_CHAR, True): (Array1D[np.uint8], CaArrayConverter), + (dbr.DBR_LONG, True): (Array1D[np.int32], CaArrayConverter), + (dbr.DBR_DOUBLE, True): (Array1D[np.float64], CaArrayConverter), } @@ -166,58 +168,46 @@ def make_converter( Dbr, get_unique({k: v.datatype for k, v in values.items()}, "datatypes") ) is_array = bool([v for v in values.values() if v.element_count > 1]) - # Infer a datatype from the dbr - inferred_datatype = _datatypes_from_dbr[(pv_dbr, is_array)] - # Create the correct converter based on requested datatype - if is_array: - if pv_dbr == dbr.DBR_STRING and datatype in (None, Sequence[str]): - # Otherwise they get string if requested or inferred - return CaSequenceStrConverter() - elif pv_dbr == dbr.DBR_CHAR and datatype is str: - # Override waveform of chars to be treated as string - return CaLongStrConverter() - elif ( - datatype in (None, inferred_datatype) - and get_origin(inferred_datatype) == np.ndarray - ): - # The requested datatype matches the inferred datatype, so use that - # We verify the origin of inferred_datatype above, but pyright doesn't know - # that, so do a cast below - return CaArrayConverter(cast(type[np.ndarray], inferred_datatype), pv_dbr) - else: - if pv_dbr == dbr.DBR_ENUM: - pv_choices = get_unique( - {k: tuple(v.enums) for k, v in values.items()}, "choices" + # Infer a datatype and converter from the dbr + inferred_datatype, converter_cls = _datatype_converter_from_dbr[(pv_dbr, is_array)] + # Some override cases + if is_array and pv_dbr == dbr.DBR_CHAR and datatype is str: + # Override waveform of chars to be treated as string + return CaLongStrConverter() + elif not is_array and pv_dbr == dbr.DBR_ENUM: + pv_choices = get_unique( + {k: tuple(v.enums) for k, v in values.items()}, "choices" + ) + if datatype is bool: + # Database can't do bools, so are often representated as enums of len 2 + if len(pv_choices) != 2: + raise TypeError(f"{pv} has {pv_choices=}, can't map to bool") + return CaBoolConverter() + elif enum_cls := get_enum_cls(datatype): + # If explicitly requested then check + return CaEnumConverter(get_supported_values(pv, enum_cls, pv_choices)) + elif datatype in (None, str): + # Drop to string for safety, but retain choices as metadata + return CaConverter( + str, + dbr.DBR_STRING, + metadata=SignalMetadata(choices=list(pv_choices)), ) - if datatype is bool: - # Database can't do bools, so are often representated as enums of len 2 - if len(pv_choices) != 2: - raise TypeError(f"{pv} has {pv_choices=}, can't map to bool") - return CaBoolConverter() - elif enum_cls := get_enum_cls(datatype): - # If explicitly requested then check - supported_values = get_supported_values(pv, enum_cls, pv_choices) - return CaEnumConverter(supported_values) - else: - # Drop to string, but retain choices as metadata - return CaConverter( - str, - dbr.DBR_STRING, - metadata=SignalMetadata(choices=list(pv_choices)), - ) - elif ( - pv_dbr == dbr.DBR_DOUBLE - and get_unique({k: v.precision for k, v in values.items()}, "precision") - == 0 - ): - # Allow int signals to represent float records when prec is 0 - return CaConverter(int, pv_dbr) - elif datatype in (None, inferred_datatype): - # If datatype matches what we are given then allow it - return CaConverter(inferred_datatype, pv_dbr) + elif ( + inferred_datatype is float + and datatype is int + and get_unique({k: v.precision for k, v in values.items()}, "precision") == 0 + ): + # Allow int signals to represent float records when prec is 0 + return CaConverter(int, pv_dbr) + elif datatype in (None, inferred_datatype): + # If datatype matches what we are given then allow it and use inferred converter + return converter_cls(inferred_datatype, pv_dbr) + if pv_dbr == dbr.DBR_ENUM: + inferred_datatype = "str | SubsetEnum | StrictEnum" raise TypeError( - f"{pv} with inferred datatype {inferred_datatype}" - f" cannot be coerced to {datatype}" + f"{pv} with inferred datatype {format_datatype(inferred_datatype)}" + f" cannot be coerced to {format_datatype(datatype)}" ) @@ -324,7 +314,13 @@ async def connect_epics(self, timeout: float) -> CaSignalBackend: initial_values: dict[str, AugmentedValue] = {} async def store_initial_value(pv: str): - initial_values[pv] = await caget(pv, format=FORMAT_CTRL, timeout=timeout) + try: + initial_values[pv] = await caget( + pv, format=FORMAT_CTRL, timeout=timeout + ) + except CANothing as exc: + logging.debug(f"signal ca://{pv} timed out") + raise NotConnected(f"ca://{pv}") from exc if self.read_pv != self.write_pv: # Different, need to connect both diff --git a/src/ophyd_async/epics/signal/_common.py b/src/ophyd_async/epics/signal/_common.py index a1862612a6..633d953056 100644 --- a/src/ophyd_async/epics/signal/_common.py +++ b/src/ophyd_async/epics/signal/_common.py @@ -1,6 +1,10 @@ from collections.abc import Sequence +from typing import Any, get_args, get_origin + +import numpy as np from ophyd_async.core import StrictEnum, get_enum_cls +from ophyd_async.core._utils import get_dtype def get_supported_values( @@ -27,3 +31,13 @@ def get_supported_values( for v in enum_cls: supported_values[v.value] = v return supported_values + + +def format_datatype(datatype: Any) -> str: + if get_origin(datatype) is np.ndarray and get_args(datatype)[0] == tuple[int]: + dtype = get_dtype(datatype) + return f"Array1D[np.{dtype.name}]" + elif isinstance(datatype, type): + return datatype.__name__ + else: + return str(datatype) diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 21ae85d376..02aea34a9f 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -2,197 +2,106 @@ import asyncio import atexit -import inspect import logging import time -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass -from enum import Enum from math import isnan, nan -from typing import Any, get_origin +from typing import Any, Generic import numpy as np -from bluesky.protocols import Reading from event_model import DataKey -from event_model.documents.event_descriptor import Dtype +from event_model.documents.event_descriptor import Limits, LimitsRange from p4p import Value from p4p.client.asyncio import Context, Subscription from pydantic import BaseModel from ophyd_async.core import ( - DEFAULT_TIMEOUT, NotConnected, SignalBackend, - T, - get_dtype, + SignalDatatypeT, + SignalMetadata, get_unique, - is_pydantic_model, wait_for_connection, ) +from ophyd_async.core._protocol import Reading +from ophyd_async.core._signal_backend import ( + Array1D, + SignalConnector, + SignalDatatype, + make_datakey, +) +from ophyd_async.core._soft_signal_backend import MockSignalBackend +from ophyd_async.core._table import Table +from ophyd_async.core._utils import Callback, SubsetEnum, get_enum_cls -from ._common import get_supported_values - -# https://mdavidsaver.github.io/p4p/values.html -specifier_to_dtype: dict[str, Dtype] = { - "?": "integer", # bool - "b": "integer", # int8 - "B": "integer", # uint8 - "h": "integer", # int16 - "H": "integer", # uint16 - "i": "integer", # int32 - "I": "integer", # uint32 - "l": "integer", # int64 - "L": "integer", # uint64 - "f": "number", # float32 - "d": "number", # float64 - "s": "string", -} - -specifier_to_np_dtype: dict[str, str] = { - "?": " DataKey: - """ - Args: - value (Value): Description of the the return type of a DB record - shape: Optional override shape when len(shape) > 1 - choices: Optional list of enum choices to pass as metadata in the datakey - dtype: Optional override dtype when AugmentedValue is ambiguous, e.g. booleans - - Returns: - DataKey: A rich DataKey describing the DB record - """ - shape = shape or [] - type_code = value.type().aspy("value") - - dtype = dtype or specifier_to_dtype[type_code] - - try: - if isinstance(type_code, tuple): - dtype_numpy = "" - if type_code[1] == "enum_t": - if dtype == "boolean": - dtype_numpy = " Limits: +def _limits_from_value(value: Any) -> Limits: def get_limits( substucture_name: str, low_name: str = "limitLow", high_name: str = "limitHigh" - ) -> LimitPair: + ) -> LimitsRange | None: substructure = getattr(value, substucture_name, None) low = getattr(substructure, low_name, nan) high = getattr(substructure, high_name, nan) - return LimitPair( - low=None if isnan(low) else low, high=None if isnan(high) else high - ) - - return Limits( - alarm=get_limits("valueAlarm", "lowAlarmLimit", "highAlarmLimit"), - control=get_limits("control"), - display=get_limits("display"), - warning=get_limits("valueAlarm", "lowWarningLimit", "highWarningLimit"), - ) - - -class PvaConverter: - def write_value(self, value): - return value - - def value(self, value): - return value["value"] + if not (isnan(low) and isnan(high)): + return LimitsRange( + low=None if isnan(low) else low, + high=None if isnan(high) else high, + ) - def reading(self, value) -> Reading: - ts = value["timeStamp"] - sv = value["alarm"]["severity"] - return { - "value": self.value(value), - "timestamp": ts["secondsPastEpoch"] + ts["nanoseconds"] * 1e-9, - "alarm_severity": -1 if sv > 2 else sv, - } + limits = Limits() + if limits_range := get_limits("valueAlarm", "lowAlarmLimit", "highAlarmLimit"): + limits["alarm"] = limits_range + if limits_range := get_limits("control"): + limits["control"] = limits_range + if limits_range := get_limits("display"): + limits["display"] = limits_range + if limits_range := get_limits("valueAlarm", "lowWarningLimit", "highWarningLimit"): + limits["warning"] = limits_range + return limits + + +def _metadata_from_value(datatype: type[SignalDatatype], value: Any) -> SignalMetadata: + metadata = SignalMetadata() + value_data: Any = getattr(value, "value", None) + display_data: Any = getattr(value, "display", None) + if hasattr(display_data, "units"): + metadata["units"] = display_data.units + if hasattr(display_data, "precision") and not isnan(display_data.precision): + metadata["precision"] = display_data.precision + if limits := _limits_from_value(value): + metadata["limits"] = limits + # Get choices from display or value + if datatype is str or issubclass(datatype, SubsetEnum): + if hasattr(display_data, "choices"): + metadata["choices"] = display_data.choices + elif hasattr(value_data, "choices"): + metadata["choices"] = value_data.choices + return metadata - def get_datakey(self, source: str, value) -> DataKey: - return _data_key_from_value(source, value) - def metadata_fields(self) -> list[str]: - """ - Fields to request from PVA for metadata. - """ - return ["alarm", "timeStamp"] +class PvaConverter(Generic[SignalDatatypeT]): + value_fields = ("value",) + reading_fields = ("alarm", "timeStamp") - def value_fields(self) -> list[str]: - """ - Fields to request from PVA for the value. - """ - return ["value"] + def __init__(self, datatype: type[SignalDatatypeT]): + self.datatype = datatype + def value(self, value: Any) -> SignalDatatypeT: + # for channel access ca_xxx classes, this + # invokes __pos__ operator to return an instance of + # the builtin base class + return value["value"] -class PvaArrayConverter(PvaConverter): - def get_datakey(self, source: str, value) -> DataKey: - return _data_key_from_value( - source, value, dtype="array", shape=[len(value["value"])] - ) + def write_value(self, value: Any) -> Any: + # The pva library will do the conversion for us + return value -class PvaNDArrayConverter(PvaConverter): - def metadata_fields(self) -> list[str]: - return super().metadata_fields() + ["dimension"] +class PvaNDArrayConverter(PvaConverter[SignalDatatypeT]): + value_fields = ("value", "dimension") def _get_dimensions(self, value) -> list[int]: dimensions: list[Value] = value["dimension"] @@ -205,62 +114,44 @@ def _get_dimensions(self, value) -> list[int]: # last index changing fastest. return dims[::-1] - def value(self, value): + def value(self, value: Any) -> SignalDatatypeT: dims = self._get_dimensions(value) return value["value"].reshape(dims) - def get_datakey(self, source: str, value) -> DataKey: - dims = self._get_dimensions(value) - return _data_key_from_value(source, value, dtype="array", shape=dims) - - def write_value(self, value): + def write_value(self, value: Any) -> Any: # No clear use-case for writing directly to an NDArray, and some # complexities around flattening to 1-D - e.g. dimension-order. # Don't support this for now. raise TypeError("Writing to NDArray not supported") -@dataclass -class PvaEnumConverter(PvaConverter): - """To prevent issues when a signal is restarted and returns with different enum - values or orders, we put treat an Enum signal as a string, and cache the - choices on this class. - """ - - def __init__(self, choices: dict[str, str]): - self.choices = tuple(choices.values()) - - def write_value(self, value: Enum | str): - if isinstance(value, Enum): - return value.value - else: - return value +class PvaEnumConverter(PvaConverter[str]): + def __init__( + self, datatype: type[str] = str, supported_values: Mapping[str, str] = {} + ): + self.supported_values = supported_values + super().__init__(datatype) - def value(self, value): - return self.choices[value["value"]["index"]] + def value(self, value: Any) -> str: + str_value = value["value"]["choices"][value["value"]["index"]] + if self.supported_values: + return self.supported_values[str_value] + else: + return str_value - def get_datakey(self, source: str, value) -> DataKey: - return _data_key_from_value( - source, value, choices=list(self.choices), dtype="string" - ) +class PvaEnumBoolConverter(PvaConverter[bool]): + def __init__(self): + super().__init__(bool) -class PvaEmumBoolConverter(PvaConverter): - def value(self, value): + def value(self, value: Any) -> bool: return bool(value["value"]["index"]) - def get_datakey(self, source: str, value) -> DataKey: - return _data_key_from_value(source, value, dtype="boolean") - -class PvaTableConverter(PvaConverter): +class PvaTableConverter(PvaConverter[Table]): def value(self, value): return value["value"].todict() - def get_datakey(self, source: str, value) -> DataKey: - # This is wrong, but defer until we know how to actually describe a table - return _data_key_from_value(source, value, dtype="object") # type: ignore - class PvaPydanticModelConverter(PvaConverter): def __init__(self, datatype: BaseModel): @@ -301,217 +192,238 @@ def value_fields(self) -> list[str]: return [] -class DisconnectedPvaConverter(PvaConverter): - def __getattribute__(self, __name: str) -> Any: - raise NotImplementedError("No PV has been set as connect() has not been called") +# https://mdavidsaver.github.io/p4p/values.html +_datatype_converter_from_typeid: dict[ + tuple[str, str], tuple[type[SignalDatatype], type[PvaConverter]] +] = { + ("epics:nt/NTScalar:1.0", "?"): (bool, PvaConverter), + ("epics:nt/NTScalar:1.0", "b"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "B"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "h"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "H"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "i"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "I"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "l"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "L"): (int, PvaConverter), + ("epics:nt/NTScalar:1.0", "f"): (float, PvaConverter), + ("epics:nt/NTScalar:1.0", "d"): (float, PvaConverter), + ("epics:nt/NTScalar:1.0", "s"): (str, PvaConverter), + ("epics:nt/NTEnum:1.0", "S"): (str, PvaEnumConverter), + ("epics:nt/NTScalarArray:1.0", "a?"): (Array1D[np.bool_], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "ab"): (Array1D[np.int8], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "aB"): (Array1D[np.uint8], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "ah"): (Array1D[np.int16], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "aH"): (Array1D[np.uint16], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "ai"): (Array1D[np.int32], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "aI"): (Array1D[np.uint32], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "al"): (Array1D[np.int64], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "aL"): (Array1D[np.uint64], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "af"): (Array1D[np.float32], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "ad"): (Array1D[np.float64], PvaConverter), + ("epics:nt/NTScalarArray:1.0", "as"): (Sequence[str], PvaConverter), + ("epics:nt/NTTable:1.0", "S"): (Table, PvaTableConverter), + ("epics:nt/NTNDArray:1.0", "v"): (np.ndarray, PvaNDArrayConverter), +} + + +def _get_specifier(value: Value): + typ = value.type("value").aspy() + if isinstance(typ, tuple): + return typ[0] + else: + return str(typ) def make_converter(datatype: type | None, values: dict[str, Any]) -> PvaConverter: pv = list(values)[0] typeid = get_unique({k: v.getID() for k, v in values.items()}, "typeids") - typ = get_unique( - {k: type(v.get("value")) for k, v in values.items()}, "value types" + specifier = get_unique( + {k: _get_specifier(v) for k, v in values.items()}, + "value type specifiers", ) - if "NTScalarArray" in typeid and typ is list: - # Waveform of strings, check we wanted this - if datatype and datatype != Sequence[str]: - raise TypeError(f"{pv} has type [str] not {datatype.__name__}") - return PvaArrayConverter() - elif "NTScalarArray" in typeid or "NTNDArray" in typeid: - pv_dtype = get_unique( - {k: v["value"].dtype for k, v in values.items()}, "dtypes" - ) - # This is an array - if datatype: - # Check we wanted an array of this type - dtype = get_dtype(datatype) - if not dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not {datatype.__name__}") - if dtype != pv_dtype: - raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") - if "NTNDArray" in typeid: - return PvaNDArrayConverter() - else: - return PvaArrayConverter() - elif "NTEnum" in typeid and datatype is bool: - # Wanted a bool, but database represents as an enum - pv_choices_len = get_unique( - {k: len(v["value"]["choices"]) for k, v in values.items()}, - "number of choices", - ) - if pv_choices_len != 2: - raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return PvaEmumBoolConverter() - elif "NTEnum" in typeid: - # This is an Enum + # Infer a datatype and converter from the typeid and specifier + inferred_datatype, converter_cls = _datatype_converter_from_typeid[ + (typeid, specifier) + ] + # Some override cases + if typeid == "epics:nt/NTEnum:1.0": pv_choices = get_unique( {k: tuple(v["value"]["choices"]) for k, v in values.items()}, "choices" ) - return PvaEnumConverter(get_supported_values(pv, datatype, pv_choices)) - elif "NTScalar" in typeid: - if ( - typ is str - and inspect.isclass(datatype) - and issubclass(datatype, RuntimeSubsetEnum) - ): + if datatype is bool: + # Database can't do bools, so are often representated as enums of len 2 + if len(pv_choices) != 2: + raise TypeError(f"{pv} has {pv_choices=}, can't map to bool") + return PvaEnumBoolConverter() + elif enum_cls := get_enum_cls(datatype): + # We were given an enum class, so make class from that return PvaEnumConverter( - get_supported_values(pv, datatype, datatype.choices) # type: ignore - ) - elif datatype and not issubclass(typ, datatype): - # Allow int signals to represent float records when prec is 0 - is_prec_zero_float = typ is float and ( - get_unique( - {k: v["display"]["precision"] for k, v in values.items()}, - "precision", - ) - == 0 + supported_values=get_supported_values(pv, enum_cls, pv_choices) ) - if not (datatype is int and is_prec_zero_float): - raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") - return PvaConverter() - elif "NTTable" in typeid: - if is_pydantic_model(datatype): - return PvaPydanticModelConverter(datatype) # type: ignore - return PvaTableConverter() - elif "structure" in typeid: - return PvaDictConverter() - else: - raise TypeError(f"{pv}: Unsupported typeid {typeid}") - - -class PvaSignalBackend(SignalBackend[T]): - _ctxt: Context | None = None - - _ALLOWED_DATATYPES = ( - bool, - int, - float, - str, - Sequence, - np.ndarray, - Enum, - BaseModel, - dict, + elif datatype in (None, str): + # Still use the Enum converter, but make choices from what it has + return PvaEnumConverter() + elif ( + inferred_datatype is float + and datatype is int + and get_unique( + {k: v["display"]["precision"] for k, v in values.items()}, "precision" + ) + == 0 + ): + # Allow int signals to represent float records when prec is 0 + return PvaConverter(int) + elif datatype in (None, inferred_datatype): + # If datatype matches what we are given then allow it and use inferred converter + return converter_cls(inferred_datatype) + raise TypeError( + f"{pv} with inferred datatype {format_datatype(inferred_datatype)}" + f" from {typeid=} {specifier=}" + f" cannot be coerced to {format_datatype(datatype)}" ) - @classmethod - def datatype_allowed(cls, dtype: Any) -> bool: - stripped_origin = get_origin(dtype) or dtype - if dtype is None: - return True - return inspect.isclass(stripped_origin) and issubclass( - stripped_origin, cls._ALLOWED_DATATYPES - ) - def __init__(self, datatype: type[T] | None, read_pv: str, write_pv: str): - self.datatype = datatype - if not PvaSignalBackend.datatype_allowed(self.datatype): - raise TypeError(f"Given datatype {self.datatype} unsupported in PVA.") +_context: Context | None = None - self.read_pv = read_pv - self.write_pv = write_pv - self.initial_values: dict[str, Any] = {} - self.converter: PvaConverter = DisconnectedPvaConverter() - self.subscription: Subscription | None = None - def source(self, name: str): - return f"pva://{self.read_pv}" +def context() -> Context: + global _context + if _context is None: + _context = Context("pva", nt=False) - @property - def ctxt(self) -> Context: - if PvaSignalBackend._ctxt is None: - PvaSignalBackend._ctxt = Context("pva", nt=False) + @atexit.register + def _del_ctxt(): + # If we don't do this we get messages like this on close: + # Error in sys.excepthook: + # Original exception was: + global _context + del _context - @atexit.register - def _del_ctxt(): - # If we don't do this we get messages like this on close: - # Error in sys.excepthook: - # Original exception was: - PvaSignalBackend._ctxt = None + return _context - return PvaSignalBackend._ctxt - async def _store_initial_value(self, pv, timeout: float = DEFAULT_TIMEOUT): - try: - self.initial_values[pv] = await asyncio.wait_for( - self.ctxt.get(pv), timeout=timeout - ) - except asyncio.TimeoutError as exc: - logging.debug(f"signal pva://{pv} timed out", exc_info=True) - raise NotConnected(f"pva://{pv}") from exc +class PvaSignalBackend(SignalBackend[SignalDatatypeT]): + def __init__( + self, + datatype: type[SignalDatatypeT] | None, + read_pv: str, + write_pv: str, + initial_values: dict[str, Any], + ): + self._converter = make_converter(datatype, initial_values) + self._read_pv = read_pv + self._write_pv = write_pv + self._initial_values = initial_values + self.subscription: Subscription | None = None - async def connect(self, timeout: float = DEFAULT_TIMEOUT): - if self.read_pv != self.write_pv: - # Different, need to connect both - await wait_for_connection( - read_pv=self._store_initial_value(self.read_pv, timeout=timeout), - write_pv=self._store_initial_value(self.write_pv, timeout=timeout), - ) - else: - # The same, so only need to connect one - await self._store_initial_value(self.read_pv, timeout=timeout) - self.converter = make_converter(self.datatype, self.initial_values) + def _make_reading(self, value: Any) -> Reading[SignalDatatypeT]: + ts = value["timeStamp"] + sv = value["alarm"]["severity"] + return { + "value": self._converter.value(value), + "timestamp": ts["secondsPastEpoch"] + ts["nanoseconds"] * 1e-9, + "alarm_severity": -1 if sv > 2 else sv, + } - async def put(self, value: T | None, wait=True, timeout=None): + async def put(self, value: SignalDatatypeT | None, wait=True, timeout=None): if value is None: - write_value = self.initial_values[self.write_pv] + write_value = self._initial_values[self._write_pv] else: - write_value = self.converter.write_value(value) - coro = self.ctxt.put(self.write_pv, {"value": write_value}, wait=wait) + write_value = self._converter.write_value(value) + coro = context().put(self._write_pv, {"value": write_value}, wait=wait) try: await asyncio.wait_for(coro, timeout) except asyncio.TimeoutError as exc: - logging.debug( - f"signal pva://{self.write_pv} timed out \ - put value: {write_value}", - exc_info=True, - ) - raise NotConnected(f"pva://{self.write_pv}") from exc + raise asyncio.TimeoutError( + f"pva://{self._write_pv}: Put timed out" + ) from exc async def get_datakey(self, source: str) -> DataKey: - value = await self.ctxt.get(self.read_pv) - return self.converter.get_datakey(source, value) + value = await context().get(self._read_pv) + metadata = _metadata_from_value(self._converter.datatype, value) + return make_datakey( + self._converter.datatype, self._converter.value(value), source, metadata + ) - def _pva_request_string(self, fields: list[str]) -> str: - """ - Converts a list of requested fields into a PVA request string which can be + def _pva_request_string(self, fields: Sequence[str]) -> str: + """Converts a list of requested fields into a PVA request string which can be passed to p4p. """ return f"field({','.join(fields)})" async def get_reading(self) -> Reading: - request: str = self._pva_request_string( - self.converter.value_fields() + self.converter.metadata_fields() + request = self._pva_request_string( + self._converter.value_fields + self._converter.reading_fields ) - value = await self.ctxt.get(self.read_pv, request=request) - return self.converter.reading(value) + value = await context().get(self._read_pv, request=request) + return self._make_reading(value) - async def get_value(self) -> T: - request: str = self._pva_request_string(self.converter.value_fields()) - value = await self.ctxt.get(self.read_pv, request=request) - return self.converter.value(value) + async def get_value(self) -> SignalDatatypeT: + request = self._pva_request_string(self._converter.value_fields) + value = await context().get(self._read_pv, request=request) + return self._converter.value(value) - async def get_setpoint(self) -> T: - value = await self.ctxt.get(self.write_pv, "field(value)") - return self.converter.value(value) + async def get_setpoint(self) -> SignalDatatypeT: + request = self._pva_request_string(self._converter.value_fields) + value = await context().get(self._write_pv, request=request) + return self._converter.value(value) - def set_callback(self, callback: ReadingValueCallback[T] | None) -> None: + def set_callback(self, callback: Callback[Reading[SignalDatatypeT]] | None) -> None: if callback: assert ( not self.subscription ), "Cannot set a callback when one is already set" async def async_callback(v): - callback(self.converter.reading(v), self.converter.value(v)) + callback(self._make_reading(v)) - request: str = self._pva_request_string( - self.converter.value_fields() + self.converter.metadata_fields() + request = self._pva_request_string( + self._converter.value_fields + self._converter.reading_fields ) + self.subscription = context().monitor( + self._read_pv, async_callback, request=request + ) + elif self.subscription: + self.subscription.close() + self.subscription = None + - self.subscription = self.ctxt.monitor( - self.read_pv, async_callback, request=request +@dataclass +class PvaSignalConnector(SignalConnector[SignalDatatypeT]): + datatype: type[SignalDatatypeT] | None + read_pv: str + write_pv: str + + async def connect(self, mock: bool, timeout: float, force_reconnect: bool) -> None: + if mock: + self.backend = MockSignalBackend(self.datatype) + else: + self.backend = await self.connect_epics(timeout) + + async def connect_epics(self, timeout: float) -> PvaSignalBackend: + initial_values: dict[str, Any] = {} + + async def store_initial_value(pv: str): + try: + initial_values[pv] = await asyncio.wait_for( + context().get(pv), timeout=timeout + ) + except asyncio.TimeoutError as exc: + logging.debug(f"signal pva://{pv} timed out", exc_info=True) + raise NotConnected(f"pva://{pv}") from exc + + if self.read_pv != self.write_pv: + # Different, need to connect both + await wait_for_connection( + read_pv=store_initial_value(self.read_pv), + write_pv=store_initial_value(self.write_pv), ) else: - if self.subscription: - self.subscription.close() - self.subscription = None + # The same, so only need to connect one + await store_initial_value(self.read_pv) + return PvaSignalBackend( + self.datatype, self.read_pv, self.write_pv, initial_values + ) + + def source(self, name: str) -> str: + return f"pva://{self.read_pv}" diff --git a/src/ophyd_async/epics/signal/_signal.py b/src/ophyd_async/epics/signal/_signal.py index a40e97ab3f..d4dcc32f4c 100644 --- a/src/ophyd_async/epics/signal/_signal.py +++ b/src/ophyd_async/epics/signal/_signal.py @@ -28,21 +28,21 @@ class EpicsProtocol(Enum): PVA = "pva" -_default_epics_protocol = EpicsProtocol.ca +_default_epics_protocol = EpicsProtocol.CA try: from ._p4p import PvaSignalConnector except ImportError as pva_error: PvaSignalConnector = _make_unavailable_class(pva_error) else: - _default_epics_protocol = EpicsProtocol.pva + _default_epics_protocol = EpicsProtocol.PVA try: from ._aioca import CaSignalConnector except ImportError as ca_error: CaSignalConnector = _make_unavailable_class(ca_error) else: - _default_epics_protocol = EpicsProtocol.ca + _default_epics_protocol = EpicsProtocol.CA def _protocol_pv(pv: str) -> tuple[EpicsProtocol, str]: @@ -50,7 +50,7 @@ def _protocol_pv(pv: str) -> tuple[EpicsProtocol, str]: if len(split) > 1: # We got something like pva://mydevice, so use specified comms mode scheme, pv = split - protocol = EpicsProtocol[scheme] + protocol = EpicsProtocol(scheme) else: # No comms mode specified, use the default protocol = _default_epics_protocol @@ -65,9 +65,9 @@ def _epics_signal_connector( w_protocol, w_pv = _protocol_pv(write_pv) protocol = get_unique({read_pv: r_protocol, write_pv: w_protocol}, "protocols") match protocol: - case EpicsProtocol.ca: + case EpicsProtocol.CA: return CaSignalConnector(datatype, r_pv, w_pv) - case EpicsProtocol.pva: + case EpicsProtocol.PVA: return PvaSignalConnector(datatype, r_pv, w_pv) diff --git a/src/ophyd_async/fastcs/panda/_block.py b/src/ophyd_async/fastcs/panda/_block.py index 9deff70015..cf941d576a 100644 --- a/src/ophyd_async/fastcs/panda/_block.py +++ b/src/ophyd_async/fastcs/panda/_block.py @@ -1,8 +1,13 @@ from __future__ import annotations -from enum import Enum - -from ophyd_async.core import Device, DeviceVector, SignalR, SignalRW, SubsetEnum +from ophyd_async.core import ( + Device, + DeviceVector, + SignalR, + SignalRW, + StrictEnum, + SubsetEnum, +) from ._table import DatasetTable, SeqTable @@ -25,13 +30,15 @@ class PulseBlock(Device): width: SignalRW[float] -class PcompDirectionOptions(str, Enum): +class PcompDirectionOptions(StrictEnum): positive = "Positive" negative = "Negative" either = "Either" -EnableDisableOptions = SubsetEnum["ZERO", "ONE"] +class EnableDisableOptions(SubsetEnum): + zero = "ZERO" + one = "ONE" class PcompBlock(Device): @@ -44,7 +51,7 @@ class PcompBlock(Device): width: SignalRW[int] -class TimeUnits(str, Enum): +class TimeUnits(StrictEnum): min = "min" s = "s" ms = "ms" diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index a021d23fa8..1e378f7669 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from enum import Enum from typing import Annotated import numpy as np @@ -8,10 +7,10 @@ from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation from typing_extensions import TypedDict -from ophyd_async.core import Table +from ophyd_async.core import StrictEnum, Table -class PandaHdf5DatasetType(str, Enum): +class PandaHdf5DatasetType(StrictEnum): FLOAT_64 = "float64" UINT_32 = "uint32" @@ -21,7 +20,7 @@ class DatasetTable(TypedDict): hdf5_type: Sequence[PandaHdf5DatasetType] -class SeqTrigger(str, Enum): +class SeqTrigger(StrictEnum): IMMEDIATE = "Immediate" BITA_0 = "BITA=0" BITA_1 = "BITA=1" diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index 03dd73d2c0..abdf1d72eb 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from enum import Enum from os import path from typing import Any from unittest.mock import patch @@ -16,6 +15,7 @@ Device, SignalR, SignalRW, + StrictEnum, all_at_once, get_signal_values, load_device, @@ -34,7 +34,7 @@ def __init__(self) -> None: self.sig2: SignalR = epics_signal_r(str, "Value2") -class EnumTest(str, Enum): +class EnumTest(StrictEnum): VAL1 = "val1" VAL2 = "val2" @@ -51,7 +51,7 @@ def __init__(self, name: str): self.position: npt.NDArray[np.int32] -class MyEnum(str, Enum): +class MyEnum(StrictEnum): one = "one" two = "two" three = "three" diff --git a/tests/core/test_flyer.py b/tests/core/test_flyer.py index b9a3186134..f2595202d1 100644 --- a/tests/core/test_flyer.py +++ b/tests/core/test_flyer.py @@ -1,6 +1,5 @@ import time from collections.abc import AsyncGenerator, AsyncIterator, Sequence -from enum import Enum from typing import Any from unittest.mock import Mock @@ -18,6 +17,7 @@ DetectorWriter, StandardDetector, StandardFlyer, + StrictEnum, TriggerInfo, TriggerLogic, observe_value, @@ -25,7 +25,7 @@ from ophyd_async.epics.signal import epics_signal_rw -class TriggerState(str, Enum): +class TriggerState(StrictEnum): null = "null" preparing = "preparing" starting = "starting" diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index aa9973a20e..22b88ce6df 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -51,14 +51,13 @@ class IOC: protocol: Literal["ca", "pva"] async def make_backend( - self, typ: type | None, suff: str, connect=True + self, typ: type | None, suff: str, timeout=10.0 ) -> SignalBackend: # Calculate the pv pv = f"{self.protocol}://{PV_PREFIX}:{self.protocol}:{suff}" # Make and connect the backend connector = _epics_signal_connector(typ, pv, pv) - if connect: - await connector.connect(None, False, 10, False) # type: ignore + await connector.connect(mock=False, timeout=timeout, force_reconnect=False) return connector.backend @@ -198,8 +197,8 @@ async def put_error( # The below will work without error await backend.put(put_value) # Change the name of write_pv to mock disconnection - backend.__setattr__("write_pv", "Disconnect") - await backend.put(put_value, timeout=3) + backend.__setattr__("_write_pv", "Disconnect") + await backend.put(put_value, timeout=0.1) class MyEnum(StrictEnum): @@ -223,10 +222,10 @@ class MySubsetEnum(SubsetEnum): "string": {}, }, "pva": { - "boolean": {"limits": ANY}, + "boolean": {}, "integer": {"units": ANY, "precision": ANY, "limits": ANY}, "number": {"units": ANY, "precision": ANY, "limits": ANY}, - "enum": {"limits": ANY}, + "enum": {}, "string": {"units": ANY, "precision": ANY, "limits": ANY}, }, } @@ -258,7 +257,7 @@ def get_dtype_numpy(suffix: str) -> str: # type: ignore if "float" in suffix or "double" in suffix: return " str: # type: ignore "stra", ["five", "six", "seven"], ["nine", "ten"], - {"pva"}, - ), - ( - Array1D[np.str_], - "stra", - ["five", "six", "seven"], - ["nine", "ten"], - {"ca"}, + {"pva", "ca"}, ), # Can't do long strings until https://github.com/epics-base/pva2pva/issues/17 # (str, "longstr", ls1, ls2), @@ -474,8 +466,8 @@ async def test_bool_conversion_of_enum(ioc: IOC, suffix: str, tmp_path: Path) -> async def test_error_raised_on_disconnected_PV(ioc: IOC) -> None: if ioc.protocol == "pva": - err = NotConnected - expected = "pva://Disconnect" + err = asyncio.TimeoutError + expected = "pva://Disconnect: Put timed out" elif ioc.protocol == "ca": err = CANothing expected = "Disconnect: User specified timeout on IO operation expired" @@ -488,7 +480,7 @@ async def test_error_raised_on_disconnected_PV(ioc: IOC) -> None: ) -class BadEnum(str, Enum): +class BadEnum(StrictEnum): a = "Aaa" b = "B" c = "Ccc" @@ -500,12 +492,12 @@ def test_enum_equality(): possibly more. """ - class GeneratedChoices(str, Enum): + class GeneratedChoices(StrictEnum): a = "Aaa" b = "B" c = "Ccc" - class ExtendedGeneratedChoices(str, Enum): + class ExtendedGeneratedChoices(StrictEnum): a = "Aaa" b = "B" c = "Ccc" @@ -543,8 +535,8 @@ class SubsetEnumWrongChoices(SubsetEnum): BadEnum, "enum", ( - "has choices ('Aaa', 'Bbb', 'Ccc'), which do not match " - ", which has ('Aaa', 'B', 'Ccc')" + "has choices ('Aaa', 'Bbb', 'Ccc'), but " + "requested ['Aaa', 'B', 'Ccc'] to be a subset of them" ), ), ( @@ -698,7 +690,7 @@ async def test_pva_ntdarray(ioc: IOC): put = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64).reshape((2, 3)) initial = np.zeros_like(put) - backend = await ioc.make_backend(Array1D[np.int64], "ntndarray") + backend = await ioc.make_backend(np.ndarray, "ntndarray") # Backdoor into the "raw" data underlying the NDArray in QSrv # not supporting direct writes to NDArray at the moment. @@ -710,9 +702,8 @@ async def test_pva_ntdarray(ioc: IOC): assert { "source": "test-source", "dtype": "array", - "dtype_numpy": "", + "dtype_numpy": " " - "cannot be coerced to ", + match=f"{ioc.protocol}:float_prec_1 with inferred datatype float" + ".* cannot be coerced to int", ): await sig.connect()