From 14be434e1eb5b69156f889693f9fe1feacdb06a2 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 23 Jan 2024 21:32:44 +0000 Subject: [PATCH] add conversion functions --- src/py/flwr/common/serde.py | 70 ++++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 2600d46edddc..20379655ebd0 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -30,7 +30,7 @@ from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet from flwr.proto.recordset_pb2 import Sint64List, StringList -from flwr.proto.task_pb2 import Value +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes, Value from flwr.proto.transport_pb2 import ( ClientMessage, Code, @@ -44,6 +44,7 @@ # pylint: enable=E0611 from . import typing from .configsrecord import ConfigsRecord +from .flowercontext import FlowerContext from .metricsrecord import MetricsRecord from .parametersrecord import Array, ParametersRecord from .recordset import RecordSet @@ -751,3 +752,70 @@ def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet: k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items() }, ) + + +# === FlowerContext === + + +def flowercontext_to_task_ins(context: FlowerContext) -> TaskIns: + """Create a TaskIns from FlowerContext.""" + return TaskIns( + task_id="", # This will be generated by the server + group_id=context.metadata.group_id, + run_id=context.metadata.run_id, + task=Task( + ttl=context.metadata.ttl, + task_type=context.metadata.task_type, + recordset=recordset_to_proto(context.out_message), + ), + ) + + +def flowercontext_from_task_ins( + context: FlowerContext, task_ins: TaskIns +) -> FlowerContext: + """Retrieve data from a TaskIns.""" + # Set MetaData + context.metadata.run_id = task_ins.run_id + context.metadata.task_id = task_ins.task_id + context.metadata.group_id = task_ins.group_id + context.metadata.ttl = task_ins.task.ttl + context.metadata.task_type = task_ins.task.task_type + + # Set `in_message` + context.in_message = recordset_from_proto(task_ins.task.recordset) + + # Return the FlowerContext + return context + + +def flowercontext_to_task_res(context: FlowerContext) -> TaskRes: + """Create a TaskRes from FlowerContext.""" + return TaskRes( + task_id="", # This will be generated by the server + group_id=context.metadata.group_id, + run_id=context.metadata.run_id, + task=Task( + ttl=context.metadata.ttl, + task_type=context.metadata.task_type, + recordset=recordset_to_proto(context.out_message), + ), + ) + + +def flowercontext_from_task_res( + context: FlowerContext, task_res: TaskRes +) -> FlowerContext: + """Retrieve data from a TaskRes.""" + # Set MetaData + context.metadata.run_id = task_res.run_id + context.metadata.task_id = task_res.task_id + context.metadata.group_id = task_res.group_id + context.metadata.ttl = task_res.task.ttl + context.metadata.task_type = task_res.task.task_type + + # Set `in_message` + context.in_message = recordset_from_proto(task_res.task.recordset) + + # Return the FlowerContext + return context