diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 30dc1bb60f48..e78948dbffa6 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -15,13 +15,20 @@ """ProtoBuf serialization and deserialization.""" -from typing import Any, Dict, List, MutableMapping, OrderedDict, cast +from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast + +from google.protobuf.message import Message # pylint: disable=E0611 -from flwr.proto.recordset_pb2 import Array as ArrayProto -from flwr.proto.recordset_pb2 import ConfigsRecord as ConfigsRecordProto -from flwr.proto.recordset_pb2 import MetricsRecord as MetricsRecordProto -from flwr.proto.recordset_pb2 import ParametersRecord as ParametersRecordProto +from flwr.proto.recordset_pb2 import Array as ProtoArray +from flwr.proto.recordset_pb2 import BoolList, BytesList +from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord +from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue +from flwr.proto.recordset_pb2 import DoubleList +from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord +from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue +from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord +from flwr.proto.recordset_pb2 import Sint64List, StringList from flwr.proto.task_pb2 import Value from flwr.proto.transport_pb2 import ( ClientMessage, @@ -502,7 +509,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: # === Value messages === -_python_type_to_field_name = { +_type_to_field = { float: "double", int: "sint64", bool: "bool", @@ -511,22 +518,20 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: } -_python_list_type_to_message_and_field_name = { - float: (Value.DoubleList, "double_list"), - int: (Value.Sint64List, "sint64_list"), - bool: (Value.BoolList, "bool_list"), - str: (Value.StringList, "string_list"), - bytes: (Value.BytesList, "bytes_list"), +_list_type_to_class_and_field = { + float: (DoubleList, "double_list"), + int: (Sint64List, "sint64_list"), + bool: (BoolList, "bool_list"), + str: (StringList, "string_list"), + bytes: (BytesList, "bytes_list"), } def _check_value(value: typing.Value) -> None: - if isinstance(value, tuple(_python_type_to_field_name.keys())): + if isinstance(value, tuple(_type_to_field.keys())): return if isinstance(value, list): - if len(value) > 0 and isinstance( - value[0], tuple(_python_type_to_field_name.keys()) - ): + if len(value) > 0 and isinstance(value[0], tuple(_type_to_field.keys())): data_type = type(value[0]) for element in value: if isinstance(element, data_type): @@ -548,12 +553,12 @@ def value_to_proto(value: typing.Value) -> Value: arg = {} if isinstance(value, list): - msg_class, field_name = _python_list_type_to_message_and_field_name[ + msg_class, field_name = _list_type_to_class_and_field[ type(value[0]) if len(value) > 0 else int ] arg[field_name] = msg_class(vals=value) else: - arg[_python_type_to_field_name[type(value)]] = value + arg[_type_to_field[type(value)]] = value return Value(**arg) @@ -587,12 +592,66 @@ def named_values_from_proto( # === Record messages === -def array_to_proto(array: Array) -> ArrayProto: +T = TypeVar("T") + + +def record_value_to_proto( + value: Any, allowed_types: List[type], proto_class: Type[T] +) -> T: + """Serialize `*RecordValue` to ProtoBuf.""" + arg = {} + for t in allowed_types: + # Single element + # Note: `isinstance(False, int) == True`. + if type(value) == t: # pylint: disable=C0123 + arg[_type_to_field[t]] = value + return proto_class(**arg) + # List + if isinstance(value, list) and all(isinstance(item, t) for item in value): + list_class, field_name = _list_type_to_class_and_field[t] + arg[field_name] = list_class(vals=value) + return proto_class(**arg) + # Invalid types + raise TypeError( + f"The type of the following value is not allowed " + f"in '{proto_class.__name__}':\n{value}" + ) + + +def record_value_from_proto(value_proto: Message) -> Any: + """Deserialize `*RecordValue` from ProtoBuf.""" + value_field = cast(str, value_proto.WhichOneof("value")) + if value_field.endswith("list"): + value = list(getattr(value_proto, value_field).vals) + else: + value = getattr(value_proto, value_field) + return value + + +def record_value_dict_to_proto( + value_dict: Dict[str, Any], allowed_types: List[type], value_proto_class: Type[T] +) -> Dict[str, T]: + """Serialize the record value dict to ProtoBuf.""" + + def proto(_v: Any) -> T: + return record_value_to_proto(_v, allowed_types, value_proto_class) + + return {k: proto(v) for k, v in value_dict.items()} + + +def record_value_dict_from_proto( + value_dict_proto: MutableMapping[str, Any] +) -> Dict[str, Any]: + """Deserialize the record value dict from ProtoBuf.""" + return {k: record_value_from_proto(v) for k, v in value_dict_proto.items()} + + +def array_to_proto(array: Array) -> ProtoArray: """Serialize Array to ProtoBuf.""" - return ArrayProto(**vars(array)) + return ProtoArray(**vars(array)) -def array_from_proto(array_proto: ArrayProto) -> Array: +def array_from_proto(array_proto: ProtoArray) -> Array: """Deserialize Array from ProtoBuf.""" return Array( dtype=array_proto.dtype, @@ -602,16 +661,16 @@ def array_from_proto(array_proto: ArrayProto) -> Array: ) -def parameters_record_to_proto(record: ParametersRecord) -> ParametersRecordProto: +def parameters_record_to_proto(record: ParametersRecord) -> ProtoParametersRecord: """Serialize ParametersRecord to ProtoBuf.""" - return ParametersRecordProto( + return ProtoParametersRecord( data_keys=record.data.keys(), data_values=map(array_to_proto, record.data.values()), ) def parameters_record_from_proto( - record_proto: ParametersRecordProto, + record_proto: ProtoParametersRecord, ) -> ParametersRecord: """Deserialize ParametersRecord from ProtoBuf.""" return ParametersRecord( @@ -622,37 +681,41 @@ def parameters_record_from_proto( ) -def metrics_record_to_proto(record: MetricsRecord) -> MetricsRecordProto: +def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord: """Serialize MetricsRecord to ProtoBuf.""" - return MetricsRecordProto( - data=named_values_to_proto(cast(Dict[str, typing.Value], record.data)) + return ProtoMetricsRecord( + data=record_value_dict_to_proto( + record.data, [float, int], ProtoMetricsRecordValue + ) ) -def metrics_record_from_proto(record_proto: MetricsRecordProto) -> MetricsRecord: +def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord: """Deserialize MetricsRecord from ProtoBuf.""" return MetricsRecord( metrics_dict=cast( Dict[str, typing.MetricsRecordValues], - named_values_from_proto(record_proto.data), + record_value_dict_from_proto(record_proto.data), ), keep_input=False, ) -def configs_record_to_proto(record: ConfigsRecord) -> ConfigsRecordProto: +def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord: """Serialize ConfigsRecord to ProtoBuf.""" - return ConfigsRecordProto( - data=named_values_to_proto(cast(Dict[str, typing.Value], record.data)) + return ProtoConfigsRecord( + data=record_value_dict_to_proto( + record.data, [int, float, bool, str, bytes], ProtoConfigsRecordValue + ) ) -def configs_record_from_proto(record_proto: ConfigsRecordProto) -> ConfigsRecord: +def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord: """Deserialize ConfigsRecord from ProtoBuf.""" return ConfigsRecord( configs_dict=cast( Dict[str, typing.ConfigsRecordValues], - named_values_from_proto(record_proto.data), + record_value_dict_from_proto(record_proto.data), ), keep_input=False, ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index dcb3c2d52e07..c584597d89f6 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -19,10 +19,10 @@ # pylint: disable=E0611 from flwr.proto import transport_pb2 as pb2 -from flwr.proto.recordset_pb2 import Array as ArrayProto -from flwr.proto.recordset_pb2 import ConfigsRecord as ConfigsRecordProto -from flwr.proto.recordset_pb2 import MetricsRecord as MetricsRecordProto -from flwr.proto.recordset_pb2 import ParametersRecord as ParametersRecordProto +from flwr.proto.recordset_pb2 import Array as ProtoArray +from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord +from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord +from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord # pylint: enable=E0611 from . import typing @@ -186,7 +186,7 @@ def test_array_serialization_deserialization() -> None: deserialized = array_from_proto(proto) # Assert - assert isinstance(proto, ArrayProto) + assert isinstance(proto, ProtoArray) assert original == deserialized @@ -208,7 +208,7 @@ def test_parameters_record_serialization_deserialization() -> None: deserialized = parameters_record_from_proto(proto) # Assert - assert isinstance(proto, ParametersRecordProto) + assert isinstance(proto, ProtoParametersRecord) assert original.data == deserialized.data @@ -224,7 +224,7 @@ def test_metrics_record_serialization_deserialization() -> None: deserialized = metrics_record_from_proto(proto) # Assert - assert isinstance(proto, MetricsRecordProto) + assert isinstance(proto, ProtoMetricsRecord) assert original.data == deserialized.data @@ -240,5 +240,5 @@ def test_configs_record_serialization_deserialization() -> None: deserialized = configs_record_from_proto(proto) # Assert - assert isinstance(proto, ConfigsRecordProto) + assert isinstance(proto, ProtoConfigsRecord) assert original.data == deserialized.data