Skip to content

Commit

Permalink
Add StandardReadableFormat annotation support
Browse files Browse the repository at this point in the history
  • Loading branch information
coretl committed Oct 29, 2024
1 parent 32cf050 commit 8513acc
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 172 deletions.
8 changes: 7 additions & 1 deletion src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@
UUIDFilenameProvider,
YMDPathProvider,
)
from ._readable import ConfigSignal, HintedSignal, StandardReadable
from ._readable import (
ConfigSignal,
HintedSignal,
StandardReadable,
StandardReadableFormat,
)
from ._signal import (
Signal,
SignalR,
Expand Down Expand Up @@ -139,6 +144,7 @@
"ConfigSignal",
"HintedSignal",
"StandardReadable",
"StandardReadableFormat",
"Signal",
"SignalR",
"SignalRW",
Expand Down
173 changes: 75 additions & 98 deletions src/ophyd_async/core/_readable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from collections.abc import Callable, Generator, Sequence
from collections.abc import Awaitable, Callable, Generator, Sequence
from contextlib import contextmanager
from enum import Enum

from bluesky.protocols import HasHints, Hints, Reading
from event_model import DataKey
Expand All @@ -11,12 +11,24 @@
from ._status import AsyncStatus
from ._utils import merge_gathered_dicts

ReadableChild = AsyncReadable | AsyncConfigurable | AsyncStageable | HasHints
ReadableChildWrapper = (
Callable[[ReadableChild], ReadableChild]
| type["ConfigSignal"]
| type["HintedSignal"]
)

class StandardReadableFormat(Enum):
CHILD = "CHILD"
CONFIG_SIGNAL = "CONFIG_SIGNAL"
HINTED_SIGNAL = "HINTED_SIGNAL"
UNCACHED_SIGNAL = "UNCACHED_SIGNAL"
HINTED_UNCACHED_SIGNAL = "HINTED_UNCACHED_SIGNAL"

def __call__(self, parent: Device, child: Device):
if not isinstance(parent, StandardReadable):
raise TypeError(f"Expected parent to be StandardReadable, got {parent}")
parent.add_readables([child], self)


# Back compat
ConfigSignal = StandardReadableFormat.CONFIG_SIGNAL
HintedSignal = StandardReadableFormat.HINTED_SIGNAL
HintedSignal.uncached = StandardReadableFormat.HINTED_UNCACHED_SIGNAL # type: ignore


class StandardReadable(
Expand All @@ -31,38 +43,13 @@ class StandardReadable(

# These must be immutable types to avoid accidental sharing between
# different instances of the class
_readables: tuple[AsyncReadable, ...] = ()
_configurables: tuple[AsyncConfigurable, ...] = ()
_describe_config_funcs: tuple[Callable[[], Awaitable[dict[str, DataKey]]], ...] = ()
_read_config_funcs: tuple[Callable[[], Awaitable[dict[str, Reading]]], ...] = ()
_describe_funcs: tuple[Callable[[], Awaitable[dict[str, DataKey]]], ...] = ()
_read_funcs: tuple[Callable[[], Awaitable[dict[str, Reading]]], ...] = ()
_stageables: tuple[AsyncStageable, ...] = ()
_has_hints: tuple[HasHints, ...] = ()

def set_readable_signals(
self,
read: Sequence[SignalR] = (),
config: Sequence[SignalR] = (),
read_uncached: Sequence[SignalR] = (),
):
"""
Parameters
----------
read:
Signals to make up :meth:`~StandardReadable.read`
conf:
Signals to make up :meth:`~StandardReadable.read_configuration`
read_uncached:
Signals to make up :meth:`~StandardReadable.read` that won't be cached
"""
warnings.warn(
DeprecationWarning(
"Migrate to `add_children_as_readables` context manager or "
"`add_readables` method"
),
stacklevel=2,
)
self.add_readables(read, wrapper=HintedSignal)
self.add_readables(config, wrapper=ConfigSignal)
self.add_readables(read_uncached, wrapper=HintedSignal.uncached)

@AsyncStatus.wrap
async def stage(self) -> None:
for sig in self._stageables:
Expand All @@ -75,19 +62,17 @@ async def unstage(self) -> None:

async def describe_configuration(self) -> dict[str, DataKey]:
return await merge_gathered_dicts(
[sig.describe_configuration() for sig in self._configurables]
[func() for func in self._describe_config_funcs]
)

async def read_configuration(self) -> dict[str, Reading]:
return await merge_gathered_dicts(
[sig.read_configuration() for sig in self._configurables]
)
return await merge_gathered_dicts([func() for func in self._read_config_funcs])

async def describe(self) -> dict[str, DataKey]:
return await merge_gathered_dicts([sig.describe() for sig in self._readables])
return await merge_gathered_dicts([func() for func in self._describe_funcs])

async def read(self) -> dict[str, Reading]:
return await merge_gathered_dicts([sig.read() for sig in self._readables])
return await merge_gathered_dicts([func() for func in self._read_funcs])

@property
def hints(self) -> Hints:
Expand Down Expand Up @@ -127,7 +112,7 @@ def hints(self) -> Hints:
@contextmanager
def add_children_as_readables(
self,
wrapper: ReadableChildWrapper | None = None,
wrapper: StandardReadableFormat = StandardReadableFormat.CHILD,
) -> Generator[None, None, None]:
"""Context manager to wrap adding Devices
Expand Down Expand Up @@ -171,8 +156,8 @@ def add_children_as_readables(

def add_readables(
self,
devices: Sequence[ReadableChild],
wrapper: ReadableChildWrapper | None = None,
devices: Sequence[Device],
wrapper: StandardReadableFormat = StandardReadableFormat.CHILD,
) -> None:
"""Add the given devices to the lists of known Devices
Expand All @@ -197,65 +182,57 @@ def add_readables(
:meth:`HintedSignal.uncached`
"""

for readable in devices:
obj = readable
if wrapper:
obj = wrapper(readable)

if isinstance(obj, AsyncReadable):
self._readables += (obj,)

if isinstance(obj, AsyncConfigurable):
self._configurables += (obj,)

if isinstance(obj, AsyncStageable):
self._stageables += (obj,)

if isinstance(obj, HasHints):
self._has_hints += (obj,)


class ConfigSignal(AsyncConfigurable):
def __init__(self, signal: ReadableChild) -> None:
assert isinstance(signal, SignalR), f"Expected signal, got {signal}"
for device in devices:
match wrapper:
case StandardReadableFormat.CHILD:
if isinstance(device, AsyncConfigurable):
self._describe_config_funcs += (device.describe_configuration,)
self._read_config_funcs += (device.read_configuration,)
if isinstance(device, AsyncReadable):
self._describe_funcs += (device.describe,)
self._read_funcs += (device.read,)
if isinstance(device, AsyncStageable):
self._stageables += (device,)
if isinstance(device, HasHints):
self._has_hints += (device,)
case StandardReadableFormat.CONFIG_SIGNAL:
assert isinstance(device, SignalR), f"{device} is not a SignalR"
self._describe_config_funcs += (device.describe,)
self._read_config_funcs += (device.read,)
case StandardReadableFormat.HINTED_SIGNAL:
assert isinstance(device, SignalR), f"{device} is not a SignalR"
self._describe_funcs += (device.describe,)
self._read_funcs += (device.read,)
self._stageables += (device,)
self._has_hints += (_HintsFromName(device),)
case StandardReadableFormat.UNCACHED_SIGNAL:
assert isinstance(device, SignalR), f"{device} is not a SignalR"
self._describe_funcs += (device.describe,)
self._read_funcs += (_UncachedRead(device),)
case StandardReadableFormat.HINTED_UNCACHED_SIGNAL:
assert isinstance(device, SignalR), f"{device} is not a SignalR"
self._describe_funcs += (device.describe,)
self._read_funcs += (_UncachedRead(device),)
self._has_hints += (_HintsFromName(device),)


class _UncachedRead:
def __init__(self, signal: SignalR) -> None:
self.signal = signal

async def read_configuration(self) -> dict[str, Reading]:
return await self.signal.read()

async def describe_configuration(self) -> dict[str, DataKey]:
return await self.signal.describe()
async def __call__(self) -> dict[str, Reading]:
return await self.signal.read(cached=False)

@property
def name(self) -> str:
return self.signal.name


class HintedSignal(HasHints, AsyncReadable):
def __init__(self, signal: ReadableChild, allow_cache: bool = True) -> None:
assert isinstance(signal, SignalR), f"Expected signal, got {signal}"
self.signal = signal
self.cached = None if allow_cache else allow_cache
if allow_cache:
self.stage = signal.stage
self.unstage = signal.unstage

async def read(self) -> dict[str, Reading]:
return await self.signal.read(cached=self.cached)

async def describe(self) -> dict[str, DataKey]:
return await self.signal.describe()
class _HintsFromName(HasHints):
def __init__(self, device: Device) -> None:
self.device = device

@property
def name(self) -> str:
return self.signal.name
return self.device.name

@property
def hints(self) -> Hints:
if self.signal.name == "":
return {"fields": []}
return {"fields": [self.signal.name]}

@classmethod
def uncached(cls, signal: ReadableChild) -> "HintedSignal":
return cls(signal, allow_cache=False)
fields = [self.name] if self.name else []
return {"fields": fields}
1 change: 0 additions & 1 deletion src/ophyd_async/core/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from dataclasses import dataclass
from enum import Enum, EnumMeta
from typing import (
Annotated,
Any,
Generic,
Literal,
Expand Down
14 changes: 9 additions & 5 deletions src/ophyd_async/epics/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@ class EpicsProtocol(Enum):
_default_epics_protocol = EpicsProtocol.CA


def _make_unavailable_function(error: Exception):
def transport_not_available(*args, **kwargs):
raise NotImplementedError("Transport not available") from error

return transport_not_available


def _make_unavailable_class(error: Exception) -> type[EpicsSignalBackend]:
class TransportNotAvailable(EpicsSignalBackend):
def __init__(*args, **kwargs):
raise NotImplementedError("Transport not available") from error
__init__ = _make_unavailable_function(error)

return TransportNotAvailable

Expand All @@ -37,9 +43,7 @@ def __init__(*args, **kwargs):
from ._p4p import PvaSignalBackend, pvget_with_timeout
except ImportError as pva_error:
PvaSignalBackend = _make_unavailable_class(pva_error)

async def pvget_with_timeout(pv: str, timeout: float):
raise NotImplementedError("Transport not available") from pva_error
pvget_with_timeout = _make_unavailable_function(pva_error)
else:
_default_epics_protocol = EpicsProtocol.PVA

Expand Down
12 changes: 3 additions & 9 deletions src/ophyd_async/epics/demo/_sensor.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import Annotated as A

from ophyd_async.core import (
ConfigSignal,
DeviceVector,
HintedSignal,
SignalR,
SignalRW,
StandardReadable,
StrictEnum,
)
from ophyd_async.core import StandardReadableFormat as Format
from ophyd_async.epics.core import EpicsDevice, PvSuffix


Expand All @@ -24,13 +23,8 @@ class EnergyMode(StrictEnum):
class Sensor(StandardReadable, EpicsDevice):
"""A demo sensor that produces a scalar value based on X and Y Movers"""

value: A[SignalR[float], PvSuffix("Value")]
mode: A[SignalRW[EnergyMode], PvSuffix("Mode")]

def __init__(self, prefix: str, name="") -> None:
super().__init__(prefix=prefix, name=name)
self.add_readables([self.value], HintedSignal)
self.add_readables([self.mode], ConfigSignal)
value: A[SignalR[float], PvSuffix("Value"), Format.HINTED_SIGNAL]
mode: A[SignalRW[EnergyMode], PvSuffix("Mode"), Format.CONFIG_SIGNAL]


class SensorGroup(StandardReadable):
Expand Down
Loading

0 comments on commit 8513acc

Please sign in to comment.