Skip to content

Commit

Permalink
Add conversion functions for Message (#2843)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Jan 25, 2024
1 parent 944afce commit 3402dbb
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 68 deletions.
68 changes: 65 additions & 3 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast

from google.protobuf.message import Message
from google.protobuf.message import Message as GrpcMessage

# pylint: disable=E0611
from flwr.proto.recordset_pb2 import Array as ProtoArray
Expand All @@ -30,7 +30,7 @@
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.recordset_pb2 import Sint64List, StringList
from flwr.proto.task_pb2 import Value
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes, Value
from flwr.proto.transport_pb2 import (
ClientMessage,
Code,
Expand All @@ -44,6 +44,7 @@
# pylint: enable=E0611
from . import typing
from .configsrecord import ConfigsRecord
from .message import Message, Metadata
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord
from .recordset import RecordSet
Expand Down Expand Up @@ -620,7 +621,7 @@ def _record_value_to_proto(
)


def _record_value_from_proto(value_proto: Message) -> Any:
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
"""Deserialize `*RecordValue` from ProtoBuf."""
value_field = cast(str, value_proto.WhichOneof("value"))
if value_field.endswith("list"):
Expand Down Expand Up @@ -751,3 +752,64 @@ def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
},
)


# === Message ===


def message_to_taskins(message: Message) -> TaskIns:
"""Create a TaskIns from the Message."""
return TaskIns(
task=Task(
ttl=message.metadata.ttl,
task_type=message.metadata.task_type,
recordset=recordset_to_proto(message.message),
),
)


def message_from_taskins(taskins: TaskIns) -> Message:
"""Create a Message from the TaskIns."""
# Retrieve the Metadata
metadata = Metadata(
run_id=taskins.run_id,
task_id=taskins.task_id,
group_id=taskins.group_id,
ttl=taskins.task.ttl,
task_type=taskins.task.task_type,
)

# Return the Message
return Message(
metadata=metadata,
message=recordset_from_proto(taskins.task.recordset),
)


def message_to_taskres(message: Message) -> TaskRes:
"""Create a TaskRes from the Message."""
return TaskRes(
task=Task(
ttl=message.metadata.ttl,
task_type=message.metadata.task_type,
recordset=recordset_to_proto(message.message),
),
)


def message_from_taskres(taskres: TaskRes) -> Message:
"""Create a Message from the TaskIns."""
# Retrieve the MetaData
metadata = Metadata(
run_id=taskres.run_id,
task_id=taskres.task_id,
group_id=taskres.group_id,
ttl=taskres.task.ttl,
task_type=taskres.task.task_type,
)

# Return the Message
return Message(
metadata=metadata,
message=recordset_from_proto(taskres.task.recordset),
)
Loading

0 comments on commit 3402dbb

Please sign in to comment.