diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py index ff54a01022..f36b60dceb 100644 --- a/src/ophyd_async/core/_table.py +++ b/src/ophyd_async/core/_table.py @@ -1,4 +1,5 @@ -from typing import TypeVar +from enum import Enum +from typing import TypeVar, get_args, get_origin import numpy as np from pydantic import BaseModel, ConfigDict, model_validator @@ -6,6 +7,13 @@ TableSubclass = TypeVar("TableSubclass", bound="Table") +def _concat(value1, value2): + if isinstance(value1, np.ndarray): + return np.concatenate((value1, value2)) + else: + return value1 + value2 + + class Table(BaseModel): """An abstraction of a Table of str to numpy array.""" @@ -13,34 +21,105 @@ class Table(BaseModel): @staticmethod def row(cls: type[TableSubclass], **kwargs) -> TableSubclass: # type: ignore - arrayified_kwargs = { - field_name: np.concatenate( - ( - (default_arr := field_value.default_factory()), # type: ignore - np.array([kwargs[field_name]], dtype=default_arr.dtype), + arrayified_kwargs = {} + for field_name, field_value in cls.model_fields.items(): + value = kwargs.pop(field_name) + if field_value.default_factory is None: + raise ValueError( + "`Table` models should have default factories for their " + "mutable empty columns." + ) + default_array = field_value.default_factory() + if isinstance(default_array, np.ndarray): + arrayified_kwargs[field_name] = np.array( + [value], dtype=default_array.dtype + ) + elif issubclass(type(value), Enum) and isinstance(value, str): + arrayified_kwargs[field_name] = [value] + else: + raise TypeError( + "Row column should be numpy arrays or sequence of string `Enum`." ) + if kwargs: + raise TypeError( + f"Unexpected keyword arguments {kwargs.keys()} for {cls.__name__}." ) - for field_name, field_value in cls.model_fields.items() - } return cls(**arrayified_kwargs) def __add__(self, right: TableSubclass) -> TableSubclass: """Concatenate the arrays in field values.""" - assert type(right) is type(self), ( - f"{right} is not a `Table`, or is not the same " - f"type of `Table` as {self}." - ) + if type(right) is not type(self): + raise RuntimeError( + f"{right} is not a `Table`, or is not the same " + f"type of `Table` as {self}." + ) return type(right)( **{ - field_name: np.concatenate( - (getattr(self, field_name), getattr(right, field_name)) + field_name: _concat( + getattr(self, field_name), getattr(right, field_name) ) for field_name in self.model_fields } ) + def numpy_dtype(self) -> np.dtype: + dtype = [] + for field_name, field_value in self.model_fields.items(): + if np.ndarray in ( + get_origin(field_value.annotation), + field_value.annotation, + ): + dtype.append((field_name, getattr(self, field_name).dtype)) + else: + enum_type = get_args(field_value.annotation)[0] + assert issubclass(enum_type, Enum) + enum_values = [element.value for element in enum_type] + max_length_in_enum = max(len(value) for value in enum_values) + dtype.append((field_name, np.dtype(f" list[np.ndarray]: + """Columns in the table can be lists of string enums or numpy arrays. + + This method returns the columns, converting the string enums to numpy arrays. + """ + + columns = [] + for field_name, field_value in self.model_fields.items(): + if np.ndarray in ( + get_origin(field_value.annotation), + field_value.annotation, + ): + columns.append(getattr(self, field_name)) + else: + enum_type = get_args(field_value.annotation)[0] + assert issubclass(enum_type, Enum) + enum_values = [element.value for element in enum_type] + max_length_in_enum = max(len(value) for value in enum_values) + dtype = np.dtype(f" "Table": first_length = len(next(iter(self))[1]) @@ -49,11 +128,15 @@ def validate_arrays(self) -> "Table": ), "Rows should all be of equal size." if not all( - np.issubdtype( - self.model_fields[field_name].default_factory().dtype, # type: ignore - field_value.dtype, + # Checks if the values are numpy subtypes if the array is a numpy array, + # or if the value is a string enum. + np.issubdtype(getattr(self, field_name).dtype, default_array.dtype) + if isinstance( + default_array := self.model_fields[field_name].default_factory(), # type: ignore + np.ndarray, ) - for field_name, field_value in self + else issubclass(get_args(field_value.annotation)[0], Enum) + for field_name, field_value in self.model_fields.items() ): raise ValueError( f"Cannot construct a `{type(self).__name__}`, " diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index 91fce9a298..a021d23fa8 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -4,7 +4,7 @@ import numpy as np import numpy.typing as npt -from pydantic import Field, field_validator, model_validator +from pydantic import Field, model_validator from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation from typing_extensions import TypedDict @@ -51,13 +51,7 @@ class SeqTrigger(str, Enum): ), Field(default_factory=lambda: np.array([], dtype=np.bool_)), ] -TriggerStr = Annotated[ - np.ndarray[tuple[int], np.dtype[np.unicode_]], - NpArrayPydanticAnnotation.factory( - data_type=np.unicode_, dimensions=1, strict_data_typing=False - ), - Field(default_factory=lambda: np.array([], dtype=np.dtype(" "SeqTable": - if isinstance(trigger, SeqTrigger): - trigger = trigger.value - return super().row(**locals()) - - @field_validator("trigger", mode="before") - @classmethod - def trigger_to_np_array(cls, trigger_column): - """ - The user can provide a list of SeqTrigger enum elements instead of a numpy str. - """ - if isinstance(trigger_column, Sequence) and all( - isinstance(trigger, SeqTrigger) for trigger in trigger_column - ): - trigger_column = np.array( - [trigger.value for trigger in trigger_column], dtype=np.dtype(" "SeqTable": diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index b34103250b..ed963c91f5 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -12,15 +12,18 @@ def test_seq_table_converts_lists(): seq_table_dict_with_lists = {field_name: [] for field_name, _ in SeqTable()} # Validation passes seq_table = SeqTable(**seq_table_dict_with_lists) - assert isinstance(seq_table.trigger, np.ndarray) - assert seq_table.trigger.dtype == np.dtype("U32") + for field_name, field_value in seq_table: + if field_name == "trigger": + assert field_value == [] + else: + assert np.array_equal(field_value, np.array([], dtype=field_value.dtype)) def test_seq_table_validation_errors(): with pytest.raises(ValidationError, match="81 validation errors for SeqTable"): SeqTable( repeats=0, - trigger="Immediate", + trigger=SeqTrigger.IMMEDIATE, position=0, time1=0, outa1=False, @@ -40,7 +43,7 @@ def test_seq_table_validation_errors(): large_seq_table = SeqTable( repeats=np.zeros(4095, dtype=np.int32), - trigger=np.array(["Immediate"] * 4095, dtype="U32"), + trigger=["Immediate"] * 4095, position=np.zeros(4095, dtype=np.int32), time1=np.zeros(4095, dtype=np.int32), outa1=np.zeros(4095, dtype=np.bool_), @@ -73,16 +76,25 @@ def test_seq_table_validation_errors(): wrong_types = { field_name: field_value.astype(np.unicode_) for field_name, field_value in row_one + if isinstance(field_value, np.ndarray) } SeqTable(**wrong_types) + with pytest.raises( + TypeError, + match="Row column should be numpy arrays or sequence of string `Enum`", + ): + SeqTable.row(trigger="A") def test_seq_table_pva_conversion(): pva_dict = { "repeats": np.array([1, 2, 3, 4], dtype=np.int32), - "trigger": np.array( - ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype=np.dtype("U32") - ), + "trigger": [ + SeqTrigger.IMMEDIATE, + SeqTrigger.IMMEDIATE, + SeqTrigger.BITC_0, + SeqTrigger.IMMEDIATE, + ], "position": np.array([1, 2, 3, 4], dtype=np.int32), "time1": np.array([1, 0, 1, 0], dtype=np.int32), "outa1": np.array([1, 0, 1, 0], dtype=np.bool_), @@ -102,7 +114,7 @@ def test_seq_table_pva_conversion(): row_wise_dicts = [ { "repeats": 1, - "trigger": "Immediate", + "trigger": SeqTrigger.IMMEDIATE, "position": 1, "time1": 1, "outa1": 1, @@ -121,7 +133,7 @@ def test_seq_table_pva_conversion(): }, { "repeats": 2, - "trigger": "Immediate", + "trigger": SeqTrigger.IMMEDIATE, "position": 2, "time1": 0, "outa1": 0, @@ -140,7 +152,7 @@ def test_seq_table_pva_conversion(): }, { "repeats": 3, - "trigger": "BITC=0", + "trigger": SeqTrigger.BITC_0, "position": 3, "time1": 1, "outa1": 1, @@ -159,7 +171,7 @@ def test_seq_table_pva_conversion(): }, { "repeats": 4, - "trigger": "Immediate", + "trigger": SeqTrigger.IMMEDIATE, "position": 4, "time1": 0, "outa1": 0, @@ -178,12 +190,20 @@ def test_seq_table_pva_conversion(): }, ] + def _assert_col_equal(column1, column2): + if isinstance(column1, np.ndarray): + assert np.array_equal(column1, column2) + assert column1.dtype == column2.dtype + else: + assert column1 == column2 + assert all(isinstance(x, SeqTrigger) for x in column1) + assert all(isinstance(x, SeqTrigger) for x in column2) + seq_table_from_pva_dict = SeqTable(**pva_dict) for (_, column1), column2 in zip( seq_table_from_pva_dict, pva_dict.values(), strict=False ): - assert np.array_equal(column1, column2) - assert column1.dtype == column2.dtype + _assert_col_equal(column1, column2) seq_table_from_rows = reduce( lambda x, y: x + y, @@ -192,41 +212,174 @@ def test_seq_table_pva_conversion(): for (_, column1), column2 in zip( seq_table_from_rows, pva_dict.values(), strict=False ): - assert np.array_equal(column1, column2) - assert column1.dtype == column2.dtype + _assert_col_equal(column1, column2) # Idempotency applied_twice_to_pva_dict = SeqTable(**pva_dict).model_dump(mode="python") for column1, column2 in zip( applied_twice_to_pva_dict.values(), pva_dict.values(), strict=False ): - assert np.array_equal(column1, column2) - assert column1.dtype == column2.dtype + _assert_col_equal(column1, column2) + + assert np.array_equal( + seq_table_from_pva_dict.numpy_columns(), + [ + np.array([1, 2, 3, 4], dtype=np.int32), + np.array( + [ + "Immediate", + "Immediate", + "BITC=0", + "Immediate", + ], + dtype="