diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 5f11912c587c..5a4a3f8fbf8b 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -23,6 +23,14 @@ from typing import Callable, Iterator, Optional, Tuple, Union from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common import recordset_compat as compat +from flwr.common import serde, typing +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, +) from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.proto.node_pb2 import Node # pylint: disable=E0611 @@ -118,7 +126,42 @@ def grpc_connection( server_message_iterator: Iterator[ServerMessage] = stub.Join(iter(queue.get, None)) def receive() -> TaskIns: - server_message = next(server_message_iterator) + # Receive ServerMessage proto + proto = next(server_message_iterator) + + # ServerMessage proto --> *Ins --> RecordSet + field = proto.WhichOneof("msg") + task_type = "" + if field == "get_properties_ins": + recordset = compat.getpropertiesins_to_recordset( + serde.get_properties_ins_from_proto(proto.get_properties_ins) + ) + task_type = TASK_TYPE_GET_PROPERTIES + elif field == "get_parameters_ins": + recordset = compat.getparametersins_to_recordset( + serde.get_parameters_ins_from_proto(proto.get_parameters_ins) + ) + task_type = TASK_TYPE_GET_PARAMETERS + elif field == "fit_ins": + recordset = compat.fitins_to_recordset( + serde.fit_ins_from_proto(proto.fit_ins), False + ) + task_type = TASK_TYPE_FIT + elif field == "evaluate_ins": + recordset = compat.evaluateins_to_recordset( + serde.evaluate_ins_from_proto(proto.evaluate_ins), False + ) + task_type = TASK_TYPE_EVALUATE + else: + raise ValueError( + "Unsupported instruction in ServerMessage, " + "cannot deserialize from ProtoBuf" + ) + + # RecordSet --> RecordSet proto + recordset_proto = serde.recordset_to_proto(recordset) + + # Construct TaskIns return TaskIns( task_id=str(uuid.uuid4()), group_id="", @@ -127,13 +170,43 @@ def receive() -> TaskIns: producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), ancestry=[], - legacy_server_message=server_message, + task_type=task_type, + recordset=recordset_proto, ), ) def send(task_res: TaskRes) -> None: - msg = task_res.task.legacy_client_message - return queue.put(msg, block=False) + # Retrieve RecordSet and task_type + recordset = serde.recordset_from_proto(task_res.task.recordset) + task_type = task_res.task.task_type + + # RecordSet --> *Res --> ClientMessage + if task_type == TASK_TYPE_GET_PROPERTIES: + client_message = typing.ClientMessage( + get_properties_res=compat.recordset_to_getpropertiesres(recordset) + ) + elif task_type == TASK_TYPE_GET_PARAMETERS: + client_message = typing.ClientMessage( + get_parameters_res=compat.recordset_to_getparametersres( + recordset, False + ) + ) + elif task_type == TASK_TYPE_FIT: + client_message = typing.ClientMessage( + fit_res=compat.recordset_to_fitres(recordset, False) + ) + elif task_type == TASK_TYPE_EVALUATE: + client_message = typing.ClientMessage( + evaluate_res=compat.recordset_to_evaluateres(recordset) + ) + else: + raise ValueError(f"Invalid task type: {task_type}") + + # ClientMessage --> ClientMessage proto + client_message_proto = serde.client_message_to_proto(client_message) + + # Send ClientMessage proto + return queue.put(client_message_proto, block=False) try: # Yield methods diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 47d89b7f2c36..8397d675b640 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -32,6 +32,12 @@ from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn from flwr.common import serde +from flwr.common.constant import ( + TASK_TYPE_EVALUATE, + TASK_TYPE_FIT, + TASK_TYPE_GET_PARAMETERS, + TASK_TYPE_GET_PROPERTIES, +) from flwr.common.context import Context from flwr.common.message import Message from flwr.common.recordset import RecordSet @@ -201,34 +207,38 @@ def handle_legacy_message_from_tasktype( task_type = message.metadata.task_type out_message = Message(metadata=message.metadata, message=RecordSet()) - if task_type == "get_properties_ins": + # Handle GetPropertiesIns + if task_type == TASK_TYPE_GET_PROPERTIES: get_properties_res = maybe_call_get_properties( client=client, get_properties_ins=recordset_to_getpropertiesins(message.message), ) out_message.message = getpropertiesres_to_recordset(get_properties_res) - elif task_type == "get_parameteres_ins": + # Handle GetParametersIns + elif task_type == TASK_TYPE_GET_PARAMETERS: get_parameters_res = maybe_call_get_parameters( client=client, get_parameters_ins=recordset_to_getparametersins(message.message), ) - out_message.message = getparametersres_to_recordset(get_parameters_res) - elif task_type == "fit_ins": + out_message.message = getparametersres_to_recordset( + get_parameters_res, keep_input=False + ) + # Handle FitIns + elif task_type == TASK_TYPE_FIT: fit_res = maybe_call_fit( client=client, fit_ins=recordset_to_fitins(message.message, keep_input=False), ) out_message.message = fitres_to_recordset(fit_res, keep_input=False) - elif task_type == "evaluate_ins": + # Handle EvaluateIns + elif task_type == TASK_TYPE_EVALUATE: evaluate_res = maybe_call_evaluate( client=client, evaluate_ins=recordset_to_evaluateins(message.message, keep_input=False), ) out_message.message = evaluateres_to_recordset(evaluate_res) else: - # TODO: what to do with reconnect? - print("do something") - + raise ValueError(f"Invalid task type: {task_type}") return out_message diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 49802f2815be..8d1d865f084b 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -31,3 +31,8 @@ TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, ] + +TASK_TYPE_GET_PROPERTIES = "get_properties" +TASK_TYPE_GET_PARAMETERS = "get_parameters" +TASK_TYPE_FIT = "fit" +TASK_TYPE_EVALUATE = "evaluate"