diff --git a/test_uai_openlabel/test_data_types/test_generic_data.py b/test_uai_openlabel/test_data_types/test_generic_data.py index 69047d7..16aa7bc 100644 --- a/test_uai_openlabel/test_data_types/test_generic_data.py +++ b/test_uai_openlabel/test_data_types/test_generic_data.py @@ -18,7 +18,14 @@ import pytest -from uai_openlabel import Attributes, NumberData, ObjectData, TextData +from uai_openlabel import ( + Attributes, + NumberData, + ObjectData, + ObjectUid, + TextData, + VectorData, +) def test_attributes_iter_yields_all() -> None: @@ -64,8 +71,7 @@ def test_values_are_converted_from_single_element_sequence( assert isinstance(data.val, conversion_source) # Test data that needs conversion - class NotExactlyConversionTarget(conversion_target): - ... + class NotExactlyConversionTarget(conversion_target): ... vals_needing_conversion = [NotExactlyConversionTarget(1.1)] @@ -73,3 +79,23 @@ class NotExactlyConversionTarget(conversion_target): data = class_name(val=vals_needing_conversion) assert len(caplog.messages) == 1 and class_name.__name__ in caplog.messages[0] assert isinstance(data.val, conversion_target) + + +def test_object_uid_in_vector_data(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.INFO) + object_uid = [ObjectUid("075c92c1-7375-49b7-9ebe-76c0f1eac398")] + + data = VectorData(val=object_uid, name="example", **{}) + + assert isinstance(data, VectorData) + assert len(caplog.messages) == 0 + + +def test_object_uid_in_text_data(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.INFO) + object_uid = ObjectUid("075c92c1-7375-49b7-9ebe-76c0f1eac398") + + data = TextData(val=object_uid, name="example", **{}) + + assert isinstance(data, TextData) + assert len(caplog.messages) == 0 diff --git a/uai_openlabel/data_types/generic_data.py b/uai_openlabel/data_types/generic_data.py index 17ee561..d8eeb95 100644 --- a/uai_openlabel/data_types/generic_data.py +++ b/uai_openlabel/data_types/generic_data.py @@ -23,7 +23,12 @@ from uai_openlabel.serializer import JsonSnakeCaseSerializableMixin # noinspection PyProtectedMember -from uai_openlabel.types_and_constants import AttributeName, CoordinateSystemUid, Number +from uai_openlabel.types_and_constants import ( + AttributeName, + CoordinateSystemUid, + Number, + ObjectUid, +) # noinspection PyProtectedMember from uai_openlabel.utils import convert_values, no_default, unpack_sequence_of_length_1 @@ -40,9 +45,7 @@ class BooleanType(Enum): @dataclass class BooleanData(JsonSnakeCaseSerializableMixin): - val: bool = field( - default_factory=lambda: no_default(field="BooleanData.val"), metadata=required - ) + val: bool = field(default_factory=lambda: no_default(field="BooleanData.val"), metadata=required) attributes: Optional["Attributes"] = field(default=None) coordinate_system: Optional[CoordinateSystemUid] = field(default=None) @@ -91,9 +94,7 @@ class NumberType(Enum): @dataclass class NumberData(JsonSnakeCaseSerializableMixin): - val: Number = field( - default_factory=lambda: no_default(field="NumberData.val"), metadata=required - ) + val: Number = field(default_factory=lambda: no_default(field="NumberData.val"), metadata=required) attributes: Optional["Attributes"] = field(default=None) coordinate_system: Optional[CoordinateSystemUid] = field(default=None) @@ -140,9 +141,7 @@ class TextType(Enum): @dataclass class TextData(JsonSnakeCaseSerializableMixin): - val: str = field( - default_factory=lambda: no_default(field="TextData.val"), metadata=required - ) + val: str = field(default_factory=lambda: no_default(field="TextData.val"), metadata=required) attributes: Optional["Attributes"] = field(default=None) coordinate_system: Optional[CoordinateSystemUid] = field(default=None) @@ -156,7 +155,7 @@ def __post_init__(self) -> None: converted_val = convert_values( values=[self.val], conversion_target=str, - dont_convert=[], + dont_convert=[ObjectUid], field_name_for_logging=field_name_for_logging, )[0] self.val = converted_val @@ -205,7 +204,7 @@ def __post_init__(self) -> None: converted_val = convert_values( values=self.val, conversion_target=float, - dont_convert=[int, str], + dont_convert=[int, str, ObjectUid], field_name_for_logging=field_name_for_logging, ) self.val = tuple(converted_val) @@ -262,9 +261,7 @@ def static_attributes_example(cls: builtins.type[A]) -> A: ) @classmethod - def dynamic_attributes_example( - cls: builtins.type[A], toggle_value: bool = False - ) -> A: + def dynamic_attributes_example(cls: builtins.type[A], toggle_value: bool = False) -> A: """Contains attributes meant to be dynamic. Use the toggle_value switch to vary them.""" return cls( boolean=[BooleanData.dynamic_example(toggle_value)],