Skip to content

Commit

Permalink
Refactor Signal type parsing in pvi logic
Browse files Browse the repository at this point in the history
  • Loading branch information
GDYendell committed Apr 17, 2024
1 parent 2fc44f2 commit ddb677f
Showing 1 changed file with 50 additions and 54 deletions.
104 changes: 50 additions & 54 deletions src/ophyd_async/epics/pvi/pvi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass
from inspect import isclass
from typing import (
Any,
Callable,
Dict,
FrozenSet,
Expand Down Expand Up @@ -44,6 +44,18 @@ def _strip_number_from_string(string: str) -> Tuple[str, Optional[int]]:
return name, number


def _split_subscript(tp: T) -> Union[Tuple[Any, Tuple[Any]], Tuple[T, None]]:
"""Split a subscripted type into the its origin and args.
If `tp` is not a subscripted type, then just return the type and None as args.
"""
if get_origin(tp) is not None:
return get_origin(tp), get_args(tp)

return tp, None


def _strip_union(field: Union[Union[T], T]) -> T:
if get_origin(field) is Union:
args = get_args(field)
Expand Down Expand Up @@ -115,86 +127,70 @@ def _parse_type(
):
if common_device_type:
# pre-defined type
device_type = _strip_union(common_device_type)
is_device_vector, device_type = _strip_device_vector(device_type)

if ((origin := get_origin(device_type)) and issubclass(origin, Signal)) or (
isclass(device_type) and issubclass(device_type, Signal)
):
# if device_type is of the form `Signal` or `Signal[type]`
is_signal = True
signal_dtype = get_args(device_type)[0]
else:
is_signal = False
signal_dtype = None
device_cls = _strip_union(common_device_type)
is_device_vector, device_cls = _strip_device_vector(device_cls)
device_cls, device_args = _split_subscript(device_cls)
assert issubclass(device_cls, Device)

is_signal = issubclass(device_cls, Signal)
signal_dtype = device_args[0] if device_args is not None else None

elif is_pvi_table:
# is a block, we can make it a DeviceVector if it ends in a number
is_device_vector = number_suffix is not None
is_signal = False
signal_dtype = None
device_type = Device
device_cls = Device
else:
# is a signal, signals aren't stored in DeviceVectors unless
# they're defined as such in the common_device_type
is_device_vector = False
is_signal = True
signal_dtype = None
device_type = Signal
device_cls = Signal

return is_device_vector, is_signal, signal_dtype, device_type
return is_device_vector, is_signal, signal_dtype, device_cls


def _sim_common_blocks(device: Device, stripped_type: Optional[Type] = None):
device_t = stripped_type or type(device)
for sub_name, sub_device_t in get_type_hints(device_t).items():
if sub_name in ("_name", "parent"):
continue
sub_devices = (
(field, field_type)
for field, field_type in get_type_hints(device_t).items()
if field not in ("_name", "parent")
)

for device_name, device_cls in sub_devices:
device_cls = _strip_union(device_cls)
is_device_vector, device_cls = _strip_device_vector(device_cls)
device_cls, device_args = _split_subscript(device_cls)
assert issubclass(device_cls, Device)

# we'll take the first type in the union which isn't NoneType
sub_device_t = _strip_union(sub_device_t)
is_device_vector, sub_device_t = _strip_device_vector(sub_device_t)
is_signal = (
(origin := get_origin(sub_device_t)) and issubclass(origin, Signal)
) or (issubclass(sub_device_t, Signal))
is_signal = issubclass(device_cls, Signal)
signal_dtype = device_args[0] if device_args is not None else None

# TODO: worth coming back to all this code once 3.9 is gone and we can use
# match statments: https://github.com/bluesky/ophyd-async/issues/180
if is_device_vector:
if is_signal:
signal_type = args[0] if (args := get_args(sub_device_t)) else None
sub_device_1 = sub_device_t(SimSignalBackend(signal_type, sub_name))
sub_device_2 = sub_device_t(SimSignalBackend(signal_type, sub_name))
sub_device = DeviceVector(
{
1: sub_device_1,
2: sub_device_2,
}
)
sub_device_1 = device_cls(SimSignalBackend(signal_dtype, device_name))
sub_device_2 = device_cls(SimSignalBackend(signal_dtype, device_name))
sub_device = DeviceVector({1: sub_device_1, 2: sub_device_2})
else:
sub_device = DeviceVector(
{
1: sub_device_t(),
2: sub_device_t(),
}
)
sub_device = DeviceVector({1: device_cls(), 2: device_cls()})

for sub_device_in_vector in sub_device.values():
_sim_common_blocks(sub_device_in_vector, stripped_type=device_cls)

for value in sub_device.values():
value.parent = sub_device

elif is_signal:
signal_type = args[0] if (args := get_args(sub_device_t)) else None
sub_device = sub_device_t(SimSignalBackend(signal_type, sub_name))
else:
sub_device = sub_device_t()

if not is_signal:
if is_device_vector:
for sub_device_in_vector in sub_device.values():
_sim_common_blocks(sub_device_in_vector, stripped_type=sub_device_t)
if is_signal:
sub_device = device_cls(SimSignalBackend(signal_dtype, device_name))
else:
_sim_common_blocks(sub_device, stripped_type=sub_device_t)
sub_device = device_cls()

_sim_common_blocks(sub_device, stripped_type=device_cls)

setattr(device, sub_name, sub_device)
setattr(device, device_name, sub_device)
sub_device.parent = device


Expand Down

0 comments on commit ddb677f

Please sign in to comment.