Skip to content

Commit

Permalink
udpate proto and add serde functions
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 19, 2024
1 parent a5eee50 commit 2864538
Show file tree
Hide file tree
Showing 13 changed files with 1,026 additions and 165 deletions.
36 changes: 36 additions & 0 deletions src/proto/flwr/proto/recordset.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright 2022 Flower Labs GmbH. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================

syntax = "proto3";

package flwr.proto;

import "flwr/proto/task.proto";

message Array {
string dtype = 1;
repeated int32 shape = 2;
string stype = 3;
bytes data = 4;
}

message ParametersRecord {
repeated string data_keys = 1;
repeated Array data_values = 2;
}

message MetricsRecord { map<string, Value> data = 1; }

message ConfigsRecord { map<string, Value> data = 1; }
83 changes: 82 additions & 1 deletion src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
"""ProtoBuf serialization and deserialization."""


from typing import Any, Dict, List, MutableMapping, cast
from typing import Any, Dict, List, MutableMapping, OrderedDict, cast

from flwr.proto.recordset_pb2 import Array as ArrayProto
from flwr.proto.recordset_pb2 import ConfigsRecord as ConfigsRecordProto
from flwr.proto.recordset_pb2 import MetricsRecord as MetricsRecordProto
from flwr.proto.recordset_pb2 import ParametersRecord as ParametersRecordProto
from flwr.proto.task_pb2 import Value # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
Expand All @@ -29,6 +33,9 @@
)

from . import typing
from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord

# === ServerMessage message ===

Expand Down Expand Up @@ -573,3 +580,77 @@ def named_values_from_proto(
) -> Dict[str, typing.Value]:
"""Deserialize named values from ProtoBuf."""
return {name: value_from_proto(value) for name, value in named_values_proto.items()}


# === Record messages ===


def array_to_proto(array: Array) -> ArrayProto:
"""Serialize Array to ProtoBuf."""
return ArrayProto(**vars(array))


def array_from_proto(array_proto: ArrayProto) -> Array:
"""Deserialize Array from ProtoBuf."""
return Array(
dtype=array_proto.dtype,
shape=list(array_proto.shape),
stype=array_proto.stype,
data=array_proto.data,
)


def parameters_record_to_proto(record: ParametersRecord) -> ParametersRecordProto:
"""Serialize ParametersRecord to ProtoBuf."""
return ParametersRecordProto(
data_keys=record.data.keys(),
data_values=map(array_to_proto, record.data.values()),
)


def parameters_record_from_proto(
record_proto: ParametersRecordProto,
) -> ParametersRecord:
"""Deserialize ParametersRecord from ProtoBuf."""
return ParametersRecord(
array_dict=OrderedDict(
zip(record_proto.data_keys, map(array_from_proto, record_proto.data_values))
),
keep_input=False,
)


def metrics_record_to_proto(record: MetricsRecord) -> MetricsRecordProto:
"""Serialize MetricsRecord to ProtoBuf."""
return MetricsRecordProto(
data=named_values_to_proto(cast(Dict[str, typing.Value], record.data))
)


def metrics_record_from_proto(record_proto: MetricsRecordProto) -> MetricsRecord:
"""Deserialize MetricsRecord from ProtoBuf."""
return MetricsRecord(
metrics_dict=cast(
Dict[str, typing.MetricsRecordValues],
named_values_from_proto(record_proto.data),
),
keep_input=False,
)


def configs_record_to_proto(record: ConfigsRecord) -> ConfigsRecordProto:
"""Serialize ConfigsRecord to ProtoBuf."""
return ConfigsRecordProto(
data=named_values_to_proto(cast(Dict[str, typing.Value], record.data))
)


def configs_record_from_proto(record_proto: ConfigsRecordProto) -> ConfigsRecord:
"""Deserialize ConfigsRecord from ProtoBuf."""
return ConfigsRecord(
configs_dict=cast(
Dict[str, typing.ConfigsRecordValues],
named_values_from_proto(record_proto.data),
),
keep_input=False,
)
87 changes: 85 additions & 2 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,29 @@
"""(De-)serialization tests."""


from typing import Dict, Union, cast
from typing import Dict, OrderedDict, Union, cast

from flwr.common import typing
from flwr.proto import transport_pb2 as pb2 # pylint: disable=E0611

from flwr.proto.recordset_pb2 import Array as ArrayProto
from flwr.proto.recordset_pb2 import ConfigsRecord as ConfigsRecordProto
from flwr.proto.recordset_pb2 import MetricsRecord as MetricsRecordProto
from flwr.proto.recordset_pb2 import ParametersRecord as ParametersRecordProto

from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord
from .serde import (
array_from_proto,
array_to_proto,
configs_record_from_proto,
configs_record_to_proto,
metrics_record_from_proto,
metrics_record_to_proto,
named_values_from_proto,
named_values_to_proto,
parameters_record_from_proto,
parameters_record_to_proto,
scalar_from_proto,
scalar_to_proto,
status_from_proto,
Expand Down Expand Up @@ -157,3 +172,71 @@ def test_named_values_serialization_deserialization() -> None:
assert elm1 == elm2
else:
assert expected == actual


def test_array_serialization_deserialization() -> None:
"""Test serialization and deserialization of Array."""
# Prepare
original = Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234")

# Execute
proto = array_to_proto(original)
deserialized = array_from_proto(proto)

# Assert
assert isinstance(proto, ArrayProto)
assert original == deserialized


def test_parameters_record_serialization_deserialization() -> None:
"""Test serialization and deserialization of ParametersRecord."""
# Prepare
original = ParametersRecord(
array_dict=OrderedDict(
[
("k1", Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234")),
("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")),
]
),
keep_input=False,
)

# Execute
proto = parameters_record_to_proto(original)
deserialized = parameters_record_from_proto(proto)

# Assert
assert isinstance(proto, ParametersRecordProto)
assert original.data == deserialized.data


def test_metrics_record_serialization_deserialization() -> None:
"""Test serialization and deserialization of MetricsRecord."""
# Prepare
original = MetricsRecord(
metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False
)

# Execute
proto = metrics_record_to_proto(original)
deserialized = metrics_record_from_proto(proto)

# Assert
assert isinstance(proto, MetricsRecordProto)
assert original.data == deserialized.data


def test_configs_record_serialization_deserialization() -> None:
"""Test serialization and deserialization of ConfigsRecord."""
# Prepare
original = ConfigsRecord(
configs_dict={"learning_rate": 0.01, "batch_size": 32}, keep_input=False
)

# Execute
proto = configs_record_to_proto(original)
deserialized = configs_record_from_proto(proto)

# Assert
assert isinstance(proto, ConfigsRecordProto)
assert original.data == deserialized.data
111 changes: 88 additions & 23 deletions src/py/flwr/proto/driver_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2864538

Please sign in to comment.