From b43ccf6188c6cf2d47ea6cb37ae8f172b32dfc7d Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 3 Apr 2024 17:25:05 +0100 Subject: [PATCH 1/4] hide members and make recordset picklable --- src/py/flwr/common/record/recordset.py | 93 +++++++++++++++++++------- 1 file changed, 67 insertions(+), 26 deletions(-) diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py index d8ef44ab15c2..ba60889ca7ef 100644 --- a/src/py/flwr/common/record/recordset.py +++ b/src/py/flwr/common/record/recordset.py @@ -16,7 +16,7 @@ from dataclasses import dataclass -from typing import Callable, Dict, Optional, Type, TypeVar +from typing import Dict, Optional, TypeVar, cast from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord @@ -26,13 +26,12 @@ T = TypeVar("T") -@dataclass -class RecordSet: - """RecordSet stores groups of parameters, metrics and configs.""" +class RecordSetData: + """Inner data container for theRecordSet class.""" - _parameters_records: TypedDict[str, ParametersRecord] - _metrics_records: TypedDict[str, MetricsRecord] - _configs_records: TypedDict[str, ConfigsRecord] + parameters_records: TypedDict[str, ParametersRecord] + metrics_records: TypedDict[str, MetricsRecord] + configs_records: TypedDict[str, ConfigsRecord] def __init__( self, @@ -40,40 +39,82 @@ def __init__( metrics_records: Optional[Dict[str, MetricsRecord]] = None, configs_records: Optional[Dict[str, ConfigsRecord]] = None, ) -> None: - def _get_check_fn(__t: Type[T]) -> Callable[[T], None]: - def _check_fn(__v: T) -> None: - if not isinstance(__v, __t): - raise TypeError(f"Expected `{__t}`, but `{type(__v)}` was passed.") - - return _check_fn - - self._parameters_records = TypedDict[str, ParametersRecord]( - _get_check_fn(str), _get_check_fn(ParametersRecord) + self.parameters_records = TypedDict[str, ParametersRecord]( + self._check_fn_str, self._check_fn_params ) - self._metrics_records = TypedDict[str, MetricsRecord]( - _get_check_fn(str), _get_check_fn(MetricsRecord) + self.metrics_records = TypedDict[str, MetricsRecord]( + self._check_fn_str, self._check_fn_metrics ) - self._configs_records = TypedDict[str, ConfigsRecord]( - _get_check_fn(str), _get_check_fn(ConfigsRecord) + self.configs_records = TypedDict[str, ConfigsRecord]( + self._check_fn_str, self._check_fn_configs ) if parameters_records is not None: - self._parameters_records.update(parameters_records) + self.parameters_records.update(parameters_records) if metrics_records is not None: - self._metrics_records.update(metrics_records) + self.metrics_records.update(metrics_records) if configs_records is not None: - self._configs_records.update(configs_records) + self.configs_records.update(configs_records) + + def _check_fn_str(self, key: str) -> None: + if not isinstance(key, str): + raise TypeError( + f"Expected `{str.__name__}`, but " + f"received `{type(key).__name__}` for the key." + ) + + def _check_fn_params(self, record: ParametersRecord) -> None: + if not isinstance(record, ParametersRecord): + raise TypeError( + f"Expected `{ParametersRecord.__name__}`, but " + f"received `{type(record).__name__}` for the value." + ) + + def _check_fn_metrics(self, record: MetricsRecord) -> None: + if not isinstance(record, MetricsRecord): + raise TypeError( + f"Expected `{MetricsRecord.__name__}`, but " + f"received `{type(record).__name__}` for the value." + ) + + def _check_fn_configs(self, record: ConfigsRecord) -> None: + if not isinstance(record, ConfigsRecord): + raise TypeError( + f"Expected `{ConfigsRecord.__name__}`, but " + f"received `{type(record).__name__}` for the value." + ) + + +@dataclass +class RecordSet: + """RecordSet stores groups of parameters, metrics and configs.""" + + def __init__( + self, + parameters_records: Optional[Dict[str, ParametersRecord]] = None, + metrics_records: Optional[Dict[str, MetricsRecord]] = None, + configs_records: Optional[Dict[str, ConfigsRecord]] = None, + ) -> None: + data = RecordSetData( + parameters_records=parameters_records, + metrics_records=metrics_records, + configs_records=configs_records, + ) + setattr(self, "_data", data) # noqa @property def parameters_records(self) -> TypedDict[str, ParametersRecord]: """Dictionary holding ParametersRecord instances.""" - return self._parameters_records + data = cast(RecordSetData, getattr(self, "_data")) # noqa + return data.parameters_records @property def metrics_records(self) -> TypedDict[str, MetricsRecord]: """Dictionary holding MetricsRecord instances.""" - return self._metrics_records + data = cast(RecordSetData, getattr(self, "_data")) # noqa + return data.metrics_records @property def configs_records(self) -> TypedDict[str, ConfigsRecord]: """Dictionary holding ConfigsRecord instances.""" - return self._configs_records + data = cast(RecordSetData, getattr(self, "_data")) # noqa + return data.configs_records From 67c83b7466a24da94c059adf1c1afdc188d661d6 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 3 Apr 2024 17:27:10 +0100 Subject: [PATCH 2/4] rm T --- src/py/flwr/common/record/recordset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py index ba60889ca7ef..82fa8ef63f10 100644 --- a/src/py/flwr/common/record/recordset.py +++ b/src/py/flwr/common/record/recordset.py @@ -16,15 +16,13 @@ from dataclasses import dataclass -from typing import Dict, Optional, TypeVar, cast +from typing import Dict, Optional, cast from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parametersrecord import ParametersRecord from .typeddict import TypedDict -T = TypeVar("T") - class RecordSetData: """Inner data container for theRecordSet class.""" From f1de6876b16c044e665afec4ba4afbefb70b0835 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 4 Apr 2024 10:57:13 +0100 Subject: [PATCH 3/4] add unit test if record is picklable --- src/py/flwr/common/record/recordset_test.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/common/record/recordset_test.py b/src/py/flwr/common/record/recordset_test.py index 0e0b149881be..94d087795841 100644 --- a/src/py/flwr/common/record/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -14,6 +14,7 @@ # ============================================================================== """RecordSet tests.""" +import pickle from copy import deepcopy from typing import Callable, Dict, List, OrderedDict, Type, Union @@ -33,7 +34,7 @@ Parameters, ) -from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord +from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet def get_ndarrays() -> NDArrays: @@ -398,3 +399,18 @@ def test_count_bytes_configsrecord() -> None: record_bytest_count = c_record.count_bytes() assert bytes_in_dict == record_bytest_count + + +def test_record_is_picklable() -> None: + """Test if RecordSet and *Record are picklable.""" + # Prepare + p_record = ParametersRecord() + m_record = MetricsRecord({"aa": 123}) + c_record = ConfigsRecord({"cc": bytes(9)}) + rs = RecordSet() + rs.parameters_records["params"] = p_record + rs.metrics_records["metrics"] = m_record + rs.configs_records["configs"] = c_record + + # Execute + pickle.dumps((p_record, m_record, c_record, rs)) From dd2f7f4c6bc901728f524252fde1cd996dcf7fe0 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 18 Apr 2024 16:26:06 +0100 Subject: [PATCH 4/4] Update src/py/flwr/common/record/recordset.py Co-authored-by: Daniel J. Beutel --- src/py/flwr/common/record/recordset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py index 82fa8ef63f10..212cbbc8e6e8 100644 --- a/src/py/flwr/common/record/recordset.py +++ b/src/py/flwr/common/record/recordset.py @@ -25,7 +25,7 @@ class RecordSetData: - """Inner data container for theRecordSet class.""" + """Inner data container for the RecordSet class.""" parameters_records: TypedDict[str, ParametersRecord] metrics_records: TypedDict[str, MetricsRecord]