Skip to content

Commit

Permalink
Make Signal, Backend generic on read/set bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Apr 22, 2024
1 parent 610d2d1 commit 5e4e2a7
Show file tree
Hide file tree
Showing 13 changed files with 269 additions and 152 deletions.
10 changes: 5 additions & 5 deletions src/ophyd_async/core/device_save_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def represent_data(self, data: Any) -> Any:


def get_signal_values(
signals: Dict[str, SignalRW[Any]], ignore: Optional[List[str]] = None
signals: Dict[str, SignalRW[Any, Any]], ignore: Optional[List[str]] = None
) -> Generator[Msg, Sequence[Location[Any]], Dict[str, Any]]:
"""Get signal values in bulk.
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_signal_values(

def walk_rw_signals(
device: Device, path_prefix: Optional[str] = ""
) -> Dict[str, SignalRW[Any]]:
) -> Dict[str, SignalRW[Any, Any]]:
"""Retrieve all SignalRWs from a device.
Stores retrieved signals with their dotted attribute paths in a dictionary. Used as
Expand Down Expand Up @@ -121,7 +121,7 @@ def walk_rw_signals(
if not path_prefix:
path_prefix = ""

signals: Dict[str, SignalRW[Any]] = {}
signals: Dict[str, SignalRW[Any, Any]] = {}
for attr_name, attr in device.children():
dot_path = f"{path_prefix}{attr_name}"
if type(attr) is SignalRW:
Expand Down Expand Up @@ -179,7 +179,7 @@ def load_from_yaml(save_path: str) -> Sequence[Dict[str, Any]]:


def set_signal_values(
signals: Dict[str, SignalRW[Any]], values: Sequence[Dict[str, Any]]
signals: Dict[str, SignalRW[Any, Any]], values: Sequence[Dict[str, Any]]
) -> Generator[Msg, None, None]:
"""Maps signals from a yaml file into device signals.
Expand All @@ -188,7 +188,7 @@ def set_signal_values(
Parameters
----------
signals : Dict[str, SignalRW[Any]]
signals : Dict[str, SignalRW[Any, Any]]
Dictionary of named signals to be updated if value found in values argument.
Can be the output of :func:`walk_rw_signals()` for a device.
Expand Down
75 changes: 40 additions & 35 deletions src/ophyd_async/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .device import Device
from .signal_backend import SignalBackend
from .sim_signal_backend import SimSignalBackend
from .utils import DEFAULT_TIMEOUT, Callback, ReadingValueCallback, T
from .utils import DEFAULT_TIMEOUT, Callback, R, ReadingValueCallback, S

_sim_backends: Dict[Signal, SimSignalBackend] = {}

Expand All @@ -42,12 +42,12 @@ def _fail(self, other, *args, **kwargs):
return NotImplemented


class Signal(Device, Generic[T]):
class Signal(Device, Generic[R, S]):
"""A Device with the concept of a value, with R, RW, W and X flavours"""

def __init__(
self,
backend: SignalBackend[T],
backend: SignalBackend[R, S],
timeout: Optional[float] = DEFAULT_TIMEOUT,
name: str = "",
) -> None:
Expand Down Expand Up @@ -76,14 +76,14 @@ def __hash__(self):
return hash(id(self))


class _SignalCache(Generic[T]):
def __init__(self, backend: SignalBackend[T], signal: Signal):
class _SignalCache(Generic[R, S]):
def __init__(self, backend: SignalBackend[R, S], signal: Signal):
self._signal = signal
self._staged = False
self._listeners: Dict[Callback, bool] = {}
self._valid = asyncio.Event()
self._reading: Optional[Reading] = None
self._value: Optional[T] = None
self._value: Optional[R] = None

self.backend = backend
backend.set_callback(self._callback)
Expand All @@ -96,12 +96,12 @@ async def get_reading(self) -> Reading:
assert self._reading is not None, "Monitor not working"
return self._reading

async def get_value(self) -> T:
async def get_value(self) -> R:
await self._valid.wait()
assert self._value is not None, "Monitor not working"
return self._value

def _callback(self, reading: Reading, value: T):
def _callback(self, reading: Reading, value: R):
self._reading = reading
self._value = value
self._valid.set()
Expand All @@ -128,7 +128,7 @@ def set_staged(self, staged: bool):
return self._staged or bool(self._listeners)


class SignalR(Signal[T], AsyncReadable, Stageable, Subscribable):
class SignalR(Signal[R, R], AsyncReadable, Stageable, Subscribable):
"""Signal that can be read from and monitored"""

_cache: Optional[_SignalCache] = None
Expand Down Expand Up @@ -166,11 +166,11 @@ async def describe(self) -> Dict[str, Descriptor]:
return {self.name: await self._backend.get_descriptor(self.source)}

@_add_timeout
async def get_value(self, cached: Optional[bool] = None) -> T:
async def get_value(self, cached: Optional[bool] = None) -> R:
"""The current value"""
return await self._backend_or_cache(cached).get_value()

def subscribe_value(self, function: Callback[T]):
def subscribe_value(self, function: Callback[R]):
"""Subscribe to updates in value of a device"""
self._get_cache().subscribe(function, want_value=True)

Expand All @@ -196,18 +196,18 @@ async def unstage(self) -> None:
USE_DEFAULT_TIMEOUT = "USE_DEFAULT_TIMEOUT"


class SignalW(Signal[T], Movable):
class SignalW(Signal[S, S], Movable):
"""Signal that can be set"""

def set(self, value: T, wait=True, timeout=USE_DEFAULT_TIMEOUT) -> AsyncStatus:
def set(self, value: S, wait=True, timeout=USE_DEFAULT_TIMEOUT) -> AsyncStatus:
"""Set the value and return a status saying when it's done"""
if timeout is USE_DEFAULT_TIMEOUT:
timeout = self._timeout
coro = self._backend.put(value, wait=wait, timeout=timeout)
return AsyncStatus(coro)


class SignalRW(SignalR[T], SignalW[T], Locatable):
class SignalRW(SignalR[R], SignalW[S], Locatable):
"""Signal that can be both read and set"""

async def locate(self) -> Location:
Expand All @@ -229,12 +229,12 @@ def trigger(self, wait=True, timeout=USE_DEFAULT_TIMEOUT) -> AsyncStatus:
return AsyncStatus(coro)


def set_sim_value(signal: Signal[T], value: T):
def set_sim_value(signal: Signal[R, S], value: S):
"""Set the value of a signal that is in sim mode."""
_sim_backends[signal]._set_value(value)


def set_sim_put_proceeds(signal: Signal[T], proceeds: bool):
def set_sim_put_proceeds(signal: Signal[R, S], proceeds: bool):
"""Allow or block a put with wait=True from proceeding"""
event = _sim_backends[signal].put_proceeds
if proceeds:
Expand All @@ -243,26 +243,29 @@ def set_sim_put_proceeds(signal: Signal[T], proceeds: bool):
event.clear()


def set_sim_callback(signal: Signal[T], callback: ReadingValueCallback[T]) -> None:
def set_sim_callback(signal: Signal[R, S], callback: ReadingValueCallback[R]) -> None:
"""Monitor the value of a signal that is in sim mode"""
return _sim_backends[signal].set_callback(callback)


def soft_signal_rw(
datatype: Optional[Type[T]] = None,
initial_value: Optional[T] = None,
datatype: Optional[Type[S]] = None,
initial_value: Optional[R] = None,
name: str = "",
) -> SignalRW[T]:
read_datatype: Optional[Type[R]] = None,
) -> SignalRW[R, S]:
"""Creates a read-writable Signal with a SimSignalBackend"""
signal = SignalRW(SimSignalBackend(datatype, initial_value), name=name)
return signal
return SignalRW(
SimSignalBackend(datatype, initial_value, read_datatype=read_datatype),
name=name,
)


