Skip to content

Commit

Permalink
better tests; definitions in typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jan 17, 2024
1 parent 092a74b commit 32b1155
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 34 deletions.
45 changes: 29 additions & 16 deletions src/py/flwr/common/metricsrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,50 @@
"""MetricsRecord."""

from dataclasses import dataclass, field
from typing import Dict, Union, get_args
from typing import Dict, Optional, Union, get_args

from .typing import Scalar, ScalarList
from .typing import MetricsScalar, MetricsScalarList

MetricsRecordValues = Union[MetricsScalar, MetricsScalarList]


@dataclass
class MetricsRecord:
"""Parameters record."""

data: Dict[str, Union[Scalar, ScalarList]] = field(default_factory=dict)
data: Dict[str, MetricsRecordValues] = field(default_factory=dict)

def __init__(self, metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None):
"""Construct a MetricsRecord object.
def add_metrics(self, metrics_dict: Dict[str, Union[Scalar, ScalarList]]) -> None:
Parameters
----------
array_dict : Optional[Dict[str, MetricsRecordValues]]
A dictionary that stores basic types (i.e. `str`, `int`, `float` as defined
in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
"""
self.data = {}
if metrics_dict:
self.set_metrics(metrics_dict)

def set_metrics(self, metrics_dict: Dict[str, MetricsRecordValues]) -> None:
"""Add metrics to record.
This not implemented as a constructor so we can cleanly create and empyt
MetricsRecord object.
Parameters
----------
array_dict : Optional[Dict[str, MetricsRecordValues]]
A dictionary that stores basic types (i.e. `str`, `int`, `float` as defined
in `MetricsScalar`) and list of such types (see `MetricsScalarList`).
"""
if any(not isinstance(k, str) for k in metrics_dict.keys()):
raise TypeError(f"Not all keys are of valide type. Expected {str}")
raise TypeError(f"Not all keys are of valid type. Expected {str}")

def is_valid(value: Scalar) -> None:
def is_valid(value: MetricsScalar) -> None:
"""Check if value is of expected type."""
if not isinstance(value, get_args(Scalar)):
if not isinstance(value, get_args(MetricsScalar)):
raise TypeError(
"Not all values are of valide type."
f" Expected {Union[Scalar, ScalarList]}"
"Not all values are of valid type."
f" Expected {MetricsRecordValues}"
)

# Check types of values
Expand All @@ -56,8 +74,3 @@ def is_valid(value: Scalar) -> None:
is_valid(list_value)
else:
is_valid(value)

# Add entries to dataclass without duplicating memory
for key in list(metrics_dict.keys()):
self.data[key] = metrics_dict[key]
del metrics_dict[key]
76 changes: 65 additions & 11 deletions src/py/flwr/common/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""RecordSet tests."""


from typing import Callable, List, OrderedDict, Type, Union
from typing import Callable, Dict, List, OrderedDict, Type, Union

import numpy as np
import pytest
Expand All @@ -27,7 +27,7 @@
parameters_to_parametersrecord,
parametersrecord_to_parameters,
)
from .typing import NDArray, NDArrays, Parameters, Scalar, ScalarList
from .typing import NDArray, NDArrays, Parameters


def get_ndarrays() -> NDArrays:
Expand Down Expand Up @@ -148,17 +148,71 @@ def test_set_parameters_with_incorrect_types(
p_record.set_parameters(array_dict) # type: ignore


def test_add_metrics_to_metricsrecord() -> None:
@pytest.mark.parametrize(
"key_type, value_fn",
[
(str, lambda x: str(x.flatten()[0])), # str: str
(str, lambda x: int(x.flatten()[0])), # str: int
(str, lambda x: float(x.flatten()[0])), # str: float
(str, lambda x: x.flatten().astype("str").tolist()), # str: List[str]
(str, lambda x: x.flatten().astype("int").tolist()), # str: List[int]
(str, lambda x: x.flatten().astype("float").tolist()), # str: List[float]
],
)
def test_set_metrics_to_metricsrecord_with_correct_types(
key_type: Type[str],
value_fn: Callable[
[NDArray], Union[str, int, float, List[str], List[int], List[float]]
],
) -> None:
"""Test adding metrics of various types to a MetricsRecord."""
m_record = MetricsRecord()

my_metrics: OrderedDict[str, Union[Scalar, ScalarList]] = OrderedDict(
{
"loss": 0.12445,
"converged": True,
"my_int": 2,
"embeddings": np.random.randn(10).tolist(),
}
labels = [1, 2.0]
arrays = get_ndarrays()

my_metrics = OrderedDict(
{key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)}
)

m_record.set_metrics(my_metrics)


@pytest.mark.parametrize(
"key_type, value_fn",
[
(str, lambda x: x), # str: NDArray (supported: unsupported)
(
str,
lambda x: {str(v): v for v in x.flatten()},
), # str: dict[str: float] (supported: unsupported)
(
str,
lambda x: [{str(v): v for v in x.flatten()}],
), # str: List[dict[str: float]] (supported: unsupported)
(
int,
lambda x: x.flatten().tolist(),
), # int: List[str] (unsupported: supported)
(
float,
lambda x: x.flatten().tolist(),
), # float: List[int] (unsupported: supported)
],
)
def test_set_metrics_to_metricsrecord_with_incorrect_types(
key_type: Type[Union[str, int, float]],
value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]],
) -> None:
"""Test adding metrics of various unsupported types to a MetricsRecord."""
m_record = MetricsRecord()

labels = [1, 2.0]
arrays = get_ndarrays()

my_metrics = OrderedDict(
{key_type(label): value_fn(arr) for label, arr in zip(labels, arrays)}
)

m_record.add_metrics(my_metrics)
with pytest.raises(TypeError):
m_record.set_metrics(my_metrics) # type: ignore
11 changes: 4 additions & 7 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,6 @@
# not conform to other definitions of what a scalar is. Source:
# https://developers.google.com/protocol-buffers/docs/overview#scalar
Scalar = Union[bool, bytes, float, int, str]
ScalarList = Union[
List[bool],
List[bytes],
List[float],
List[int],
List[str],
]
Value = Union[
bool,
bytes,
Expand All @@ -52,6 +45,10 @@
List[str],
]

# Value types for common.MetricsRecord
MetricsScalar = Union[str, int, float]
MetricsScalarList = Union[List[str], List[int], List[float]]

Metrics = Dict[str, Scalar]
MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics]

Expand Down

0 comments on commit 32b1155

Please sign in to comment.