Skip to content

Commit

Permalink
Merge branch 'main' into recordset-to-legacy-ins-res
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Jan 23, 2024
2 parents 21f07a0 + 0e0af84 commit 6d8e788
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 19 deletions.
6 changes: 6 additions & 0 deletions src/proto/flwr/proto/recordset.proto
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,9 @@ message ParametersRecord {
message MetricsRecord { map<string, MetricsRecordValue> data = 1; }

message ConfigsRecord { map<string, ConfigsRecordValue> data = 1; }

message RecordSet {
map<string, ParametersRecord> parameters = 1;
map<string, MetricsRecord> metrics = 2;
map<string, ConfigsRecord> configs = 3;
}
3 changes: 2 additions & 1 deletion src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ message Task {
string ttl = 5;
repeated string ancestry = 6;
string task_type = 7;
SecureAggregation sa = 8;
RecordSet recordset = 8;

ServerMessage legacy_server_message = 101 [ deprecated = true ];
ClientMessage legacy_client_message = 102 [ deprecated = true ];
SecureAggregation sa = 103 [ deprecated = true ];
}

message TaskIns {
Expand Down
32 changes: 32 additions & 0 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.recordset_pb2 import Sint64List, StringList
from flwr.proto.task_pb2 import Value
from flwr.proto.transport_pb2 import (
Expand All @@ -45,6 +46,7 @@
from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord
from .recordset import RecordSet

# === ServerMessage message ===

Expand Down Expand Up @@ -719,3 +721,33 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord
),
keep_input=False,
)


# === RecordSet message ===


def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
"""Serialize RecordSet to ProtoBuf."""
return ProtoRecordSet(
parameters={
k: parameters_record_to_proto(v) for k, v in recordset.parameters.items()
},
metrics={k: metrics_record_to_proto(v) for k, v in recordset.metrics.items()},
configs={k: configs_record_to_proto(v) for k, v in recordset.configs.items()},
)


def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
"""Deserialize RecordSet from ProtoBuf."""
return RecordSet(
parameters={
k: parameters_record_from_proto(v)
for k, v in recordset_proto.parameters.items()
},
metrics={
k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items()
},
configs={
k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
},
)
65 changes: 65 additions & 0 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet

# pylint: enable=E0611
from . import typing
from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord
from .recordset import RecordSet
from .serde import (
array_from_proto,
array_to_proto,
Expand All @@ -40,6 +42,8 @@
named_values_to_proto,
parameters_record_from_proto,
parameters_record_to_proto,
recordset_from_proto,
recordset_to_proto,
scalar_from_proto,
scalar_to_proto,
status_from_proto,
Expand Down Expand Up @@ -242,3 +246,64 @@ def test_configs_record_serialization_deserialization() -> None:
# Assert
assert isinstance(proto, ProtoConfigsRecord)
assert original.data == deserialized.data


def test_recordset_serialization_deserialization() -> None:
"""Test serialization and deserialization of RecordSet."""
# Prepare
encoder_params_record = ParametersRecord(
array_dict=OrderedDict(
[
(
"k1",
Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234"),
),
("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")),
]
),
keep_input=False,
)
decoder_params_record = ParametersRecord(
array_dict=OrderedDict(
[
(
"k1",
Array(
dtype="float", shape=[32, 32, 4], stype="dense", data=b"0987"
),
),
]
),
keep_input=False,
)

original = RecordSet(
parameters={
"encoder_parameters": encoder_params_record,
"decoder_parameters": decoder_params_record,
},
metrics={
"acc_metrics": MetricsRecord(
metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False
)
},
configs={
"my_configs": ConfigsRecord(
configs_dict={
"learning_rate": 0.01,
"batch_size": 32,
"public_key": b"21f8sioj@!#",
"log": "Hello, world!",
},
keep_input=False,
)
},
)

# Execute
proto = recordset_to_proto(original)
deserialized = recordset_from_proto(proto)

# Assert
assert isinstance(proto, ProtoRecordSet)
assert original == deserialized
16 changes: 15 additions & 1 deletion src/py/flwr/proto/recordset_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions src/py/flwr/proto/recordset_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,68 @@ class ConfigsRecord(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ...
global___ConfigsRecord = ConfigsRecord

class RecordSet(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class ParametersEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: typing.Text
@property
def value(self) -> global___ParametersRecord: ...
def __init__(self,
*,
key: typing.Text = ...,
value: typing.Optional[global___ParametersRecord] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

class MetricsEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: typing.Text
@property
def value(self) -> global___MetricsRecord: ...
def __init__(self,
*,
key: typing.Text = ...,
value: typing.Optional[global___MetricsRecord] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

class ConfigsEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: typing.Text
@property
def value(self) -> global___ConfigsRecord: ...
def __init__(self,
*,
key: typing.Text = ...,
value: typing.Optional[global___ConfigsRecord] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

PARAMETERS_FIELD_NUMBER: builtins.int
METRICS_FIELD_NUMBER: builtins.int
CONFIGS_FIELD_NUMBER: builtins.int
@property
def parameters(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ParametersRecord]: ...
@property
def metrics(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricsRecord]: ...
@property
def configs(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigsRecord]: ...
def __init__(self,
*,
parameters: typing.Optional[typing.Mapping[typing.Text, global___ParametersRecord]] = ...,
metrics: typing.Optional[typing.Mapping[typing.Text, global___MetricsRecord]] = ...,
configs: typing.Optional[typing.Mapping[typing.Text, global___ConfigsRecord]] = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["configs",b"configs","metrics",b"metrics","parameters",b"parameters"]) -> None: ...
global___RecordSet = RecordSet
Loading

0 comments on commit 6d8e788

Please sign in to comment.