Skip to content

Commit

Permalink
add serde functions
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 22, 2024
1 parent 673db0b commit 2b54a53
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 42 deletions.
131 changes: 97 additions & 34 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@
"""ProtoBuf serialization and deserialization."""


from typing import Any, Dict, List, MutableMapping, OrderedDict, cast
from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast

from google.protobuf.message import Message

# pylint: disable=E0611
from flwr.proto.recordset_pb2 import Array as ArrayProto
from flwr.proto.recordset_pb2 import ConfigsRecord as ConfigsRecordProto
from flwr.proto.recordset_pb2 import MetricsRecord as MetricsRecordProto
from flwr.proto.recordset_pb2 import ParametersRecord as ParametersRecordProto
from flwr.proto.recordset_pb2 import Array as ProtoArray
from flwr.proto.recordset_pb2 import BoolList, BytesList
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
from flwr.proto.recordset_pb2 import DoubleList
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import Sint64List, StringList
from flwr.proto.task_pb2 import Value
from flwr.proto.transport_pb2 import (
ClientMessage,
Expand Down Expand Up @@ -502,7 +509,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
# === Value messages ===


_python_type_to_field_name = {
_type_to_field = {
float: "double",
int: "sint64",
bool: "bool",
Expand All @@ -511,22 +518,20 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
}


_python_list_type_to_message_and_field_name = {
float: (Value.DoubleList, "double_list"),
int: (Value.Sint64List, "sint64_list"),
bool: (Value.BoolList, "bool_list"),
str: (Value.StringList, "string_list"),
bytes: (Value.BytesList, "bytes_list"),
_list_type_to_class_and_field = {
float: (DoubleList, "double_list"),
int: (Sint64List, "sint64_list"),
bool: (BoolList, "bool_list"),
str: (StringList, "string_list"),
bytes: (BytesList, "bytes_list"),
}


def _check_value(value: typing.Value) -> None:
if isinstance(value, tuple(_python_type_to_field_name.keys())):
if isinstance(value, tuple(_type_to_field.keys())):
return
if isinstance(value, list):
if len(value) > 0 and isinstance(
value[0], tuple(_python_type_to_field_name.keys())
):
if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())):
data_type = type(value[0])
for element in value:
if isinstance(element, data_type):
Expand All @@ -548,12 +553,12 @@ def value_to_proto(value: typing.Value) -> Value:

arg = {}
if isinstance(value, list):
msg_class, field_name = _python_list_type_to_message_and_field_name[
msg_class, field_name = _list_type_to_class_and_field[
type(value[0]) if len(value) > 0 else int
]
arg[field_name] = msg_class(vals=value)
else:
arg[_python_type_to_field_name[type(value)]] = value
arg[_type_to_field[type(value)]] = value
return Value(**arg)


Expand Down Expand Up @@ -587,12 +592,66 @@ def named_values_from_proto(
# === Record messages ===


def array_to_proto(array: Array) -> ArrayProto:
T = TypeVar("T")


def record_value_to_proto(
value: Any, allowed_types: List[type], proto_class: Type[T]
) -> T:
"""Serialize `*RecordValue` to ProtoBuf."""
arg = {}
for t in allowed_types:
# Single element
# Note: `isinstance(False, int) == True`.
if type(value) == t: # pylint: disable=C0123
arg[_type_to_field[t]] = value
return proto_class(**arg)
# List
if isinstance(value, list) and all(isinstance(item, t) for item in value):
list_class, field_name = _list_type_to_class_and_field[t]
arg[field_name] = list_class(vals=value)
return proto_class(**arg)
# Invalid types
raise TypeError(
f"The type of the following value is not allowed "
f"in '{proto_class.__name__}':\n{value}"
)


def record_value_from_proto(value_proto: Message) -> Any:
"""Deserialize `*RecordValue` from ProtoBuf."""
value_field = cast(str, value_proto.WhichOneof("value"))
if value_field.endswith("list"):
value = list(getattr(value_proto, value_field).vals)
else:
value = getattr(value_proto, value_field)
return value


def record_value_dict_to_proto(
value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T]
) -> Dict[str, T]:
"""Serialize the record value dict to ProtoBuf."""

def proto(_v: Any) -> T:
return record_value_to_proto(_v, allowed_types, value_proto_class)

return {k: proto(v) for k, v in value_dict.items()}


def record_value_dict_from_proto(
value_dict_proto: MutableMapping[str, Any]
) -> Dict[str, Any]:
"""Deserialize the record value dict from ProtoBuf."""
return {k: record_value_from_proto(v) for k, v in value_dict_proto.items()}


def array_to_proto(array: Array) -> ProtoArray:
"""Serialize Array to ProtoBuf."""
return ArrayProto(**vars(array))
return ProtoArray(**vars(array))


def array_from_proto(array_proto: ArrayProto) -> Array:
def array_from_proto(array_proto: ProtoArray) -> Array:
"""Deserialize Array from ProtoBuf."""
return Array(
dtype=array_proto.dtype,
Expand All @@ -602,16 +661,16 @@ def array_from_proto(array_proto: ArrayProto) -> Array:
)


def parameters_record_to_proto(record: ParametersRecord) -> ParametersRecordProto:
def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord:
"""Serialize ParametersRecord to ProtoBuf."""
return ParametersRecordProto(
return ProtoParametersRecord(
data_keys=record.data.keys(),
data_values=map(array_to_proto, record.data.values()),
)


def parameters_record_from_proto(
record_proto: ParametersRecordProto,
record_proto: ProtoParametersRecord,
) -> ParametersRecord:
"""Deserialize ParametersRecord from ProtoBuf."""
return ParametersRecord(
Expand All @@ -622,37 +681,41 @@ def parameters_record_from_proto(
)


def metrics_record_to_proto(record: MetricsRecord) -> MetricsRecordProto:
def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord:
"""Serialize MetricsRecord to ProtoBuf."""
return MetricsRecordProto(
data=named_values_to_proto(cast(Dict[str, typing.Value], record.data))
return ProtoMetricsRecord(
data=record_value_dict_to_proto(
record.data, [float, int], ProtoMetricsRecordValue
)
)


def metrics_record_from_proto(record_proto: MetricsRecordProto) -> MetricsRecord:
def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord:
"""Deserialize MetricsRecord from ProtoBuf."""
return MetricsRecord(
metrics_dict=cast(
Dict[str, typing.MetricsRecordValues],
named_values_from_proto(record_proto.data),
record_value_dict_from_proto(record_proto.data),
),
keep_input=False,
)


def configs_record_to_proto(record: ConfigsRecord) -> ConfigsRecordProto:
def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord:
"""Serialize ConfigsRecord to ProtoBuf."""
return ConfigsRecordProto(
data=named_values_to_proto(cast(Dict[str, typing.Value], record.data))
return ProtoConfigsRecord(
data=record_value_dict_to_proto(
record.data, [int, float, bool, str, bytes], ProtoConfigsRecordValue
)
)


def configs_record_from_proto(record_proto: ConfigsRecordProto) -> ConfigsRecord:
def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord:
"""Deserialize ConfigsRecord from ProtoBuf."""
return ConfigsRecord(
configs_dict=cast(
Dict[str, typing.ConfigsRecordValues],
named_values_from_proto(record_proto.data),
record_value_dict_from_proto(record_proto.data),
),
keep_input=False,
)
16 changes: 8 additions & 8 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

# pylint: disable=E0611
from flwr.proto import transport_pb2 as pb2
from flwr.proto.recordset_pb2 import Array as ArrayProto
from flwr.proto.recordset_pb2 import ConfigsRecord as ConfigsRecordProto
from flwr.proto.recordset_pb2 import MetricsRecord as MetricsRecordProto
from flwr.proto.recordset_pb2 import ParametersRecord as ParametersRecordProto
from flwr.proto.recordset_pb2 import Array as ProtoArray
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord

# pylint: enable=E0611
from . import typing
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_array_serialization_deserialization() -> None:
deserialized = array_from_proto(proto)

# Assert
assert isinstance(proto, ArrayProto)
assert isinstance(proto, ProtoArray)
assert original == deserialized


Expand All @@ -208,7 +208,7 @@ def test_parameters_record_serialization_deserialization() -> None:
deserialized = parameters_record_from_proto(proto)

# Assert
assert isinstance(proto, ParametersRecordProto)
assert isinstance(proto, ProtoParametersRecord)
assert original.data == deserialized.data


Expand All @@ -224,7 +224,7 @@ def test_metrics_record_serialization_deserialization() -> None:
deserialized = metrics_record_from_proto(proto)

# Assert
assert isinstance(proto, MetricsRecordProto)
assert isinstance(proto, ProtoMetricsRecord)
assert original.data == deserialized.data


Expand All @@ -240,5 +240,5 @@ def test_configs_record_serialization_deserialization() -> None:
deserialized = configs_record_from_proto(proto)

# Assert
assert isinstance(proto, ConfigsRecordProto)
assert isinstance(proto, ProtoConfigsRecord)
assert original.data == deserialized.data

0 comments on commit 2b54a53

Please sign in to comment.