diff --git a/pandablocks/utils.py b/pandablocks/utils.py index 7970731c9..2453f0623 100644 --- a/pandablocks/utils.py +++ b/pandablocks/utils.py @@ -13,7 +13,9 @@ def words_to_table( - words: Iterable[str], table_field_info: TableFieldInfo + words: Iterable[str], + table_field_info: TableFieldInfo, + convert_enum_indices: bool = False, ) -> Dict[str, UnpackedArray]: """Unpacks the given `packed` data based on the fields provided. Returns the unpacked data in {column_name: column_data} column-indexed format @@ -23,6 +25,9 @@ def words_to_table( expected to be the string representation of a uint32. table_fields_info: The info for tables, containing the number of words per row, and the bit information for fields. + convert_enum_indices: If True, converts enum indices to labels, the packed + value will be a list of strings. If False the packed value will be a + numpy array of the indices the labels correspond to. Returns: unpacked: A dict containing record information, where keys are field names and values are numpy arrays or a sequence of strings of record values @@ -56,10 +61,15 @@ def words_to_table( if field_info.subtype == "int": # First convert from 2's complement to offset, then add in offset. packing_value = (value ^ (1 << (bit_length - 1))) + (-1 << (bit_length - 1)) - elif field_info.labels: + elif convert_enum_indices and field_info.labels: packing_value = [field_info.labels[x] for x in value] else: - packing_value = value + if bit_length <= 8: + packing_value = value.astype(np.uint8) + elif bit_length <= 16: + packing_value = value.astype(np.uint16) + else: + packing_value = value.astype(np.uint32) unpacked.update({field_name: packing_value}) @@ -67,7 +77,7 @@ def words_to_table( def table_to_words( - table: Dict[str, Iterable], table_field_info: TableFieldInfo + table: Dict[str, Union[np.ndarray, List]], table_field_info: TableFieldInfo ) -> List[str]: """Convert records based on the field definitions into the format PandA expects for table writes. @@ -88,8 +98,8 @@ def table_to_words( for column_name, column in table.items(): field_details = table_field_info.fields[column_name] - if field_details.labels: - # Must convert the list of ints into strings + if field_details.labels and len(column) and isinstance(column[0], str): + # Must convert the list of strings to list of ints column = [field_details.labels.index(x) for x in column] # PandA always handles tables in uint32 format diff --git a/tests/test_utils.py b/tests/test_utils.py index 6057dd635..95c14473b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, List, OrderedDict +from typing import Dict, List, OrderedDict, Union import numpy as np import pytest @@ -157,7 +157,7 @@ def table_field_info(table_fields) -> TableFieldInfo: @pytest.fixture -def table_1() -> OrderedDict[str, Iterable]: +def table_1() -> OrderedDict[str, Union[List, np.ndarray]]: return OrderedDict( { "REPEATS": [5, 0, 50000], @@ -182,7 +182,7 @@ def table_1() -> OrderedDict[str, Iterable]: @pytest.fixture -def table_1_np_arrays() -> OrderedDict[str, Iterable]: +def table_1_np_arrays() -> OrderedDict[str, Union[List, np.ndarray]]: # Intentionally not in panda order. Whatever types the np arrays are, # the outputs from words_to_table will be uint32 or int32. return OrderedDict( @@ -209,7 +209,7 @@ def table_1_np_arrays() -> OrderedDict[str, Iterable]: @pytest.fixture -def table_1_not_in_panda_order() -> OrderedDict[str, Iterable]: +def table_1_not_in_panda_order() -> OrderedDict[str, Union[List, np.ndarray]]: return OrderedDict( { "REPEATS": [5, 0, 50000], @@ -252,8 +252,8 @@ def table_data_1() -> List[str]: @pytest.fixture -def table_2() -> Dict[str, Iterable]: - table: Dict[str, Iterable] = dict( +def table_2() -> Dict[str, Union[List, np.ndarray]]: + table: Dict[str, Union[List, np.ndarray]] = dict( REPEATS=[1, 0], TRIGGER=["Immediate", "Immediate"], POSITION=[-20, 2**31 - 1], @@ -284,7 +284,7 @@ def table_data_2() -> List[str]: def test_table_packing_pack_length_mismatched( - table_1: OrderedDict[str, Iterable], + table_1: OrderedDict[str, Union[List, np.ndarray]], table_field_info: TableFieldInfo, ): assert table_field_info.row_words @@ -312,12 +312,16 @@ def test_table_to_words_and_words_to_table( table_field_info: TableFieldInfo, request, ): - table: Dict[str, Iterable] = request.getfixturevalue(table_fixture_name) + table: Dict[str, Union[List, np.ndarray]] = request.getfixturevalue( + table_fixture_name + ) table_data: List[str] = request.getfixturevalue(table_data_fixture_name) output_data = table_to_words(table, table_field_info) assert output_data == table_data - output_table = words_to_table(output_data, table_field_info) + output_table = words_to_table( + output_data, table_field_info, convert_enum_indices=True + ) # Test the correct keys are outputted assert output_table.keys() == table.keys() @@ -337,7 +341,9 @@ def test_table_packing_unpack( table_data_1: List[str], ): assert table_field_info.row_words - output_table = words_to_table(table_data_1, table_field_info) + output_table = words_to_table( + table_data_1, table_field_info, convert_enum_indices=True + ) actual: UnpackedArray for field_name, actual in output_table.items(): @@ -346,7 +352,7 @@ def test_table_packing_unpack( def test_table_packing_pack( - table_1: Dict[str, Iterable], + table_1: Dict[str, Union[List, np.ndarray]], table_field_info: TableFieldInfo, table_data_1: List[str], ):