Add *Record ProtoBuf messages and corresponding serde functions. #2831

Merged
merged 16 commits into from
Jan 22, 2024
36 changes: 36 additions & 0 deletions src/proto/flwr/proto/recordset.proto
@@ -0,0 +1,36 @@
// Copyright 2024 Flower Labs GmbH. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================

syntax = "proto3";

package flwr.proto;

import "flwr/proto/task.proto";

message Array {
  string dtype = 1;
  repeated int32 shape = 2;
  string stype = 3;
  bytes data = 4;
}

message ParametersRecord {
  repeated string data_keys = 1;
  repeated Array data_values = 2;
}

message MetricsRecord { map<string, Value> data = 1; }
Member

MetricsRecord only allows int and float types (and lists of those two types)
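
A minimal sketch of this constraint in plain Python. The names `MetricsValue` and `is_metrics_value` are hypothetical and not taken from the Flower code base. Rejecting `bool` and requiring homogeneous lists are assumptions made here for illustration, since `bool` is a subclass of `int` in Python.

```python
from typing import List, Union

# Hypothetical alias for illustration; not part of Flower.
MetricsValue = Union[int, float, List[int], List[float]]


def _is_number(value: object) -> bool:
    """Return True for int or float, excluding bool (a subclass of int)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def is_metrics_value(value: object) -> bool:
    """Accept an int, a float, or a list whose items are all int or all float."""
    if _is_number(value):
        return True
    if isinstance(value, list):
        # Assumption: lists must be homogeneous (all int or all float).
        all_int = all(type(item) is int for item in value)
        all_float = all(type(item) is float for item in value)
        return all_int or all_float
    return False
```

With this helper, `0.95` and `[1, 2, 3]` pass, while `"0.95"`, `True`, and the mixed list `[1, 2.0]` are rejected.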

Contributor Author

@panh99 panh99 Jan 20, 2024

Yes, I get it.
My thinking was that, since we enforce type checks when instantiating the MetricsRecord, the data field inside must already be valid, so we can reuse Value to store it and avoid repetitive code in the proto files and in serde.py. But I'm also happy to copy-paste part of the serde functions for Value and create new messages for MetricsRecordValue.
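
The argument can be sketched in plain Python: if the record validates its contents on construction, a serializer written for a wider value type can be reused safely. `MetricsRecordSketch`, `WideValue`, and `to_wire` are hypothetical stand-ins, not Flower's actual classes or serde functions, and list values are left out for brevity.

```python
from typing import Dict, Union

# Stand-in for the wider value type that a shared serializer accepts.
WideValue = Union[int, float, str, bytes, bool]


class MetricsRecordSketch:
    """Hypothetical record that only accepts int and float values."""

    def __init__(self, metrics: Dict[str, Union[int, float]]) -> None:
        for key, value in metrics.items():
            # bool is a subclass of int, so it is rejected explicitly here.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise TypeError(
                    f"unsupported type for {key!r}: {type(value).__name__}"
                )
        self.data = dict(metrics)


def to_wire(values: Dict[str, WideValue]) -> Dict[str, str]:
    """Toy generic serializer: tag every value with its type name."""
    return {key: f"{type(value).__name__}:{value!r}" for key, value in values.items()}
```

Because invalid values never make it into `data`, `to_wire(record.data)` only ever sees int and float, even though it would accept more.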

Member

Got it, makes sense. I suspected that to be the case :)

My recommendation would be to create new messages that map closely to the Python code for two reasons:

  1. Enable us to remove Value once we have migrated everything to RecordSet
  2. Make it easier to implement non-Python clients based on the ProtoBuf definitions


message ConfigsRecord { map<string, Value> data = 1; }
Member

ConfigsRecord doesn't allow bool yet. That's something we should fix in a separate PR by adding it to ConfigsRecord.

The more general question: should we define ProtoBuf messages that follow the same naming scheme as their Python dataclass counterparts?

Contributor Author

@panh99 panh99 Jan 20, 2024

The reason for reusing Value instead of creating a new message ConfigsRecordValue is the same as above (MetricsRecord).

Re: naming. IMO, we should use different names, or even store them in a different way in a TaskIns. There is no benefit to keeping them strictly equivalent. I think the ProtoBuf messages are aimed solely at transferring contents over the wire, which is different from the purpose of introducing record types. Hence I think we don't need to follow the same naming scheme, and we don't even need a counterpart for each dataclass record.

The advantage of not having counterparts is that we may not have to change the protobuf messages and serde functions every time we decide to modify our RecordSet. The disadvantage is that it's not easy to design protobuf messages for general use, and the naming will be less intuitive if we allow users to change protobuf messages in the future.

Member

I agree that we do not need to aim for a strict 1:1 mapping from Python to ProtoBuf.

I do think however that there is an advantage to keeping them close. As stated in the other comment, we want to implement Flower clients in languages other than Python. This usually starts with compiling existing ProtoBuf messages. If those messages are close to the Python level, it will be easier for others to implement the Java/C++/... counterpart.

In addition to that, it will make maintenance of the Python client easier as well. In the case of MetricsRecord, for example, we would not need to check for unsupported Value types if we have a MetricsRecord on the ProtoBuf level that supports the exact set of types that the Python MetricsRecord supports.
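
The maintenance point can be illustrated with a toy wire format in plain Python. Everything here is an assumption for illustration: the `(kind, payload)` pairs mimic a oneof, and the kind names and both functions are hypothetical, not Flower's serde API.

```python
from typing import Dict, Tuple, Union

Number = Union[int, float]

# Toy wire value: a (kind, payload) pair, loosely mimicking a oneof field.
Tagged = Tuple[str, object]

# Kinds a metrics record may hold (names assumed for this sketch).
METRIC_KINDS = {"sint64", "double"}


def metrics_from_wide(data: Dict[str, Tagged]) -> Dict[str, Number]:
    """Decode from a shared, wider value type: other kinds must be rejected."""
    out: Dict[str, Number] = {}
    for key, (kind, payload) in data.items():
        if kind not in METRIC_KINDS:
            raise ValueError(f"unsupported kind for {key!r}: {kind}")
        out[key] = payload  # type: ignore[assignment]
    return out


def metrics_from_narrow(data: Dict[str, Number]) -> Dict[str, Number]:
    """Decode from a dedicated message: the schema already limits the types."""
    return dict(data)
```

The extra branch in `metrics_from_wide` is the check that a dedicated message definition would make unnecessary.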

83 changes: 82 additions & 1 deletion src/py/flwr/common/serde.py
@@ -15,8 +15,12 @@
"""ProtoBuf serialization and deserialization."""


-from typing import Any, Dict, List, MutableMapping, cast
+from typing import Any, Dict, List, MutableMapping, OrderedDict, cast

from flwr.proto.recordset_pb2 import Array as ArrayProto
from flwr.proto.recordset_pb2 import ConfigsRecord as ConfigsRecordProto
from flwr.proto.recordset_pb2 import MetricsRecord as MetricsRecordProto
from flwr.proto.recordset_pb2 import ParametersRecord as ParametersRecordProto
from flwr.proto.task_pb2 import Value # pylint: disable=E0611
from flwr.proto.transport_pb2 import (  # pylint: disable=E0611
    ClientMessage,
@@ -29,6 +33,9 @@
)

from . import typing
from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord

# === ServerMessage message ===

@@ -573,3 +580,77 @@ def named_values_from_proto(
) -> Dict[str, typing.Value]:
    """Deserialize named values from ProtoBuf."""
    return {name: value_from_proto(value) for name, value in named_values_proto.items()}


# === Record messages ===


def array_to_proto(array: Array) -> ArrayProto:
    """Serialize Array to ProtoBuf."""
    return ArrayProto(**vars(array))


def array_from_proto(array_proto: ArrayProto) -> Array:
    """Deserialize Array from ProtoBuf."""
    return Array(
        dtype=array_proto.dtype,
        shape=list(array_proto.shape),
        stype=array_proto.stype,
        data=array_proto.data,
    )


def parameters_record_to_proto(record: ParametersRecord) -> ParametersRecordProto:
    """Serialize ParametersRecord to ProtoBuf."""
    return ParametersRecordProto(
        data_keys=record.data.keys(),
        data_values=map(array_to_proto, record.data.values()),
    )


def parameters_record_from_proto(
    record_proto: ParametersRecordProto,
) -> ParametersRecord:
    """Deserialize ParametersRecord from ProtoBuf."""
    return ParametersRecord(
        array_dict=OrderedDict(
            zip(record_proto.data_keys, map(array_from_proto, record_proto.data_values))
        ),
        keep_input=False,
    )


def metrics_record_to_proto(record: MetricsRecord) -> MetricsRecordProto:
    """Serialize MetricsRecord to ProtoBuf."""
    return MetricsRecordProto(
        data=named_values_to_proto(cast(Dict[str, typing.Value], record.data))
    )


def metrics_record_from_proto(record_proto: MetricsRecordProto) -> MetricsRecord:
    """Deserialize MetricsRecord from ProtoBuf."""
    return MetricsRecord(
        metrics_dict=cast(
            Dict[str, typing.MetricsRecordValues],
            named_values_from_proto(record_proto.data),
        ),
        keep_input=False,
    )


def configs_record_to_proto(record: ConfigsRecord) -> ConfigsRecordProto:
    """Serialize ConfigsRecord to ProtoBuf."""
    return ConfigsRecordProto(
        data=named_values_to_proto(cast(Dict[str, typing.Value], record.data))
    )


def configs_record_from_proto(record_proto: ConfigsRecordProto) -> ConfigsRecord:
    """Deserialize ConfigsRecord from ProtoBuf."""
    return ConfigsRecord(
        configs_dict=cast(
            Dict[str, typing.ConfigsRecordValues],
            named_values_from_proto(record_proto.data),
        ),
        keep_input=False,
    )
87 changes: 85 additions & 2 deletions src/py/flwr/common/serde_test.py
@@ -15,14 +15,29 @@
"""(De-)serialization tests."""


-from typing import Dict, Union, cast
+from typing import Dict, OrderedDict, Union, cast

from flwr.common import typing
from flwr.proto import transport_pb2 as pb2 # pylint: disable=E0611

from flwr.proto.recordset_pb2 import Array as ArrayProto
from flwr.proto.recordset_pb2 import ConfigsRecord as ConfigsRecordProto
from flwr.proto.recordset_pb2 import MetricsRecord as MetricsRecordProto
from flwr.proto.recordset_pb2 import ParametersRecord as ParametersRecordProto

from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord
from .serde import (
    array_from_proto,
    array_to_proto,
    configs_record_from_proto,
    configs_record_to_proto,
    metrics_record_from_proto,
    metrics_record_to_proto,
    named_values_from_proto,
    named_values_to_proto,
    parameters_record_from_proto,
    parameters_record_to_proto,
    scalar_from_proto,
    scalar_to_proto,
    status_from_proto,
@@ -157,3 +172,71 @@ def test_named_values_serialization_deserialization() -> None:
assert elm1 == elm2
else:
assert expected == actual


def test_array_serialization_deserialization() -> None:
    """Test serialization and deserialization of Array."""
    # Prepare
    original = Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234")

    # Execute
    proto = array_to_proto(original)
    deserialized = array_from_proto(proto)

    # Assert
    assert isinstance(proto, ArrayProto)
    assert original == deserialized


def test_parameters_record_serialization_deserialization() -> None:
    """Test serialization and deserialization of ParametersRecord."""
    # Prepare
    original = ParametersRecord(
        array_dict=OrderedDict(
            [
                ("k1", Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234")),
                ("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")),
            ]
        ),
        keep_input=False,
    )

    # Execute
    proto = parameters_record_to_proto(original)
    deserialized = parameters_record_from_proto(proto)

    # Assert
    assert isinstance(proto, ParametersRecordProto)
    assert original.data == deserialized.data


def test_metrics_record_serialization_deserialization() -> None:
    """Test serialization and deserialization of MetricsRecord."""
    # Prepare
    original = MetricsRecord(
        metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False
    )

    # Execute
    proto = metrics_record_to_proto(original)
    deserialized = metrics_record_from_proto(proto)

    # Assert
    assert isinstance(proto, MetricsRecordProto)
    assert original.data == deserialized.data


def test_configs_record_serialization_deserialization() -> None:
    """Test serialization and deserialization of ConfigsRecord."""
    # Prepare
    original = ConfigsRecord(
        configs_dict={"learning_rate": 0.01, "batch_size": 32}, keep_input=False
    )

    # Execute
    proto = configs_record_to_proto(original)
    deserialized = configs_record_from_proto(proto)

    # Assert
    assert isinstance(proto, ConfigsRecordProto)
    assert original.data == deserialized.data
41 changes: 41 additions & 0 deletions src/py/flwr/proto/recordset_pb2.py

Some generated files are not rendered by default.

104 changes: 104 additions & 0 deletions src/py/flwr/proto/recordset_pb2.pyi
@@ -0,0 +1,104 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import flwr.proto.task_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import typing
import typing_extensions

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

class Array(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor
    DTYPE_FIELD_NUMBER: builtins.int
    SHAPE_FIELD_NUMBER: builtins.int
    STYPE_FIELD_NUMBER: builtins.int
    DATA_FIELD_NUMBER: builtins.int
    dtype: typing.Text
    @property
    def shape(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
    stype: typing.Text
    data: builtins.bytes
    def __init__(self,
        *,
        dtype: typing.Text = ...,
        shape: typing.Optional[typing.Iterable[builtins.int]] = ...,
        stype: typing.Text = ...,
        data: builtins.bytes = ...,
        ) -> None: ...
    def ClearField(self, field_name: typing_extensions.Literal["data",b"data","dtype",b"dtype","shape",b"shape","stype",b"stype"]) -> None: ...
global___Array = Array

class ParametersRecord(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor
    DATA_KEYS_FIELD_NUMBER: builtins.int
    DATA_VALUES_FIELD_NUMBER: builtins.int
    @property
    def data_keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
    @property
    def data_values(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Array]: ...
    def __init__(self,
        *,
        data_keys: typing.Optional[typing.Iterable[typing.Text]] = ...,
        data_values: typing.Optional[typing.Iterable[global___Array]] = ...,
        ) -> None: ...
    def ClearField(self, field_name: typing_extensions.Literal["data_keys",b"data_keys","data_values",b"data_values"]) -> None: ...
global___ParametersRecord = ParametersRecord

class MetricsRecord(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor
    class DataEntry(google.protobuf.message.Message):
        DESCRIPTOR: google.protobuf.descriptor.Descriptor
        KEY_FIELD_NUMBER: builtins.int
        VALUE_FIELD_NUMBER: builtins.int
        key: typing.Text
        @property
        def value(self) -> flwr.proto.task_pb2.Value: ...
        def __init__(self,
            *,
            key: typing.Text = ...,
            value: typing.Optional[flwr.proto.task_pb2.Value] = ...,
            ) -> None: ...
        def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
        def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

    DATA_FIELD_NUMBER: builtins.int
    @property
    def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.task_pb2.Value]: ...
    def __init__(self,
        *,
        data: typing.Optional[typing.Mapping[typing.Text, flwr.proto.task_pb2.Value]] = ...,
        ) -> None: ...
    def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ...
global___MetricsRecord = MetricsRecord

class ConfigsRecord(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor
    class DataEntry(google.protobuf.message.Message):
        DESCRIPTOR: google.protobuf.descriptor.Descriptor
        KEY_FIELD_NUMBER: builtins.int
        VALUE_FIELD_NUMBER: builtins.int
        key: typing.Text
        @property
        def value(self) -> flwr.proto.task_pb2.Value: ...
        def __init__(self,
            *,
            key: typing.Text = ...,
            value: typing.Optional[flwr.proto.task_pb2.Value] = ...,
            ) -> None: ...
        def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
        def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

    DATA_FIELD_NUMBER: builtins.int
    @property
    def data(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.task_pb2.Value]: ...
    def __init__(self,
        *,
        data: typing.Optional[typing.Mapping[typing.Text, flwr.proto.task_pb2.Value]] = ...,
        ) -> None: ...
    def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ...
global___ConfigsRecord = ConfigsRecord
4 changes: 4 additions & 0 deletions src/py/flwr/proto/recordset_pb2_grpc.py
@@ -0,0 +1,4 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

4 changes: 4 additions & 0 deletions src/py/flwr/proto/recordset_pb2_grpc.pyi
@@ -0,0 +1,4 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""