def soft_signal_r_and_backend(
datatype: Optional[Type[T]] = None,
initial_value: Optional[T] = None,
datatype: Optional[Type[R]] = None,
initial_value: Optional[R] = None,
name: str = "",
) -> Tuple[SignalR[T], SimSignalBackend]:
) -> Tuple[SignalR[R], SimSignalBackend]:
"""Returns a tuple of a read-only Signal and its SimSignalBackend through
which the signal can be internally modified within the device. Use
soft_signal_rw if you want a device that is externally modifiable
Expand All @@ -272,7 +275,7 @@ def soft_signal_r_and_backend(
return (signal, backend)


async def observe_value(signal: SignalR[T], timeout=None) -> AsyncGenerator[T, None]:
async def observe_value(signal: SignalR[R], timeout=None) -> AsyncGenerator[R, None]:
"""Subscribe to the value of a signal so it can be iterated from.
Parameters
Expand All @@ -288,7 +291,7 @@ async def observe_value(signal: SignalR[T], timeout=None) -> AsyncGenerator[T, N
async for value in observe_value(sig):
do_something_with(value)
"""
q: asyncio.Queue[T] = asyncio.Queue()
q: asyncio.Queue[R] = asyncio.Queue()
if timeout is None:
get_value = q.get
else:
Expand All @@ -304,19 +307,19 @@ async def get_value():
signal.clear_sub(q.put_nowait)


class _ValueChecker(Generic[T]):
def __init__(self, matcher: Callable[[T], bool], matcher_name: str):
self._last_value: Optional[T] = None
class _ValueChecker(Generic[R]):
def __init__(self, matcher: Callable[[R], bool], matcher_name: str):
self._last_value: Optional[R] = None
self._matcher = matcher
self._matcher_name = matcher_name

async def _wait_for_value(self, signal: SignalR[T]):
async def _wait_for_value(self, signal: SignalR[R]):
async for value in observe_value(signal):
self._last_value = value
if self._matcher(value):
return

async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]):
async def wait_for_value(self, signal: SignalR[R], timeout: Optional[float]):
try:
await asyncio.wait_for(self._wait_for_value(signal), timeout)
except asyncio.TimeoutError as e:
Expand All @@ -327,7 +330,9 @@ async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]):


async def wait_for_value(
signal: SignalR[T], match: Union[T, Callable[[T], bool]], timeout: Optional[float]
signal: SignalR[R],
match: Union[R, Callable[[R], bool]],
timeout: Optional[float],
):
"""Wait for a signal to have a matching value.
Expand Down Expand Up @@ -360,8 +365,8 @@ async def wait_for_value(


async def set_and_wait_for_value(
signal: SignalRW[T],
value: T,
signal: SignalRW[R, S],
value: S,
timeout: float = DEFAULT_TIMEOUT,
status_timeout: Optional[float] = None,
) -> AsyncStatus:
Expand Down
18 changes: 10 additions & 8 deletions src/ophyd_async/core/signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

from bluesky.protocols import Descriptor, Reading

from .utils import DEFAULT_TIMEOUT, ReadingValueCallback, T
from .utils import DEFAULT_TIMEOUT, R, ReadingValueCallback, S


class SignalBackend(Generic[T]):
class SignalBackend(Generic[R, S]):
"""A read/write/monitor backend for a Signals"""

#: Datatype of the signal value
datatype: Optional[Type[T]] = None
#: Datatype of the read signal value
read_datatype: Optional[Type[R]] = None
#: Datatype of the write signal value
set_datatype: Optional[Type[S]] = None

#: Like ca://PV_PREFIX:SIGNAL
@abstractmethod
Expand All @@ -23,7 +25,7 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT):
"""Connect to underlying hardware"""

@abstractmethod
async def put(self, value: Optional[T], wait=True, timeout=None):
async def put(self, value: Optional[S], wait=True, timeout=None):
"""Put a value to the PV, if wait then wait for completion for up to timeout"""

@abstractmethod
Expand All @@ -35,13 +37,13 @@ async def get_reading(self) -> Reading:
"""The current value, timestamp and severity"""

@abstractmethod
async def get_value(self) -> T:
async def get_value(self) -> R:
"""The current value"""

@abstractmethod
async def get_setpoint(self) -> T:
async def get_setpoint(self) -> R:
"""The point that a signal was requested to move to."""

@abstractmethod
def set_callback(self, callback: Optional[ReadingValueCallback[T]]) -> None:
def set_callback(self, callback: Optional[ReadingValueCallback[R]]) -> None:
"""Observe changes to the current value, timestamp and severity"""
Loading

0 comments on commit 5e4e2a7

Please sign in to comment.