Skip to content

Commit

Permalink
add conversion functions
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 23, 2024
1 parent 214d1c8 commit 14be434
Showing 1 changed file with 69 additions and 1 deletion.
70 changes: 69 additions & 1 deletion src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.recordset_pb2 import Sint64List, StringList
from flwr.proto.task_pb2 import Value
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes, Value
from flwr.proto.transport_pb2 import (
ClientMessage,
Code,
Expand All @@ -44,6 +44,7 @@
# pylint: enable=E0611
from . import typing
from .configsrecord import ConfigsRecord
from .flowercontext import FlowerContext
from .metricsrecord import MetricsRecord
from .parametersrecord import Array, ParametersRecord
from .recordset import RecordSet
Expand Down Expand Up @@ -751,3 +752,70 @@ def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
},
)


# === FlowerContext ===


def flowercontext_to_task_ins(context: FlowerContext) -> TaskIns:
"""Create a TaskIns from FlowerContext."""
return TaskIns(
task_id="", # This will be generated by the server
group_id=context.metadata.group_id,
run_id=context.metadata.run_id,
task=Task(
ttl=context.metadata.ttl,
task_type=context.metadata.task_type,
recordset=recordset_to_proto(context.out_message),
),
)


def flowercontext_from_task_ins(
context: FlowerContext, task_ins: TaskIns
) -> FlowerContext:
"""Retrieve data from a TaskIns."""
# Set MetaData
context.metadata.run_id = task_ins.run_id
context.metadata.task_id = task_ins.task_id
context.metadata.group_id = task_ins.group_id
context.metadata.ttl = task_ins.task.ttl
context.metadata.task_type = task_ins.task.task_type

# Set `in_message`
context.in_message = recordset_from_proto(task_ins.task.recordset)

# Return the FlowerContext
return context


def flowercontext_to_task_res(context: FlowerContext) -> TaskRes:
"""Create a TaskRes from FlowerContext."""
return TaskRes(
task_id="", # This will be generated by the server
group_id=context.metadata.group_id,
run_id=context.metadata.run_id,
task=Task(
ttl=context.metadata.ttl,
task_type=context.metadata.task_type,
recordset=recordset_to_proto(context.out_message),
),
)


def flowercontext_from_task_res(
context: FlowerContext, task_res: TaskRes
) -> FlowerContext:
"""Retrieve data from a TaskRes."""
# Set MetaData
context.metadata.run_id = task_res.run_id
context.metadata.task_id = task_res.task_id
context.metadata.group_id = task_res.group_id
context.metadata.ttl = task_res.task.ttl
context.metadata.task_type = task_res.task.task_type

# Set `in_message`
context.in_message = recordset_from_proto(task_res.task.recordset)

# Return the FlowerContext
return context

0 comments on commit 14be434

Please sign in to comment.