Skip to content

Commit

Permalink
Update serde
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Jul 8, 2024
1 parent 72eb30a commit 8c0d1b9
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,32 @@
"""ProtoBuf serialization and deserialization."""


from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast
from typing import (
Any,
Dict,
List,
MutableMapping,
OrderedDict,
Type,
TypeVar,
Union,
cast,
)

from google.protobuf.message import Message as GrpcMessage

# pylint: disable=E0611
from flwr.proto.common_pb2 import BoolList, BytesList
from flwr.proto.common_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
from flwr.proto.common_pb2 import DoubleList, Sint64List, StringList
from flwr.proto.error_pb2 import Error as ProtoError
from flwr.proto.node_pb2 import Node
from flwr.proto.recordset_pb2 import Array as ProtoArray
from flwr.proto.recordset_pb2 import BoolList, BytesList
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
from flwr.proto.recordset_pb2 import DoubleList
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.recordset_pb2 import Sint64List, StringList
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import (
ClientMessage,
Expand All @@ -47,6 +56,7 @@
from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet, typing
from .message import Error, Message, Metadata
from .record.typeddict import TypedDict
from .typing import ConfigsRecordValues

# === Parameters message ===

Expand Down Expand Up @@ -411,8 +421,8 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
return value


def _record_value_dict_to_proto(
value_dict: TypedDict[str, Any],
def record_value_dict_to_proto(
value_dict: Union[TypedDict[str, Any], Dict[str, ConfigsRecordValues]],
allowed_types: List[type],
value_proto_class: Type[T],
) -> Dict[str, T]:
Expand All @@ -431,7 +441,7 @@ def proto(_v: Any) -> T:
return {k: proto(v) for k, v in value_dict.items()}


def _record_value_dict_from_proto(
def record_value_dict_from_proto(
value_dict_proto: MutableMapping[str, Any]
) -> Dict[str, Any]:
"""Deserialize the record value dict from ProtoBuf."""
Expand Down Expand Up @@ -476,7 +486,7 @@ def parameters_record_from_proto(
def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord:
"""Serialize MetricsRecord to ProtoBuf."""
return ProtoMetricsRecord(
data=_record_value_dict_to_proto(record, [float, int], ProtoMetricsRecordValue)
data=record_value_dict_to_proto(record, [float, int], ProtoMetricsRecordValue)
)


Expand All @@ -485,7 +495,7 @@ def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord
return MetricsRecord(
metrics_dict=cast(
Dict[str, typing.MetricsRecordValues],
_record_value_dict_from_proto(record_proto.data),
record_value_dict_from_proto(record_proto.data),
),
keep_input=False,
)
Expand All @@ -494,7 +504,7 @@ def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord
def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord:
"""Serialize ConfigsRecord to ProtoBuf."""
return ProtoConfigsRecord(
data=_record_value_dict_to_proto(
data=record_value_dict_to_proto(
record,
[bool, int, float, str, bytes],
ProtoConfigsRecordValue,
Expand All @@ -507,7 +517,7 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord
return ConfigsRecord(
configs_dict=cast(
Dict[str, typing.ConfigsRecordValues],
_record_value_dict_from_proto(record_proto.data),
record_value_dict_from_proto(record_proto.data),
),
keep_input=False,
)
Expand Down

0 comments on commit 8c0d1b9

Please sign in to comment.