@dataclass
class MetricsRecord:
    """Metrics record.

    Holds named metrics. Keys are strings; values are scalars (`int` or
    `float`, see `MetricsScalar`) or lists of such scalars.
    """

    # Whether `set_metrics` leaves the caller's dictionary untouched (True)
    # or drains it while moving entries into the record (False).
    keep_input: bool
    # The stored metrics, keyed by metric name.
    data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

    def __init__(
        self,
        metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
        keep_input: bool = True,
    ):
        """Construct a MetricsRecord object.

        Parameters
        ----------
        metrics_dict : Optional[Dict[str, MetricsRecordValues]]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
        keep_input : bool (default: True)
            A boolean indicating whether the entries of the input dictionary
            should be kept after adding them to the record. When set to True,
            the input dictionary is left untouched and the record keeps its own
            dictionary referencing the same values. When set to False, every
            entry is removed from the input dictionary as soon as it has been
            added to the record. If memory is a concern, set it to False.

        Raises
        ------
        TypeError
            If `metrics_dict` contains a key or a value of an unsupported type.
        """
        self.keep_input = keep_input
        self.data = {}
        if metrics_dict:
            self.set_metrics(metrics_dict)

    def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
        """Add metrics to the record.

        Entries already present in the record are kept; an incoming key that
        already exists overwrites the stored value. This holds for both
        settings of `keep_input`.

        Parameters
        ----------
        metrics_dict : Dict[str, MetricsRecordValues]
            A dictionary that stores basic types (i.e. `int`, `float` as defined
            in `MetricsScalar`) and list of such types (see `MetricsScalarList`).

        Raises
        ------
        TypeError
            If a key is not a `str`, or if a value (or an element of a list
            value) is not one of the types in `MetricsScalar`. Validation
            happens before anything is stored, so on error neither the record
            nor `metrics_dict` is modified.
        """
        if any(not isinstance(k, str) for k in metrics_dict):
            raise TypeError(f"Not all keys are of valid type. Expected {str}.")

        # Resolve the accepted scalar types once instead of once per element.
        scalar_types = get_args(MetricsScalar)

        def is_valid(value: MetricsScalar) -> None:
            """Check if value is of expected type."""
            if not isinstance(value, scalar_types):
                raise TypeError(
                    "Not all values are of valid type."
                    f" Expected {MetricsRecordValues} but you passed {type(value)}."
                )

        # Check types of values
        # Split between those values that are list and those that aren't
        # then process in the same way
        for value in metrics_dict.values():
            if isinstance(value, list):
                # If your lists are large (e.g. 1M+ elements) this will be slow
                # 1s to check 10M element list on a M2 Pro
                # In such settings, you'd be better off treating such metric as
                # an array and pass it to a ParametersRecord.
                for list_value in value:
                    is_valid(list_value)
            else:
                is_valid(value)

        # Add metrics to record
        if self.keep_input:
            # Merge into the record's own dict; the input dict is left as is.
            # (Rebinding `self.data` to a copy would silently drop metrics
            # added by earlier calls, unlike the branch below.)
            self.data.update(metrics_dict)
        else:
            # Move entries one by one so the input dict is emptied as we go.
            # Iterate over a snapshot of the keys because the dict is mutated.
            for key in list(metrics_dict.keys()):
                self.data[key] = metrics_dict.pop(key)
@pytest.mark.parametrize(
    "key_type, value_fn",
    [
        (str, lambda x: int(x.flatten()[0])),  # str key, int value
        (str, lambda x: float(x.flatten()[0])),  # str key, float value
        (str, lambda x: x.flatten().astype("int").tolist()),  # str key, list of int
        (
            str,
            lambda x: x.flatten().astype("float").tolist(),
        ),  # str key, list of float
    ],
)
def test_set_metrics_to_metricsrecord_with_correct_types(
    key_type: Type[str],
    value_fn: Callable[[NDArray], MetricsRecordValues],
) -> None:
    """Check that supported key/value types are accepted and stored as given."""
    record = MetricsRecord()

    # One metric per array; the labels are cast to the key type under test.
    names = [1, 2.0]
    metrics = OrderedDict(
        {key_type(name): value_fn(array) for name, array in zip(names, get_ndarrays())}
    )

    record.set_metrics(metrics)

    # The record must hold exactly what was passed in.
    assert metrics == record.data


@pytest.mark.parametrize(
    "key_type, value_fn",
    [
        # Supported key type, unsupported value type
        (str, lambda x: str(x.flatten()[0])),  # str value
        (str, lambda x: x.flatten().astype("str").tolist()),  # list of str
        (str, lambda x: x),  # the array itself
        (str, lambda x: {str(v): v for v in x.flatten()}),  # dict value
        (str, lambda x: [{str(v): v for v in x.flatten()}]),  # list of dict
        # Unsupported key type, plain list of array elements as value
        (int, lambda x: x.flatten().tolist()),  # int key
        (float, lambda x: x.flatten().tolist()),  # float key
    ],
)
def test_set_metrics_to_metricsrecord_with_incorrect_types(
    key_type: Type[Union[str, int, float]],
    value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]],
) -> None:
    """Check that unsupported key or value types are rejected with TypeError."""
    record = MetricsRecord()

    names = [1, 2.0]
    metrics = OrderedDict(
        {key_type(name): value_fn(array) for name, array in zip(names, get_ndarrays())}
    )

    with pytest.raises(TypeError):
        record.set_metrics(metrics)  # type: ignore


@pytest.mark.parametrize("keep_input", [True, False])
def test_set_metrics_to_metricsrecord_with_and_without_keeping_input(
    keep_input: bool,
) -> None:
    """Check how `keep_input` affects the dictionary passed to `set_metrics`."""
    record = MetricsRecord(keep_input=keep_input)

    # Build a valid input: str keys mapped to flat lists of array elements.
    names = [1, 2.0]
    metrics = OrderedDict(
        {str(name): array.flatten().tolist() for name, array in zip(names, get_ndarrays())}
    )

    # Keep a snapshot, since the input may be emptied by the call below.
    snapshot = metrics.copy()

    record.set_metrics(metrics)

    if not keep_input:
        # Entries were moved: the record has them all, the input has none.
        assert snapshot == record.data
        assert len(metrics) == 0
    else:
        # Input left untouched and equal to what the record now holds.
        assert metrics == record.data