Skip to content

Commit

Permalink
Made conversion on enum fields to List[str] optional
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Oct 10, 2023
1 parent 3e086c9 commit ad05991
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
22 changes: 16 additions & 6 deletions pandablocks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@


def words_to_table(
words: Iterable[str], table_field_info: TableFieldInfo
words: Iterable[str],
table_field_info: TableFieldInfo,
convert_enum_indices: bool = False,
) -> Dict[str, UnpackedArray]:
"""Unpacks the given `packed` data based on the fields provided.
Returns the unpacked data in {column_name: column_data} column-indexed format
Expand All @@ -23,6 +25,9 @@ def words_to_table(
expected to be the string representation of a uint32.
table_fields_info: The info for tables, containing the number of words per row,
and the bit information for fields.
convert_enum_indices: If True, converts enum indices to labels, the packed
value will be a list of strings. If False the packed value will be a
numpy array of the indices the labels correspond to.
Returns:
unpacked: A dict containing record information, where keys are field names
and values are numpy arrays or a sequence of strings of record values
Expand Down Expand Up @@ -56,18 +61,23 @@ def words_to_table(
if field_info.subtype == "int":
# First convert from 2's complement to offset, then add in offset.
packing_value = (value ^ (1 << (bit_length - 1))) + (-1 << (bit_length - 1))
elif field_info.labels:
elif convert_enum_indices and field_info.labels:
packing_value = [field_info.labels[x] for x in value]
else:
packing_value = value
if bit_length <= 8:
packing_value = value.astype(np.uint8)
elif bit_length <= 16:
packing_value = value.astype(np.uint16)
else:
packing_value = value.astype(np.uint32)

unpacked.update({field_name: packing_value})

return unpacked


def table_to_words(
table: Dict[str, Iterable], table_field_info: TableFieldInfo
table: Dict[str, Union[np.ndarray, List]], table_field_info: TableFieldInfo
) -> List[str]:
"""Convert records based on the field definitions into the format PandA expects
for table writes.
Expand All @@ -88,8 +98,8 @@ def table_to_words(

for column_name, column in table.items():
field_details = table_field_info.fields[column_name]
if field_details.labels:
# Must convert the list of ints into strings
if field_details.labels and len(column) and isinstance(column[0], str):
# Must convert the list of strings to list of ints
column = [field_details.labels.index(x) for x in column]

# PandA always handles tables in uint32 format
Expand Down
28 changes: 17 additions & 11 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, OrderedDict
from typing import Dict, List, OrderedDict, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -157,7 +157,7 @@ def table_field_info(table_fields) -> TableFieldInfo:


@pytest.fixture
def table_1() -> OrderedDict[str, Iterable]:
def table_1() -> OrderedDict[str, Union[List, np.ndarray]]:
return OrderedDict(
{
"REPEATS": [5, 0, 50000],
Expand All @@ -182,7 +182,7 @@ def table_1() -> OrderedDict[str, Iterable]:


@pytest.fixture
def table_1_np_arrays() -> OrderedDict[str, Iterable]:
def table_1_np_arrays() -> OrderedDict[str, Union[List, np.ndarray]]:
# Intentionally not in panda order. Whatever types the np arrays are,
# the outputs from words_to_table will be uint32 or int32.
return OrderedDict(
Expand All @@ -209,7 +209,7 @@ def table_1_np_arrays() -> OrderedDict[str, Iterable]:


@pytest.fixture
def table_1_not_in_panda_order() -> OrderedDict[str, Iterable]:
def table_1_not_in_panda_order() -> OrderedDict[str, Union[List, np.ndarray]]:
return OrderedDict(
{
"REPEATS": [5, 0, 50000],
Expand Down Expand Up @@ -252,8 +252,8 @@ def table_data_1() -> List[str]:


@pytest.fixture
def table_2() -> Dict[str, Iterable]:
table: Dict[str, Iterable] = dict(
def table_2() -> Dict[str, Union[List, np.ndarray]]:
table: Dict[str, Union[List, np.ndarray]] = dict(
REPEATS=[1, 0],
TRIGGER=["Immediate", "Immediate"],
POSITION=[-20, 2**31 - 1],
Expand Down Expand Up @@ -284,7 +284,7 @@ def table_data_2() -> List[str]:


def test_table_packing_pack_length_mismatched(
table_1: OrderedDict[str, Iterable],
table_1: OrderedDict[str, Union[List, np.ndarray]],
table_field_info: TableFieldInfo,
):
assert table_field_info.row_words
Expand Down Expand Up @@ -312,12 +312,16 @@ def test_table_to_words_and_words_to_table(
table_field_info: TableFieldInfo,
request,
):
table: Dict[str, Iterable] = request.getfixturevalue(table_fixture_name)
table: Dict[str, Union[List, np.ndarray]] = request.getfixturevalue(
table_fixture_name
)
table_data: List[str] = request.getfixturevalue(table_data_fixture_name)

output_data = table_to_words(table, table_field_info)
assert output_data == table_data
output_table = words_to_table(output_data, table_field_info)
output_table = words_to_table(
output_data, table_field_info, convert_enum_indices=True
)

# Test the correct keys are outputted
assert output_table.keys() == table.keys()
Expand All @@ -337,7 +341,9 @@ def test_table_packing_unpack(
table_data_1: List[str],
):
assert table_field_info.row_words
output_table = words_to_table(table_data_1, table_field_info)
output_table = words_to_table(
table_data_1, table_field_info, convert_enum_indices=True
)

actual: UnpackedArray
for field_name, actual in output_table.items():
Expand All @@ -346,7 +352,7 @@ def test_table_packing_unpack(


def test_table_packing_pack(
table_1: Dict[str, Iterable],
table_1: Dict[str, Union[List, np.ndarray]],
table_field_info: TableFieldInfo,
table_data_1: List[str],
):
Expand Down

0 comments on commit ad05991

Please sign in to comment.