diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 5a4a3f8fbf8b..4ca80642199d 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -20,11 +20,12 @@ from logging import DEBUG from pathlib import Path from queue import Queue -from typing import Callable, Iterator, Optional, Tuple, Union +from typing import Callable, Iterator, Optional, Tuple, Union, cast from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common import recordset_compat as compat -from flwr.common import serde, typing +from flwr.common import serde +from flwr.common.configsrecord import ConfigsRecord from flwr.common.constant import ( TASK_TYPE_EVALUATE, TASK_TYPE_FIT, @@ -33,10 +34,12 @@ ) from flwr.common.grpc import create_channel from flwr.common.logger import log +from flwr.common.recordset import RecordSet from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, + Reason, ServerMessage, ) from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611 @@ -54,7 +57,7 @@ def on_channel_state_change(channel_connectivity: str) -> None: @contextmanager -def grpc_connection( +def grpc_connection( # pylint: disable=R0915 server_address: str, insecure: bool, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, @@ -152,6 +155,12 @@ def receive() -> TaskIns: serde.evaluate_ins_from_proto(proto.evaluate_ins), False ) task_type = TASK_TYPE_EVALUATE + elif field == "reconnect_ins": + recordset = RecordSet() + recordset.set_configs( + "config", ConfigsRecord({"seconds": proto.reconnect_ins.seconds}) + ) + task_type = "reconnect" else: raise ValueError( "Unsupported instruction in ServerMessage, " @@ -180,33 +189,33 @@ def send(task_res: TaskRes) -> None: recordset = serde.recordset_from_proto(task_res.task.recordset) task_type = task_res.task.task_type - # RecordSet --> *Res --> ClientMessage + # RecordSet --> *Res --> *Res proto -> ClientMessage proto if task_type == TASK_TYPE_GET_PROPERTIES: - client_message = typing.ClientMessage( - get_properties_res=compat.recordset_to_getpropertiesres(recordset) + getpropres = compat.recordset_to_getpropertiesres(recordset) + msg_proto = ClientMessage( + get_properties_res=serde.get_properties_res_to_proto(getpropres) ) elif task_type == TASK_TYPE_GET_PARAMETERS: - client_message = typing.ClientMessage( - get_parameters_res=compat.recordset_to_getparametersres( - recordset, False - ) + getparamres = compat.recordset_to_getparametersres(recordset, False) + msg_proto = ClientMessage( + get_parameters_res=serde.get_parameters_res_to_proto(getparamres) ) elif task_type == TASK_TYPE_FIT: - client_message = typing.ClientMessage( - fit_res=compat.recordset_to_fitres(recordset, False) - ) + fitres = compat.recordset_to_fitres(recordset, False) + msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres)) elif task_type == TASK_TYPE_EVALUATE: - client_message = typing.ClientMessage( - evaluate_res=compat.recordset_to_evaluateres(recordset) + evalres = compat.recordset_to_evaluateres(recordset) + msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres)) + elif task_type == "reconnect": + reason = cast(Reason.ValueType, recordset.get_configs("config")["reason"]) + msg_proto = ClientMessage( + disconnect_res=ClientMessage.DisconnectRes(reason=reason) ) else: raise ValueError(f"Invalid task type: {task_type}") - # ClientMessage --> ClientMessage proto - client_message_proto = serde.client_message_to_proto(client_message) - # Send ClientMessage proto - return queue.put(client_message_proto, block=False) + return queue.put(msg_proto, block=False) try: # Yield methods diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index bcfa76bb36c0..f2b362750df4 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -23,6 +23,12 @@ import grpc +from flwr.common import recordset_compat as compat +from flwr.common import serde +from flwr.common.configsrecord import ConfigsRecord +from flwr.common.constant import TASK_TYPE_GET_PROPERTIES +from flwr.common.recordset import RecordSet +from flwr.common.typing import Code, GetPropertiesRes, Status from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, @@ -35,11 +41,21 @@ EXPECTED_NUM_SERVER_MESSAGE = 10 -SERVER_MESSAGE = ServerMessage() +SERVER_MESSAGE = ServerMessage(get_properties_ins=ServerMessage.GetPropertiesIns()) SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect_ins=ServerMessage.ReconnectIns()) -CLIENT_MESSAGE = ClientMessage() -CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect_res=ClientMessage.DisconnectRes()) +TASK_GET_PROPERTIES = Task( + task_type=TASK_TYPE_GET_PROPERTIES, + recordset=serde.recordset_to_proto( + compat.getpropertiesres_to_recordset(GetPropertiesRes(Status(Code.OK, ""), {})) + ), +) +TASK_DISCONNECT = Task( + task_type="reconnect", + recordset=serde.recordset_to_proto( + RecordSet(configs={"config": ConfigsRecord({"reason": 0})}) + ), +) def unused_tcp_port() -> int: @@ -104,31 +120,14 @@ def run_client() -> int: # Block until server responds with a message task_ins = receive() - if task_ins is None: - raise ValueError("Unexpected None value") - - # pylint: disable=no-member - if task_ins.HasField("task") and task_ins.task.HasField( - "legacy_server_message" - ): - server_message = task_ins.task.legacy_server_message - else: - server_message = None - # pylint: enable=no-member - - if server_message is None: - raise ValueError("Unexpected None value") - messages_received += 1 - if server_message.HasField("reconnect_ins"): - task_res = TaskRes( - task=Task(legacy_client_message=CLIENT_MESSAGE_DISCONNECT) - ) + if task_ins.task.task_type == "reconnect": # type: ignore + task_res = TaskRes(task=TASK_DISCONNECT) send(task_res) break # Process server_message and send client_message... - task_res = TaskRes(task=Task(legacy_client_message=CLIENT_MESSAGE)) + task_res = TaskRes(task=TASK_GET_PROPERTIES) send(task_res) return messages_received diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 000ecdbdb35e..e16d190b0289 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -15,22 +15,17 @@ """Client-side message handler.""" -from typing import Optional, Tuple +from typing import Optional, Tuple, cast from flwr.client.client import ( - Client, maybe_call_evaluate, maybe_call_fit, maybe_call_get_parameters, maybe_call_get_properties, ) -from flwr.client.message_handler.task_handler import ( - get_server_message_from_task_ins, - wrap_client_message_in_task_res, -) -from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn from flwr.common import serde +from flwr.common.configsrecord import ConfigsRecord from flwr.common.constant import ( TASK_TYPE_EVALUATE, TASK_TYPE_FIT, @@ -50,12 +45,7 @@ recordset_to_getparametersins, recordset_to_getpropertiesins, ) -from flwr.proto.task_pb2 import ( # pylint: disable=E0611 - SecureAggregation, - Task, - TaskIns, - TaskRes, -) +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, Reason, @@ -81,120 +71,37 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: Returns ------- + task_res : Optional[TaskRes] + TaskRes to be returned to the server. If None, the client should + continue to process messages from the server. sleep_duration : int Number of seconds that the client should disconnect from the server. - keep_going : bool - Flag that indicates whether the client should continue to process the - next message from the server (True) or disconnect and optionally - reconnect later (False). """ - server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - - # SecAgg message - if server_msg is None: - return None, 0 - - # ReconnectIns message - field = server_msg.WhichOneof("msg") - if field == "reconnect_ins": - disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect_ins) - task_res = wrap_client_message_in_task_res(disconnect_msg) + if task_ins.task.task_type == "reconnect": + # Retrieve ReconnectIns from recordset + recordset = serde.recordset_from_proto(task_ins.task.recordset) + seconds = cast(int, recordset.get_configs("config")["seconds"]) + # Construct ReconnectIns and call _reconnect + disconnect_msg, sleep_duration = _reconnect( + ServerMessage.ReconnectIns(seconds=seconds) + ) + # Store DisconnectRes in recordset + reason = cast(int, disconnect_msg.disconnect_res.reason) + recordset = RecordSet() + recordset.set_configs("config", ConfigsRecord({"reason": reason})) + task_res = TaskRes( + task=Task( + task_type="reconnect", + recordset=serde.recordset_to_proto(recordset), + ) + ) + # Return TaskRes and sleep duration return task_res, sleep_duration # Any other message return None, 0 -def handle( - client_fn: ClientFn, context: Context, task_ins: TaskIns -) -> Tuple[TaskRes, Context]: - """Handle incoming TaskIns from the server. - - Parameters - ---------- - client_fn : ClientFn - A callable that instantiates a Client. - context : Context - A dataclass storing the context for the run being executed by the client. - task_ins: TaskIns - The task instruction coming from the server, to be processed by the client. - - Returns - ------- - task_res : TaskRes - The task response that should be returned to the server. - """ - server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False) - if server_msg is None: - # Instantiate the client - client = client_fn("-1") - client.set_context(context) - # Secure Aggregation - if task_ins.task.HasField("sa") and isinstance( - client, SecureAggregationHandler - ): - # pylint: disable-next=invalid-name - named_values = serde.named_values_from_proto(task_ins.task.sa.named_values) - res = client.handle_secure_aggregation(named_values) - task_res = TaskRes( - task_id="", - group_id="", - run_id=0, - task=Task( - ancestry=[], - sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), - ), - ) - return task_res, client.get_context() - raise NotImplementedError() - client_msg, updated_context = handle_legacy_message(client_fn, context, server_msg) - task_res = wrap_client_message_in_task_res(client_msg) - return task_res, updated_context - - -def handle_legacy_message( - client_fn: ClientFn, context: Context, server_msg: ServerMessage -) -> Tuple[ClientMessage, Context]: - """Handle incoming messages from the server. - - Parameters - ---------- - client_fn : ClientFn - A callable that instantiates a Client. - context : Context - A dataclass storing the context for the run being executed by the client. - server_msg: ServerMessage - The message coming from the server, to be processed by the client. - - Returns - ------- - client_msg : ClientMessage - The result message that should be returned to the server. - """ - field = server_msg.WhichOneof("msg") - - # Must be handled elsewhere - if field == "reconnect_ins": - raise UnexpectedServerMessage() - - # Instantiate the client - client = client_fn("-1") - client.set_context(context) - # Execute task - message = None - if field == "get_properties_ins": - message = _get_properties(client, server_msg.get_properties_ins) - if field == "get_parameters_ins": - message = _get_parameters(client, server_msg.get_parameters_ins) - if field == "fit_ins": - message = _fit(client, server_msg.fit_ins) - if field == "evaluate_ins": - message = _evaluate(client, server_msg.evaluate_ins) - if message: - return message, client.get_context() - raise UnknownServerMessage() - - def handle_legacy_message_from_tasktype( client_fn: ClientFn, message: Message, context: Context ) -> Message: @@ -253,67 +160,3 @@ def _reconnect( # Build DisconnectRes message disconnect_res = ClientMessage.DisconnectRes(reason=reason) return ClientMessage(disconnect_res=disconnect_res), sleep_duration - - -def _get_properties( - client: Client, get_properties_msg: ServerMessage.GetPropertiesIns -) -> ClientMessage: - # Deserialize `get_properties` instruction - get_properties_ins = serde.get_properties_ins_from_proto(get_properties_msg) - - # Request properties - get_properties_res = maybe_call_get_properties( - client=client, - get_properties_ins=get_properties_ins, - ) - - # Serialize response - get_properties_res_proto = serde.get_properties_res_to_proto(get_properties_res) - return ClientMessage(get_properties_res=get_properties_res_proto) - - -def _get_parameters( - client: Client, get_parameters_msg: ServerMessage.GetParametersIns -) -> ClientMessage: - # Deserialize `get_parameters` instruction - get_parameters_ins = serde.get_parameters_ins_from_proto(get_parameters_msg) - - # Request parameters - get_parameters_res = maybe_call_get_parameters( - client=client, - get_parameters_ins=get_parameters_ins, - ) - - # Serialize response - get_parameters_res_proto = serde.get_parameters_res_to_proto(get_parameters_res) - return ClientMessage(get_parameters_res=get_parameters_res_proto) - - -def _fit(client: Client, fit_msg: ServerMessage.FitIns) -> ClientMessage: - # Deserialize fit instruction - fit_ins = serde.fit_ins_from_proto(fit_msg) - - # Perform fit - fit_res = maybe_call_fit( - client=client, - fit_ins=fit_ins, - ) - - # Serialize fit result - fit_res_proto = serde.fit_res_to_proto(fit_res) - return ClientMessage(fit_res=fit_res_proto) - - -def _evaluate(client: Client, evaluate_msg: ServerMessage.EvaluateIns) -> ClientMessage: - # Deserialize evaluate instruction - evaluate_ins = serde.evaluate_ins_from_proto(evaluate_msg) - - # Perform evaluation - evaluate_res = maybe_call_evaluate( - client=client, - evaluate_ins=evaluate_ins, - ) - - # Serialize evaluate result - evaluate_res_proto = serde.evaluate_res_to_proto(evaluate_res) - return ClientMessage(evaluate_res=evaluate_res_proto) diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 707570cd8e57..842b6d1e1ee1 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -20,6 +20,7 @@ from flwr.client import Client from flwr.client.typing import ClientFn from flwr.common import ( + Code, EvaluateIns, EvaluateRes, FitIns, @@ -29,21 +30,16 @@ GetPropertiesIns, GetPropertiesRes, Parameters, - serde, - typing, + Status, ) +from flwr.common import recordset_compat as compat +from flwr.common import typing +from flwr.common.constant import TASK_TYPE_GET_PROPERTIES from flwr.common.context import Context +from flwr.common.message import Message, Metadata from flwr.common.recordset import RecordSet -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 - ClientMessage, - Code, - ServerMessage, - Status, -) -from .message_handler import handle, handle_control_message +from .message_handler import handle_legacy_message_from_tasktype class ClientWithoutProps(Client): @@ -122,137 +118,71 @@ def test_client_without_get_properties() -> None: """Test client implementing get_properties.""" # Prepare client = ClientWithoutProps() - ins = ServerMessage.GetPropertiesIns() - - task_ins: TaskIns = TaskIns( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - task=Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[], - legacy_server_message=ServerMessage(get_properties_ins=ins), + recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) + message = Message( + metadata=Metadata( + run_id=0, + task_id=str(uuid.uuid4()), + group_id="", + ttl="", + task_type=TASK_TYPE_GET_PROPERTIES, ), + message=recordset, ) # Execute - disconnect_task_res, actual_sleep_duration = handle_control_message( - task_ins=task_ins - ) - task_res, _ = handle( + actual_msg = handle_legacy_message_from_tasktype( client_fn=_get_client_fn(client), + message=message, context=Context(state=RecordSet()), - task_ins=task_ins, - ) - - if not task_res.HasField("task"): - raise ValueError("Task value not found") - - # pylint: disable=no-member - if not task_res.task.HasField("legacy_client_message"): - raise ValueError("Unexpected None value") - # pylint: enable=no-member - - task_res.MergeFrom( - TaskRes( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - ) - ) - # pylint: disable=no-member - task_res.task.MergeFrom( - Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[task_ins.task_id], - ) ) - actual_msg = task_res.task.legacy_client_message - # pylint: enable=no-member - # Assert - expected_get_properties_res = ClientMessage.GetPropertiesRes( + expected_get_properties_res = GetPropertiesRes( status=Status( code=Code.GET_PROPERTIES_NOT_IMPLEMENTED, message="Client does not implement `get_properties`", - ) + ), + properties={}, ) - expected_msg = ClientMessage(get_properties_res=expected_get_properties_res) + expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) + expected_msg = Message(message.metadata, expected_rs) assert actual_msg == expected_msg - assert not disconnect_task_res - assert actual_sleep_duration == 0 def test_client_with_get_properties() -> None: """Test client not implementing get_properties.""" # Prepare client = ClientWithProps() - ins = ServerMessage.GetPropertiesIns() - task_ins = TaskIns( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - task=Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[], - legacy_server_message=ServerMessage(get_properties_ins=ins), + recordset = compat.getpropertiesins_to_recordset(GetPropertiesIns({})) + message = Message( + metadata=Metadata( + run_id=0, + task_id=str(uuid.uuid4()), + group_id="", + ttl="", + task_type=TASK_TYPE_GET_PROPERTIES, ), + message=recordset, ) # Execute - disconnect_task_res, actual_sleep_duration = handle_control_message( - task_ins=task_ins - ) - task_res, _ = handle( + actual_msg = handle_legacy_message_from_tasktype( client_fn=_get_client_fn(client), + message=message, context=Context(state=RecordSet()), - task_ins=task_ins, ) - if not task_res.HasField("task"): - raise ValueError("Task value not found") - - # pylint: disable=no-member - if not task_res.task.HasField("legacy_client_message"): - raise ValueError("Unexpected None value") - # pylint: enable=no-member - - task_res.MergeFrom( - TaskRes( - task_id=str(uuid.uuid4()), - group_id="", - run_id=0, - ) - ) - # pylint: disable=no-member - task_res.task.MergeFrom( - Task( - producer=Node(node_id=0, anonymous=True), - consumer=Node(node_id=0, anonymous=True), - ancestry=[task_ins.task_id], - ) - ) - - actual_msg = task_res.task.legacy_client_message - # pylint: enable=no-member - # Assert - expected_get_properties_res = ClientMessage.GetPropertiesRes( + expected_get_properties_res = GetPropertiesRes( status=Status( code=Code.OK, message="Success", ), - properties=serde.properties_to_proto( - properties={"str_prop": "val", "int_prop": 1} - ), + properties={"str_prop": "val", "int_prop": 1}, ) - expected_msg = ClientMessage(get_properties_res=expected_get_properties_res) + expected_rs = compat.getpropertiesres_to_recordset(expected_get_properties_res) + expected_msg = Message(message.metadata, expected_rs) assert actual_msg == expected_msg - assert not disconnect_task_res - assert actual_sleep_duration == 0