From 32b1155a017fac4c5f2fbd73c12e0d3a209b347a Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 17 Jan 2024 21:10:39 +0000 Subject: [PATCH] better tests; definitions in typing --- src/py/flwr/common/metricsrecord.py | 45 ++++++++++------ src/py/flwr/common/recordset_test.py | 76 ++++++++++++++++++++++++---- src/py/flwr/common/typing.py | 11 ++-- 3 files changed, 98 insertions(+), 34 deletions(-) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index aa683ff7a2ca..3b9d1dc35b2e 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -15,32 +15,50 @@ """MetricsRecord.""" from dataclasses import dataclass, field -from typing import Dict, Union, get_args +from typing import Dict, Optional, Union, get_args -from .typing import Scalar, ScalarList +from .typing import MetricsScalar, MetricsScalarList + +MetricsRecordValues = Union[MetricsScalar, MetricsScalarList] @dataclass class MetricsRecord: """Parameters record.""" - data: Dict[str, Union[Scalar, ScalarList]] = field(default_factory=dict) + data: Dict[str, MetricsRecordValues] = field(default_factory=dict) + + def __init__(self, metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None): + """Construct a MetricsRecord object. - def add_metrics(self, metrics_dict: Dict[str, Union[Scalar, ScalarList]]) -> None: + Parameters + ---------- + metrics_dict : Optional[Dict[str, MetricsRecordValues]] + A dictionary that stores basic types (i.e. `str`, `int`, `float` as defined + in `MetricsScalar`) and list of such types (see `MetricsScalarList`). + """ + self.data = {} + if metrics_dict: + self.set_metrics(metrics_dict) + + def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None: """Add metrics to record. - This not implemented as a constructor so we can cleanly create and empyt - MetricsRecord object. + Parameters + ---------- + metrics_dict : Dict[str, MetricsRecordValues] + A dictionary that stores basic types (i.e. 
`str`, `int`, `float` as defined + in `MetricsScalar`) and list of such types (see `MetricsScalarList`). """ if any(not isinstance(k, str) for k in metrics_dict.keys()): - raise TypeError(f"Not all keys are of valide type. Expected {str}") + raise TypeError(f"Not all keys are of valid type. Expected {str}") - def is_valid(value: Scalar) -> None: + def is_valid(value: MetricsScalar) -> None: """Check if value is of expected type.""" - if not isinstance(value, get_args(Scalar)): + if not isinstance(value, get_args(MetricsScalar)): raise TypeError( - "Not all values are of valide type." - f" Expected {Union[Scalar, ScalarList]}" + "Not all values are of valid type." + f" Expected {MetricsRecordValues}" ) # Check types of values @@ -56,8 +74,3 @@ def is_valid(value: Scalar) -> None: is_valid(list_value) else: is_valid(value) - - # Add entries to dataclass without duplicating memory - for key in list(metrics_dict.keys()): - self.data[key] = metrics_dict[key] - del metrics_dict[key] diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index 17a501ed50cd..a8ae6f0778a6 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -15,7 +15,7 @@ """RecordSet tests.""" -from typing import Callable, List, OrderedDict, Type, Union +from typing import Callable, Dict, List, OrderedDict, Type, Union import numpy as np import pytest @@ -27,7 +27,7 @@ parameters_to_parametersrecord, parametersrecord_to_parameters, ) -from .typing import NDArray, NDArrays, Parameters, Scalar, ScalarList +from .typing import NDArray, NDArrays, Parameters def get_ndarrays() -> NDArrays: @@ -148,17 +148,71 @@ def test_set_parameters_with_incorrect_types( p_record.set_parameters(array_dict) # type: ignore -def test_add_metrics_to_metricsrecord() -> None: +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: str(x.flatten()[0])), # str: str + (str, lambda x: int(x.flatten()[0])), # str: int + (str, lambda x: 
float(x.flatten()[0])), # str: float + (str, lambda x: x.flatten().astype("str").tolist()), # str: List[str] + (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] + (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] + ], +) +def test_set_metrics_to_metricsrecord_with_correct_types( + key_type: Type[str], + value_fn: Callable[ + [NDArray], Union[str, int, float, List[str], List[int], List[float]] + ], +) -> None: """Test adding metrics of various types to a MetricsRecord.""" m_record = MetricsRecord() - my_metrics: OrderedDict[str, Union[Scalar, ScalarList]] = OrderedDict( - { - "loss": 0.12445, - "converged": True, - "my_int": 2, - "embeddings": np.random.randn(10).tolist(), - } + labels = [1, 2.0] + arrays = get_ndarrays() + + my_metrics = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} + ) + + m_record.set_metrics(my_metrics) + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: x), # str: NDArray (supported: unsupported) + ( + str, + lambda x: {str(v): v for v in x.flatten()}, + ), # str: dict[str: float] (supported: unsupported) + ( + str, + lambda x: [{str(v): v for v in x.flatten()}], + ), # str: List[dict[str: float]] (supported: unsupported) + ( + int, + lambda x: x.flatten().tolist(), + ), # int: List[str] (unsupported: supported) + ( + float, + lambda x: x.flatten().tolist(), + ), # float: List[int] (unsupported: supported) + ], +) +def test_set_metrics_to_metricsrecord_with_incorrect_types( + key_type: Type[Union[str, int, float]], + value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], +) -> None: + """Test adding metrics of various unsupported types to a MetricsRecord.""" + m_record = MetricsRecord() + + labels = [1, 2.0] + arrays = get_ndarrays() + + my_metrics = OrderedDict( + {key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)} ) - m_record.add_metrics(my_metrics) + with pytest.raises(TypeError): + 
m_record.set_metrics(my_metrics) # type: ignore diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index ffa7be88e40c..6ec7979835fe 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -32,13 +32,6 @@ # not conform to other definitions of what a scalar is. Source: # https://developers.google.com/protocol-buffers/docs/overview#scalar Scalar = Union[bool, bytes, float, int, str] -ScalarList = Union[ - List[bool], - List[bytes], - List[float], - List[int], - List[str], -] Value = Union[ bool, bytes, @@ -52,6 +45,10 @@ List[str], ] +# Value types for common.MetricsRecord +MetricsScalar = Union[str, int, float] +MetricsScalarList = Union[List[str], List[int], List[float]] + Metrics = Dict[str, Scalar] MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics]