From eb36760e5121515cb67a18d5e10feb67b5657e0b Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 15 Jan 2025 17:30:16 +0000 Subject: [PATCH] adopt pull/push messages rpcs --- src/py/flwr/client/connection/fleet_api.py | 24 +++---- .../grpc_adapter_fleet_connection.py | 16 ----- .../grpc_adapter_fleet_connection_test.py | 4 ++ .../grpc_rere/client_interceptor_test.py | 62 ++++++++++--------- .../connection/rere_fleet_connection.py | 40 ++++++------ .../connection/rest/rest_fleet_connection.py | 32 +++++----- 6 files changed, 86 insertions(+), 92 deletions(-) diff --git a/src/py/flwr/client/connection/fleet_api.py b/src/py/flwr/client/connection/fleet_api.py index 67834dc23be7..d23968566c60 100644 --- a/src/py/flwr/client/connection/fleet_api.py +++ b/src/py/flwr/client/connection/fleet_api.py @@ -26,10 +26,10 @@ DeleteNodeResponse, PingRequest, PingResponse, - PullTaskInsRequest, - PullTaskInsResponse, - PushTaskResRequest, - PushTaskResResponse, + PullMessagesRequest, + PullMessagesResponse, + PushMessagesRequest, + PushMessagesResponse, ) from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 @@ -56,16 +56,16 @@ def DeleteNode( # pylint: disable=C0103 """Fleet.DeleteNode.""" @abstractmethod - def PullTaskIns( # pylint: disable=C0103 - self, request: PullTaskInsRequest, **kwargs: Any - ) -> PullTaskInsResponse: - """Fleet.PullTaskIns.""" + def PullMessages( # pylint: disable=C0103 + self, request: PullMessagesRequest, **kwargs: Any + ) -> PullMessagesResponse: + """Fleet.PullMessages.""" @abstractmethod - def PushTaskRes( # pylint: disable=C0103 - self, request: PushTaskResRequest, **kwargs: Any - ) -> PushTaskResResponse: - """Fleet.PushTaskRes.""" + def PushMessages( # pylint: disable=C0103 + self, request: PushMessagesRequest, **kwargs: Any + ) -> PushMessagesResponse: + """Fleet.PushMessages.""" @abstractmethod def GetRun( # pylint: disable=C0103 diff --git a/src/py/flwr/client/connection/grpc_adapter/grpc_adapter_fleet_connection.py b/src/py/flwr/client/connection/grpc_adapter/grpc_adapter_fleet_connection.py index f69b01fd5e26..ee62edf9f0af 100644 --- a/src/py/flwr/client/connection/grpc_adapter/grpc_adapter_fleet_connection.py +++ b/src/py/flwr/client/connection/grpc_adapter/grpc_adapter_fleet_connection.py @@ -46,12 +46,8 @@ PingResponse, PullMessagesRequest, PullMessagesResponse, - PullTaskInsRequest, - PullTaskInsResponse, PushMessagesRequest, PushMessagesResponse, - PushTaskResRequest, - PushTaskResResponse, ) from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611 from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub @@ -165,24 +161,12 @@ def Ping( # pylint: disable=C0103 """.""" return self._send_and_receive(request, PingResponse, **kwargs) - def PullTaskIns( # pylint: disable=C0103 - self, request: PullTaskInsRequest, **kwargs: Any - ) -> PullTaskInsResponse: - """.""" - return self._send_and_receive(request, PullTaskInsResponse, **kwargs) - def PullMessages( # pylint: disable=C0103 self, request: PullMessagesRequest, **kwargs: Any ) -> PullMessagesResponse: """.""" return self._send_and_receive(request, PullMessagesResponse, **kwargs) - def PushTaskRes( # pylint: disable=C0103 - self, request: PushTaskResRequest, **kwargs: Any - ) -> PushTaskResResponse: - """.""" - return self._send_and_receive(request, PushTaskResResponse, **kwargs) - def PushMessages( # pylint: disable=C0103 self, request: PushMessagesRequest, **kwargs: Any ) -> PushMessagesResponse: diff --git a/src/py/flwr/client/connection/grpc_adapter/grpc_adapter_fleet_connection_test.py b/src/py/flwr/client/connection/grpc_adapter/grpc_adapter_fleet_connection_test.py index a92896a29572..ac0146405d76 100644 --- a/src/py/flwr/client/connection/grpc_adapter/grpc_adapter_fleet_connection_test.py +++ b/src/py/flwr/client/connection/grpc_adapter/grpc_adapter_fleet_connection_test.py @@ -36,5 +36,9 @@ def test_grpc_adapter_methods() -> None: if inspect.isfunction(ref) } + # Backward compatibility + expected_methods.remove("PullTaskIns") + expected_methods.remove("PushTaskRes") + # Assert assert expected_methods.issubset(methods) diff --git a/src/py/flwr/client/connection/grpc_rere/client_interceptor_test.py b/src/py/flwr/client/connection/grpc_rere/client_interceptor_test.py index 50da0202dfc3..079b221eca0c 100644 --- a/src/py/flwr/client/connection/grpc_rere/client_interceptor_test.py +++ b/src/py/flwr/client/connection/grpc_rere/client_interceptor_test.py @@ -25,7 +25,7 @@ import grpc -from flwr.common import GRPC_MAX_MESSAGE_LENGTH, serde +from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.message import Message, Metadata from flwr.common.record import RecordSet from flwr.common.retry_invoker import RetryInvoker, exponential @@ -35,21 +35,26 @@ generate_shared_key, public_key_to_bytes, ) -from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 + +# pylint: disable=E0611 +from flwr.proto.fleet_pb2 import ( CreateNodeRequest, CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - PullTaskInsRequest, - PullTaskInsResponse, - PushTaskResRequest, - PushTaskResResponse, + PullMessagesRequest, + PullMessagesResponse, + PushMessagesRequest, + PushMessagesResponse, ) from flwr.proto.fleet_pb2_grpc import FleetServicer -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 +from flwr.proto.message_pb2 import Message as MessageProto +from flwr.proto.message_pb2 import Metadata as MetadataProto +from flwr.proto.node_pb2 import Node +from flwr.proto.recordset_pb2 import RecordSet as RecordSetProto +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse +# pylint: enable=E0611 from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request from .grpc_rere_fleet_connection import GrpcRereFleetConnection @@ -66,10 +71,11 @@ def __init__(self) -> None: self.server_private_key, self.server_public_key = generate_key_pairs() self._received_message_bytes: bytes = b"" - def unary_unary( - self, request: Request, context: grpc.ServicerContext - ) -> Union[ - CreateNodeResponse, DeleteNodeResponse, PushTaskResResponse, PullTaskInsResponse + def unary_unary(self, request: Request, context: grpc.ServicerContext) -> Union[ + CreateNodeResponse, + DeleteNodeResponse, + PushMessagesResponse, + PullMessagesResponse, ]: """Handle unary call.""" with self._lock: @@ -90,16 +96,14 @@ def unary_unary( return CreateNodeResponse(node=Node(node_id=123)) if isinstance(request, DeleteNodeRequest): return DeleteNodeResponse() - if isinstance(request, PushTaskResRequest): - return PushTaskResResponse() - - return PullTaskInsResponse( - task_ins_list=[ - TaskIns( - task=Task( - consumer=Node(node_id=123), - recordset=serde.recordset_to_proto(RecordSet()), - ) + if isinstance(request, PushMessagesRequest): + return PushMessagesResponse() + + return PullMessagesResponse( + messages_list=[ + MessageProto( + metadata=MetadataProto(dst_node_id=123), + content=RecordSetProto(), ) ] ) @@ -129,15 +133,15 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: request_deserializer=DeleteNodeRequest.FromString, response_serializer=DeleteNodeResponse.SerializeToString, ), - "PullTaskIns": grpc.unary_unary_rpc_method_handler( + "PullMessages": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, - request_deserializer=PullTaskInsRequest.FromString, - response_serializer=PullTaskInsResponse.SerializeToString, + request_deserializer=PullMessagesRequest.FromString, + response_serializer=PullMessagesResponse.SerializeToString, ), - "PushTaskRes": grpc.unary_unary_rpc_method_handler( + "PushMessages": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, - request_deserializer=PushTaskResRequest.FromString, - response_serializer=PushTaskResResponse.SerializeToString, + request_deserializer=PushMessagesRequest.FromString, + response_serializer=PushMessagesResponse.SerializeToString, ), "GetRun": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, diff --git a/src/py/flwr/client/connection/rere_fleet_connection.py b/src/py/flwr/client/connection/rere_fleet_connection.py index 8ed305769702..6da844ddf28c 100644 --- a/src/py/flwr/client/connection/rere_fleet_connection.py +++ b/src/py/flwr/client/connection/rere_fleet_connection.py @@ -28,7 +28,6 @@ from flwr.client.heartbeat import start_ping_loop from flwr.client.message_handler.message_handler import validate_out_message -from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.constant import ( PING_BASE_MULTIPLIER, @@ -39,7 +38,7 @@ from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.retry_invoker import RetryInvoker -from flwr.common.serde import message_from_taskins, message_to_taskres, run_from_proto +from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto from flwr.common.typing import Fab, Run, RunNotRunningException from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 @@ -48,12 +47,12 @@ DeleteNodeRequest, PingRequest, PingResponse, - PullTaskInsRequest, - PushTaskResRequest, + PullMessagesRequest, + PullMessagesResponse, + PushMessagesRequest, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 -from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 from .fleet_api import FleetApi from .fleet_connection import FleetConnection @@ -170,22 +169,23 @@ def receive(self) -> Message | None: log(ERROR, "Node instance missing") return None - # Request instructions (task) from server - req = PullTaskInsRequest(node=self.node) - res = self.retry_invoker.invoke(self.api.PullTaskIns, request=req) + # Request instructions (message) from server + req = PullMessagesRequest(node=self.node) + res: PullMessagesResponse = self.retry_invoker.invoke( + self.api.PullMessages, request=req + ) # Get the current TaskIns - task_ins: TaskIns | None = get_task_ins(res) + message_proto = res.messages_list[0] if res.messages_list else None - # Discard the current TaskIns if not valid - if task_ins is not None and not ( - task_ins.task.consumer.node_id == self.node.node_id - and validate_task_ins(task_ins) + # Discard the current message if not valid + if message_proto is not None and not ( + message_proto.metadata.dst_node_id == self.node.node_id ): - task_ins = None + message_proto = None # Construct the Message - in_message = message_from_taskins(task_ins) if task_ins else None + in_message = message_from_proto(message_proto) if message_proto else None # Remember `metadata` of the in message if in_message: @@ -213,12 +213,10 @@ def send(self, message: Message) -> None: log(ERROR, "Invalid out message") return - # Construct TaskRes - task_res = message_to_taskres(message) - - # Serialize ProtoBuf to bytes - req = PushTaskResRequest(node=self.node, task_res_list=[task_res]) - self.retry_invoker.invoke(self.api.PushTaskRes, req) + # Serialize Message + message_proto = message_to_proto(message=message) + req = PushMessagesRequest(node=self.node, messages_list=[message_proto]) + self.retry_invoker.invoke(self.api.PushMessages, req) # Cleanup metadata = None diff --git a/src/py/flwr/client/connection/rest/rest_fleet_connection.py b/src/py/flwr/client/connection/rest/rest_fleet_connection.py index a143d89f5b7e..e5ecfb2874bb 100644 --- a/src/py/flwr/client/connection/rest/rest_fleet_connection.py +++ b/src/py/flwr/client/connection/rest/rest_fleet_connection.py @@ -34,10 +34,10 @@ DeleteNodeResponse, PingRequest, PingResponse, - PullTaskInsRequest, - PullTaskInsResponse, - PushTaskResRequest, - PushTaskResResponse, + PullMessagesRequest, + PullMessagesResponse, + PushMessagesRequest, + PushMessagesResponse, ) from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 @@ -51,8 +51,8 @@ PATH_CREATE_NODE: str = "api/v0/fleet/create-node" PATH_DELETE_NODE: str = "api/v0/fleet/delete-node" -PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins" -PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res" +PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages" +PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages" PATH_PING: str = "api/v0/fleet/ping" PATH_GET_RUN: str = "/api/v0/fleet/get-run" PATH_GET_FAB: str = "/api/v0/fleet/get-fab" @@ -181,17 +181,21 @@ def Ping( # pylint: disable=C0103 """.""" return self._request(request, PingResponse, PATH_PING, **kwargs) - def PullTaskIns( # pylint: disable=C0103 - self, request: PullTaskInsRequest, **kwargs: Any - ) -> PullTaskInsResponse: + def PullMessages( # pylint: disable=C0103 + self, request: PullMessagesRequest, **kwargs: Any + ) -> PullMessagesResponse: """.""" - return self._request(request, PullTaskInsResponse, PATH_PULL_TASK_INS, **kwargs) + return self._request( + request, PullMessagesResponse, PATH_PULL_MESSAGES, **kwargs + ) - def PushTaskRes( # pylint: disable=C0103 - self, request: PushTaskResRequest, **kwargs: Any - ) -> PushTaskResResponse: + def PushMessages( # pylint: disable=C0103 + self, request: PushMessagesRequest, **kwargs: Any + ) -> PushMessagesResponse: """.""" - return self._request(request, PushTaskResResponse, PATH_PUSH_TASK_RES, **kwargs) + return self._request( + request, PushMessagesResponse, PATH_PUSH_MESSAGES, **kwargs + ) def GetRun( # pylint: disable=C0103 self, request: GetRunRequest, **kwargs: Any