Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MetricsRecord #2802

Merged
merged 18 commits into from
Jan 18, 2024
Merged
96 changes: 96 additions & 0 deletions src/py/flwr/common/metricsrecord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MetricsRecord."""


from dataclasses import dataclass, field
from typing import Dict, Optional, get_args

from .typing import MetricsRecordValues, MetricsScalar


@dataclass
class MetricsRecord:
"""Metrics record."""

keep_input: bool
data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

def __init__(
self,
metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None,
keep_input: bool = True,
):
"""Construct a MetricsRecord object.

Parameters
----------
metrics_dict : Optional[Dict[str, MetricsRecordValues]]
A dictionary that stores basic types (i.e. `int`, `float` as defined
in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
keep_input : bool (default: True)
A boolean indicating whether metrics should be deleted from the input
dictionary immediately after adding them to the record. When set
to True, the data is duplicated in memory. If memory is a concern, set
it to False.
"""
self.keep_input = keep_input
self.data = {}
if metrics_dict:
self.set_metrics(metrics_dict)

def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
"""Add metrics to the record.

Parameters
----------
metrics_dict : Dict[str, MetricsRecordValues]
A dictionary that stores basic types (i.e. `int`, `float` as defined
in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
"""
if any(not isinstance(k, str) for k in metrics_dict.keys()):
raise TypeError(f"Not all keys are of valid type. Expected {str}.")

def is_valid(value: MetricsScalar) -> None:
"""Check if value is of expected type."""
if not isinstance(value, get_args(MetricsScalar)):
raise TypeError(
"Not all values are of valid type."
f" Expected {MetricsRecordValues} but you passed {type(value)}."
)

# Check types of values
# Split between those values that are list and those that aren't
# then process in the same way
for value in metrics_dict.values():
if isinstance(value, list):
# If your lists are large (e.g. 1M+ elements) this will be slow
# 1s to check 10M element list on a M2 Pro
# In such settings, you'd be better of treating such metric as
# an array and pass it to a ParametersRecord.
for list_value in value:
is_valid(list_value)
else:
is_valid(value)

# Add metrics to record
if self.keep_input:
# Copy
self.data = metrics_dict.copy()
else:
# Add entries to dataclass without duplicating memory
for key in list(metrics_dict.keys()):
self.data[key] = metrics_dict[key]
del metrics_dict[key]
6 changes: 1 addition & 5 deletions src/py/flwr/common/recordset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,10 @@
from dataclasses import dataclass, field
from typing import Dict

from .metricsrecord import MetricsRecord
from .parametersrecord import ParametersRecord


@dataclass
class MetricsRecord:
"""Metrics record."""


@dataclass
class ConfigsRecord:
"""Configs record."""
Expand Down
114 changes: 112 additions & 2 deletions src/py/flwr/common/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
"""RecordSet tests."""


from typing import Callable, List, OrderedDict, Type, Union
from typing import Callable, Dict, List, OrderedDict, Type, Union

import numpy as np
import pytest

from .metricsrecord import MetricsRecord
from .parameter import ndarrays_to_parameters, parameters_to_ndarrays
from .parametersrecord import Array, ParametersRecord
from .recordset_utils import (
parameters_to_parametersrecord,
parametersrecord_to_parameters,
)
from .typing import NDArray, NDArrays, Parameters
from .typing import MetricsRecordValues, NDArray, NDArrays, Parameters


def get_ndarrays() -> NDArrays:
Expand Down Expand Up @@ -145,3 +146,112 @@ def test_set_parameters_with_incorrect_types(

with pytest.raises(TypeError):
p_record.set_parameters(array_dict) # type: ignore


@pytest.mark.parametrize(
"key_type, value_fn",
[
(str, lambda x: int(x.flatten()[0])), # str: int
(str, lambda x: float(x.flatten()[0])), # str: float
(str, lambda x: x.flatten().astype("int").tolist()), # str: List[int]
(str, lambda x: x.flatten().astype("float").tolist()), # str: List[float]
],
)
def test_set_metrics_to_metricsrecord_with_correct_types(
key_type: Type[str],
value_fn: Callable[[NDArray], MetricsRecordValues],
) -> None:
"""Test adding metrics of various types to a MetricsRecord."""
m_record = MetricsRecord()

labels = [1, 2.0]
arrays = get_ndarrays()

my_metrics = OrderedDict(
{key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)}
)

# Add metric
m_record.set_metrics(my_metrics)

# Check metrics are actually added
assert my_metrics == m_record.data


@pytest.mark.parametrize(
"key_type, value_fn",
[
(str, lambda x: str(x.flatten()[0])), # str: str (supported: unsupported)
(
str,
lambda x: x.flatten().astype("str").tolist(),
), # str: List[str] (supported: unsupported)
(str, lambda x: x), # str: NDArray (supported: unsupported)
(
str,
lambda x: {str(v): v for v in x.flatten()},
), # str: dict[str: float] (supported: unsupported)
(
str,
lambda x: [{str(v): v for v in x.flatten()}],
), # str: List[dict[str: float]] (supported: unsupported)
(
int,
lambda x: x.flatten().tolist(),
), # int: List[str] (unsupported: supported)
(
float,
lambda x: x.flatten().tolist(),
), # float: List[int] (unsupported: supported)
],
)
def test_set_metrics_to_metricsrecord_with_incorrect_types(
key_type: Type[Union[str, int, float]],
value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]],
) -> None:
"""Test adding metrics of various unsupported types to a MetricsRecord."""
m_record = MetricsRecord()

labels = [1, 2.0]
arrays = get_ndarrays()

my_metrics = OrderedDict(
{key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)}
)

with pytest.raises(TypeError):
m_record.set_metrics(my_metrics) # type: ignore


@pytest.mark.parametrize(
"keep_input",
[
(True),
(False),
],
)
def test_set_metrics_to_metricsrecord_with_and_without_keeping_input(
keep_input: bool,
) -> None:
"""Test keep_input functionality for MetricsRecord."""
m_record = MetricsRecord(keep_input=keep_input)

# constructing a valid input
labels = [1, 2.0]
arrays = get_ndarrays()
my_metrics = OrderedDict(
{str(label): arr.flatten().tolist() for label, arr in zip(labels, arrays)}
)

my_metrics_copy = my_metrics.copy()

# Add metric
m_record.set_metrics(my_metrics)

# Check metrics are actually added
# Check that input dict has been emptied when enabled such behaviour
if keep_input:
assert my_metrics == m_record.data
else:
assert my_metrics_copy == m_record.data
assert len(my_metrics) == 0
5 changes: 5 additions & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@
List[str],
]

# Value types for common.MetricsRecord
MetricsScalar = Union[int, float]
MetricsScalarList = Union[List[int], List[float]]
MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]

Metrics = Dict[str, Scalar]
MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics]

Expand Down