Skip to content

Commit

Permalink
*res.metrics stored as ConfigsRecord
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Jan 23, 2024
1 parent e884586 commit 21f07a0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 55 deletions.
16 changes: 8 additions & 8 deletions src/py/flwr/common/recordset_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes:
num_examples = cast(
int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"]
)
metrics_record = recordset.get_metrics(f"{ins_str}.metrics")
configs_record = recordset.get_configs(f"{ins_str}.metrics")

metrics = _check_mapping_from_recordscalartype_to_scalar(metrics_record.data)
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data)
status = _extract_status_from_recordset(ins_str, recordset)

return FitRes(
Expand All @@ -226,8 +226,8 @@ def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet:

res_str = "fitres"

recordset.set_metrics(
name=f"{res_str}.metrics", record=MetricsRecord(fitres.metrics) # type: ignore
recordset.set_configs(
name=f"{res_str}.metrics", record=ConfigsRecord(fitres.metrics) # type: ignore
)
recordset.set_metrics(
name=f"{res_str}.num_examples",
Expand Down Expand Up @@ -269,9 +269,9 @@ def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes:
num_examples = cast(
int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"]
)
metrics_record = recordset.get_metrics(f"{ins_str}.metrics")
configs_record = recordset.get_configs(f"{ins_str}.metrics")

metrics = _check_mapping_from_recordscalartype_to_scalar(metrics_record.data)
metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data)
status = _extract_status_from_recordset(ins_str, recordset)

return EvaluateRes(
Expand All @@ -297,9 +297,9 @@ def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet:
)

# metrics
recordset.set_metrics(
recordset.set_configs(
name=f"{res_str}.metrics",
record=MetricsRecord(evaluateres.metrics), # type: ignore
record=ConfigsRecord(evaluateres.metrics), # type: ignore
)

# status
Expand Down
62 changes: 15 additions & 47 deletions src/py/flwr/common/recordset_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@
# ==============================================================================
"""RecordSet from legacy messages tests."""

from contextlib import nullcontext
from copy import deepcopy
from typing import Any, Dict
from typing import Dict

import numpy as np
import pytest

from .parameter import ndarrays_to_parameters
from .recordset_compat import (
Expand Down Expand Up @@ -75,9 +73,10 @@ def _get_valid_fitins() -> FitIns:
return FitIns(parameters=ndarrays_to_parameters(arrays), config={"a": 1.0, "b": 0})


def _get_valid_fitres_with_config(metrics: Dict[str, Scalar]) -> FitRes:
def _get_valid_fitres() -> FitRes:
"""Returnn Valid parameters but potentially invalid config."""
arrays = get_ndarrays()
metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0}
return FitRes(
parameters=ndarrays_to_parameters(arrays),
num_examples=1,
Expand All @@ -91,8 +90,9 @@ def _get_valid_evaluateins() -> EvaluateIns:
return EvaluateIns(parameters=fit_ins.parameters, config=fit_ins.config)


def _get_valid_evaluateres_with_config(metrics: Dict[str, Scalar]) -> EvaluateRes:
def _get_valid_evaluateres() -> EvaluateRes:
"""Return potentially invalid config."""
metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0}
return EvaluateRes(
num_examples=1,
loss=0.1,
Expand Down Expand Up @@ -149,31 +149,16 @@ def test_fitins_to_recordset_and_back() -> None:
assert fitins_copy == fitins_


@pytest.mark.parametrize(
"context, metrics",
[
(nullcontext(), {"a": 1.0, "b": 0}),
(
pytest.raises(TypeError),
{"a": 1.0, "b": 3, "c": True},
), # fails due to unsupported type for metricsrecord value
],
)
def test_fitres_to_recordset_and_back(context: Any, metrics: Dict[str, Scalar]) -> None:
def test_fitres_to_recordset_and_back() -> None:
"""Test conversion FitRes --> RecordSet --> FitRes."""
fitres = _get_valid_fitres_with_config(metrics)
fitres = _get_valid_fitres()

fitres_copy = deepcopy(fitres)

with context:
recordset = fitres_to_recordset(fitres, keep_input=False)
fitres_ = recordset_to_fitres(recordset, keep_input=False)
recordset = fitres_to_recordset(fitres, keep_input=False)
fitres_ = recordset_to_fitres(recordset, keep_input=False)

# only check if we didn't test for an invalid setting. Only in valid settings
# makes sense to evaluate the below, since both functions above have succesfully
# being executed.
if isinstance(context, nullcontext):
assert fitres_copy == fitres_
assert fitres_copy == fitres_


def test_evaluateins_to_recordset_and_back() -> None:
Expand All @@ -189,33 +174,16 @@ def test_evaluateins_to_recordset_and_back() -> None:
assert evaluateins_copy == evaluateins_


@pytest.mark.parametrize(
"context, metrics",
[
(nullcontext(), {"a": 1.0, "b": 0}),
(
pytest.raises(TypeError),
{"a": 1.0, "b": 3, "c": True},
), # fails due to unsupported type for metricsrecord value
],
)
def test_evaluateres_to_recordset_and_back(
context: Any, metrics: Dict[str, Scalar]
) -> None:
def test_evaluateres_to_recordset_and_back() -> None:
"""Test conversion EvaluateRes --> RecordSet --> EvaluateRes."""
evaluateres = _get_valid_evaluateres_with_config(metrics)
evaluateres = _get_valid_evaluateres()

evaluateres_copy = deepcopy(evaluateres)

with context:
recordset = evaluateres_to_recordset(evaluateres)
evaluateres_ = recordset_to_evaluateres(recordset)
recordset = evaluateres_to_recordset(evaluateres)
evaluateres_ = recordset_to_evaluateres(recordset)

# only check if we didn't test for an invalid setting. Only in valid settings
# makes sense to evaluate the below, since both functions above have succesfully
# being executed.
if isinstance(context, nullcontext):
assert evaluateres_copy == evaluateres_
assert evaluateres_copy == evaluateres_


def test_get_properties_ins_to_recordset_and_back() -> None:
Expand Down

0 comments on commit 21f07a0

Please sign in to comment.