From 1350ef41ede17dff8b8dca5224f8a88d550a93ca Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 29 Jan 2025 22:08:52 +0100 Subject: [PATCH] refactor(framework) Remove unused `rpcs` in `Fleet API` and presence in framework (#4874) --- src/proto/flwr/proto/fleet.proto | 29 +----- .../client_interceptor_test.py | 61 +++--------- .../client/grpc_rere_client/grpc_adapter.py | 16 --- src/py/flwr/proto/fleet_pb2.py | 67 +++++-------- src/py/flwr/proto/fleet_pb2.pyi | 84 ---------------- src/py/flwr/proto/fleet_pb2_grpc.py | 98 +------------------ src/py/flwr/proto/fleet_pb2_grpc.pyi | 50 +++------- .../fleet/grpc_rere/fleet_servicer.py | 38 ------- .../fleet/grpc_rere/fleet_servicer_test.py | 61 ------------ .../grpc_rere/server_interceptor_test.py | 44 --------- .../fleet/message_handler/message_handler.py | 50 +--------- .../message_handler/message_handler_test.py | 59 +---------- .../superlink/fleet/rest_rere/rest_api.py | 27 ----- 13 files changed, 65 insertions(+), 619 deletions(-) diff --git a/src/proto/flwr/proto/fleet.proto b/src/proto/flwr/proto/fleet.proto index 1f76d460af2b..00eb25e1cd7c 100644 --- a/src/proto/flwr/proto/fleet.proto +++ b/src/proto/flwr/proto/fleet.proto @@ -18,7 +18,6 @@ syntax = "proto3"; package flwr.proto; import "flwr/proto/node.proto"; -import "flwr/proto/task.proto"; import "flwr/proto/run.proto"; import "flwr/proto/fab.proto"; import "flwr/proto/message.proto"; @@ -28,17 +27,13 @@ service Fleet { rpc DeleteNode(DeleteNodeRequest) returns (DeleteNodeResponse) {} rpc Ping(PingRequest) returns (PingResponse) {} - // Retrieve one or more tasks, if possible + // Retrieve one or more messages, if possible // - // HTTP API path: /api/v1/fleet/pull-task-ins - rpc PullTaskIns(PullTaskInsRequest) returns (PullTaskInsResponse) {} // HTTP API path: /api/v1/fleet/pull-messages rpc PullMessages(PullMessagesRequest) returns (PullMessagesResponse) {} - // Complete one or more tasks, if possible + // Complete one or more messages, if possible // - // HTTP API path: /api/v1/fleet/push-task-res - rpc PushTaskRes(PushTaskResRequest) returns (PushTaskResResponse) {} // HTTP API path: /api/v1/fleet/push-messages rpc PushMessages(PushMessagesRequest) returns (PushMessagesResponse) {} @@ -63,26 +58,6 @@ message PingRequest { } message PingResponse { bool success = 1; } -// PullTaskIns messages -message PullTaskInsRequest { - Node node = 1; - repeated string task_ids = 2; -} -message PullTaskInsResponse { - Reconnect reconnect = 1; - repeated TaskIns task_ins_list = 2; -} - -// PushTaskRes messages -message PushTaskResRequest { - Node node = 1; - repeated TaskRes task_res_list = 2; -} -message PushTaskResResponse { - Reconnect reconnect = 1; - map results = 2; -} - // PullMessages messages message PullMessagesRequest { Node node = 1; diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index 1eb75f5c2efe..3bd09228bfe9 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -50,16 +50,11 @@ DeleteNodeResponse, PullMessagesRequest, PullMessagesResponse, - PullTaskInsRequest, - PullTaskInsResponse, PushMessagesRequest, PushMessagesResponse, - PushTaskResRequest, - PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 class _MockServicer: @@ -85,41 +80,27 @@ def unary_unary( # pylint: disable=too-many-return-statements return CreateNodeResponse(node=Node(node_id=123)) if isinstance(request, DeleteNodeRequest): return DeleteNodeResponse() - if isinstance(request, PushTaskResRequest): - return PushTaskResResponse() if isinstance(request, PushMessagesRequest): return PushMessagesResponse() if isinstance(request, GetRunRequest): return GetRunResponse() - if isinstance(request, PullMessagesRequest): - - msg = Message( - metadata=Metadata( - run_id=1234, - message_id="", - src_node_id=123, - dst_node_id=SUPERLINK_NODE_ID, - group_id="", - ttl=DEFAULT_TTL, - message_type="mock", - reply_to_message="", - ), - content=RecordSet(), - ) - proto_msg = serde.message_to_proto(msg) - proto_msg.metadata.created_at = now().timestamp() - return PullMessagesResponse(messages_list=[]) - - return PullTaskInsResponse( - task_ins_list=[ - TaskIns( - task=Task( - consumer=Node(node_id=123), - recordset=serde.recordset_to_proto(RecordSet()), - ) - ) - ] + + msg = Message( + metadata=Metadata( + run_id=1234, + message_id="", + src_node_id=123, + dst_node_id=SUPERLINK_NODE_ID, + group_id="", + ttl=DEFAULT_TTL, + message_type="mock", + reply_to_message="", + ), + content=RecordSet(), ) + proto_msg = serde.message_to_proto(msg) + proto_msg.metadata.created_at = now().timestamp() + return PullMessagesResponse(messages_list=[]) def received_client_metadata( self, @@ -146,21 +127,11 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: request_deserializer=DeleteNodeRequest.FromString, response_serializer=DeleteNodeResponse.SerializeToString, ), - "PullTaskIns": grpc.unary_unary_rpc_method_handler( - servicer.unary_unary, - request_deserializer=PullTaskInsRequest.FromString, - response_serializer=PullTaskInsResponse.SerializeToString, - ), "PullMessages": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, request_deserializer=PullMessagesRequest.FromString, response_serializer=PullMessagesResponse.SerializeToString, ), - "PushTaskRes": grpc.unary_unary_rpc_method_handler( - servicer.unary_unary, - request_deserializer=PushTaskResRequest.FromString, - response_serializer=PushTaskResResponse.SerializeToString, - ), "PushMessages": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, request_deserializer=PushMessagesRequest.FromString, diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index b9a761b59639..5ec6337eb6ee 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -42,12 +42,8 @@ PingResponse, PullMessagesRequest, PullMessagesResponse, - PullTaskInsRequest, - PullTaskInsResponse, PushMessagesRequest, PushMessagesResponse, - PushTaskResRequest, - PushTaskResResponse, ) from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611 from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub @@ -130,24 +126,12 @@ def Ping( # pylint: disable=C0103 """.""" return self._send_and_receive(request, PingResponse, **kwargs) - def PullTaskIns( # pylint: disable=C0103 - self, request: PullTaskInsRequest, **kwargs: Any - ) -> PullTaskInsResponse: - """.""" - return self._send_and_receive(request, PullTaskInsResponse, **kwargs) - def PullMessages( # pylint: disable=C0103 self, request: PullMessagesRequest, **kwargs: Any ) -> PullMessagesResponse: """.""" return self._send_and_receive(request, PullMessagesResponse, **kwargs) - def PushTaskRes( # pylint: disable=C0103 - self, request: PushTaskResRequest, **kwargs: Any - ) -> PushTaskResResponse: - """.""" - return self._send_and_receive(request, PushTaskResResponse, **kwargs) - def PushMessages( # pylint: disable=C0103 self, request: PushMessagesRequest, **kwargs: Any ) -> PushMessagesResponse: diff --git a/src/py/flwr/proto/fleet_pb2.py b/src/py/flwr/proto/fleet_pb2.py index 69b0e3a7b3c0..4d57dee23a29 100644 --- a/src/py/flwr/proto/fleet_pb2.py +++ b/src/py/flwr/proto/fleet_pb2.py @@ -23,57 +23,44 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 -from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"`\n\x12PushTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rtask_res_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"l\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\"a\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\"\xb0\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xb6\x05\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a .flwr.proto.PullMessagesResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"l\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\"a\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\"\xb0\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x92\x04\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a .flwr.proto.PullMessagesResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.fleet_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._loaded_options = None - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001' _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._loaded_options = None _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001' - _globals['_CREATENODEREQUEST']._serialized_start=154 - _globals['_CREATENODEREQUEST']._serialized_end=196 - _globals['_CREATENODERESPONSE']._serialized_start=198 - _globals['_CREATENODERESPONSE']._serialized_end=250 - _globals['_DELETENODEREQUEST']._serialized_start=252 - _globals['_DELETENODEREQUEST']._serialized_end=303 - _globals['_DELETENODERESPONSE']._serialized_start=305 - _globals['_DELETENODERESPONSE']._serialized_end=325 - _globals['_PINGREQUEST']._serialized_start=327 - _globals['_PINGREQUEST']._serialized_end=395 - _globals['_PINGRESPONSE']._serialized_start=397 - _globals['_PINGRESPONSE']._serialized_end=428 - _globals['_PULLTASKINSREQUEST']._serialized_start=430 - _globals['_PULLTASKINSREQUEST']._serialized_end=500 - _globals['_PULLTASKINSRESPONSE']._serialized_start=502 - _globals['_PULLTASKINSRESPONSE']._serialized_end=609 - _globals['_PUSHTASKRESREQUEST']._serialized_start=611 - _globals['_PUSHTASKRESREQUEST']._serialized_end=707 - _globals['_PUSHTASKRESRESPONSE']._serialized_start=710 - _globals['_PUSHTASKRESRESPONSE']._serialized_end=884 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=838 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=884 - _globals['_PULLMESSAGESREQUEST']._serialized_start=886 - _globals['_PULLMESSAGESREQUEST']._serialized_end=960 - _globals['_PULLMESSAGESRESPONSE']._serialized_start=962 - _globals['_PULLMESSAGESRESPONSE']._serialized_end=1070 - _globals['_PUSHMESSAGESREQUEST']._serialized_start=1072 - _globals['_PUSHMESSAGESREQUEST']._serialized_end=1169 - _globals['_PUSHMESSAGESRESPONSE']._serialized_start=1172 - _globals['_PUSHMESSAGESRESPONSE']._serialized_end=1348 - _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=838 - _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=884 - _globals['_RECONNECT']._serialized_start=1350 - _globals['_RECONNECT']._serialized_end=1380 - _globals['_FLEET']._serialized_start=1383 - _globals['_FLEET']._serialized_end=2077 + _globals['_CREATENODEREQUEST']._serialized_start=131 + _globals['_CREATENODEREQUEST']._serialized_end=173 + _globals['_CREATENODERESPONSE']._serialized_start=175 + _globals['_CREATENODERESPONSE']._serialized_end=227 + _globals['_DELETENODEREQUEST']._serialized_start=229 + _globals['_DELETENODEREQUEST']._serialized_end=280 + _globals['_DELETENODERESPONSE']._serialized_start=282 + _globals['_DELETENODERESPONSE']._serialized_end=302 + _globals['_PINGREQUEST']._serialized_start=304 + _globals['_PINGREQUEST']._serialized_end=372 + _globals['_PINGRESPONSE']._serialized_start=374 + _globals['_PINGRESPONSE']._serialized_end=405 + _globals['_PULLMESSAGESREQUEST']._serialized_start=407 + _globals['_PULLMESSAGESREQUEST']._serialized_end=481 + _globals['_PULLMESSAGESRESPONSE']._serialized_start=483 + _globals['_PULLMESSAGESRESPONSE']._serialized_end=591 + _globals['_PUSHMESSAGESREQUEST']._serialized_start=593 + _globals['_PUSHMESSAGESREQUEST']._serialized_end=690 + _globals['_PUSHMESSAGESRESPONSE']._serialized_start=693 + _globals['_PUSHMESSAGESRESPONSE']._serialized_end=869 + _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=823 + _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=869 + _globals['_RECONNECT']._serialized_start=871 + _globals['_RECONNECT']._serialized_end=901 + _globals['_FLEET']._serialized_start=904 + _globals['_FLEET']._serialized_end=1434 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/fleet_pb2.pyi b/src/py/flwr/proto/fleet_pb2.pyi index b7ee05bc1086..956eaebc5a5b 100644 --- a/src/py/flwr/proto/fleet_pb2.pyi +++ b/src/py/flwr/proto/fleet_pb2.pyi @@ -5,7 +5,6 @@ isort:skip_file import builtins import flwr.proto.message_pb2 import flwr.proto.node_pb2 -import flwr.proto.task_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -87,89 +86,6 @@ class PingResponse(google.protobuf.message.Message): def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ... global___PingResponse = PingResponse -class PullTaskInsRequest(google.protobuf.message.Message): - """PullTaskIns messages""" - DESCRIPTOR: google.protobuf.descriptor.Descriptor - NODE_FIELD_NUMBER: builtins.int - TASK_IDS_FIELD_NUMBER: builtins.int - @property - def node(self) -> flwr.proto.node_pb2.Node: ... - @property - def task_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... - def __init__(self, - *, - node: typing.Optional[flwr.proto.node_pb2.Node] = ..., - task_ids: typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_ids",b"task_ids"]) -> None: ... -global___PullTaskInsRequest = PullTaskInsRequest - -class PullTaskInsResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RECONNECT_FIELD_NUMBER: builtins.int - TASK_INS_LIST_FIELD_NUMBER: builtins.int - @property - def reconnect(self) -> global___Reconnect: ... - @property - def task_ins_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskIns]: ... - def __init__(self, - *, - reconnect: typing.Optional[global___Reconnect] = ..., - task_ins_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskIns]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect","task_ins_list",b"task_ins_list"]) -> None: ... -global___PullTaskInsResponse = PullTaskInsResponse - -class PushTaskResRequest(google.protobuf.message.Message): - """PushTaskRes messages""" - DESCRIPTOR: google.protobuf.descriptor.Descriptor - NODE_FIELD_NUMBER: builtins.int - TASK_RES_LIST_FIELD_NUMBER: builtins.int - @property - def node(self) -> flwr.proto.node_pb2.Node: ... - @property - def task_res_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskRes]: ... - def __init__(self, - *, - node: typing.Optional[flwr.proto.node_pb2.Node] = ..., - task_res_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskRes]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_res_list",b"task_res_list"]) -> None: ... -global___PushTaskResRequest = PushTaskResRequest - -class PushTaskResResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - class ResultsEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: typing.Text - value: builtins.int - def __init__(self, - *, - key: typing.Text = ..., - value: builtins.int = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - - RECONNECT_FIELD_NUMBER: builtins.int - RESULTS_FIELD_NUMBER: builtins.int - @property - def reconnect(self) -> global___Reconnect: ... - @property - def results(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, builtins.int]: ... - def __init__(self, - *, - reconnect: typing.Optional[global___Reconnect] = ..., - results: typing.Optional[typing.Mapping[typing.Text, builtins.int]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect","results",b"results"]) -> None: ... -global___PushTaskResResponse = PushTaskResResponse - class PullMessagesRequest(google.protobuf.message.Message): """PullMessages messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/src/py/flwr/proto/fleet_pb2_grpc.py b/src/py/flwr/proto/fleet_pb2_grpc.py index e4e00b4a33bf..7949894e5115 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.py +++ b/src/py/flwr/proto/fleet_pb2_grpc.py @@ -51,21 +51,11 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString, _registered_method=True) - self.PullTaskIns = channel.unary_unary( - '/flwr.proto.Fleet/PullTaskIns', - request_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.FromString, - _registered_method=True) self.PullMessages = channel.unary_unary( '/flwr.proto.Fleet/PullMessages', request_serializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesResponse.FromString, _registered_method=True) - self.PushTaskRes = channel.unary_unary( - '/flwr.proto.Fleet/PushTaskRes', - request_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.FromString, - _registered_method=True) self.PushMessages = channel.unary_unary( '/flwr.proto.Fleet/PushMessages', request_serializer=flwr_dot_proto_dot_fleet__pb2.PushMessagesRequest.SerializeToString, @@ -104,33 +94,19 @@ def Ping(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def PullTaskIns(self, request, context): - """Retrieve one or more tasks, if possible - - HTTP API path: /api/v1/fleet/pull-task-ins - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - def PullMessages(self, request, context): - """HTTP API path: /api/v1/fleet/pull-messages - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + """Retrieve one or more messages, if possible - def PushTaskRes(self, request, context): - """Complete one or more tasks, if possible - - HTTP API path: /api/v1/fleet/push-task-res + HTTP API path: /api/v1/fleet/pull-messages """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') def PushMessages(self, request, context): - """HTTP API path: /api/v1/fleet/push-messages + """Complete one or more messages, if possible + + HTTP API path: /api/v1/fleet/push-messages """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -167,21 +143,11 @@ def add_FleetServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.FromString, response_serializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.SerializeToString, ), - 'PullTaskIns': grpc.unary_unary_rpc_method_handler( - servicer.PullTaskIns, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.SerializeToString, - ), 'PullMessages': grpc.unary_unary_rpc_method_handler( servicer.PullMessages, request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesRequest.FromString, response_serializer=flwr_dot_proto_dot_fleet__pb2.PullMessagesResponse.SerializeToString, ), - 'PushTaskRes': grpc.unary_unary_rpc_method_handler( - servicer.PushTaskRes, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString, - ), 'PushMessages': grpc.unary_unary_rpc_method_handler( servicer.PushMessages, request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushMessagesRequest.FromString, @@ -289,33 +255,6 @@ def Ping(request, metadata, _registered_method=True) - @staticmethod - def PullTaskIns(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/flwr.proto.Fleet/PullTaskIns', - flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.SerializeToString, - flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - @staticmethod def PullMessages(request, target, @@ -343,33 +282,6 @@ def PullMessages(request, metadata, _registered_method=True) - @staticmethod - def PushTaskRes(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/flwr.proto.Fleet/PushTaskRes', - flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.SerializeToString, - flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - @staticmethod def PushMessages(request, target, diff --git a/src/py/flwr/proto/fleet_pb2_grpc.pyi b/src/py/flwr/proto/fleet_pb2_grpc.pyi index 6e5771462aa9..3c0c14085c22 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.pyi +++ b/src/py/flwr/proto/fleet_pb2_grpc.pyi @@ -22,31 +22,21 @@ class FleetStub: flwr.proto.fleet_pb2.PingRequest, flwr.proto.fleet_pb2.PingResponse] - PullTaskIns: grpc.UnaryUnaryMultiCallable[ - flwr.proto.fleet_pb2.PullTaskInsRequest, - flwr.proto.fleet_pb2.PullTaskInsResponse] - """Retrieve one or more tasks, if possible - - HTTP API path: /api/v1/fleet/pull-task-ins - """ - PullMessages: grpc.UnaryUnaryMultiCallable[ flwr.proto.fleet_pb2.PullMessagesRequest, flwr.proto.fleet_pb2.PullMessagesResponse] - """HTTP API path: /api/v1/fleet/pull-messages""" + """Retrieve one or more messages, if possible - PushTaskRes: grpc.UnaryUnaryMultiCallable[ - flwr.proto.fleet_pb2.PushTaskResRequest, - flwr.proto.fleet_pb2.PushTaskResResponse] - """Complete one or more tasks, if possible - - HTTP API path: /api/v1/fleet/push-task-res + HTTP API path: /api/v1/fleet/pull-messages """ PushMessages: grpc.UnaryUnaryMultiCallable[ flwr.proto.fleet_pb2.PushMessagesRequest, flwr.proto.fleet_pb2.PushMessagesResponse] - """HTTP API path: /api/v1/fleet/push-messages""" + """Complete one or more messages, if possible + + HTTP API path: /api/v1/fleet/push-messages + """ GetRun: grpc.UnaryUnaryMultiCallable[ flwr.proto.run_pb2.GetRunRequest, @@ -77,33 +67,14 @@ class FleetServicer(metaclass=abc.ABCMeta): context: grpc.ServicerContext, ) -> flwr.proto.fleet_pb2.PingResponse: ... - @abc.abstractmethod - def PullTaskIns(self, - request: flwr.proto.fleet_pb2.PullTaskInsRequest, - context: grpc.ServicerContext, - ) -> flwr.proto.fleet_pb2.PullTaskInsResponse: - """Retrieve one or more tasks, if possible - - HTTP API path: /api/v1/fleet/pull-task-ins - """ - pass - @abc.abstractmethod def PullMessages(self, request: flwr.proto.fleet_pb2.PullMessagesRequest, context: grpc.ServicerContext, ) -> flwr.proto.fleet_pb2.PullMessagesResponse: - """HTTP API path: /api/v1/fleet/pull-messages""" - pass + """Retrieve one or more messages, if possible - @abc.abstractmethod - def PushTaskRes(self, - request: flwr.proto.fleet_pb2.PushTaskResRequest, - context: grpc.ServicerContext, - ) -> flwr.proto.fleet_pb2.PushTaskResResponse: - """Complete one or more tasks, if possible - - HTTP API path: /api/v1/fleet/push-task-res + HTTP API path: /api/v1/fleet/pull-messages """ pass @@ -112,7 +83,10 @@ class FleetServicer(metaclass=abc.ABCMeta): request: flwr.proto.fleet_pb2.PushMessagesRequest, context: grpc.ServicerContext, ) -> flwr.proto.fleet_pb2.PushMessagesResponse: - """HTTP API path: /api/v1/fleet/push-messages""" + """Complete one or more messages, if possible + + HTTP API path: /api/v1/fleet/push-messages + """ pass @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index f9bb49b4c6c6..92a51c637486 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -33,12 +33,8 @@ PingResponse, PullMessagesRequest, PullMessagesResponse, - PullTaskInsRequest, - PullTaskInsResponse, PushMessagesRequest, PushMessagesResponse, - PushTaskResRequest, - PushTaskResResponse, ) from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.ffs.ffs_factory import FfsFactory @@ -89,17 +85,6 @@ def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingRespo state=self.state_factory.state(), ) - def PullTaskIns( - self, request: PullTaskInsRequest, context: grpc.ServicerContext - ) -> PullTaskInsResponse: - """Pull TaskIns.""" - log(INFO, "[Fleet.PullTaskIns] node_id=%s", request.node.node_id) - log(DEBUG, "[Fleet.PullTaskIns] Request: %s", MessageToDict(request)) - return message_handler.pull_task_ins( - request=request, - state=self.state_factory.state(), - ) - def PullMessages( self, request: PullMessagesRequest, context: grpc.ServicerContext ) -> PullMessagesResponse: @@ -111,29 +96,6 @@ def PullMessages( state=self.state_factory.state(), ) - def PushTaskRes( - self, request: PushTaskResRequest, context: grpc.ServicerContext - ) -> PushTaskResResponse: - """Push TaskRes.""" - if request.task_res_list: - log( - INFO, - "[Fleet.PushTaskRes] Push results from node_id=%s", - request.task_res_list[0].task.producer.node_id, - ) - else: - log(INFO, "[Fleet.PushTaskRes] No task results to push") - - try: - res = message_handler.push_task_res( - request=request, - state=self.state_factory.state(), - ) - except InvalidRunStatusException as e: - abort_grpc_context(e.message, context) - - return res - def PushMessages( self, request: PushMessagesRequest, context: grpc.ServicerContext ) -> PushMessagesResponse: diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer_test.py index d943436d102f..3f949bbfc497 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer_test.py @@ -32,12 +32,9 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 PushMessagesRequest, PushMessagesResponse, - PushTaskResRequest, - PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.server.app import _run_fleet_api_grpc_rere from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.linkstate.linkstate_factory import LinkStateFactory @@ -70,11 +67,6 @@ def setUp(self) -> None: ) self._channel = grpc.insecure_channel("localhost:9092") - self._push_task_res = self._channel.unary_unary( - "/flwr.proto.Fleet/PushTaskRes", - request_serializer=PushTaskResRequest.SerializeToString, - response_deserializer=PushTaskResResponse.FromString, - ) self._push_messages = self._channel.unary_unary( "/flwr.proto.Fleet/PushMessages", request_serializer=PushMessagesRequest.SerializeToString, @@ -103,26 +95,6 @@ def _transition_run_status(self, run_id: int, num_transitions: int) -> None: if num_transitions > 2: _ = self.state.update_run_status(run_id, RunStatus(Status.FINISHED, "", "")) - def test_successful_push_task_res_if_running(self) -> None: - """Test `PushTaskRes` success.""" - # Prepare - node_id = self.state.create_node(ping_interval=30) - run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) - # Transition status to running. PushTaskRes is only allowed in running status. - self._transition_run_status(run_id, 2) - request = PushTaskResRequest( - task_res_list=[ - TaskRes(task=Task(producer=Node(node_id=node_id)), run_id=run_id) - ] - ) - - # Execute - response, call = self._push_task_res.with_call(request=request) - - # Assert - assert isinstance(response, PushTaskResResponse) - assert grpc.StatusCode.OK == call.code() - def test_successful_push_messages_if_running(self) -> None: """Test `PushMessages` success.""" # Prepare @@ -145,39 +117,6 @@ def test_successful_push_messages_if_running(self) -> None: assert isinstance(response, PushMessagesResponse) assert grpc.StatusCode.OK == call.code() - def _assert_push_task_res_not_allowed(self, node_id: int, run_id: int) -> None: - """Assert `PushTaskRes` not allowed.""" - run_status = self.state.get_run_status({run_id})[run_id] - request = PushTaskResRequest( - task_res_list=[ - TaskRes(task=Task(producer=Node(node_id=node_id)), run_id=run_id) - ] - ) - - with self.assertRaises(grpc.RpcError) as e: - self._push_task_res.with_call(request=request) - assert e.exception.code() == grpc.StatusCode.PERMISSION_DENIED - assert e.exception.details() == self.status_to_msg[run_status.status] - - @parameterized.expand( - [ - (0,), # Test not successful if RunStatus is pending. - (1,), # Test not successful if RunStatus is starting. - (3,), # Test not successful if RunStatus is finished. - ] - ) # type: ignore - def test_push_task_res_not_successful_if_not_running( - self, num_transitions: int - ) -> None: - """Test `PushTaskRes` not successful if RunStatus is not running.""" - # Prepare - node_id = self.state.create_node(ping_interval=30) - run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) - self._transition_run_status(run_id, num_transitions) - - # Execute & Assert - self._assert_push_task_res_not_allowed(node_id, run_id) - def _assert_push_messages_not_allowed(self, node_id: int, run_id: int) -> None: """Assert `PushMessages` not allowed.""" run_status = self.state.get_run_status({run_id})[run_id] diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 5550eea63618..6807548fd1df 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -47,16 +47,11 @@ PingResponse, PullMessagesRequest, PullMessagesResponse, - PullTaskInsRequest, - PullTaskInsResponse, PushMessagesRequest, PushMessagesResponse, - PushTaskResRequest, - PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.server.app import _run_fleet_api_grpc_rere from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.linkstate.linkstate_factory import LinkStateFactory @@ -98,21 +93,11 @@ def setUp(self) -> None: request_serializer=DeleteNodeRequest.SerializeToString, response_deserializer=DeleteNodeResponse.FromString, ) - self._pull_task_ins = self._channel.unary_unary( - "/flwr.proto.Fleet/PullTaskIns", - request_serializer=PullTaskInsRequest.SerializeToString, - response_deserializer=PullTaskInsResponse.FromString, - ) self._pull_messages = self._channel.unary_unary( "/flwr.proto.Fleet/PullMessages", request_serializer=PullMessagesRequest.SerializeToString, response_deserializer=PullMessagesResponse.FromString, ) - self._push_task_res = self._channel.unary_unary( - "/flwr.proto.Fleet/PushTaskRes", - request_serializer=PushTaskResRequest.SerializeToString, - response_deserializer=PushTaskResResponse.FromString, - ) self._push_messages = self._channel.unary_unary( "/flwr.proto.Fleet/PushMessages", request_serializer=PushMessagesRequest.SerializeToString, @@ -193,33 +178,12 @@ def _test_delete_node(self, metadata: list[Any]) -> Any: req = DeleteNodeRequest(node=Node(node_id=node_id)) return self._delete_node.with_call(request=req, metadata=metadata) - def _test_pull_task_ins(self, metadata: list[Any]) -> Any: - """Test PullTaskIns.""" - node_id = self._create_node_and_set_public_key() - req = PullTaskInsRequest(node=Node(node_id=node_id)) - return self._pull_task_ins.with_call(request=req, metadata=metadata) - def _test_pull_messages(self, metadata: list[Any]) -> Any: """Test PullMessages.""" node_id = self._create_node_and_set_public_key() req = PullMessagesRequest(node=Node(node_id=node_id)) return self._pull_messages.with_call(request=req, metadata=metadata) - def _test_push_task_res(self, metadata: list[Any]) -> Any: - """Test PushTaskRes.""" - node_id = self._create_node_and_set_public_key() - run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) - # Transition status to running. PushTaskRes is only allowed in running status. - self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) - self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) - req = PushTaskResRequest( - node=Node(node_id=node_id), - task_res_list=[ - TaskRes(task=Task(producer=Node(node_id=node_id)), run_id=run_id) - ], - ) - return self._push_task_res.with_call(request=req, metadata=metadata) - def _test_push_messages(self, metadata: list[Any]) -> Any: """Test PushMessages.""" node_id = self._create_node_and_set_public_key() @@ -274,8 +238,6 @@ def _create_node_and_set_public_key(self) -> int: [ (_test_create_node,), (_test_delete_node,), - (_test_pull_task_ins,), - (_test_push_task_res,), (_test_pull_messages,), (_test_push_messages,), (_test_get_run,), @@ -297,8 +259,6 @@ def test_successful_rpc_with_metadata( [ (_test_create_node,), (_test_delete_node,), - (_test_pull_task_ins,), - (_test_push_task_res,), (_test_pull_messages,), (_test_push_messages,), (_test_get_run,), @@ -319,8 +279,6 @@ def test_unsuccessful_rpc_with_invalid_signature( [ (_test_create_node,), (_test_delete_node,), - (_test_pull_task_ins,), - (_test_push_task_res,), (_test_pull_messages,), (_test_push_messages,), (_test_get_run,), @@ -341,8 +299,6 @@ def test_unsuccessful_rpc_with_invalid_public_key( [ (_test_create_node,), (_test_delete_node,), - (_test_pull_task_ins,), - (_test_push_task_res,), (_test_pull_messages,), (_test_push_messages,), (_test_get_run,), diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index bc46d9462a49..fcba1afe3008 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -39,12 +39,8 @@ PingResponse, PullMessagesRequest, PullMessagesResponse, - PullTaskInsRequest, - PullTaskInsResponse, PushMessagesRequest, PushMessagesResponse, - PushTaskResRequest, - PushTaskResResponse, Reconnect, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 @@ -53,7 +49,7 @@ GetRunResponse, Run, ) -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 from flwr.server.superlink.ffs.ffs import Ffs from flwr.server.superlink.linkstate import LinkState from flwr.server.superlink.utils import check_abort @@ -89,21 +85,6 @@ def ping( return PingResponse(success=res) -def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse: - """Pull TaskIns handler.""" - node = request.node # pylint: disable=no-member - node_id: int = node.node_id - - # Retrieve TaskIns from State - task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1) - - # Build response - response = PullTaskInsResponse( - task_ins_list=task_ins_list, - ) - return response - - def pull_messages( request: PullMessagesRequest, state: LinkState ) -> PullMessagesResponse: @@ -124,35 +105,6 @@ def pull_messages( return PullMessagesResponse(messages_list=msg_proto) -def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse: - """Push TaskRes handler.""" - # pylint: disable=no-member - task_res: TaskRes = request.task_res_list[0] - # pylint: enable=no-member - - # Abort if the run is not running - abort_msg = check_abort( - task_res.run_id, - [Status.PENDING, Status.STARTING, Status.FINISHED], - state, - ) - if abort_msg: - raise InvalidRunStatusException(abort_msg) - - # Set pushed_at (timestamp in seconds) - task_res.task.pushed_at = time.time() - - # Store TaskRes in State - task_id: Optional[UUID] = state.store_task_res(task_res=task_res) - - # Build response - response = PushTaskResResponse( - reconnect=Reconnect(reconnect=5), - results={str(task_id): 0}, - ) - return response - - def push_messages( request: PushMessagesRequest, state: LinkState ) -> PushMessagesResponse: diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py index 03fdc46bdc17..85805ac2bd6d 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler_test.py @@ -23,21 +23,11 @@ CreateNodeRequest, DeleteNodeRequest, PullMessagesRequest, - PullTaskInsRequest, PushMessagesRequest, - PushTaskResRequest, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 - -from .message_handler import ( - create_node, - delete_node, - pull_messages, - pull_task_ins, - push_messages, - push_task_res, -) + +from .message_handler import create_node, delete_node, pull_messages, push_messages def test_create_node() -> None: @@ -94,24 +84,6 @@ def test_delete_node_success() -> None: state.get_task_res.assert_not_called() -def test_pull_task_ins() -> None: - """Test pull_task_ins.""" - # Prepare - request = PullTaskInsRequest(node=Node(node_id=123)) - state = MagicMock() - - # Execute - pull_task_ins(request=request, state=state) - - # Assert - state.create_node.assert_not_called() - state.delete_node.assert_not_called() - state.store_task_ins.assert_not_called() - state.get_task_ins.assert_called_once() - state.store_task_res.assert_not_called() - state.get_task_res.assert_not_called() - - def test_pull_messages() -> None: """Test pull_messages.""" # Prepare @@ -130,33 +102,6 @@ def test_pull_messages() -> None: state.get_task_res.assert_not_called() -def test_push_task_res() -> None: - """Test push_task_res.""" - # Prepare - request = PushTaskResRequest( - task_res_list=[ - TaskRes( - task_id="", - group_id="", - run_id=0, - task=Task(), - ), - ], - ) - state = MagicMock() - - # Execute - push_task_res(request=request, state=state) - - # Assert - state.create_node.assert_not_called() - state.delete_node.assert_not_called() - state.store_task_ins.assert_not_called() - state.get_task_ins.assert_not_called() - state.store_task_res.assert_called_once() - state.get_task_res.assert_not_called() - - def test_push_messages() -> None: """Test push_messages.""" # Prepare diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index 91abe7639c1c..74102238bd02 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -33,12 +33,8 @@ PingResponse, PullMessagesRequest, PullMessagesResponse, - PullTaskInsRequest, - PullTaskInsResponse, PushMessagesRequest, PushMessagesResponse, - PushTaskResRequest, - PushTaskResResponse, ) from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.ffs.ffs import Ffs @@ -110,16 +106,6 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse: return message_handler.delete_node(request=request, state=state) -@rest_request_response(PullTaskInsRequest) -async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse: - """Pull TaskIns.""" - # Get state from app - state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state() - - # Handle message - return message_handler.pull_task_ins(request=request, state=state) - - @rest_request_response(PullMessagesRequest) async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse: """Pull PullMessages.""" @@ -130,17 +116,6 @@ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse: return message_handler.pull_messages(request=request, state=state) -# Check if token is needed here -@rest_request_response(PushTaskResRequest) -async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse: - """Push TaskRes.""" - # Get state from app - state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state() - - # Handle message - return message_handler.push_task_res(request=request, state=state) - - @rest_request_response(PushMessagesRequest) async def push_message(request: PushMessagesRequest) -> PushMessagesResponse: """Pull PushMessages.""" @@ -187,9 +162,7 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse: routes = [ Route("/api/v0/fleet/create-node", create_node, methods=["POST"]), Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]), - Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]), Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]), - Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]), Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]), Route("/api/v0/fleet/ping", ping, methods=["POST"]), Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),