Skip to content

Commit

Permalink
add serde functions for recordset
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 23, 2024
1 parent 1ce30a3 commit 4dd9375
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
32 changes: 32 additions & 0 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.recordset_pb2 import Sint64List, StringList
from flwr.proto.task_pb2 import Value
from flwr.proto.transport_pb2 import (
Expand All @@ -45,6 +46,7 @@
from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord
from .recordset import RecordSet

# === ServerMessage message ===

Expand Down Expand Up @@ -719,3 +721,33 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord
),
keep_input=False,
)


# === RecordSet message ===


def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:
"""Serialize RecordSet to ProtoBuf."""
return ProtoRecordSet(
parameters={
k: parameters_record_to_proto(v) for k, v in recordset.parameters.items()
},
metrics={k: metrics_record_to_proto(v) for k, v in recordset.metrics.items()},
configs={k: configs_record_to_proto(v) for k, v in recordset.configs.items()},
)


def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
"""Deserialize RecordSet from ProtoBuf."""
return RecordSet(
parameters={
k: parameters_record_from_proto(v)
for k, v in recordset_proto.parameters.items()
},
metrics={
k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items()
},
configs={
k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
},
)
65 changes: 65 additions & 0 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet

# pylint: enable=E0611
from . import typing
from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord
from .recordset import RecordSet
from .serde import (
array_from_proto,
array_to_proto,
Expand All @@ -40,6 +42,8 @@
named_values_to_proto,
parameters_record_from_proto,
parameters_record_to_proto,
recordset_from_proto,
recordset_to_proto,
scalar_from_proto,
scalar_to_proto,
status_from_proto,
Expand Down Expand Up @@ -242,3 +246,64 @@ def test_configs_record_serialization_deserialization() -> None:
# Assert
assert isinstance(proto, ProtoConfigsRecord)
assert original.data == deserialized.data


def test_recordset_serialization_deserialization() -> None:
"""Test serialization and deserialization of RecordSet."""
# Prepare
encoder_params_record = ParametersRecord(
array_dict=OrderedDict(
[
(
"k1",
Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234"),
),
("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")),
]
),
keep_input=False,
)
decoder_params_record = ParametersRecord(
array_dict=OrderedDict(
[
(
"k1",
Array(
dtype="float", shape=[32, 32, 4], stype="dense", data=b"0987"
),
),
]
),
keep_input=False,
)

original = RecordSet(
parameters={
"encoder_parameters": encoder_params_record,
"decoder_parameters": decoder_params_record,
},
metrics={
"acc_metrics": MetricsRecord(
metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False
)
},
configs={
"my_configs": ConfigsRecord(
configs_dict={
"learning_rate": 0.01,
"batch_size": 32,
"public_key": b"21f8sioj@!#",
"log": "Hello, world!",
},
keep_input=False,
)
},
)

# Execute
proto = recordset_to_proto(original)
deserialized = recordset_from_proto(proto)

# Assert
assert isinstance(proto, ProtoRecordSet)
assert original == deserialized

0 comments on commit 4dd9375

Please sign in to comment.