Skip to content

Commit

Permalink
added utilities for converting SeqTable to numpy array
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Sep 16, 2024
1 parent 83144c1 commit 46c9d2a
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 49 deletions.
49 changes: 34 additions & 15 deletions src/ophyd_async/core/_table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import get_args
from typing import get_args, get_origin

import numpy as np
from pydantic import BaseModel, ConfigDict, model_validator
Expand Down Expand Up @@ -63,42 +63,61 @@ def __add__(self, right: "Table") -> "Table":
}
)

@property
def numpy_dtype(self) -> np.dtype:
dtype = []
for field_value in self.model_fields.values():
if isinstance(field_value, np.ndarray):
dtype.append(field_value.dtype)
for field_name, field_value in self.model_fields.items():
if np.ndarray in (
get_origin(field_value.annotation),
field_value.annotation,
):
dtype.append((field_name, getattr(self, field_name).dtype))
else:
enum_type = get_args(field_value.annotation)[0]
assert issubclass(enum_type, Enum)
enum_values = [element.value for element in enum_type]
max_length_in_enum = max(len(value) for value in enum_values)
dtype.append(np.dtype(f"<U{max_length_in_enum}"))
dtype.append((field_name, np.dtype(f"<U{max_length_in_enum}")))

def numpy_table(self):
return np.array(
self.numpy_columns(),
dtype=self.numpy_dtype(),
).transpose()
return np.dtype(dtype)

@property
def numpy_table(self):
# It would be nice to be able to use np.transpose for this,
# but it defaults to the largest dtype for everything.
dtype = self.numpy_dtype
transposed_list = [
np.array(tuple(row), dtype=dtype) for row in zip(*self.numpy_columns)
]
transposed = np.array(transposed_list, dtype=dtype)
return transposed

@property
def numpy_columns(self) -> list[np.ndarray]:
"""Columns in the table can be lists of string enums or numpy arrays.
This method returns the columns, converting the string enums to numpy arrays.
"""

columns = []
for field_value in self.model_fields.values():
if isinstance(field_value, np.ndarray):
columns.append(field_value)
for field_name, field_value in self.model_fields.items():
if np.ndarray in (
get_origin(field_value.annotation),
field_value.annotation,
):
columns.append(getattr(self, field_name))
else:
enum_type = get_args(field_value.field_info.annotation)[0]
enum_type = get_args(field_value.annotation)[0]
assert issubclass(enum_type, Enum)
enum_values = [element.value for element in enum_type]
max_length_in_enum = max(len(value) for value in enum_values)
dtype = np.dtype(f"<U{max_length_in_enum}")

columns.append(np.array(enum_values, dtype=dtype))
columns.append(
np.array(
[enum.value for enum in getattr(self, field_name)], dtype=dtype
)
)

return columns

Expand Down
4 changes: 0 additions & 4 deletions src/ophyd_async/fastcs/panda/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,6 @@ def row(
) -> "SeqTable":
sig = inspect.signature(cls.row)
kwargs = {k: v for k, v in locals().items() if k in sig.parameters}
if not isinstance(kwargs["trigger"], SeqTrigger):
if kwargs["trigger"] not in SeqTrigger.__members__.values():
raise ValueError(f"'{kwargs['trigger']}' is not a valid SeqTrigger.")
kwargs["trigger"] = SeqTrigger(kwargs["trigger"])
return Table.row(cls, **kwargs)

