diff --git a/council/utils/data_object.py b/council/utils/data_object.py index 228b712d..8497b7d3 100644 --- a/council/utils/data_object.py +++ b/council/utils/data_object.py @@ -1,11 +1,13 @@ from __future__ import annotations import abc -from typing import Any, Dict, Generic, Optional, Type, TypeVar +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union import yaml from typing_extensions import Self +Label = Optional[Union[str, List[str]]] + class DataObjectMetadata: def __init__(self, name: str, labels: Dict[str, Any], description: Optional[str] = None) -> None: @@ -21,11 +23,26 @@ def get_label_value(self, label: str) -> Optional[Any]: return self.labels[label] return None - def is_matching_labels(self, labels: Dict[str, Any]) -> bool: - for label in labels: - value = self.get_label_value(label) - if value != labels[label]: + def is_matching_labels(self, labels: Dict[str, Label]) -> bool: + """ + Returns true if the test_case_object satisfies any of the following: + - if value_to_check is None, check if the label exists + - exact match of label-value pair + - when a label maps to a list, check if value_to_check is in the list of values for that label + """ + for label, value_to_check in labels.items(): + if not self.has_label(label): return False + + value = self.get_label_value(label) + if value_to_check is not None: + if type(value_to_check) is not type(value): + return False + elif isinstance(value_to_check, str) and isinstance(value, str) and value_to_check != value: + return False + elif isinstance(value_to_check, list) and isinstance(value, list): + if not all(v in value for v in value_to_check): + return False return True def to_dict(self) -> Dict[str, Any]: diff --git a/tests/unit/utils/test_data_object.py b/tests/unit/utils/test_data_object.py new file mode 100644 index 00000000..2aa614e0 --- /dev/null +++ b/tests/unit/utils/test_data_object.py @@ -0,0 +1,24 @@ +import unittest + +from council.utils.data_object import DataObjectMetadata + + +class TestCodeDataObject(unittest.TestCase): + + def test_code_data_object(self): + labels = {"test": "value1", "test2": "value2", "test_array": ["value31", "value32", "value33"]} + metadata = DataObjectMetadata("test", labels=labels) + self.assertEqual(metadata.labels, labels.copy()) + self.assertTrue(metadata.has_label("test")) + + self.assertTrue(metadata.is_matching_labels(labels={"test": "value1"})) + self.assertTrue(metadata.is_matching_labels(labels={"test": "value1", "test2": "value2", "test_array": None})) + self.assertTrue(metadata.is_matching_labels(labels={"test": None})) + self.assertTrue(metadata.is_matching_labels(labels={"test_array": None})) + self.assertTrue(metadata.is_matching_labels(labels={"test_array": ["value31", "value33"]})) + + self.assertFalse(metadata.is_matching_labels(labels={"test": "value2"})) + self.assertFalse(metadata.is_matching_labels(labels={"key_not_exist": None})) + self.assertFalse(metadata.is_matching_labels(labels={"key_not_exist": "value2"})) + self.assertFalse(metadata.is_matching_labels(labels={"test_array": "value32"})) + self.assertFalse(metadata.is_matching_labels(labels={"test_array": ["value31", "value33", "value34"]}))