From 033e4c2fcac846fb9a066d49c4b4bc9cd25b5f34 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Mon, 22 Jan 2024 11:19:32 +0100 Subject: [PATCH 1/7] Migrate quickstart HuggingFace to flwr-datasets (#2829) Co-authored-by: jafermarq --- examples/quickstart-huggingface/README.md | 6 +- examples/quickstart-huggingface/client.py | 60 +++++++++---------- .../quickstart-huggingface/pyproject.toml | 1 + .../quickstart-huggingface/requirements.txt | 1 + examples/quickstart-huggingface/run.sh | 2 +- 5 files changed, 36 insertions(+), 34 deletions(-) diff --git a/examples/quickstart-huggingface/README.md b/examples/quickstart-huggingface/README.md index c1e3cc4edc06..fd868aa1fcce 100644 --- a/examples/quickstart-huggingface/README.md +++ b/examples/quickstart-huggingface/README.md @@ -1,6 +1,6 @@ # Federated HuggingFace Transformers using Flower and PyTorch -This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.dev/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for detailed explaination for the transformer pipeline. +This is an introductory example of using [HuggingFace](https://huggingface.co) Transformers with Flower and PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.dev/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline. Like `quickstart-pytorch`, running this example in itself is also meant to be quite easy. @@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear Start client 1 in the first terminal: ```shell -python3 client.py +python3 client.py --node-id 0 ``` Start client 2 in the second terminal: ```shell -python3 client.py +python3 client.py --node-id 1 ``` You will see that PyTorch is starting a federated training.
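The diff to `client.py` that follows implements this `--node-id` flow with Flower Datasets. As a minimal sketch of the new data-loading logic, using only calls that appear in the patch (the node id `0` is just an example value):

```python
from flwr_datasets import FederatedDataset

# Partition IMDB into 1,000 IID shards; each client loads the shard
# matching its --node-id argument (0 is used here for illustration).
fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000})
partition = fds.load_partition(0)

# Each client then splits its shard locally: 80% train, 20% test.
partition_train_test = partition.train_test_split(test_size=0.2)
```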
diff --git a/examples/quickstart-huggingface/client.py b/examples/quickstart-huggingface/client.py index 8717d710ad9c..5fa10b9ca0f2 100644 --- a/examples/quickstart-huggingface/client.py +++ b/examples/quickstart-huggingface/client.py @@ -1,58 +1,48 @@ -from collections import OrderedDict +import argparse import warnings +from collections import OrderedDict import flwr as fl import torch -import numpy as np - -import random -from torch.utils.data import DataLoader - -from datasets import load_dataset from evaluate import load as load_metric - -from transformers import AutoTokenizer, DataCollatorWithPadding +from torch.optim import AdamW +from torch.utils.data import DataLoader from transformers import AutoModelForSequenceClassification -from transformers import AdamW +from transformers import AutoTokenizer, DataCollatorWithPadding + +from flwr_datasets import FederatedDataset warnings.filterwarnings("ignore", category=UserWarning) DEVICE = torch.device("cpu") CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint -def load_data(): +def load_data(node_id): """Load IMDB data (training and eval)""" - raw_datasets = load_dataset("imdb") - raw_datasets = raw_datasets.shuffle(seed=42) - - # remove unnecessary data split - del raw_datasets["unsupervised"] + fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000}) + partition = fds.load_partition(node_id) + # Divide data: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT) def tokenize_function(examples): return tokenizer(examples["text"], truncation=True) - # random 100 samples - population = random.sample(range(len(raw_datasets["train"])), 100) - - tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) - tokenized_datasets["train"] = tokenized_datasets["train"].select(population) - tokenized_datasets["test"] = tokenized_datasets["test"].select(population) - - tokenized_datasets = tokenized_datasets.remove_columns("text") - tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + partition_train_test = partition_train_test.map(tokenize_function, batched=True) + partition_train_test = partition_train_test.remove_columns("text") + partition_train_test = partition_train_test.rename_column("label", "labels") data_collator = DataCollatorWithPadding(tokenizer=tokenizer) trainloader = DataLoader( - tokenized_datasets["train"], + partition_train_test["train"], shuffle=True, batch_size=32, collate_fn=data_collator, ) testloader = DataLoader( - tokenized_datasets["test"], batch_size=32, collate_fn=data_collator + partition_train_test["test"], batch_size=32, collate_fn=data_collator ) return trainloader, testloader @@ -88,12 +78,12 @@ def test(net, testloader): return loss, accuracy -def main(): +def main(node_id): net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 ).to(DEVICE) - trainloader, testloader = load_data() + trainloader, testloader = load_data(node_id) # Flower client class IMDBClient(fl.client.NumPyClient): @@ -122,4 +112,14 @@ def evaluate(self, parameters, config): if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--node-id", + choices=list(range(1_000)), + required=True, + type=int, + help="Partition of the dataset divided into 1,000 iid partitions created " + "artificially.", + ) + node_id = parser.parse_args().node_id + main(node_id) diff --git a/examples/quickstart-huggingface/pyproject.toml 
b/examples/quickstart-huggingface/pyproject.toml index eb9687c5152c..50ba0b37f8d2 100644 --- a/examples/quickstart-huggingface/pyproject.toml +++ b/examples/quickstart-huggingface/pyproject.toml @@ -14,6 +14,7 @@ authors = [ [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = ">=0.0.2,<1.0.0" torch = ">=1.13.1,<2.0" transformers = ">=4.30.0,<5.0" evaluate = ">=0.4.0,<1.0" diff --git a/examples/quickstart-huggingface/requirements.txt b/examples/quickstart-huggingface/requirements.txt index aeb2d13fc4a4..3cd5735625ba 100644 --- a/examples/quickstart-huggingface/requirements.txt +++ b/examples/quickstart-huggingface/requirements.txt @@ -1,4 +1,5 @@ flwr>=1.0, <2.0 +flwr-datasets>=0.0.2, <1.0.0 torch>=1.13.1, <2.0 transformers>=4.30.0, <5.0 evaluate>=0.4.0, <1.0 diff --git a/examples/quickstart-huggingface/run.sh b/examples/quickstart-huggingface/run.sh index c64f362086aa..e722a24a21a9 100755 --- a/examples/quickstart-huggingface/run.sh +++ b/examples/quickstart-huggingface/run.sh @@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start for i in `seq 0 1`; do echo "Starting client $i" - python client.py & + python client.py --node-id ${i}& done # This will allow you to use CTRL+C to stop all background processes From 255925938bd7c559af3f7f1ad7b363d34aab28d3 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 22 Jan 2024 17:02:31 +0000 Subject: [PATCH 2/7] Improve records type checking (#2838) --- src/py/flwr/common/configsrecord.py | 13 +++++++++++-- src/py/flwr/common/metricsrecord.py | 13 +++++++++++-- src/py/flwr/common/recordset_test.py | 10 ++++++++++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/common/configsrecord.py b/src/py/flwr/common/configsrecord.py index 332269503ac0..b0480841e06c 100644 --- a/src/py/flwr/common/configsrecord.py +++ b/src/py/flwr/common/configsrecord.py @@ -87,8 +87,17 @@ def is_valid(value: ConfigsScalar) -> None: # 1s to check 10M element list on a M2 Pro # In such settings, you'd be better of treating such config as # an array and pass it to a ParametersRecord. - for list_value in value: - is_valid(list_value) + # Empty lists are valid + if len(value) > 0: + is_valid(value[0]) + # all elements in the list must be of the same valid type + # this is needed for protobuf + value_type = type(value[0]) + if not all(isinstance(v, value_type) for v in value): + raise TypeError( + "All values in a list must be of the same valid type. " + f"One of {ConfigsScalar}." + ) else: is_valid(value) diff --git a/src/py/flwr/common/metricsrecord.py b/src/py/flwr/common/metricsrecord.py index ecb8eff830ab..e70b0cb31d55 100644 --- a/src/py/flwr/common/metricsrecord.py +++ b/src/py/flwr/common/metricsrecord.py @@ -87,8 +87,17 @@ def is_valid(value: MetricsScalar) -> None: # 1s to check 10M element list on a M2 Pro # In such settings, you'd be better of treating such metric as # an array and pass it to a ParametersRecord. - for list_value in value: - is_valid(list_value) + # Empty lists are valid + if len(value) > 0: + is_valid(value[0]) + # all elements in the list must be of the same valid type + # this is needed for protobuf + value_type = type(value[0]) + if not all(isinstance(v, value_type) for v in value): + raise TypeError( + "All values in a list must be of the same valid type. " + f"One of {MetricsScalar}." 
+ ) else: is_valid(value) diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index 0e4c351647da..83e1e4595f1d 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -162,6 +162,7 @@ def test_set_parameters_with_incorrect_types( (str, lambda x: float(x.flatten()[0])), # str: float (str, lambda x: x.flatten().astype("int").tolist()), # str: List[int] (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] + (str, lambda x: []), # str: empty list ], ) def test_set_metrics_to_metricsrecord_with_correct_types( @@ -203,6 +204,10 @@ str, lambda x: [{str(v): v for v in x.flatten()}], ), # str: List[dict[str: float]] (supported: unsupported) + ( + str, + lambda x: [1, 2.0, 3.0, 4], + ), # str: List[mixing valid types] (supported: unsupported) ( int, lambda x: x.flatten().tolist(), @@ -278,6 +283,7 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( (str, lambda x: x.flatten().astype("float").tolist()), # str: List[float] (str, lambda x: x.flatten().astype("bool").tolist()), # str: List[bool] (str, lambda x: [x.flatten().tobytes()]), # str: List[bytes] + (str, lambda x: []), # str: empty list ], ) def test_set_configs_to_configsrecord_with_correct_types( @@ -310,6 +316,10 @@ str, lambda x: [{str(v): v for v in x.flatten()}], ), # str: List[dict[str: float]] (supported: unsupported) + ( + str, + lambda x: [1, 2.0, 3.0, 4], + ), # str: List[mixing valid types] (supported: unsupported) ( int, lambda x: x.flatten().tolist(), From d7be8fb64507a83968479782976b9458435f909a Mon Sep 17 00:00:00 2001 From: Heng Pan <134433891+panh99@users.noreply.github.com> Date: Mon, 22 Jan 2024 17:50:14 +0000 Subject: [PATCH 3/7] Add `*Record` ProtoBuf messages and corresponding serde functions. (#2831) --- src/proto/flwr/proto/recordset.proto | 70 +++++++ src/proto/flwr/proto/task.proto | 7 +- src/py/flwr/common/serde.py | 180 +++++++++++++++-- src/py/flwr/common/serde_test.py | 95 ++++++++- src/py/flwr/proto/recordset_pb2.py | 54 +++++ src/py/flwr/proto/recordset_pb2.pyi | 240 +++++++++++++++++++++++ src/py/flwr/proto/recordset_pb2_grpc.py | 4 + src/py/flwr/proto/recordset_pb2_grpc.pyi | 4 + src/py/flwr/proto/task_pb2.py | 37 ++-- src/py/flwr/proto/task_pb2.pyi | 76 ++----- src/py/flwr_tool/init_py_check.py | 2 +- src/py/flwr_tool/protoc.py | 2 +- src/py/flwr_tool/protoc_test.py | 2 +- 13 files changed, 654 insertions(+), 119 deletions(-) create mode 100644 src/proto/flwr/proto/recordset.proto create mode 100644 src/py/flwr/proto/recordset_pb2.py create mode 100644 src/py/flwr/proto/recordset_pb2.pyi create mode 100644 src/py/flwr/proto/recordset_pb2_grpc.py create mode 100644 src/py/flwr/proto/recordset_pb2_grpc.pyi diff --git a/src/proto/flwr/proto/recordset.proto b/src/proto/flwr/proto/recordset.proto new file mode 100644 index 000000000000..8e2e5d60b6db --- /dev/null +++ b/src/proto/flwr/proto/recordset.proto @@ -0,0 +1,70 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +message DoubleList { repeated double vals = 1; } +message Sint64List { repeated sint64 vals = 1; } +message BoolList { repeated bool vals = 1; } +message StringList { repeated string vals = 1; } +message BytesList { repeated bytes vals = 1; } + +message Array { + string dtype = 1; + repeated int32 shape = 2; + string stype = 3; + bytes data = 4; +} + +message MetricsRecordValue { + oneof value { + // Single element + double double = 1; + sint64 sint64 = 2; + + // List types + DoubleList double_list = 21; + Sint64List sint64_list = 22; + } +} + +message ConfigsRecordValue { + oneof value { + // Single element + double double = 1; + sint64 sint64 = 2; + bool bool = 3; + string string = 4; + bytes bytes = 5; + + // List types + DoubleList double_list = 21; + Sint64List sint64_list = 22; + BoolList bool_list = 23; + StringList string_list = 24; + BytesList bytes_list = 25; + } +} + +message ParametersRecord { + repeated string data_keys = 1; + repeated Array data_values = 2; +} + +message MetricsRecord { map data = 1; } + +message ConfigsRecord { map data = 1; } diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 9faabe1eebd1..20dd5a3aa6c8 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -18,6 +18,7 @@ syntax = "proto3"; package flwr.proto; import "flwr/proto/node.proto"; +import "flwr/proto/recordset.proto"; import "flwr/proto/transport.proto"; message Task { @@ -49,12 +50,6 @@ message TaskRes { } message Value { - message DoubleList { repeated double vals = 1; } - message Sint64List { repeated sint64 vals = 1; } - message BoolList { repeated bool vals = 1; } - message StringList { repeated string vals = 1; } - message BytesList { repeated bytes vals = 1; } - oneof value { // Single element double double = 1; diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index a059d95af833..2094c76a9856 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -15,10 +15,22 @@ """ProtoBuf serialization and deserialization.""" -from typing import Any, Dict, List, MutableMapping, cast - -from flwr.proto.task_pb2 import Value # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 +from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast + +from google.protobuf.message import Message + +# pylint: disable=E0611 +from flwr.proto.recordset_pb2 import Array as ProtoArray +from flwr.proto.recordset_pb2 import BoolList, BytesList +from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord +from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue +from flwr.proto.recordset_pb2 import DoubleList +from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord +from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue +from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord +from flwr.proto.recordset_pb2 
import Sint64List, StringList +from flwr.proto.task_pb2 import Value +from flwr.proto.transport_pb2 import ( ClientMessage, Code, Parameters, @@ -28,7 +40,11 @@ Status, ) +# pylint: enable=E0611 from . import typing +from .configsrecord import ConfigsRecord +from .metricsrecord import MetricsRecord +from .parametersrecord import Array, ParametersRecord # === ServerMessage message === @@ -493,7 +509,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: # === Value messages === -_python_type_to_field_name = { +_type_to_field = { float: "double", int: "sint64", bool: "bool", @@ -502,22 +518,20 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: } -_python_list_type_to_message_and_field_name = { - float: (Value.DoubleList, "double_list"), - int: (Value.Sint64List, "sint64_list"), - bool: (Value.BoolList, "bool_list"), - str: (Value.StringList, "string_list"), - bytes: (Value.BytesList, "bytes_list"), +_list_type_to_class_and_field = { + float: (DoubleList, "double_list"), + int: (Sint64List, "sint64_list"), + bool: (BoolList, "bool_list"), + str: (StringList, "string_list"), + bytes: (BytesList, "bytes_list"), } def _check_value(value: typing.Value) -> None: - if isinstance(value, tuple(_python_type_to_field_name.keys())): + if isinstance(value, tuple(_type_to_field.keys())): return if isinstance(value, list): - if len(value) > 0 and isinstance( - value[0], tuple(_python_type_to_field_name.keys()) - ): + if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())): data_type = type(value[0]) for element in value: if isinstance(element, data_type): @@ -539,12 +553,12 @@ def value_to_proto(value: typing.Value) -> Value: arg = {} if isinstance(value, list): - msg_class, field_name = _python_list_type_to_message_and_field_name[ + msg_class, field_name = _list_type_to_class_and_field[ type(value[0]) if len(value) > 0 else int ] arg[field_name] = msg_class(vals=value) else: - arg[_python_type_to_field_name[type(value)]] = value + arg[_type_to_field[type(value)]] = value return Value(**arg) @@ -573,3 +587,135 @@ def named_values_from_proto( ) -> Dict[str, typing.Value]: """Deserialize named values from ProtoBuf.""" return {name: value_from_proto(value) for name, value in named_values_proto.items()} + + +# === Record messages === + + +T = TypeVar("T") + + +def _record_value_to_proto( + value: Any, allowed_types: List[type], proto_class: Type[T] +) -> T: + """Serialize `*RecordValue` to ProtoBuf.""" + arg = {} + for t in allowed_types: + # Single element + # Note: `isinstance(False, int) == True`. 
+ if type(value) == t: # pylint: disable=C0123 + arg[_type_to_field[t]] = value + return proto_class(**arg) + # List + if isinstance(value, list) and all(isinstance(item, t) for item in value): + list_class, field_name = _list_type_to_class_and_field[t] + arg[field_name] = list_class(vals=value) + return proto_class(**arg) + # Invalid types + raise TypeError( + f"The type of the following value is not allowed " + f"in '{proto_class.__name__}':\n{value}" + ) + + +def _record_value_from_proto(value_proto: Message) -> Any: + """Deserialize `*RecordValue` from ProtoBuf.""" + value_field = cast(str, value_proto.WhichOneof("value")) + if value_field.endswith("list"): + value = list(getattr(value_proto, value_field).vals) + else: + value = getattr(value_proto, value_field) + return value + + +def _record_value_dict_to_proto( + value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T] +) -> Dict[str, T]: + """Serialize the record value dict to ProtoBuf.""" + + def proto(_v: Any) -> T: + return _record_value_to_proto(_v, allowed_types, value_proto_class) + + return {k: proto(v) for k, v in value_dict.items()} + + +def _record_value_dict_from_proto( + value_dict_proto: MutableMapping[str, Any] +) -> Dict[str, Any]: + """Deserialize the record value dict from ProtoBuf.""" + return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()} + + +def array_to_proto(array: Array) -> ProtoArray: + """Serialize Array to ProtoBuf.""" + return ProtoArray(**vars(array)) + + +def array_from_proto(array_proto: ProtoArray) -> Array: + """Deserialize Array from ProtoBuf.""" + return Array( + dtype=array_proto.dtype, + shape=list(array_proto.shape), + stype=array_proto.stype, + data=array_proto.data, + ) + + +def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord: + """Serialize ParametersRecord to ProtoBuf.""" + return ProtoParametersRecord( + data_keys=record.data.keys(), + data_values=map(array_to_proto, record.data.values()), + ) + + +def parameters_record_from_proto( + record_proto: ProtoParametersRecord, +) -> ParametersRecord: + """Deserialize ParametersRecord from ProtoBuf.""" + return ParametersRecord( + array_dict=OrderedDict( + zip(record_proto.data_keys, map(array_from_proto, record_proto.data_values)) + ), + keep_input=False, + ) + + +def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord: + """Serialize MetricsRecord to ProtoBuf.""" + return ProtoMetricsRecord( + data=_record_value_dict_to_proto( + record.data, [float, int], ProtoMetricsRecordValue + ) + ) + + +def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord: + """Deserialize MetricsRecord from ProtoBuf.""" + return MetricsRecord( + metrics_dict=cast( + Dict[str, typing.MetricsRecordValues], + _record_value_dict_from_proto(record_proto.data), + ), + keep_input=False, + ) + + +def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord: + """Serialize ConfigsRecord to ProtoBuf.""" + return ProtoConfigsRecord( + data=_record_value_dict_to_proto( + record.data, [int, float, bool, str, bytes], ProtoConfigsRecordValue + ) + ) + + +def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord: + """Deserialize ConfigsRecord from ProtoBuf.""" + return ConfigsRecord( + configs_dict=cast( + Dict[str, typing.ConfigsRecordValues], + _record_value_dict_from_proto(record_proto.data), + ), + keep_input=False, + ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 
2c61c28eb0ee..c584597d89f6 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -15,14 +15,31 @@ """(De-)serialization tests.""" -from typing import Dict, Union, cast - -from flwr.common import typing -from flwr.proto import transport_pb2 as pb2 # pylint: disable=E0611 - +from typing import Dict, OrderedDict, Union, cast + +# pylint: disable=E0611 +from flwr.proto import transport_pb2 as pb2 +from flwr.proto.recordset_pb2 import Array as ProtoArray +from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord +from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord +from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord + +# pylint: enable=E0611 +from . import typing +from .configsrecord import ConfigsRecord +from .metricsrecord import MetricsRecord +from .parametersrecord import Array, ParametersRecord from .serde import ( + array_from_proto, + array_to_proto, + configs_record_from_proto, + configs_record_to_proto, + metrics_record_from_proto, + metrics_record_to_proto, named_values_from_proto, named_values_to_proto, + parameters_record_from_proto, + parameters_record_to_proto, scalar_from_proto, scalar_to_proto, status_from_proto, @@ -157,3 +174,71 @@ def test_named_values_serialization_deserialization() -> None: assert elm1 == elm2 else: assert expected == actual + + +def test_array_serialization_deserialization() -> None: + """Test serialization and deserialization of Array.""" + # Prepare + original = Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234") + + # Execute + proto = array_to_proto(original) + deserialized = array_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoArray) + assert original == deserialized + + +def test_parameters_record_serialization_deserialization() -> None: + """Test serialization and deserialization of ParametersRecord.""" + # Prepare + original = ParametersRecord( + array_dict=OrderedDict( + [ + ("k1", Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234")), + ("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")), + ] + ), + keep_input=False, + ) + + # Execute + proto = parameters_record_to_proto(original) + deserialized = parameters_record_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoParametersRecord) + assert original.data == deserialized.data + + +def test_metrics_record_serialization_deserialization() -> None: + """Test serialization and deserialization of MetricsRecord.""" + # Prepare + original = MetricsRecord( + metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False + ) + + # Execute + proto = metrics_record_to_proto(original) + deserialized = metrics_record_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoMetricsRecord) + assert original.data == deserialized.data + + +def test_configs_record_serialization_deserialization() -> None: + """Test serialization and deserialization of ConfigsRecord.""" + # Prepare + original = ConfigsRecord( + configs_dict={"learning_rate": 0.01, "batch_size": 32}, keep_input=False + ) + + # Execute + proto = configs_record_to_proto(original) + deserialized = configs_record_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoConfigsRecord) + assert original.data == deserialized.data diff --git a/src/py/flwr/proto/recordset_pb2.py b/src/py/flwr/proto/recordset_pb2.py new file mode 100644 index 000000000000..4134511f1f53 --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# Generated by the 
protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/recordset.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.recordset_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_METRICSRECORD_DATAENTRY']._options = None + _globals['_METRICSRECORD_DATAENTRY']._serialized_options = b'8\001' + _globals['_CONFIGSRECORD_DATAENTRY']._options = None + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_options = b'8\001' + _globals['_DOUBLELIST']._serialized_start=42 + _globals['_DOUBLELIST']._serialized_end=68 + _globals['_SINT64LIST']._serialized_start=70 + _globals['_SINT64LIST']._serialized_end=96 + _globals['_BOOLLIST']._serialized_start=98 + _globals['_BOOLLIST']._serialized_end=122 + _globals['_STRINGLIST']._serialized_start=124 + _globals['_STRINGLIST']._serialized_end=150 + _globals['_BYTESLIST']._serialized_start=152 + 
_globals['_BYTESLIST']._serialized_end=177 + _globals['_ARRAY']._serialized_start=179 + _globals['_ARRAY']._serialized_end=245 + _globals['_METRICSRECORDVALUE']._serialized_start=248 + _globals['_METRICSRECORDVALUE']._serialized_end=407 + _globals['_CONFIGSRECORDVALUE']._serialized_start=410 + _globals['_CONFIGSRECORDVALUE']._serialized_end=755 + _globals['_PARAMETERSRECORD']._serialized_start=757 + _globals['_PARAMETERSRECORD']._serialized_end=834 + _globals['_METRICSRECORD']._serialized_start=837 + _globals['_METRICSRECORD']._serialized_end=980 + _globals['_METRICSRECORD_DATAENTRY']._serialized_start=905 + _globals['_METRICSRECORD_DATAENTRY']._serialized_end=980 + _globals['_CONFIGSRECORD']._serialized_start=983 + _globals['_CONFIGSRECORD']._serialized_end=1126 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_start=1051 + _globals['_CONFIGSRECORD_DATAENTRY']._serialized_end=1126 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/recordset_pb2.pyi b/src/py/flwr/proto/recordset_pb2.pyi new file mode 100644 index 000000000000..1e9556de9ce6 --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2.pyi @@ -0,0 +1,240 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class DoubleList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.float]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___DoubleList = DoubleList + +class Sint64List(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.int]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___Sint64List = Sint64List + +class BoolList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bool]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___BoolList = BoolList + +class StringList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[typing.Text]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... 
+global___StringList = StringList + +class BytesList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALS_FIELD_NUMBER: builtins.int + @property + def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + def __init__(self, + *, + vals: typing.Optional[typing.Iterable[builtins.bytes]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... +global___BytesList = BytesList + +class Array(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DTYPE_FIELD_NUMBER: builtins.int + SHAPE_FIELD_NUMBER: builtins.int + STYPE_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + dtype: typing.Text + @property + def shape(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + stype: typing.Text + data: builtins.bytes + def __init__(self, + *, + dtype: typing.Text = ..., + shape: typing.Optional[typing.Iterable[builtins.int]] = ..., + stype: typing.Text = ..., + data: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data",b"data","dtype",b"dtype","shape",b"shape","stype",b"stype"]) -> None: ... +global___Array = Array + +class MetricsRecordValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DOUBLE_FIELD_NUMBER: builtins.int + SINT64_FIELD_NUMBER: builtins.int + DOUBLE_LIST_FIELD_NUMBER: builtins.int + SINT64_LIST_FIELD_NUMBER: builtins.int + double: builtins.float + """Single element""" + + sint64: builtins.int + @property + def double_list(self) -> global___DoubleList: + """List types""" + pass + @property + def sint64_list(self) -> global___Sint64List: ... + def __init__(self, + *, + double: builtins.float = ..., + sint64: builtins.int = ..., + double_list: typing.Optional[global___DoubleList] = ..., + sint64_list: typing.Optional[global___Sint64List] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","double_list","sint64_list"]]: ... +global___MetricsRecordValue = MetricsRecordValue + +class ConfigsRecordValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DOUBLE_FIELD_NUMBER: builtins.int + SINT64_FIELD_NUMBER: builtins.int + BOOL_FIELD_NUMBER: builtins.int + STRING_FIELD_NUMBER: builtins.int + BYTES_FIELD_NUMBER: builtins.int + DOUBLE_LIST_FIELD_NUMBER: builtins.int + SINT64_LIST_FIELD_NUMBER: builtins.int + BOOL_LIST_FIELD_NUMBER: builtins.int + STRING_LIST_FIELD_NUMBER: builtins.int + BYTES_LIST_FIELD_NUMBER: builtins.int + double: builtins.float + """Single element""" + + sint64: builtins.int + bool: builtins.bool + string: typing.Text + bytes: builtins.bytes + @property + def double_list(self) -> global___DoubleList: + """List types""" + pass + @property + def sint64_list(self) -> global___Sint64List: ... + @property + def bool_list(self) -> global___BoolList: ... + @property + def string_list(self) -> global___StringList: ... 
+ @property + def bytes_list(self) -> global___BytesList: ... + def __init__(self, + *, + double: builtins.float = ..., + sint64: builtins.int = ..., + bool: builtins.bool = ..., + string: typing.Text = ..., + bytes: builtins.bytes = ..., + double_list: typing.Optional[global___DoubleList] = ..., + sint64_list: typing.Optional[global___Sint64List] = ..., + bool_list: typing.Optional[global___BoolList] = ..., + string_list: typing.Optional[global___StringList] = ..., + bytes_list: typing.Optional[global___BytesList] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["value",b"value"]) -> typing.Optional[typing_extensions.Literal["double","sint64","bool","string","bytes","double_list","sint64_list","bool_list","string_list","bytes_list"]]: ... +global___ConfigsRecordValue = ConfigsRecordValue + +class ParametersRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + DATA_KEYS_FIELD_NUMBER: builtins.int + DATA_VALUES_FIELD_NUMBER: builtins.int + @property + def data_keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... + @property + def data_values(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Array]: ... + def __init__(self, + *, + data_keys: typing.Optional[typing.Iterable[typing.Text]] = ..., + data_values: typing.Optional[typing.Iterable[global___Array]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data_keys",b"data_keys","data_values",b"data_values"]) -> None: ... +global___ParametersRecord = ParametersRecord + +class MetricsRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class DataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___MetricsRecordValue: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___MetricsRecordValue] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + DATA_FIELD_NUMBER: builtins.int + @property + def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricsRecordValue]: ... + def __init__(self, + *, + data: typing.Optional[typing.Mapping[typing.Text, global___MetricsRecordValue]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ... 
+global___MetricsRecord = MetricsRecord + +class ConfigsRecord(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class DataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___ConfigsRecordValue: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___ConfigsRecordValue] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + DATA_FIELD_NUMBER: builtins.int + @property + def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigsRecordValue]: ... + def __init__(self, + *, + data: typing.Optional[typing.Mapping[typing.Text, global___ConfigsRecordValue]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ... +global___ConfigsRecord = ConfigsRecord diff --git a/src/py/flwr/proto/recordset_pb2_grpc.py b/src/py/flwr/proto/recordset_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/recordset_pb2_grpc.pyi b/src/py/flwr/proto/recordset_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/recordset_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! 
+isort:skip_file +""" diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 83ae15c0aba2..963b07db94f8 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -13,10 +13,11 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 +from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd1\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12)\n\x02sa\x18\x08 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xf3\x03\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12\x33\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x1c.flwr.proto.Value.DoubleListH\x00\x12\x33\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x1c.flwr.proto.Value.Sint64ListH\x00\x12/\n\tbool_list\x18\x17 \x01(\x0b\x32\x1a.flwr.proto.Value.BoolListH\x00\x12\x33\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x1c.flwr.proto.Value.StringListH\x00\x12\x31\n\nbytes_list\x18\x19 \x01(\x0b\x32\x1b.flwr.proto.Value.BytesListH\x00\x1a\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\x1a\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\x1a\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\x1a\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\x1a\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd1\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12)\n\x02sa\x18\x08 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 
\x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xcc\x02\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -29,26 +30,16 @@ _globals['_TASK'].fields_by_name['legacy_client_message']._serialized_options = b'\030\001' _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._options = None _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_options = b'8\001' - _globals['_TASK']._serialized_start=89 - _globals['_TASK']._serialized_end=426 - _globals['_TASKINS']._serialized_start=428 - _globals['_TASKINS']._serialized_end=520 - _globals['_TASKRES']._serialized_start=522 - _globals['_TASKRES']._serialized_end=614 - _globals['_VALUE']._serialized_start=617 - _globals['_VALUE']._serialized_end=1116 - _globals['_VALUE_DOUBLELIST']._serialized_start=972 - _globals['_VALUE_DOUBLELIST']._serialized_end=998 - _globals['_VALUE_SINT64LIST']._serialized_start=1000 - _globals['_VALUE_SINT64LIST']._serialized_end=1026 - _globals['_VALUE_BOOLLIST']._serialized_start=1028 - _globals['_VALUE_BOOLLIST']._serialized_end=1052 - _globals['_VALUE_STRINGLIST']._serialized_start=1054 - _globals['_VALUE_STRINGLIST']._serialized_end=1080 - _globals['_VALUE_BYTESLIST']._serialized_start=1082 - _globals['_VALUE_BYTESLIST']._serialized_end=1107 - _globals['_SECUREAGGREGATION']._serialized_start=1119 - _globals['_SECUREAGGREGATION']._serialized_end=1279 - _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_start=1210 - _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_end=1279 + _globals['_TASK']._serialized_start=117 + _globals['_TASK']._serialized_end=454 + _globals['_TASKINS']._serialized_start=456 + _globals['_TASKINS']._serialized_end=548 + _globals['_TASKRES']._serialized_start=550 + _globals['_TASKRES']._serialized_end=642 + _globals['_VALUE']._serialized_start=645 + _globals['_VALUE']._serialized_end=977 + _globals['_SECUREAGGREGATION']._serialized_start=980 + _globals['_SECUREAGGREGATION']._serialized_end=1140 + _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_start=1071 + _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_end=1140 # @@protoc_insertion_point(module_scope) diff --git 
a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index 1039fd3d56ae..ebe69d05c974 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import flwr.proto.node_pb2 +import flwr.proto.recordset_pb2 import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers @@ -104,61 +105,6 @@ global___TaskRes = TaskRes class Value(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - class DoubleList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.float]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class Sint64List(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.int]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class BoolList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.bool]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class StringList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - - class BytesList(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - VALS_FIELD_NUMBER: builtins.int - @property - def vals(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... - def __init__(self, - *, - vals: typing.Optional[typing.Iterable[builtins.bytes]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["vals",b"vals"]) -> None: ... - DOUBLE_FIELD_NUMBER: builtins.int SINT64_FIELD_NUMBER: builtins.int BOOL_FIELD_NUMBER: builtins.int @@ -177,17 +123,17 @@ class Value(google.protobuf.message.Message): string: typing.Text bytes: builtins.bytes @property - def double_list(self) -> global___Value.DoubleList: + def double_list(self) -> flwr.proto.recordset_pb2.DoubleList: """List types""" pass @property - def sint64_list(self) -> global___Value.Sint64List: ... + def sint64_list(self) -> flwr.proto.recordset_pb2.Sint64List: ... @property - def bool_list(self) -> global___Value.BoolList: ... + def bool_list(self) -> flwr.proto.recordset_pb2.BoolList: ... @property - def string_list(self) -> global___Value.StringList: ... 
+ def string_list(self) -> flwr.proto.recordset_pb2.StringList: ... @property - def bytes_list(self) -> global___Value.BytesList: ... + def bytes_list(self) -> flwr.proto.recordset_pb2.BytesList: ... def __init__(self, *, double: builtins.float = ..., @@ -195,11 +141,11 @@ class Value(google.protobuf.message.Message): bool: builtins.bool = ..., string: typing.Text = ..., bytes: builtins.bytes = ..., - double_list: typing.Optional[global___Value.DoubleList] = ..., - sint64_list: typing.Optional[global___Value.Sint64List] = ..., - bool_list: typing.Optional[global___Value.BoolList] = ..., - string_list: typing.Optional[global___Value.StringList] = ..., - bytes_list: typing.Optional[global___Value.BytesList] = ..., + double_list: typing.Optional[flwr.proto.recordset_pb2.DoubleList] = ..., + sint64_list: typing.Optional[flwr.proto.recordset_pb2.Sint64List] = ..., + bool_list: typing.Optional[flwr.proto.recordset_pb2.BoolList] = ..., + string_list: typing.Optional[flwr.proto.recordset_pb2.StringList] = ..., + bytes_list: typing.Optional[flwr.proto.recordset_pb2.BytesList] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["bool",b"bool","bool_list",b"bool_list","bytes",b"bytes","bytes_list",b"bytes_list","double",b"double","double_list",b"double_list","sint64",b"sint64","sint64_list",b"sint64_list","string",b"string","string_list",b"string_list","value",b"value"]) -> None: ... diff --git a/src/py/flwr_tool/init_py_check.py b/src/py/flwr_tool/init_py_check.py index 8cdc2e0ab5be..67425139f991 100755 --- a/src/py/flwr_tool/init_py_check.py +++ b/src/py/flwr_tool/init_py_check.py @@ -36,7 +36,7 @@ def check_missing_init_files(absolute_path: str) -> None: if __name__ == "__main__": if len(sys.argv) == 0: - raise Exception( + raise Exception( # pylint: disable=W0719 "Please provide at least one directory path relative to your current working directory." 
) for i, _ in enumerate(sys.argv): diff --git a/src/py/flwr_tool/protoc.py b/src/py/flwr_tool/protoc.py index 5d3ce942c1e0..b0b078c2eae4 100644 --- a/src/py/flwr_tool/protoc.py +++ b/src/py/flwr_tool/protoc.py @@ -51,7 +51,7 @@ def compile_all() -> None: exit_code = protoc.main(command) if exit_code != 0: - raise Exception(f"Error: {command} failed") + raise Exception(f"Error: {command} failed") # pylint: disable=W0719 if __name__ == "__main__": diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 57ca3ff423c2..607d808c8497 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 5 + assert len(PROTO_FILES) == 6 From dfa30a30681042e1ad03264fa4aa5c9a1af1960a Mon Sep 17 00:00:00 2001 From: Gustavo Bertoli Date: Tue, 23 Jan 2024 15:25:47 +0100 Subject: [PATCH 4/7] Example for Custom Metrics calculation during Federated Learning (#1958) Co-authored-by: Yan Gao --- examples/custom-metrics/README.md | 106 +++++++++++++++++++++ examples/custom-metrics/client.py | 71 +++++++++++++++ examples/custom-metrics/pyproject.toml | 19 ++++ examples/custom-metrics/requirements.txt | 4 + examples/custom-metrics/run.sh | 15 ++++ examples/custom-metrics/server.py | 58 +++++++++++++ 6 files changed, 273 insertions(+) create mode 100644 examples/custom-metrics/README.md create mode 100644 examples/custom-metrics/client.py create mode 100644 examples/custom-metrics/pyproject.toml create mode 100644 examples/custom-metrics/requirements.txt create mode 100755 examples/custom-metrics/run.sh create mode 100644 examples/custom-metrics/server.py diff --git a/examples/custom-metrics/README.md b/examples/custom-metrics/README.md new file mode 100644 index 000000000000..debcd7919839 --- /dev/null +++ b/examples/custom-metrics/README.md @@ -0,0 +1,106 @@ +# Flower Example using Custom Metrics + +This simple example demonstrates how to calculate custom metrics over multiple clients beyond the traditional ones available in the ML frameworks. In this case, it demonstrates the use of readily available `scikit-learn` metrics: accuracy, recall, precision, and f1-score. + +Once both the test values (`y_test`) and the predictions (`y_pred`) are available on the client side (`client.py`), other standard or custom metrics can be calculated. + +The main takeaways of this implementation are: + +- the use of `output_dict` on the client side - inside the `evaluate` method in `client.py` +- the use of the `evaluate_metrics_aggregation_fn` - to aggregate the metrics on the server side, as part of the `strategy` in `server.py` + +This example is based on the `quickstart-tensorflow` with CIFAR-10, source [here](https://flower.dev/docs/quickstart-tensorflow.html), with the addition of [Flower Datasets](https://flower.dev/docs/datasets/index.html) to retrieve CIFAR-10. + +Since CIFAR-10 classification is a multi-class classification problem, some changes in how the metrics are calculated, using `average='micro'` and `np.argmax`, are required. For binary classification, this is not required. Also, for unsupervised learning tasks, such as using a deep autoencoder, a custom metric based on reconstruction error could be implemented on the client side. + +## Project Setup + +Start by cloning the example project.
We prepared a single-line command that you can copy into your shell, which will check out the example for you:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/custom-metrics . && rm -rf flower && cd custom-metrics
+```
+
+This will create a new directory called `custom-metrics` containing the following files:
+
+```shell
+-- pyproject.toml
+-- requirements.txt
+-- client.py
+-- server.py
+-- run.sh
+-- README.md
+```
+
+### Installing Dependencies
+
+Project dependencies (such as `scikit-learn`, `tensorflow` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
+
+#### Poetry
+
+```shell
+poetry install
+poetry shell
+```
+
+Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly, you can run the following command:
+
+```shell
+poetry run python3 -c "import flwr"
+```
+
+If you don't see any errors, you're good to go!
+
+#### pip
+
+Run the command below in your terminal to install the dependencies listed in the configuration file `requirements.txt`.
+
+```shell
+python -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt
+```
+
+## Run Federated Learning with Custom Metrics
+
+Afterwards, you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:
+
+```shell
+python server.py
+```
+
+Now you are ready to start the Flower clients which will participate in the learning. To do so, simply open two more terminals and run the following command in each:
+
+```shell
+python client.py
+```
+
+Alternatively, you can run all of it in one shell as follows:
+
+```shell
+python server.py &
+# Wait for a few seconds to give the server enough time to start, then:
+python client.py &
+python client.py
+```
+
+or
+
+```shell
+chmod +x run.sh
+./run.sh
+```
+
+You will see that Keras is starting a federated training. Have a look at the [Flower Quickstarter documentation](https://flower.dev/docs/quickstart-tensorflow.html) for a detailed explanation. You can add `steps_per_epoch=3` to `model.fit()` if you just want to evaluate that everything works without having to wait for the client-side training to finish (this will save you a lot of time during development).
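+As mentioned in the introduction, an unsupervised client (e.g. a deep autoencoder) could report a custom reconstruction-error metric instead of classification scores. The following is only a hypothetical sketch - the autoencoder `model`, the metric name, and the reuse of `x_test` are assumptions, not part of this example's code:
+
+```python
+import numpy as np
+
+# Hypothetical autoencoder client: `model` reconstructs its own input
+reconstructions = model.predict(x_test)
+# Mean squared reconstruction error over the test set
+reconstruction_error = float(np.mean(np.square(x_test - reconstructions)))
+output_dict = {"reconstruction_error": reconstruction_error}
+```
+
+Such an `output_dict` would then be aggregated on the server exactly like the scikit-learn metrics used in this example.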
+ +Running `run.sh` will result in the following output (after 3 rounds): + +```shell +INFO flwr 2024-01-17 17:45:23,794 | app.py:228 | app_fit: metrics_distributed { + 'accuracy': [(1, 0.10000000149011612), (2, 0.10000000149011612), (3, 0.3393000066280365)], + 'acc': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'rec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'prec': [(1, 0.1), (2, 0.1), (3, 0.3393)], + 'f1': [(1, 0.10000000000000002), (2, 0.10000000000000002), (3, 0.3393)] +} +``` diff --git a/examples/custom-metrics/client.py b/examples/custom-metrics/client.py new file mode 100644 index 000000000000..b2206118ed44 --- /dev/null +++ b/examples/custom-metrics/client.py @@ -0,0 +1,71 @@ +import os + +import flwr as fl +import numpy as np +import tensorflow as tf +from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score +from flwr_datasets import FederatedDataset + + +# Make TensorFlow log less verbose +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + + +# Load model (MobileNetV2) +model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None) +model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) + +# Load data with Flower Datasets (CIFAR-10) +fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) +train = fds.load_full("train") +test = fds.load_full("test") + +# Using Numpy format +train_np = train.with_format("numpy") +test_np = test.with_format("numpy") +x_train, y_train = train_np["img"], train_np["label"] +x_test, y_test = test_np["img"], test_np["label"] + + +# Method for extra learning metrics calculation +def eval_learning(y_test, y_pred): + acc = accuracy_score(y_test, y_pred) + rec = recall_score( + y_test, y_pred, average="micro" + ) # average argument required for multi-class + prec = precision_score(y_test, y_pred, average="micro") + f1 = f1_score(y_test, y_pred, average="micro") + return acc, rec, prec, f1 + + +# Define Flower client +class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config): + return model.get_weights() + + def fit(self, parameters, config): + model.set_weights(parameters) + model.fit(x_train, y_train, epochs=1, batch_size=32) + return model.get_weights(), len(x_train), {} + + def evaluate(self, parameters, config): + model.set_weights(parameters) + loss, accuracy = model.evaluate(x_test, y_test) + y_pred = model.predict(x_test) + y_pred = np.argmax(y_pred, axis=1).reshape( + -1, 1 + ) # MobileNetV2 outputs 10 possible classes, argmax returns just the most probable + + acc, rec, prec, f1 = eval_learning(y_test, y_pred) + output_dict = { + "accuracy": accuracy, # accuracy from tensorflow model.evaluate + "acc": acc, + "rec": rec, + "prec": prec, + "f1": f1, + } + return loss, len(x_test), output_dict + + +# Start Flower client +fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) diff --git a/examples/custom-metrics/pyproject.toml b/examples/custom-metrics/pyproject.toml new file mode 100644 index 000000000000..8a2da6562018 --- /dev/null +++ b/examples/custom-metrics/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "custom-metrics" +version = "0.1.0" +description = "Federated Learning with Flower and Custom Metrics" +authors = [ + "The Flower Authors ", + "Gustavo Bertoli " +] + +[tool.poetry.dependencies] +python = ">=3.8,<3.11" +flwr = ">=1.0,<2.0" +flwr-datasets = { version = "*", extras = ["vision"] } +scikit-learn = "^1.2.2" +tensorflow = 
"==2.12.0" \ No newline at end of file diff --git a/examples/custom-metrics/requirements.txt b/examples/custom-metrics/requirements.txt new file mode 100644 index 000000000000..69d867c5f287 --- /dev/null +++ b/examples/custom-metrics/requirements.txt @@ -0,0 +1,4 @@ +flwr>=1.0,<2.0 +flwr-datasets[vision] +scikit-learn>=1.2.2 +tensorflow==2.12.0 diff --git a/examples/custom-metrics/run.sh b/examples/custom-metrics/run.sh new file mode 100755 index 000000000000..c64f362086aa --- /dev/null +++ b/examples/custom-metrics/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +echo "Starting server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in `seq 0 1`; do + echo "Starting client $i" + python client.py & +done + +# This will allow you to use CTRL+C to stop all background processes +trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/custom-metrics/server.py b/examples/custom-metrics/server.py new file mode 100644 index 000000000000..f8420bf51f16 --- /dev/null +++ b/examples/custom-metrics/server.py @@ -0,0 +1,58 @@ +import flwr as fl +import numpy as np + + +# Define metrics aggregation function +def average_metrics(metrics): + """Aggregate metrics from multiple clients by calculating mean averages. + + Parameters: + - metrics (list): A list containing tuples, where each tuple represents metrics for a client. + Each tuple is structured as (num_examples, metric), where: + - num_examples (int): The number of examples used to compute the metrics. + - metric (dict): A dictionary containing custom metrics provided as `output_dict` + in the `evaluate` method from `client.py`. + + Returns: + A dictionary with the aggregated metrics, calculating mean averages. The keys of the + dictionary represent different metrics, including: + - 'accuracy': Mean accuracy calculated by TensorFlow. + - 'acc': Mean accuracy from scikit-learn. + - 'rec': Mean recall from scikit-learn. + - 'prec': Mean precision from scikit-learn. + - 'f1': Mean F1 score from scikit-learn. + + Note: If a weighted average is required, the `num_examples` parameter can be leveraged. 
+
+    Example:
+        A `metrics` list for two clients after the last round:
+        [(10000, {'prec': 0.108, 'acc': 0.108, 'f1': 0.108, 'accuracy': 0.1080000028014183, 'rec': 0.108}),
+         (10000, {'f1': 0.108, 'rec': 0.108, 'accuracy': 0.1080000028014183, 'prec': 0.108, 'acc': 0.108})]
+    """
+
+    # num_examples is deliberately ignored here (unpacked as _) because we
+    # compute plain, unweighted means
+    accuracies_tf = np.mean([metric["accuracy"] for _, metric in metrics])
+    accuracies = np.mean([metric["acc"] for _, metric in metrics])
+    recalls = np.mean([metric["rec"] for _, metric in metrics])
+    precisions = np.mean([metric["prec"] for _, metric in metrics])
+    f1s = np.mean([metric["f1"] for _, metric in metrics])
+
+    return {
+        "accuracy": accuracies_tf,
+        "acc": accuracies,
+        "rec": recalls,
+        "prec": precisions,
+        "f1": f1s,
+    }
+
+
+# Define strategy and the custom aggregation function for the evaluation metrics
+strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=average_metrics)
+
+
+# Start Flower server
+fl.server.start_server(
+    server_address="0.0.0.0:8080",
+    config=fl.server.ServerConfig(num_rounds=3),
+    strategy=strategy,
+)

From a04265abb829bf955495d061ed092d923726dff9 Mon Sep 17 00:00:00 2001
From: Yan Gao 
Date: Tue, 23 Jan 2024 14:42:26 +0000
Subject: [PATCH 5/7] Update xgb doc for new features (#2821)

Co-authored-by: yan-gao-GY 
---
 doc/source/tutorial-quickstart-xgboost.rst | 616 ++++++++++++++++++---
 1 file changed, 542 insertions(+), 74 deletions(-)

diff --git a/doc/source/tutorial-quickstart-xgboost.rst b/doc/source/tutorial-quickstart-xgboost.rst
index 7eb58da7f2f6..3a7b356c4d2a 100644
--- a/doc/source/tutorial-quickstart-xgboost.rst
+++ b/doc/source/tutorial-quickstart-xgboost.rst
@@ -595,9 +595,164 @@ Comprehensive Federated XGBoost

 Now that you know how federated XGBoost works with Flower, it's time to run some more comprehensive experiments by customising the experimental settings.
 In the xgboost-comprehensive example (`full code `_),
-we provide more options to define various experimental setups, including data partitioning and centralised/distributed evaluation.
+we provide more options to define various experimental setups, including aggregation strategies, data partitioning and centralised/distributed evaluation.
+We also support `Flower simulation `_, making it easy to simulate large client cohorts in a resource-aware manner.
 Let's take a look!

+Cyclic training
+~~~~~~~~~~~~~~~~~~
+
+In addition to bagging aggregation, we offer a cyclic training scheme, which performs FL in a client-by-client fashion.
+Instead of aggregating multiple clients, only a single client participates in training per round in the cyclic scenario.
+The trained local XGBoost trees are passed to the next client as the initialised model for the next round's boosting.
+
+To do this, we first customise a :code:`ClientManager` in :code:`server_utils.py`:
+
+.. code-block:: python
+
+    class CyclicClientManager(SimpleClientManager):
+        """Provides a cyclic client selection rule."""
+
+        def sample(
+            self,
+            num_clients: int,
+            min_num_clients: Optional[int] = None,
+            criterion: Optional[Criterion] = None,
+        ) -> List[ClientProxy]:
+            """Sample a number of Flower ClientProxy instances."""
+
+            # Block until at least num_clients are connected.
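+            # Note: unlike SimpleClientManager.sample, this method returns *all*
+            # available clients (in connection order); the cyclic strategy then
+            # picks a single client per round.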
+            if min_num_clients is None:
+                min_num_clients = num_clients
+            self.wait_for(min_num_clients)
+
+            # Sample clients which meet the criterion
+            available_cids = list(self.clients)
+            if criterion is not None:
+                available_cids = [
+                    cid for cid in available_cids if criterion.select(self.clients[cid])
+                ]
+
+            if num_clients > len(available_cids):
+                log(
+                    INFO,
+                    "Sampling failed: number of available clients"
+                    " (%s) is less than number of requested clients (%s).",
+                    len(available_cids),
+                    num_clients,
+                )
+                return []
+
+            # Return all available clients
+            return [self.clients[cid] for cid in available_cids]
+
+The customised :code:`ClientManager` samples all available clients in each FL round based on the order of connection to the server.
+Then, we define a new strategy :code:`FedXgbCyclic` in :code:`flwr.server.strategy.fedxgb_cyclic.py`,
+in order to sequentially select only one client in a given round and pass the received model to the next client.
+
+.. code-block:: python
+
+    class FedXgbCyclic(FedAvg):
+        """Configurable FedXgbCyclic strategy implementation."""
+
+        # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
+        def __init__(
+            self,
+            **kwargs: Any,
+        ):
+            self.global_model: Optional[bytes] = None
+            super().__init__(**kwargs)
+
+        def aggregate_fit(
+            self,
+            server_round: int,
+            results: List[Tuple[ClientProxy, FitRes]],
+            failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+        ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
+            """Aggregate fit results using bagging."""
+            if not results:
+                return None, {}
+            # Do not aggregate if there are failures and failures are not accepted
+            if not self.accept_failures and failures:
+                return None, {}
+
+            # Fetch the client model from last round as global model
+            for _, fit_res in results:
+                update = fit_res.parameters.tensors
+                for bst in update:
+                    self.global_model = bst
+
+            return (
+                Parameters(tensor_type="", tensors=[cast(bytes, self.global_model)]),
+                {},
+            )
+
+Unlike the original :code:`FedAvg`, we don't perform aggregation here.
+Instead, we simply keep a copy of the received client model as the global model, by overriding :code:`aggregate_fit`.
+
+Also, the customised :code:`configure_fit` and :code:`configure_evaluate` methods ensure that the clients are selected sequentially over the FL rounds:
+
+.. code-block:: python
+
+    def configure_fit(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, FitIns]]:
+        """Configure the next round of training."""
+        config = {}
+        if self.on_fit_config_fn is not None:
+            # Custom fit config function provided
+            config = self.on_fit_config_fn(server_round)
+        fit_ins = FitIns(parameters, config)
+
+        # Sample clients
+        sample_size, min_num_clients = self.num_fit_clients(
+            client_manager.num_available()
+        )
+        clients = client_manager.sample(
+            num_clients=sample_size,
+            min_num_clients=min_num_clients,
+        )
+
+        # Sample the clients sequentially given server_round
+        sampled_idx = (server_round - 1) % len(clients)
+        sampled_clients = [clients[sampled_idx]]
+
+        # Return client/config pairs
+        return [(client, fit_ins) for client in sampled_clients]
+
+    def configure_evaluate(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
+        """Configure the next round of evaluation."""
+        # Do not configure federated evaluation if fraction eval is 0.
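+        # (Evaluation mirrors training: the single client indexed by
+        # server_round is selected further below.)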
+        if self.fraction_evaluate == 0.0:
+            return []
+
+        # Parameters and config
+        config = {}
+        if self.on_evaluate_config_fn is not None:
+            # Custom evaluation config function provided
+            config = self.on_evaluate_config_fn(server_round)
+        evaluate_ins = EvaluateIns(parameters, config)
+
+        # Sample clients
+        sample_size, min_num_clients = self.num_evaluation_clients(
+            client_manager.num_available()
+        )
+        clients = client_manager.sample(
+            num_clients=sample_size,
+            min_num_clients=min_num_clients,
+        )
+
+        # Sample the clients sequentially given server_round
+        sampled_idx = (server_round - 1) % len(clients)
+        sampled_clients = [clients[sampled_idx]]
+
+        # Return client/config pairs
+        return [(client, evaluate_ins) for client in sampled_clients]
+
+
 Customised data partitioning
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -634,7 +789,7 @@ Currently, we provide four supported partitioner types to simulate the uniformity

 Customised centralised/distributed evaluation
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-To facilitate centralised evaluation, we define a function in :code:`server.py`:
+To facilitate centralised evaluation, we define a function in :code:`server_utils.py`:

 .. code-block:: python

@@ -670,51 +825,265 @@ This function returns an evaluation function, which instantiates a :code:`Booster`
 The evaluation is conducted by calling the :code:`eval_set()` method, and the tested AUC value is reported.

 As for distributed evaluation on the clients, it's the same as the quick-start example:
-overriding the :code:`evaluate()` method insides the :code:`XgbClient` class in :code:`client.py`.
+we override the :code:`evaluate()` method inside the :code:`XgbClient` class in :code:`client_utils.py`.

-Arguments parser
-~~~~~~~~~~~~~~~~~~~~~~
+Flower simulation
+~~~~~~~~~~~~~~~~~~~~
+We also provide example code (:code:`sim.py`) that uses the simulation capabilities of Flower to simulate federated XGBoost training on either a single machine or a cluster of machines.

-In :code:`utils.py`, we define the arguments parsers for clients and server, allowing users to specify different experimental settings.
-Let's first see the sever side:
+.. code-block:: python
+
+    from logging import INFO
+    import xgboost as xgb
+    from tqdm import tqdm
+
+    import flwr as fl
+    from flwr_datasets import FederatedDataset
+    from flwr.common.logger import log
+    from flwr.server.strategy import FedXgbBagging, FedXgbCyclic
+
+    from dataset import (
+        instantiate_partitioner,
+        train_test_split,
+        transform_dataset_to_dmatrix,
+        separate_xy,
+        resplit,
+    )
+    from utils import (
+        sim_args_parser,
+        NUM_LOCAL_ROUND,
+        BST_PARAMS,
+    )
+    from server_utils import (
+        eval_config,
+        fit_config,
+        evaluate_metrics_aggregation,
+        get_evaluate_fn,
+        CyclicClientManager,
+    )
+    from client_utils import XgbClient
+
+After importing all required packages, we define a :code:`main()` function to perform the simulation process:
+
+.. 
code-block:: python - import argparse + def main(): + # Parse arguments for experimental settings + args = sim_args_parser() + # Load (HIGGS) dataset and conduct partitioning + partitioner = instantiate_partitioner( + partitioner_type=args.partitioner_type, num_partitions=args.pool_size + ) + fds = FederatedDataset( + dataset="jxie/higgs", + partitioners={"train": partitioner}, + resplitter=resplit, + ) - def server_args_parser(): - """Parse arguments to define experimental settings on server side.""" - parser = argparse.ArgumentParser() + # Load centralised test set + if args.centralised_eval or args.centralised_eval_client: + log(INFO, "Loading centralised test set...") + test_data = fds.load_full("test") + test_data.set_format("numpy") + num_test = test_data.shape[0] + test_dmatrix = transform_dataset_to_dmatrix(test_data) + + # Load partitions and reformat data to DMatrix for xgboost + log(INFO, "Loading client local partitions...") + train_data_list = [] + valid_data_list = [] + + # Load and process all client partitions. This upfront cost is amortized soon + # after the simulation begins since clients wont need to preprocess their partition. + for node_id in tqdm(range(args.pool_size), desc="Extracting client partition"): + # Extract partition for client with node_id + partition = fds.load_partition(node_id=node_id, split="train") + partition.set_format("numpy") + + if args.centralised_eval_client: + # Use centralised test set for evaluation + train_data = partition + num_train = train_data.shape[0] + x_test, y_test = separate_xy(test_data) + valid_data_list.append(((x_test, y_test), num_test)) + else: + # Train/test splitting + train_data, valid_data, num_train, num_val = train_test_split( + partition, test_fraction=args.test_fraction, seed=args.seed + ) + x_valid, y_valid = separate_xy(valid_data) + valid_data_list.append(((x_valid, y_valid), num_val)) - parser.add_argument( - "--pool-size", default=2, type=int, help="Number of total clients." - ) - parser.add_argument( - "--num-rounds", default=5, type=int, help="Number of FL rounds." - ) - parser.add_argument( - "--num-clients-per-round", - default=2, - type=int, - help="Number of clients participate in training each round.", - ) - parser.add_argument( - "--num-evaluate-clients", - default=2, - type=int, - help="Number of clients selected for evaluation.", + x_train, y_train = separate_xy(train_data) + train_data_list.append(((x_train, y_train), num_train)) + +We first load the dataset and perform data partitioning, and the pre-processed data is stored in a :code:`list`. +After the simulation begins, the clients won't need to pre-process their partitions again. + +Then, we define the strategies and other hyper-parameters: + +.. 
code-block:: python + + # Define strategy + if args.train_method == "bagging": + # Bagging training + strategy = FedXgbBagging( + evaluate_function=get_evaluate_fn(test_dmatrix) + if args.centralised_eval + else None, + fraction_fit=(float(args.num_clients_per_round) / args.pool_size), + min_fit_clients=args.num_clients_per_round, + min_available_clients=args.pool_size, + min_evaluate_clients=args.num_evaluate_clients + if not args.centralised_eval + else 0, + fraction_evaluate=1.0 if not args.centralised_eval else 0.0, + on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation + if not args.centralised_eval + else None, ) - parser.add_argument( - "--centralised-eval", - action="store_true", - help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + else: + # Cyclic training + strategy = FedXgbCyclic( + fraction_fit=1.0, + min_available_clients=args.pool_size, + fraction_evaluate=1.0, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, + on_evaluate_config_fn=eval_config, + on_fit_config_fn=fit_config, ) - args = parser.parse_args() - return args + # Resources to be assigned to each virtual client + # In this example we use CPU by default + client_resources = { + "num_cpus": args.num_cpus_per_client, + "num_gpus": 0.0, + } + + # Hyper-parameters for xgboost training + num_local_round = NUM_LOCAL_ROUND + params = BST_PARAMS + + # Setup learning rate + if args.train_method == "bagging" and args.scaled_lr: + new_lr = params["eta"] / args.pool_size + params.update({"eta": new_lr}) + +After that, we start the simulation by calling :code:`fl.simulation.start_simulation`: + +.. code-block:: python -This allows user to specify the number of total clients / FL rounds / participating clients / clients for evaluation, + # Start simulation + fl.simulation.start_simulation( + client_fn=get_client_fn( + train_data_list, + valid_data_list, + args.train_method, + params, + num_local_round, + ), + num_clients=args.pool_size, + client_resources=client_resources, + config=fl.server.ServerConfig(num_rounds=args.num_rounds), + strategy=strategy, + client_manager=CyclicClientManager() if args.train_method == "cyclic" else None, + ) + +One of key parameters for :code:`start_simulation` is :code:`client_fn` which returns a function to construct a client. +We define it as follows: + +.. code-block:: python + + def get_client_fn( + train_data_list, valid_data_list, train_method, params, num_local_round + ): + """Return a function to construct a client. + + The VirtualClientEngine will execute this function whenever a client is sampled by + the strategy to participate. 
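+
+        Note: the `cid` string passed to `client_fn` is used as a positional
+        index into the pre-processed partition lists, so client "0" always
+        receives partition 0.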
+ """ + + def client_fn(cid: str) -> fl.client.Client: + """Construct a FlowerClient with its own dataset partition.""" + x_train, y_train = train_data_list[int(cid)][0] + x_valid, y_valid = valid_data_list[int(cid)][0] + + # Reformat data to DMatrix + train_dmatrix = xgb.DMatrix(x_train, label=y_train) + valid_dmatrix = xgb.DMatrix(x_valid, label=y_valid) + + # Fetch the number of examples + num_train = train_data_list[int(cid)][1] + num_val = valid_data_list[int(cid)][1] + + # Create and return client + return XgbClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + train_method, + ) + + return client_fn + + + +Arguments parser +~~~~~~~~~~~~~~~~~~~~~~ + +In :code:`utils.py`, we define the arguments parsers for clients, server and simulation, allowing users to specify different experimental settings. +Let's first see the sever side: + +.. code-block:: python + + import argparse + + + def server_args_parser(): + """Parse arguments to define experimental settings on server side.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "--train-method", + default="bagging", + type=str, + choices=["bagging", "cyclic"], + help="Training methods selected from bagging aggregation or cyclic training.", + ) + parser.add_argument( + "--pool-size", default=2, type=int, help="Number of total clients." + ) + parser.add_argument( + "--num-rounds", default=5, type=int, help="Number of FL rounds." + ) + parser.add_argument( + "--num-clients-per-round", + default=2, + type=int, + help="Number of clients participate in training each round.", + ) + parser.add_argument( + "--num-evaluate-clients", + default=2, + type=int, + help="Number of clients selected for evaluation.", + ) + parser.add_argument( + "--centralised-eval", + action="store_true", + help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + ) + + args = parser.parse_args() + return args + +This allows user to specify training strategies / the number of total clients / FL rounds / participating clients / clients for evaluation, and evaluation fashion. Note that with :code:`--centralised-eval`, the sever will do centralised evaluation and all functionalities for client evaluation will be disabled. @@ -723,60 +1092,159 @@ Then, the argument parser on client side: .. code-block:: python def client_args_parser(): - """Parse arguments to define experimental settings on client side.""" - parser = argparse.ArgumentParser() + """Parse arguments to define experimental settings on client side.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "--train-method", + default="bagging", + type=str, + choices=["bagging", "cyclic"], + help="Training methods selected from bagging aggregation or cyclic training.", + ) + parser.add_argument( + "--num-partitions", default=10, type=int, help="Number of partitions." + ) + parser.add_argument( + "--partitioner-type", + default="uniform", + type=str, + choices=["uniform", "linear", "square", "exponential"], + help="Partitioner types.", + ) + parser.add_argument( + "--node-id", + default=0, + type=int, + help="Node ID used for the current client.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="Seed used for train/test splitting." 
+ ) + parser.add_argument( + "--test-fraction", + default=0.2, + type=float, + help="Test fraction for train/test splitting.", + ) + parser.add_argument( + "--centralised-eval", + action="store_true", + help="Conduct evaluation on centralised test set (True), or on hold-out data (False).", + ) + parser.add_argument( + "--scaled-lr", + action="store_true", + help="Perform scaled learning rate based on the number of clients (True).", + ) + + args = parser.parse_args() + return args - parser.add_argument( - "--num-partitions", default=10, type=int, help="Number of partitions." - ) - parser.add_argument( - "--partitioner-type", - default="uniform", - type=str, - choices=["uniform", "linear", "square", "exponential"], - help="Partitioner types.", - ) - parser.add_argument( - "--node-id", - default=0, - type=int, - help="Node ID used for the current client.", - ) - parser.add_argument( - "--seed", default=42, type=int, help="Seed used for train/test splitting." - ) - parser.add_argument( - "--test-fraction", - default=0.2, - type=float, - help="Test fraction for train/test splitting.", - ) - parser.add_argument( - "--centralised-eval", - action="store_true", - help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", - ) +This defines various options for client data partitioning. +Besides, clients also have an option to conduct evaluation on centralised test set by setting :code:`--centralised-eval`, +as well as an option to perform scaled learning rate based on the number of clients by setting :code:`--scaled-lr`. - args = parser.parse_args() - return args +We also have an argument parser for simulation: -This defines various options for client data partitioning. -Besides, clients also have a option to conduct evaluation on centralised test set by setting :code:`--centralised-eval`. +.. code-block:: python + + def sim_args_parser(): + """Parse arguments to define experimental settings on server side.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "--train-method", + default="bagging", + type=str, + choices=["bagging", "cyclic"], + help="Training methods selected from bagging aggregation or cyclic training.", + ) + + # Server side + parser.add_argument( + "--pool-size", default=5, type=int, help="Number of total clients." + ) + parser.add_argument( + "--num-rounds", default=30, type=int, help="Number of FL rounds." + ) + parser.add_argument( + "--num-clients-per-round", + default=5, + type=int, + help="Number of clients participate in training each round.", + ) + parser.add_argument( + "--num-evaluate-clients", + default=5, + type=int, + help="Number of clients selected for evaluation.", + ) + parser.add_argument( + "--centralised-eval", + action="store_true", + help="Conduct centralised evaluation (True), or client evaluation on hold-out data (False).", + ) + parser.add_argument( + "--num-cpus-per-client", + default=2, + type=int, + help="Number of CPUs used for per client.", + ) + + # Client side + parser.add_argument( + "--partitioner-type", + default="uniform", + type=str, + choices=["uniform", "linear", "square", "exponential"], + help="Partitioner types.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="Seed used for train/test splitting." 
+ ) + parser.add_argument( + "--test-fraction", + default=0.2, + type=float, + help="Test fraction for train/test splitting.", + ) + parser.add_argument( + "--centralised-eval-client", + action="store_true", + help="Conduct evaluation on centralised test set (True), or on hold-out data (False).", + ) + parser.add_argument( + "--scaled-lr", + action="store_true", + help="Perform scaled learning rate based on the number of clients (True).", + ) + + args = parser.parse_args() + return args + +This integrates all arguments for both client and server sides. Example commands ~~~~~~~~~~~~~~~~~~~~~ -To run a centralised evaluated experiment on 5 clients with exponential distribution for 50 rounds, +To run a centralised evaluated experiment with bagging strategy on 5 clients with exponential distribution for 50 rounds, we first start the server as below: .. code-block:: shell - $ python3 server.py --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --centralised-eval + $ python3 server.py --train-method=bagging --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --centralised-eval Then, on each client terminal, we start the clients: .. code-block:: shell - $ python3 clients.py --num-partitions=5 --partitioner-type=exponential --node-id=NODE_ID + $ python3 clients.py --train-method=bagging --num-partitions=5 --partitioner-type=exponential --node-id=NODE_ID + +To run the same experiment with Flower simulation: + +.. code-block:: shell + + $ python3 sim.py --train-method=bagging --pool-size=5 --num-rounds=50 --num-clients-per-round=5 --partitioner-type=exponential --centralised-eval The full `code `_ for this comprehensive example can be found in :code:`examples/xgboost-comprehensive`. From 0e0af842ff9fead510dfc122a93f1440eeb3a75a Mon Sep 17 00:00:00 2001 From: Heng Pan <134433891+panh99@users.noreply.github.com> Date: Tue, 23 Jan 2024 17:41:24 +0000 Subject: [PATCH 6/7] Add `RecordSet` to `Task` proto (#2841) --- src/proto/flwr/proto/recordset.proto | 6 +++ src/proto/flwr/proto/task.proto | 3 +- src/py/flwr/common/serde.py | 32 ++++++++++++++ src/py/flwr/common/serde_test.py | 65 ++++++++++++++++++++++++++++ src/py/flwr/proto/recordset_pb2.py | 16 ++++++- src/py/flwr/proto/recordset_pb2.pyi | 65 ++++++++++++++++++++++++++++ src/py/flwr/proto/task_pb2.py | 26 ++++++----- src/py/flwr/proto/task_pb2.pyi | 14 +++--- 8 files changed, 208 insertions(+), 19 deletions(-) diff --git a/src/proto/flwr/proto/recordset.proto b/src/proto/flwr/proto/recordset.proto index 8e2e5d60b6db..d51d0f9ce416 100644 --- a/src/proto/flwr/proto/recordset.proto +++ b/src/proto/flwr/proto/recordset.proto @@ -68,3 +68,9 @@ message ParametersRecord { message MetricsRecord { map data = 1; } message ConfigsRecord { map data = 1; } + +message RecordSet { + map parameters = 1; + map metrics = 2; + map configs = 3; +} diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 20dd5a3aa6c8..2cde16143d8d 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -29,10 +29,11 @@ message Task { string ttl = 5; repeated string ancestry = 6; string task_type = 7; - SecureAggregation sa = 8; + RecordSet recordset = 8; ServerMessage legacy_server_message = 101 [ deprecated = true ]; ClientMessage legacy_client_message = 102 [ deprecated = true ]; + SecureAggregation sa = 103 [ deprecated = true ]; } message TaskIns { diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 2094c76a9856..2600d46edddc 100644 --- a/src/py/flwr/common/serde.py +++ 
b/src/py/flwr/common/serde.py @@ -28,6 +28,7 @@ from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord +from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet from flwr.proto.recordset_pb2 import Sint64List, StringList from flwr.proto.task_pb2 import Value from flwr.proto.transport_pb2 import ( @@ -45,6 +46,7 @@ from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parametersrecord import Array, ParametersRecord +from .recordset import RecordSet # === ServerMessage message === @@ -719,3 +721,33 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord ), keep_input=False, ) + + +# === RecordSet message === + + +def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet: + """Serialize RecordSet to ProtoBuf.""" + return ProtoRecordSet( + parameters={ + k: parameters_record_to_proto(v) for k, v in recordset.parameters.items() + }, + metrics={k: metrics_record_to_proto(v) for k, v in recordset.metrics.items()}, + configs={k: configs_record_to_proto(v) for k, v in recordset.configs.items()}, + ) + + +def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet: + """Deserialize RecordSet from ProtoBuf.""" + return RecordSet( + parameters={ + k: parameters_record_from_proto(v) + for k, v in recordset_proto.parameters.items() + }, + metrics={ + k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items() + }, + configs={ + k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items() + }, + ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index c584597d89f6..53f40eee5e53 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -23,12 +23,14 @@ from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord +from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet # pylint: enable=E0611 from . 
import typing from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parametersrecord import Array, ParametersRecord +from .recordset import RecordSet from .serde import ( array_from_proto, array_to_proto, @@ -40,6 +42,8 @@ named_values_to_proto, parameters_record_from_proto, parameters_record_to_proto, + recordset_from_proto, + recordset_to_proto, scalar_from_proto, scalar_to_proto, status_from_proto, @@ -242,3 +246,64 @@ def test_configs_record_serialization_deserialization() -> None: # Assert assert isinstance(proto, ProtoConfigsRecord) assert original.data == deserialized.data + + +def test_recordset_serialization_deserialization() -> None: + """Test serialization and deserialization of RecordSet.""" + # Prepare + encoder_params_record = ParametersRecord( + array_dict=OrderedDict( + [ + ( + "k1", + Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234"), + ), + ("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")), + ] + ), + keep_input=False, + ) + decoder_params_record = ParametersRecord( + array_dict=OrderedDict( + [ + ( + "k1", + Array( + dtype="float", shape=[32, 32, 4], stype="dense", data=b"0987" + ), + ), + ] + ), + keep_input=False, + ) + + original = RecordSet( + parameters={ + "encoder_parameters": encoder_params_record, + "decoder_parameters": decoder_params_record, + }, + metrics={ + "acc_metrics": MetricsRecord( + metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False + ) + }, + configs={ + "my_configs": ConfigsRecord( + configs_dict={ + "learning_rate": 0.01, + "batch_size": 32, + "public_key": b"21f8sioj@!#", + "log": "Hello, world!", + }, + keep_input=False, + ) + }, + ) + + # Execute + proto = recordset_to_proto(original) + deserialized = recordset_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoRecordSet) + assert original == deserialized diff --git a/src/py/flwr/proto/recordset_pb2.py b/src/py/flwr/proto/recordset_pb2.py index 4134511f1f53..f7f74d72182b 100644 --- a/src/py/flwr/proto/recordset_pb2.py +++ b/src/py/flwr/proto/recordset_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 
\x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\x97\x03\n\tRecordSet\x12\x39\n\nparameters\x18\x01 \x03(\x0b\x32%.flwr.proto.RecordSet.ParametersEntry\x12\x33\n\x07metrics\x18\x02 \x03(\x0b\x32\".flwr.proto.RecordSet.MetricsEntry\x12\x33\n\x07\x63onfigs\x18\x03 \x03(\x0b\x32\".flwr.proto.RecordSet.ConfigsEntry\x1aO\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12+\n\x05value\x18\x02 \x01(\x0b\x32\x1c.flwr.proto.ParametersRecord:\x02\x38\x01\x1aI\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.MetricsRecord:\x02\x38\x01\x1aI\n\x0c\x43onfigsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord:\x02\x38\x01\x62\x06proto3') _globals = globals() 
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,6 +25,12 @@ _globals['_METRICSRECORD_DATAENTRY']._serialized_options = b'8\001' _globals['_CONFIGSRECORD_DATAENTRY']._options = None _globals['_CONFIGSRECORD_DATAENTRY']._serialized_options = b'8\001' + _globals['_RECORDSET_PARAMETERSENTRY']._options = None + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_options = b'8\001' + _globals['_RECORDSET_METRICSENTRY']._options = None + _globals['_RECORDSET_METRICSENTRY']._serialized_options = b'8\001' + _globals['_RECORDSET_CONFIGSENTRY']._options = None + _globals['_RECORDSET_CONFIGSENTRY']._serialized_options = b'8\001' _globals['_DOUBLELIST']._serialized_start=42 _globals['_DOUBLELIST']._serialized_end=68 _globals['_SINT64LIST']._serialized_start=70 @@ -51,4 +57,12 @@ _globals['_CONFIGSRECORD']._serialized_end=1126 _globals['_CONFIGSRECORD_DATAENTRY']._serialized_start=1051 _globals['_CONFIGSRECORD_DATAENTRY']._serialized_end=1126 + _globals['_RECORDSET']._serialized_start=1129 + _globals['_RECORDSET']._serialized_end=1536 + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_start=1307 + _globals['_RECORDSET_PARAMETERSENTRY']._serialized_end=1386 + _globals['_RECORDSET_METRICSENTRY']._serialized_start=1388 + _globals['_RECORDSET_METRICSENTRY']._serialized_end=1461 + _globals['_RECORDSET_CONFIGSENTRY']._serialized_start=1463 + _globals['_RECORDSET_CONFIGSENTRY']._serialized_end=1536 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/recordset_pb2.pyi b/src/py/flwr/proto/recordset_pb2.pyi index 1e9556de9ce6..86244697129c 100644 --- a/src/py/flwr/proto/recordset_pb2.pyi +++ b/src/py/flwr/proto/recordset_pb2.pyi @@ -238,3 +238,68 @@ class ConfigsRecord(google.protobuf.message.Message): ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ... global___ConfigsRecord = ConfigsRecord + +class RecordSet(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class ParametersEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___ParametersRecord: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___ParametersRecord] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + class MetricsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___MetricsRecord: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___MetricsRecord] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + class ConfigsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> global___ConfigsRecord: ... 
+ def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[global___ConfigsRecord] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + PARAMETERS_FIELD_NUMBER: builtins.int + METRICS_FIELD_NUMBER: builtins.int + CONFIGS_FIELD_NUMBER: builtins.int + @property + def parameters(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ParametersRecord]: ... + @property + def metrics(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricsRecord]: ... + @property + def configs(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigsRecord]: ... + def __init__(self, + *, + parameters: typing.Optional[typing.Mapping[typing.Text, global___ParametersRecord]] = ..., + metrics: typing.Optional[typing.Mapping[typing.Text, global___MetricsRecord]] = ..., + configs: typing.Optional[typing.Mapping[typing.Text, global___ConfigsRecord]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["configs",b"configs","metrics",b"metrics","parameters",b"parameters"]) -> None: ... +global___RecordSet = RecordSet diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 963b07db94f8..f9b2180b15dd 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -17,7 +17,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd1\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12)\n\x02sa\x18\x08 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xcc\x02\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 
\x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xff\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\x12-\n\x02sa\x18g \x01(\x0b\x32\x1d.flwr.proto.SecureAggregationB\x02\x18\x01\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xcc\x02\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -28,18 +28,20 @@ _globals['_TASK'].fields_by_name['legacy_server_message']._serialized_options = b'\030\001' _globals['_TASK'].fields_by_name['legacy_client_message']._options = None _globals['_TASK'].fields_by_name['legacy_client_message']._serialized_options = b'\030\001' + _globals['_TASK'].fields_by_name['sa']._options = None + _globals['_TASK'].fields_by_name['sa']._serialized_options = b'\030\001' _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._options = None _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_options = b'8\001' _globals['_TASK']._serialized_start=117 - _globals['_TASK']._serialized_end=454 - _globals['_TASKINS']._serialized_start=456 - _globals['_TASKINS']._serialized_end=548 - _globals['_TASKRES']._serialized_start=550 - _globals['_TASKRES']._serialized_end=642 - _globals['_VALUE']._serialized_start=645 - _globals['_VALUE']._serialized_end=977 - _globals['_SECUREAGGREGATION']._serialized_start=980 - _globals['_SECUREAGGREGATION']._serialized_end=1140 - _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_start=1071 - 
_globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_end=1140 + _globals['_TASK']._serialized_end=500 + _globals['_TASKINS']._serialized_start=502 + _globals['_TASKINS']._serialized_end=594 + _globals['_TASKRES']._serialized_start=596 + _globals['_TASKRES']._serialized_end=688 + _globals['_VALUE']._serialized_start=691 + _globals['_VALUE']._serialized_end=1023 + _globals['_SECUREAGGREGATION']._serialized_start=1026 + _globals['_SECUREAGGREGATION']._serialized_end=1186 + _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_start=1117 + _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_end=1186 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index ebe69d05c974..39119797c9e4 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -23,9 +23,10 @@ class Task(google.protobuf.message.Message): TTL_FIELD_NUMBER: builtins.int ANCESTRY_FIELD_NUMBER: builtins.int TASK_TYPE_FIELD_NUMBER: builtins.int - SA_FIELD_NUMBER: builtins.int + RECORDSET_FIELD_NUMBER: builtins.int LEGACY_SERVER_MESSAGE_FIELD_NUMBER: builtins.int LEGACY_CLIENT_MESSAGE_FIELD_NUMBER: builtins.int + SA_FIELD_NUMBER: builtins.int @property def producer(self) -> flwr.proto.node_pb2.Node: ... @property @@ -37,11 +38,13 @@ class Task(google.protobuf.message.Message): def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... task_type: typing.Text @property - def sa(self) -> global___SecureAggregation: ... + def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ... @property def legacy_server_message(self) -> flwr.proto.transport_pb2.ServerMessage: ... @property def legacy_client_message(self) -> flwr.proto.transport_pb2.ClientMessage: ... + @property + def sa(self) -> global___SecureAggregation: ... def __init__(self, *, producer: typing.Optional[flwr.proto.node_pb2.Node] = ..., @@ -51,12 +54,13 @@ class Task(google.protobuf.message.Message): ttl: typing.Text = ..., ancestry: typing.Optional[typing.Iterable[typing.Text]] = ..., task_type: typing.Text = ..., - sa: typing.Optional[global___SecureAggregation] = ..., + recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., legacy_server_message: typing.Optional[flwr.proto.transport_pb2.ServerMessage] = ..., legacy_client_message: typing.Optional[flwr.proto.transport_pb2.ClientMessage] = ..., + sa: typing.Optional[global___SecureAggregation] = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","sa",b"sa"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","sa",b"sa","task_type",b"task_type","ttl",b"ttl"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","recordset",b"recordset","sa",b"sa"]) -> builtins.bool: ... 
+    def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","recordset",b"recordset","sa",b"sa","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
 global___Task = Task

 class TaskIns(google.protobuf.message.Message):

From 214d1c873487b3ad9d466eaa878dcfeda4d27d82 Mon Sep 17 00:00:00 2001
From: Javier 
Date: Tue, 23 Jan 2024 17:53:05 +0000
Subject: [PATCH 7/7] Convert `RecordSet` to/from legacy `*Ins/*Res` (#2828)

---
 src/py/flwr/common/recordset_compat.py      | 401 ++++++++++++++++++++
 src/py/flwr/common/recordset_compat_test.py | 234 ++++++++++++
 src/py/flwr/common/recordset_test.py        |   3 +-
 src/py/flwr/common/recordset_utils.py       |  87 -----
 4 files changed, 636 insertions(+), 89 deletions(-)
 create mode 100644 src/py/flwr/common/recordset_compat.py
 create mode 100644 src/py/flwr/common/recordset_compat_test.py
 delete mode 100644 src/py/flwr/common/recordset_utils.py

diff --git a/src/py/flwr/common/recordset_compat.py b/src/py/flwr/common/recordset_compat.py
new file mode 100644
index 000000000000..c45f7fcd9fb8
--- /dev/null
+++ b/src/py/flwr/common/recordset_compat.py
@@ -0,0 +1,401 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RecordSet utilities."""
+
+
+from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args
+
+from .configsrecord import ConfigsRecord
+from .metricsrecord import MetricsRecord
+from .parametersrecord import Array, ParametersRecord
+from .recordset import RecordSet
+from .typing import (
+    Code,
+    ConfigsRecordValues,
+    EvaluateIns,
+    EvaluateRes,
+    FitIns,
+    FitRes,
+    GetParametersIns,
+    GetParametersRes,
+    GetPropertiesIns,
+    GetPropertiesRes,
+    MetricsRecordValues,
+    Parameters,
+    Scalar,
+    Status,
+)
+
+
+def parametersrecord_to_parameters(
+    record: ParametersRecord, keep_input: bool = False
+) -> Parameters:
+    """Convert ParameterRecord to legacy Parameters.
+
+    Warning: Because `Arrays` in `ParametersRecord` encode more information about the
+    array-like or tensor-like data (e.g. their datatype and shape) than `Parameters`,
+    it might not be possible to reconstruct such data structures from `Parameters`
+    objects alone. Additional information or metadata must be provided from elsewhere.
+
+    Parameters
+    ----------
+    record : ParametersRecord
+        The record to be converted into Parameters.
+    keep_input : bool (default: False)
+        A boolean indicating whether the entries should be kept in the input record;
+        if ``False`` (default), they are deleted from the record immediately after
+        being added to the Parameters object.
+    """
+    parameters = Parameters(tensors=[], tensor_type="")
+
+    for key in list(record.data.keys()):
+        parameters.tensors.append(record[key].data)
+
+        if not parameters.tensor_type:
+            # Setting from the first array in the record.
+            # of this function.
+            parameters.tensor_type = record[key].stype
+
+        if not keep_input:
+            del record.data[key]
+
+    return parameters
+
+
+def parameters_to_parametersrecord(
+    parameters: Parameters, keep_input: bool = False
+) -> ParametersRecord:
+    """Convert legacy Parameters into a single ParametersRecord.
+
+    Because there is no concept of names in the legacy Parameters, arbitrary keys will
+    be used when constructing the ParametersRecord. Similarly, the shape and data type
+    won't be recorded in the Array objects.
+
+    Parameters
+    ----------
+    parameters : Parameters
+        Parameters object to be represented as a ParametersRecord.
+    keep_input : bool (default: False)
+        A boolean indicating whether parameters should be deleted from the input
+        Parameters object (i.e. a list of serialized NumPy arrays) immediately after
+        adding them to the record.
+    """
+    tensor_type = parameters.tensor_type
+
+    p_record = ParametersRecord()
+
+    num_arrays = len(parameters.tensors)
+    for idx in range(num_arrays):
+        if keep_input:
+            tensor = parameters.tensors[idx]
+        else:
+            tensor = parameters.tensors.pop(0)
+        p_record.set_parameters(
+            OrderedDict(
+                {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])}
+            )
+        )
+
+    return p_record
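Note that both helpers above are destructive by default. A minimal sketch of the `keep_input` semantics (assuming the `flwr.common` layout created by this patch is importable; the byte strings stand in for serialized arrays, and "numpy.ndarray" mirrors the tag used by `ndarrays_to_parameters`):

    from flwr.common.recordset_compat import (
        parameters_to_parametersrecord,
        parametersrecord_to_parameters,
    )
    from flwr.common.typing import Parameters

    # Two already-serialized tensors, as legacy Parameters stores them
    params = Parameters(tensors=[b"\x00\x01", b"\x02\x03"], tensor_type="numpy.ndarray")

    record = parameters_to_parametersrecord(params, keep_input=False)
    assert params.tensors == []  # keep_input=False pops each tensor from the input

    restored = parametersrecord_to_parameters(record, keep_input=False)
    assert restored.tensor_type == "numpy.ndarray"
    assert len(record.data) == 0  # the record is drained as well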
" + f"You used type: {type(value)}" + ) + return cast(Dict[str, Scalar], record_data) + + +def _recordset_to_fit_or_evaluate_ins_components( + recordset: RecordSet, + ins_str: str, + keep_input: bool, +) -> Tuple[Parameters, Dict[str, Scalar]]: + """Derive Fit/Evaluate Ins from a RecordSet.""" + # get Array and construct Parameters + parameters_record = recordset.get_parameters(f"{ins_str}.parameters") + + parameters = parametersrecord_to_parameters( + parameters_record, keep_input=keep_input + ) + + # get config dict + config_record = recordset.get_configs(f"{ins_str}.config") + + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + + return parameters, config_dict + + +def _fit_or_evaluate_ins_to_recordset( + ins: Union[FitIns, EvaluateIns], keep_input: bool +) -> RecordSet: + recordset = RecordSet() + + ins_str = "fitins" if isinstance(ins, FitIns) else "evaluateins" + recordset.set_parameters( + name=f"{ins_str}.parameters", + record=parameters_to_parametersrecord(ins.parameters, keep_input=keep_input), + ) + + recordset.set_configs( + name=f"{ins_str}.config", record=ConfigsRecord(ins.config) # type: ignore + ) + + return recordset + + +def _embed_status_into_recordset( + res_str: str, status: Status, recordset: RecordSet +) -> RecordSet: + status_dict: Dict[str, ConfigsRecordValues] = { + "code": int(status.code.value), + "message": status.message, + } + # we add it to a `ConfigsRecord`` because the `status.message`` is a string + # and `str` values aren't supported in `MetricsRecords` + recordset.set_configs(f"{res_str}.status", record=ConfigsRecord(status_dict)) + return recordset + + +def _extract_status_from_recordset(res_str: str, recordset: RecordSet) -> Status: + status = recordset.get_configs(f"{res_str}.status") + code = cast(int, status["code"]) + return Status(code=Code(code), message=str(status["message"])) + + +def recordset_to_fitins(recordset: RecordSet, keep_input: bool) -> FitIns: + """Derive FitIns from a RecordSet object.""" + parameters, config = _recordset_to_fit_or_evaluate_ins_components( + recordset, + ins_str="fitins", + keep_input=keep_input, + ) + + return FitIns(parameters=parameters, config=config) + + +def fitins_to_recordset(fitins: FitIns, keep_input: bool) -> RecordSet: + """Construct a RecordSet from a FitIns object.""" + return _fit_or_evaluate_ins_to_recordset(fitins, keep_input) + + +def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes: + """Derive FitRes from a RecordSet object.""" + ins_str = "fitres" + parameters = parametersrecord_to_parameters( + recordset.get_parameters(f"{ins_str}.parameters"), keep_input=keep_input + ) + + num_examples = cast( + int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"] + ) + configs_record = recordset.get_configs(f"{ins_str}.metrics") + + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data) + status = _extract_status_from_recordset(ins_str, recordset) + + return FitRes( + status=status, parameters=parameters, num_examples=num_examples, metrics=metrics + ) + + +def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet: + """Construct a RecordSet from a FitRes object.""" + recordset = RecordSet() + + res_str = "fitres" + + recordset.set_configs( + name=f"{res_str}.metrics", record=ConfigsRecord(fitres.metrics) # type: ignore + ) + recordset.set_metrics( + name=f"{res_str}.num_examples", + record=MetricsRecord({"num_examples": fitres.num_examples}), + ) + recordset.set_parameters( + 
name=f"{res_str}.parameters", + record=parameters_to_parametersrecord(fitres.parameters, keep_input), + ) + + # status + recordset = _embed_status_into_recordset(res_str, fitres.status, recordset) + + return recordset + + +def recordset_to_evaluateins(recordset: RecordSet, keep_input: bool) -> EvaluateIns: + """Derive EvaluateIns from a RecordSet object.""" + parameters, config = _recordset_to_fit_or_evaluate_ins_components( + recordset, + ins_str="evaluateins", + keep_input=keep_input, + ) + + return EvaluateIns(parameters=parameters, config=config) + + +def evaluateins_to_recordset(evaluateins: EvaluateIns, keep_input: bool) -> RecordSet: + """Construct a RecordSet from a EvaluateIns object.""" + return _fit_or_evaluate_ins_to_recordset(evaluateins, keep_input) + + +def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes: + """Derive EvaluateRes from a RecordSet object.""" + ins_str = "evaluateres" + + loss = cast(int, recordset.get_metrics(f"{ins_str}.loss")["loss"]) + + num_examples = cast( + int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"] + ) + configs_record = recordset.get_configs(f"{ins_str}.metrics") + + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data) + status = _extract_status_from_recordset(ins_str, recordset) + + return EvaluateRes( + status=status, loss=loss, num_examples=num_examples, metrics=metrics + ) + + +def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet: + """Construct a RecordSet from a EvaluateRes object.""" + recordset = RecordSet() + + res_str = "evaluateres" + # loss + recordset.set_metrics( + name=f"{res_str}.loss", + record=MetricsRecord({"loss": evaluateres.loss}), + ) + + # num_examples + recordset.set_metrics( + name=f"{res_str}.num_examples", + record=MetricsRecord({"num_examples": evaluateres.num_examples}), + ) + + # metrics + recordset.set_configs( + name=f"{res_str}.metrics", + record=ConfigsRecord(evaluateres.metrics), # type: ignore + ) + + # status + recordset = _embed_status_into_recordset( + f"{res_str}", evaluateres.status, recordset + ) + + return recordset + + +def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns: + """Derive GetParametersIns from a RecordSet object.""" + config_record = recordset.get_configs("getparametersins.config") + + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + + return GetParametersIns(config=config_dict) + + +def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> RecordSet: + """Construct a RecordSet from a GetParametersIns object.""" + recordset = RecordSet() + + recordset.set_configs( + name="getparametersins.config", + record=ConfigsRecord(getparameters_ins.config), # type: ignore + ) + return recordset + + +def getparametersres_to_recordset(getparametersres: GetParametersRes) -> RecordSet: + """Construct a RecordSet from a GetParametersRes object.""" + recordset = RecordSet() + res_str = "getparametersres" + parameters_record = parameters_to_parametersrecord(getparametersres.parameters) + recordset.set_parameters(f"{res_str}.parameters", parameters_record) + + # status + recordset = _embed_status_into_recordset( + res_str, getparametersres.status, recordset + ) + + return recordset + + +def recordset_to_getparametersres(recordset: RecordSet) -> GetParametersRes: + """Derive GetParametersRes from a RecordSet object.""" + res_str = "getparametersres" + parameters = parametersrecord_to_parameters( + recordset.get_parameters(f"{res_str}.parameters") + ) + + status = 
+    return GetParametersRes(status=status, parameters=parameters)
+
+
+def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns:
+    """Derive GetPropertiesIns from a RecordSet object."""
+    config_record = recordset.get_configs("getpropertiesins.config")
+    config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data)
+
+    return GetPropertiesIns(config=config_dict)
+
+
+def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordSet:
+    """Construct a RecordSet from a GetPropertiesIns object."""
+    recordset = RecordSet()
+    recordset.set_configs(
+        name="getpropertiesins.config",
+        record=ConfigsRecord(getpropertiesins.config),  # type: ignore
+    )
+    return recordset
+
+
+def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes:
+    """Derive GetPropertiesRes from a RecordSet object."""
+    res_str = "getpropertiesres"
+    config_record = recordset.get_configs(f"{res_str}.properties")
+    properties = _check_mapping_from_recordscalartype_to_scalar(config_record.data)
+
+    status = _extract_status_from_recordset(res_str, recordset=recordset)
+
+    return GetPropertiesRes(status=status, properties=properties)
+
+
+def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordSet:
+    """Construct a RecordSet from a GetPropertiesRes object."""
+    recordset = RecordSet()
+    res_str = "getpropertiesres"
+    recordset.set_configs(
+        name=f"{res_str}.properties",
+        record=ConfigsRecord(getpropertiesres.properties),  # type: ignore
+    )
+    # status
+    recordset = _embed_status_into_recordset(
+        res_str, getpropertiesres.status, recordset
+    )
+
+    return recordset
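As a quick illustration of how a `*Res` status is embedded and recovered, a sketch round-tripping a GetPropertiesRes (mirroring the test below; `Code(0)` is the OK status code, and the property values are hypothetical):

    from flwr.common.recordset_compat import (
        getpropertiesres_to_recordset,
        recordset_to_getpropertiesres,
    )
    from flwr.common.typing import Code, GetPropertiesRes, Status

    res = GetPropertiesRes(
        status=Status(code=Code(0), message=""),
        properties={"cpu_cores": 8, "has_gpu": True},
    )
    # Properties land under "getpropertiesres.properties"; the status is stored
    # under "getpropertiesres.status" as a ConfigsRecord (messages are strings)
    recordset = getpropertiesres_to_recordset(res)
    assert recordset_to_getpropertiesres(recordset) == res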
diff --git a/src/py/flwr/common/recordset_compat_test.py b/src/py/flwr/common/recordset_compat_test.py
new file mode 100644
index 000000000000..ad91cd3a42fc
--- /dev/null
+++ b/src/py/flwr/common/recordset_compat_test.py
@@ -0,0 +1,234 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RecordSet from legacy messages tests."""
+
+from copy import deepcopy
+from typing import Dict
+
+import numpy as np
+
+from .parameter import ndarrays_to_parameters
+from .recordset_compat import (
+    evaluateins_to_recordset,
+    evaluateres_to_recordset,
+    fitins_to_recordset,
+    fitres_to_recordset,
+    getparametersins_to_recordset,
+    getparametersres_to_recordset,
+    getpropertiesins_to_recordset,
+    getpropertiesres_to_recordset,
+    recordset_to_evaluateins,
+    recordset_to_evaluateres,
+    recordset_to_fitins,
+    recordset_to_fitres,
+    recordset_to_getparametersins,
+    recordset_to_getparametersres,
+    recordset_to_getpropertiesins,
+    recordset_to_getpropertiesres,
+)
+from .typing import (
+    Code,
+    EvaluateIns,
+    EvaluateRes,
+    FitIns,
+    FitRes,
+    GetParametersIns,
+    GetParametersRes,
+    GetPropertiesIns,
+    GetPropertiesRes,
+    NDArrays,
+    Scalar,
+    Status,
+)
+
+
+def get_ndarrays() -> NDArrays:
+    """Return list of NumPy arrays."""
+    arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]])
+    arr2 = np.eye(2, 7, 3)
+
+    return [arr1, arr2]
+
+
+##################################################
+# Testing conversion: *Ins --> RecordSet --> *Ins
+# Testing conversion: *Res <-- RecordSet <-- *Res
+##################################################
+
+
+def _get_valid_fitins() -> FitIns:
+    arrays = get_ndarrays()
+    return FitIns(parameters=ndarrays_to_parameters(arrays), config={"a": 1.0, "b": 0})
+
+
+def _get_valid_fitres() -> FitRes:
+    """Return valid parameters but a potentially invalid config."""
+    arrays = get_ndarrays()
+    metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0}
+    return FitRes(
+        parameters=ndarrays_to_parameters(arrays),
+        num_examples=1,
+        status=Status(code=Code(0), message=""),
+        metrics=metrics,
+    )
+
+
+def _get_valid_evaluateins() -> EvaluateIns:
+    fit_ins = _get_valid_fitins()
+    return EvaluateIns(parameters=fit_ins.parameters, config=fit_ins.config)
+
+
+def _get_valid_evaluateres() -> EvaluateRes:
+    """Return a potentially invalid config."""
+    metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0}
+    return EvaluateRes(
+        num_examples=1,
+        loss=0.1,
+        status=Status(code=Code(0), message=""),
+        metrics=metrics,
+    )
+
+
+def _get_valid_getparametersins() -> GetParametersIns:
+    config_dict: Dict[str, Scalar] = {
+        "a": 1.0,
+        "b": 3,
+        "c": True,
+    }  # valid since both Ins/Res communicate over ConfigsRecord
+
+    return GetParametersIns(config_dict)
+
+
+def _get_valid_getparametersres() -> GetParametersRes:
+    arrays = get_ndarrays()
+    return GetParametersRes(
+        status=Status(code=Code(0), message=""),
+        parameters=ndarrays_to_parameters(arrays),
+    )
+
+
+def _get_valid_getpropertiesins() -> GetPropertiesIns:
+    getparamsins = _get_valid_getparametersins()
+    return GetPropertiesIns(config=getparamsins.config)
+
+
+def _get_valid_getpropertiesres() -> GetPropertiesRes:
+    config_dict: Dict[str, Scalar] = {
+        "a": 1.0,
+        "b": 3,
+        "c": True,
+    }  # valid since both Ins/Res communicate over ConfigsRecord
+
+    return GetPropertiesRes(
+        status=Status(code=Code(0), message=""), properties=config_dict
+    )
+
+
+def test_fitins_to_recordset_and_back() -> None:
+    """Test conversion FitIns --> RecordSet --> FitIns."""
+    fitins = _get_valid_fitins()
+
+    fitins_copy = deepcopy(fitins)
+
+    recordset = fitins_to_recordset(fitins, keep_input=False)
+
+    fitins_ = recordset_to_fitins(recordset, keep_input=False)
+
+    assert fitins_copy == fitins_
+def test_fitres_to_recordset_and_back() -> None:
+    """Test conversion FitRes --> RecordSet --> FitRes."""
+    fitres = _get_valid_fitres()
+
+    fitres_copy = deepcopy(fitres)
+
+    recordset = fitres_to_recordset(fitres, keep_input=False)
+    fitres_ = recordset_to_fitres(recordset, keep_input=False)
+
+    assert fitres_copy == fitres_
+
+
+def test_evaluateins_to_recordset_and_back() -> None:
+    """Test conversion EvaluateIns --> RecordSet --> EvaluateIns."""
+    evaluateins = _get_valid_evaluateins()
+
+    evaluateins_copy = deepcopy(evaluateins)
+
+    recordset = evaluateins_to_recordset(evaluateins, keep_input=False)
+
+    evaluateins_ = recordset_to_evaluateins(recordset, keep_input=False)
+
+    assert evaluateins_copy == evaluateins_
+
+
+def test_evaluateres_to_recordset_and_back() -> None:
+    """Test conversion EvaluateRes --> RecordSet --> EvaluateRes."""
+    evaluateres = _get_valid_evaluateres()
+
+    evaluateres_copy = deepcopy(evaluateres)
+
+    recordset = evaluateres_to_recordset(evaluateres)
+    evaluateres_ = recordset_to_evaluateres(recordset)
+
+    assert evaluateres_copy == evaluateres_
+
+
+def test_get_properties_ins_to_recordset_and_back() -> None:
+    """Test conversion GetPropertiesIns --> RecordSet --> GetPropertiesIns."""
+    getproperties_ins = _get_valid_getpropertiesins()
+
+    getproperties_ins_copy = deepcopy(getproperties_ins)
+
+    recordset = getpropertiesins_to_recordset(getproperties_ins)
+    getproperties_ins_ = recordset_to_getpropertiesins(recordset)
+
+    assert getproperties_ins_copy == getproperties_ins_
+
+
+def test_get_properties_res_to_recordset_and_back() -> None:
+    """Test conversion GetPropertiesRes --> RecordSet --> GetPropertiesRes."""
+    getproperties_res = _get_valid_getpropertiesres()
+
+    getproperties_res_copy = deepcopy(getproperties_res)
+
+    recordset = getpropertiesres_to_recordset(getproperties_res)
+    getproperties_res_ = recordset_to_getpropertiesres(recordset)
+
+    assert getproperties_res_copy == getproperties_res_
+
+
+def test_get_parameters_ins_to_recordset_and_back() -> None:
+    """Test conversion GetParametersIns --> RecordSet --> GetParametersIns."""
+    getparameters_ins = _get_valid_getparametersins()
+
+    getparameters_ins_copy = deepcopy(getparameters_ins)
+
+    recordset = getparametersins_to_recordset(getparameters_ins)
+    getparameters_ins_ = recordset_to_getparametersins(recordset)
+
+    assert getparameters_ins_copy == getparameters_ins_
+
+
+def test_get_parameters_res_to_recordset_and_back() -> None:
+    """Test conversion GetParametersRes --> RecordSet --> GetParametersRes."""
+    getparameters_res = _get_valid_getparametersres()
+
+    getparameters_res_copy = deepcopy(getparameters_res)
+
+    recordset = getparametersres_to_recordset(getparameters_res)
+    getparameters_res_ = recordset_to_getparametersres(recordset)
+
+    assert getparameters_res_copy == getparameters_res_
diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py
index 83e1e4595f1d..e1825eaeef14 100644
--- a/src/py/flwr/common/recordset_test.py
+++ b/src/py/flwr/common/recordset_test.py
@@ -14,7 +14,6 @@
 # ==============================================================================
 """RecordSet tests."""

-
 from typing import Callable, Dict, List, OrderedDict, Type, Union

 import numpy as np
@@ -24,7 +23,7 @@
 from .metricsrecord import MetricsRecord
 from .parameter import ndarrays_to_parameters, parameters_to_ndarrays
 from .parametersrecord import Array, ParametersRecord
-from .recordset_utils import (
+from .recordset_compat import (
     parameters_to_parametersrecord,
     parametersrecord_to_parameters,
 )
diff --git a/src/py/flwr/common/recordset_utils.py
b/src/py/flwr/common/recordset_utils.py deleted file mode 100644 index c1e724fa2758..000000000000 --- a/src/py/flwr/common/recordset_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""RecordSet utilities.""" - - -from typing import OrderedDict - -from .parametersrecord import Array, ParametersRecord -from .typing import Parameters - - -def parametersrecord_to_parameters( - record: ParametersRecord, keep_input: bool = False -) -> Parameters: - """Convert ParameterRecord to legacy Parameters. - - Warning: Because `Arrays` in `ParametersRecord` encode more information of the - array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it - might not be possible to reconstruct such data structures from `Parameters` objects - alone. Additional information or metadta must be provided from elsewhere. - - Parameters - ---------- - record : ParametersRecord - The record to be conveted into Parameters. - keep_input : bool (default: False) - A boolean indicating whether entries in the record should be deleted from the - input dictionary immediately after adding them to the record. - """ - parameters = Parameters(tensors=[], tensor_type="") - - for key in list(record.data.keys()): - parameters.tensors.append(record.data[key].data) - - if not keep_input: - del record.data[key] - - return parameters - - -def parameters_to_parametersrecord( - parameters: Parameters, keep_input: bool = False -) -> ParametersRecord: - """Convert legacy Parameters into a single ParametersRecord. - - Because there is no concept of names in the legacy Parameters, arbitrary keys will - be used when constructing the ParametersRecord. Similarly, the shape and data type - won't be recorded in the Array objects. - - Parameters - ---------- - parameters : Parameters - Parameters object to be represented as a ParametersRecord. - keep_input : bool (default: False) - A boolean indicating whether parameters should be deleted from the input - Parameters object (i.e. a list of serialized NumPy arrays) immediately after - adding them to the record. - """ - tensor_type = parameters.tensor_type - - p_record = ParametersRecord() - - num_arrays = len(parameters.tensors) - for idx in range(num_arrays): - if keep_input: - tensor = parameters.tensors[idx] - else: - tensor = parameters.tensors.pop(0) - p_record.set_parameters( - OrderedDict( - {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} - ) - ) - - return p_record
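With `recordset_utils.py` removed, any remaining imports of it must move to `recordset_compat`, exactly as the `recordset_test.py` hunk above does:

    # Before this patch (module removed):
    # from flwr.common.recordset_utils import (
    #     parameters_to_parametersrecord,
    #     parametersrecord_to_parameters,
    # )

    # After this patch:
    from flwr.common.recordset_compat import (
        parameters_to_parametersrecord,
        parametersrecord_to_parameters,
    )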