@model_validator(mode="after")
Expand Down
200 changes: 170 additions & 30 deletions tests/fastcs/panda/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_seq_table_validation_errors():
with pytest.raises(ValidationError, match="81 validation errors for SeqTable"):
SeqTable(
repeats=0,
trigger="Immediate",
trigger=SeqTrigger.IMMEDIATE,
position=0,
time1=0,
outa1=False,
Expand All @@ -43,7 +43,7 @@ def test_seq_table_validation_errors():

large_seq_table = SeqTable(
repeats=np.zeros(4095, dtype=np.int32),
trigger=np.array(["Immediate"] * 4095, dtype="U32"),
trigger=["Immediate"] * 4095,
position=np.zeros(4095, dtype=np.int32),
time1=np.zeros(4095, dtype=np.int32),
outa1=np.zeros(4095, dtype=np.bool_),
Expand Down Expand Up @@ -79,7 +79,10 @@ def test_seq_table_validation_errors():
if isinstance(field_value, np.ndarray)
}
SeqTable(**wrong_types)
with pytest.raises(ValueError, match="'A' is not a valid SeqTrigger."):
with pytest.raises(
TypeError,
match="Row column should be numpy arrays or sequence of string `Enum`",
):
SeqTable.row(trigger="A")


Expand Down Expand Up @@ -111,7 +114,7 @@ def test_seq_table_pva_conversion():
row_wise_dicts = [
{
"repeats": 1,
"trigger": "Immediate",
"trigger": SeqTrigger.IMMEDIATE,
"position": 1,
"time1": 1,
"outa1": 1,
Expand All @@ -130,7 +133,7 @@ def test_seq_table_pva_conversion():
},
{
"repeats": 2,
"trigger": "Immediate",
"trigger": SeqTrigger.IMMEDIATE,
"position": 2,
"time1": 0,
"outa1": 0,
Expand All @@ -149,7 +152,7 @@ def test_seq_table_pva_conversion():
},
{
"repeats": 3,
"trigger": "BITC=0",
"trigger": SeqTrigger.BITC_0,
"position": 3,
"time1": 1,
"outa1": 1,
Expand All @@ -168,7 +171,7 @@ def test_seq_table_pva_conversion():
},
{
"repeats": 4,
"trigger": "Immediate",
"trigger": SeqTrigger.IMMEDIATE,
"position": 4,
"time1": 0,
"outa1": 0,
Expand Down Expand Up @@ -212,28 +215,165 @@ def _assert_col_equal(column1, column2):
for column1, column2 in zip(applied_twice_to_pva_dict.values(), pva_dict.values()):
_assert_col_equal(column1, column2)

assert np.array_equal(
seq_table_from_pva_dict.numpy_columns,
[
np.array([1, 2, 3, 4], dtype=np.int32),
np.array(
[
"Immediate",
"Immediate",
"BITC=0",
"Immediate",
],
dtype="<U14",
),
np.array([1, 2, 3, 4], dtype=np.int32),
np.array([1, 0, 1, 0], dtype=np.int32),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
np.array([1, 2, 3, 4], dtype=np.int32),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
np.array([True, False, True, False], dtype=np.bool_),
],
)
dtype = seq_table_from_pva_dict.numpy_dtype
assert dtype == np.dtype(
[
("repeats", np.int32),
("trigger", "<U14"),
("position", np.int32),
("time1", np.int32),
("outa1", np.bool_),
("outb1", np.bool_),
("outc1", np.bool_),
("outd1", np.bool_),
("oute1", np.bool_),
("outf1", np.bool_),
("time2", np.int32),
("outa2", np.bool_),
("outb2", np.bool_),
("outc2", np.bool_),
("outd2", np.bool_),
("oute2", np.bool_),
("outf2", np.bool_),
]
)

assert np.array_equal(
seq_table_from_pva_dict.numpy_table,
np.array(
[
(
1,
"Immediate",
1,
1,
True,
True,
True,
True,
True,
True,
1,
True,
True,
True,
True,
True,
True,
),
(
2,
"Immediate",
2,
0,
False,
False,
False,
False,
False,
False,
2,
False,
False,
False,
False,
False,
False,
),
(
3,
"BITC=0",
3,
1,
True,
True,
True,
True,
True,
True,
3,
True,
True,
True,
True,
True,
True,
),
(
4,
"Immediate",
4,
0,
False,
False,
False,
False,
False,
False,
4,
False,
False,
False,
False,
False,
False,
),
],
dtype=dtype,
),
)


def test_seq_table_takes_trigger_enum_row():
for trigger in (SeqTrigger.BITA_0, "BITA=0"):
table = SeqTable.row(trigger=trigger)
assert table.trigger[0] == SeqTrigger.BITA_0
table = SeqTable(
repeats=np.array([1], dtype=np.int32),
trigger=[trigger],
position=np.array([1], dtype=np.int32),
time1=np.array([1], dtype=np.int32),
outa1=np.array([1], dtype=np.bool_),
outb1=np.array([1], dtype=np.bool_),
outc1=np.array([1], dtype=np.bool_),
outd1=np.array([1], dtype=np.bool_),
oute1=np.array([1], dtype=np.bool_),
outf1=np.array([1], dtype=np.bool_),
time2=np.array([1], dtype=np.int32),
outa2=np.array([1], dtype=np.bool_),
outb2=np.array([1], dtype=np.bool_),
outc2=np.array([1], dtype=np.bool_),
outd2=np.array([1], dtype=np.bool_),
oute2=np.array([1], dtype=np.bool_),
outf2=np.array([1], dtype=np.bool_),
)
assert table.trigger[0] == SeqTrigger.BITA_0
table = SeqTable.row(trigger=SeqTrigger.BITA_0)
assert table.trigger[0] == SeqTrigger.BITA_0
table = SeqTable(
repeats=np.array([1], dtype=np.int32),
trigger=[SeqTrigger.BITA_0],
position=np.array([1], dtype=np.int32),
time1=np.array([1], dtype=np.int32),
outa1=np.array([1], dtype=np.bool_),
outb1=np.array([1], dtype=np.bool_),
outc1=np.array([1], dtype=np.bool_),
outd1=np.array([1], dtype=np.bool_),
oute1=np.array([1], dtype=np.bool_),
outf1=np.array([1], dtype=np.bool_),
time2=np.array([1], dtype=np.int32),
outa2=np.array([1], dtype=np.bool_),
outb2=np.array([1], dtype=np.bool_),
outc2=np.array([1], dtype=np.bool_),
outd2=np.array([1], dtype=np.bool_),
oute2=np.array([1], dtype=np.bool_),
outf2=np.array([1], dtype=np.bool_),
)
assert table.trigger[0] == SeqTrigger.BITA_0

0 comments on commit 46c9d2a

Please sign in to comment.