Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conversion functions for Message #2843

Merged
merged 13 commits into from
Jan 25, 2024
74 changes: 71 additions & 3 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast

from google.protobuf.message import Message
from google.protobuf.message import Message as GrpcMessage

# pylint: disable=E0611
from flwr.proto.recordset_pb2 import Array as ProtoArray
Expand All @@ -30,7 +30,7 @@
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.recordset_pb2 import Sint64List, StringList
from flwr.proto.task_pb2 import Value
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes, Value
from flwr.proto.transport_pb2 import (
ClientMessage,
Code,
Expand All @@ -44,6 +44,7 @@
# pylint: enable=E0611
from . import typing
from .configsrecord import ConfigsRecord
from .message import Message, Metadata
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord
from .recordset import RecordSet
Expand Down Expand Up @@ -620,7 +621,7 @@ def _record_value_to_proto(
)


def _record_value_from_proto(value_proto: Message) -> Any:
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
"""Deserialize `*RecordValue` from ProtoBuf."""
value_field = cast(str, value_proto.WhichOneof("value"))
if value_field.endswith("list"):
Expand Down Expand Up @@ -751,3 +752,70 @@ def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
},
)


# === Message ===


def message_to_task_ins(message: Message) -> TaskIns:
"""Create a TaskIns from the Message."""
return TaskIns(
task_id="", # This will be generated by the server
group_id=message.metadata.group_id,
run_id=message.metadata.run_id,
task=Task(
ttl=message.metadata.ttl,
task_type=message.metadata.task_type,
recordset=recordset_to_proto(message.message),
),
)


def message_from_task_ins(task_ins: TaskIns) -> Message:
"""Create a Message from the TaskIns."""
# Retrieve the Metadata
metadata = Metadata(
run_id=task_ins.run_id,
task_id=task_ins.task_id,
group_id=task_ins.group_id,
ttl=task_ins.task.ttl,
task_type=task_ins.task.task_type,
)

# Return the Message
return Message(
metadata=metadata,
message=recordset_from_proto(task_ins.task.recordset),
)


def message_to_task_res(message: Message) -> TaskRes:
"""Create a TaskRes from the Message."""
return TaskRes(
task_id="", # This will be generated by the server
group_id=message.metadata.group_id,
run_id=message.metadata.run_id,
task=Task(
ttl=message.metadata.ttl,
task_type=message.metadata.task_type,
recordset=recordset_to_proto(message.message),
),
)


def message_from_task_res(task_res: TaskRes) -> Message:
"""Create a Message from the TaskIns."""
# Retrieve the MetaData
metadata = Metadata(
run_id=task_res.run_id,
task_id=task_res.task_id,
group_id=task_res.group_id,
ttl=task_res.task.ttl,
task_type=task_res.task.task_type,
)

# Return the Message
return Message(
metadata=metadata,
message=recordset_from_proto(task_res.task.recordset),
)
Loading
Loading