From e54a33bdbb96dfbec7abcf844aa3157e5480ee9c Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 12 Jun 2024 14:54:02 +0100 Subject: [PATCH] feat(framework) Add `GetRun` rpc to the Driver servicer (#3578) --- src/proto/flwr/proto/driver.proto | 4 ++ src/proto/flwr/proto/fleet.proto | 10 +-- src/proto/flwr/proto/run.proto | 26 ++++++++ .../grpc_rere_client/client_interceptor.py | 2 +- .../client_interceptor_test.py | 3 +- .../client/grpc_rere_client/connection.py | 3 +- src/py/flwr/client/rest_client/connection.py | 3 +- src/py/flwr/proto/driver_pb2.py | 39 ++++++------ src/py/flwr/proto/driver_pb2_grpc.py | 35 +++++++++++ src/py/flwr/proto/driver_pb2_grpc.pyi | 14 +++++ src/py/flwr/proto/fleet_pb2.py | 61 +++++++++---------- src/py/flwr/proto/fleet_pb2.pyi | 42 ------------- src/py/flwr/proto/fleet_pb2_grpc.py | 13 ++-- src/py/flwr/proto/fleet_pb2_grpc.pyi | 9 +-- src/py/flwr/proto/run_pb2.py | 30 +++++++++ src/py/flwr/proto/run_pb2.pyi | 52 ++++++++++++++++ src/py/flwr/proto/run_pb2_grpc.py | 4 ++ src/py/flwr/proto/run_pb2_grpc.pyi | 4 ++ .../superlink/driver/driver_servicer.py | 7 +++ .../fleet/grpc_rere/fleet_servicer.py | 3 +- .../fleet/grpc_rere/server_interceptor.py | 3 +- .../grpc_rere/server_interceptor_test.py | 3 +- .../fleet/message_handler/message_handler.py | 8 ++- .../superlink/fleet/rest_rere/rest_api.py | 2 +- src/py/flwr_tool/protoc_test.py | 2 +- 25 files changed, 251 insertions(+), 131 deletions(-) create mode 100644 src/proto/flwr/proto/run.proto create mode 100644 src/py/flwr/proto/run_pb2.py create mode 100644 src/py/flwr/proto/run_pb2.pyi create mode 100644 src/py/flwr/proto/run_pb2_grpc.py create mode 100644 src/py/flwr/proto/run_pb2_grpc.pyi diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index 54e6b6b41b68..edbd5d91bb5b 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -19,6 +19,7 @@ package flwr.proto; import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; +import "flwr/proto/run.proto"; service Driver { // Request run_id @@ -32,6 +33,9 @@ service Driver { // Get task results rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {} + + // Get run details + rpc GetRun(GetRunRequest) returns (GetRunResponse) {} } // CreateRun diff --git a/src/proto/flwr/proto/fleet.proto b/src/proto/flwr/proto/fleet.proto index df6b5843023d..24f60bb3d825 100644 --- a/src/proto/flwr/proto/fleet.proto +++ b/src/proto/flwr/proto/fleet.proto @@ -19,6 +19,7 @@ package flwr.proto; import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; +import "flwr/proto/run.proto"; service Fleet { rpc CreateNode(CreateNodeRequest) returns (CreateNodeResponse) {} @@ -70,13 +71,4 @@ message PushTaskResResponse { map results = 2; } -// GetRun messages -message Run { - sint64 run_id = 1; - string fab_id = 2; - string fab_version = 3; -} -message GetRunRequest { sint64 run_id = 1; } -message GetRunResponse { Run run = 1; } - message Reconnect { uint64 reconnect = 1; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto new file mode 100644 index 000000000000..76a7fd91532f --- /dev/null +++ b/src/proto/flwr/proto/run.proto @@ -0,0 +1,26 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +message Run { + sint64 run_id = 1; + string fab_id = 2; + string fab_version = 3; +} +message GetRunRequest { sint64 run_id = 1; } +message GetRunResponse { Run run = 1; } diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 8bc55878971d..d2dded8a73d9 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -31,11 +31,11 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, - GetRunRequest, PingRequest, PullTaskInsRequest, PushTaskResRequest, ) +from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611 _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index 487361a06026..cc35ffef46db 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -41,13 +41,12 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, PushTaskResResponse, ) +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 9579d5830165..52d0cc58b2bb 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -44,8 +44,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -53,6 +51,7 @@ ) from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 from .client_interceptor import AuthenticateClientInterceptor diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index da8fbd351ab1..7383eae3d22b 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -46,8 +46,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -56,6 +54,7 @@ PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 try: diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index b0caae58ff6f..a2458b445563 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -14,31 +14,32 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc1\x02\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"7\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\x84\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATERUNREQUEST']._serialized_start=85 - _globals['_CREATERUNREQUEST']._serialized_end=140 - _globals['_CREATERUNRESPONSE']._serialized_start=142 - _globals['_CREATERUNRESPONSE']._serialized_end=177 - _globals['_GETNODESREQUEST']._serialized_start=179 - _globals['_GETNODESREQUEST']._serialized_end=212 - _globals['_GETNODESRESPONSE']._serialized_start=214 - _globals['_GETNODESRESPONSE']._serialized_end=265 - _globals['_PUSHTASKINSREQUEST']._serialized_start=267 - _globals['_PUSHTASKINSREQUEST']._serialized_end=331 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=333 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=372 - _globals['_PULLTASKRESREQUEST']._serialized_start=374 - _globals['_PULLTASKRESREQUEST']._serialized_end=444 - _globals['_PULLTASKRESRESPONSE']._serialized_start=446 - _globals['_PULLTASKRESRESPONSE']._serialized_end=511 - _globals['_DRIVER']._serialized_start=514 - _globals['_DRIVER']._serialized_end=835 + _globals['_CREATERUNREQUEST']._serialized_start=107 + _globals['_CREATERUNREQUEST']._serialized_end=162 + _globals['_CREATERUNRESPONSE']._serialized_start=164 + _globals['_CREATERUNRESPONSE']._serialized_end=199 + _globals['_GETNODESREQUEST']._serialized_start=201 + _globals['_GETNODESREQUEST']._serialized_end=234 + _globals['_GETNODESRESPONSE']._serialized_start=236 + _globals['_GETNODESRESPONSE']._serialized_end=287 + _globals['_PUSHTASKINSREQUEST']._serialized_start=289 + _globals['_PUSHTASKINSREQUEST']._serialized_end=353 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=355 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=394 + _globals['_PULLTASKRESREQUEST']._serialized_start=396 + _globals['_PULLTASKRESREQUEST']._serialized_end=466 + _globals['_PULLTASKRESRESPONSE']._serialized_start=468 + _globals['_PULLTASKRESRESPONSE']._serialized_end=533 + _globals['_DRIVER']._serialized_start=536 + _globals['_DRIVER']._serialized_end=924 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2_grpc.py b/src/py/flwr/proto/driver_pb2_grpc.py index ac6815023ebd..2cd3ebe62a63 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.py +++ b/src/py/flwr/proto/driver_pb2_grpc.py @@ -3,6 +3,7 @@ import grpc from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 class DriverStub(object): @@ -34,6 +35,11 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_driver__pb2.PullTaskResRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.FromString, ) + self.GetRun = channel.unary_unary( + '/flwr.proto.Driver/GetRun', + request_serializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString, + ) class DriverServicer(object): @@ -67,6 +73,13 @@ def PullTaskRes(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetRun(self, request, context): + """Get run details + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_DriverServicer_to_server(servicer, server): rpc_method_handlers = { @@ -90,6 +103,11 @@ def add_DriverServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_driver__pb2.PullTaskResRequest.FromString, response_serializer=flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.SerializeToString, ), + 'GetRun': grpc.unary_unary_rpc_method_handler( + servicer.GetRun, + request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'flwr.proto.Driver', rpc_method_handlers) @@ -167,3 +185,20 @@ def PullTaskRes(request, flwr_dot_proto_dot_driver__pb2.PullTaskResResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetRun(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/GetRun', + flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/driver_pb2_grpc.pyi b/src/py/flwr/proto/driver_pb2_grpc.pyi index 43cf45f39b25..4ff09db588ca 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.pyi +++ b/src/py/flwr/proto/driver_pb2_grpc.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import abc import flwr.proto.driver_pb2 +import flwr.proto.run_pb2 import grpc class DriverStub: @@ -28,6 +29,11 @@ class DriverStub: flwr.proto.driver_pb2.PullTaskResResponse] """Get task results""" + GetRun: grpc.UnaryUnaryMultiCallable[ + flwr.proto.run_pb2.GetRunRequest, + flwr.proto.run_pb2.GetRunResponse] + """Get run details""" + class DriverServicer(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -62,5 +68,13 @@ class DriverServicer(metaclass=abc.ABCMeta): """Get task results""" pass + @abc.abstractmethod + def GetRun(self, + request: flwr.proto.run_pb2.GetRunRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.run_pb2.GetRunResponse: + """Get run details""" + pass + def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/proto/fleet_pb2.py b/src/py/flwr/proto/fleet_pb2.py index 42f3292d910d..9763b71fed2f 100644 --- a/src/py/flwr/proto/fleet_pb2.py +++ b/src/py/flwr/proto/fleet_pb2.py @@ -14,9 +14,10 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\":\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xc9\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xc9\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,36 +26,30 @@ DESCRIPTOR._options = None _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._options = None _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001' - _globals['_CREATENODEREQUEST']._serialized_start=84 - _globals['_CREATENODEREQUEST']._serialized_end=126 - _globals['_CREATENODERESPONSE']._serialized_start=128 - _globals['_CREATENODERESPONSE']._serialized_end=180 - _globals['_DELETENODEREQUEST']._serialized_start=182 - _globals['_DELETENODEREQUEST']._serialized_end=233 - _globals['_DELETENODERESPONSE']._serialized_start=235 - _globals['_DELETENODERESPONSE']._serialized_end=255 - _globals['_PINGREQUEST']._serialized_start=257 - _globals['_PINGREQUEST']._serialized_end=325 - _globals['_PINGRESPONSE']._serialized_start=327 - _globals['_PINGRESPONSE']._serialized_end=358 - _globals['_PULLTASKINSREQUEST']._serialized_start=360 - _globals['_PULLTASKINSREQUEST']._serialized_end=430 - _globals['_PULLTASKINSRESPONSE']._serialized_start=432 - _globals['_PULLTASKINSRESPONSE']._serialized_end=539 - _globals['_PUSHTASKRESREQUEST']._serialized_start=541 - _globals['_PUSHTASKRESREQUEST']._serialized_end=605 - _globals['_PUSHTASKRESRESPONSE']._serialized_start=608 - _globals['_PUSHTASKRESRESPONSE']._serialized_end=782 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=736 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=782 - _globals['_RUN']._serialized_start=784 - _globals['_RUN']._serialized_end=842 - _globals['_GETRUNREQUEST']._serialized_start=844 - _globals['_GETRUNREQUEST']._serialized_end=875 - _globals['_GETRUNRESPONSE']._serialized_start=877 - _globals['_GETRUNRESPONSE']._serialized_end=923 - _globals['_RECONNECT']._serialized_start=925 - _globals['_RECONNECT']._serialized_end=955 - _globals['_FLEET']._serialized_start=958 - _globals['_FLEET']._serialized_end=1415 + _globals['_CREATENODEREQUEST']._serialized_start=106 + _globals['_CREATENODEREQUEST']._serialized_end=148 + _globals['_CREATENODERESPONSE']._serialized_start=150 + _globals['_CREATENODERESPONSE']._serialized_end=202 + _globals['_DELETENODEREQUEST']._serialized_start=204 + _globals['_DELETENODEREQUEST']._serialized_end=255 + _globals['_DELETENODERESPONSE']._serialized_start=257 + _globals['_DELETENODERESPONSE']._serialized_end=277 + _globals['_PINGREQUEST']._serialized_start=279 + _globals['_PINGREQUEST']._serialized_end=347 + _globals['_PINGRESPONSE']._serialized_start=349 + _globals['_PINGRESPONSE']._serialized_end=380 + _globals['_PULLTASKINSREQUEST']._serialized_start=382 + _globals['_PULLTASKINSREQUEST']._serialized_end=452 + _globals['_PULLTASKINSRESPONSE']._serialized_start=454 + _globals['_PULLTASKINSRESPONSE']._serialized_end=561 + _globals['_PUSHTASKRESREQUEST']._serialized_start=563 + _globals['_PUSHTASKRESREQUEST']._serialized_end=627 + _globals['_PUSHTASKRESRESPONSE']._serialized_start=630 + _globals['_PUSHTASKRESRESPONSE']._serialized_end=804 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=758 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=804 + _globals['_RECONNECT']._serialized_start=806 + _globals['_RECONNECT']._serialized_end=836 + _globals['_FLEET']._serialized_start=839 + _globals['_FLEET']._serialized_end=1296 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/fleet_pb2.pyi b/src/py/flwr/proto/fleet_pb2.pyi index a6f38b703e76..5989f45c5c60 100644 --- a/src/py/flwr/proto/fleet_pb2.pyi +++ b/src/py/flwr/proto/fleet_pb2.pyi @@ -164,48 +164,6 @@ class PushTaskResResponse(google.protobuf.message.Message): def ClearField(self, field_name: typing_extensions.Literal["reconnect",b"reconnect","results",b"results"]) -> None: ... global___PushTaskResResponse = PushTaskResResponse -class Run(google.protobuf.message.Message): - """GetRun messages""" - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_ID_FIELD_NUMBER: builtins.int - FAB_ID_FIELD_NUMBER: builtins.int - FAB_VERSION_FIELD_NUMBER: builtins.int - run_id: builtins.int - fab_id: typing.Text - fab_version: typing.Text - def __init__(self, - *, - run_id: builtins.int = ..., - fab_id: typing.Text = ..., - fab_version: typing.Text = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ... -global___Run = Run - -class GetRunRequest(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_ID_FIELD_NUMBER: builtins.int - run_id: builtins.int - def __init__(self, - *, - run_id: builtins.int = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... -global___GetRunRequest = GetRunRequest - -class GetRunResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_FIELD_NUMBER: builtins.int - @property - def run(self) -> global___Run: ... - def __init__(self, - *, - run: typing.Optional[global___Run] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["run",b"run"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["run",b"run"]) -> None: ... -global___GetRunResponse = GetRunResponse - class Reconnect(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor RECONNECT_FIELD_NUMBER: builtins.int diff --git a/src/py/flwr/proto/fleet_pb2_grpc.py b/src/py/flwr/proto/fleet_pb2_grpc.py index 16757eaed381..e0b0fbc50460 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.py +++ b/src/py/flwr/proto/fleet_pb2_grpc.py @@ -3,6 +3,7 @@ import grpc from flwr.proto import fleet_pb2 as flwr_dot_proto_dot_fleet__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 class FleetStub(object): @@ -41,8 +42,8 @@ def __init__(self, channel): ) self.GetRun = channel.unary_unary( '/flwr.proto.Fleet/GetRun', - request_serializer=flwr_dot_proto_dot_fleet__pb2.GetRunRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_fleet__pb2.GetRunResponse.FromString, + request_serializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString, ) @@ -121,8 +122,8 @@ def add_FleetServicer_to_server(servicer, server): ), 'GetRun': grpc.unary_unary_rpc_method_handler( servicer.GetRun, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.GetRunRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.GetRunResponse.SerializeToString, + request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.GetRunResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -231,7 +232,7 @@ def GetRun(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/GetRun', - flwr_dot_proto_dot_fleet__pb2.GetRunRequest.SerializeToString, - flwr_dot_proto_dot_fleet__pb2.GetRunResponse.FromString, + flwr_dot_proto_dot_run__pb2.GetRunRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.GetRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/fleet_pb2_grpc.pyi b/src/py/flwr/proto/fleet_pb2_grpc.pyi index f275cd149d69..1c0ab862d45c 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.pyi +++ b/src/py/flwr/proto/fleet_pb2_grpc.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import abc import flwr.proto.fleet_pb2 +import flwr.proto.run_pb2 import grpc class FleetStub: @@ -37,8 +38,8 @@ class FleetStub: """ GetRun: grpc.UnaryUnaryMultiCallable[ - flwr.proto.fleet_pb2.GetRunRequest, - flwr.proto.fleet_pb2.GetRunResponse] + flwr.proto.run_pb2.GetRunRequest, + flwr.proto.run_pb2.GetRunResponse] class FleetServicer(metaclass=abc.ABCMeta): @@ -84,9 +85,9 @@ class FleetServicer(metaclass=abc.ABCMeta): @abc.abstractmethod def GetRun(self, - request: flwr.proto.fleet_pb2.GetRunRequest, + request: flwr.proto.run_pb2.GetRunRequest, context: grpc.ServicerContext, - ) -> flwr.proto.fleet_pb2.GetRunResponse: ... + ) -> flwr.proto.run_pb2.GetRunResponse: ... def add_FleetServicer_to_server(servicer: FleetServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py new file mode 100644 index 000000000000..13f06e7169aa --- /dev/null +++ b/src/py/flwr/proto/run_pb2.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/run.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\":\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.run_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_RUN']._serialized_start=36 + _globals['_RUN']._serialized_end=94 + _globals['_GETRUNREQUEST']._serialized_start=96 + _globals['_GETRUNREQUEST']._serialized_end=127 + _globals['_GETRUNRESPONSE']._serialized_start=129 + _globals['_GETRUNRESPONSE']._serialized_end=175 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi new file mode 100644 index 000000000000..401d27855a41 --- /dev/null +++ b/src/py/flwr/proto/run_pb2.pyi @@ -0,0 +1,52 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class Run(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + FAB_ID_FIELD_NUMBER: builtins.int + FAB_VERSION_FIELD_NUMBER: builtins.int + run_id: builtins.int + fab_id: typing.Text + fab_version: typing.Text + def __init__(self, + *, + run_id: builtins.int = ..., + fab_id: typing.Text = ..., + fab_version: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ... +global___Run = Run + +class GetRunRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int + def __init__(self, + *, + run_id: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... +global___GetRunRequest = GetRunRequest + +class GetRunResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_FIELD_NUMBER: builtins.int + @property + def run(self) -> global___Run: ... + def __init__(self, + *, + run: typing.Optional[global___Run] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["run",b"run"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["run",b"run"]) -> None: ... +global___GetRunResponse = GetRunResponse diff --git a/src/py/flwr/proto/run_pb2_grpc.py b/src/py/flwr/proto/run_pb2_grpc.py new file mode 100644 index 000000000000..2daafffebfc8 --- /dev/null +++ b/src/py/flwr/proto/run_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/py/flwr/proto/run_pb2_grpc.pyi b/src/py/flwr/proto/run_pb2_grpc.pyi new file mode 100644 index 000000000000..f3a5a087ef5d --- /dev/null +++ b/src/py/flwr/proto/run_pb2_grpc.pyi @@ -0,0 +1,4 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index ce2d9d68d8ca..e808616af778 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -35,6 +35,7 @@ PushTaskInsResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import State, StateFactory from flwr.server.utils.validator import validate_task_ins_or_res @@ -129,6 +130,12 @@ def on_rpc_done() -> None: context.set_code(grpc.StatusCode.OK) return PullTaskResResponse(task_res_list=task_res_list) + def GetRun( + self, request: GetRunRequest, context: grpc.ServicerContext + ) -> GetRunResponse: + """Get run information.""" + raise NotImplementedError + def _raise_if(validation_error: bool, detail: str) -> None: if validation_error: diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index 03a2ec064213..13e024eb31e4 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -26,8 +26,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -35,6 +33,7 @@ PushTaskResRequest, PushTaskResResponse, ) +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.fleet.message_handler import message_handler from flwr.server.superlink.state import StateFactory diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 6a302679a235..21e9c44907cd 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -34,8 +34,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -44,6 +42,7 @@ PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.state import State _PUBLIC_KEY_HEADER = "public-key" diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index c4c71e5a8188..01499102b7d8 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -32,8 +32,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -42,6 +40,7 @@ PushTaskResResponse, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.server.app import ADDRESS_FLEET_API_GRPC_RERE, _run_fleet_api_grpc_rere from flwr.server.superlink.state.state_factory import StateFactory diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 83b005a4cb8e..4c796502436b 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -24,8 +24,6 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, - GetRunRequest, - GetRunResponse, PingRequest, PingResponse, PullTaskInsRequest, @@ -33,9 +31,13 @@ PushTaskResRequest, PushTaskResResponse, Reconnect, - Run, ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.run_pb2 import ( # pylint: disable=E0611 + GetRunRequest, + GetRunResponse, + Run, +) from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state import State diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index 8ac7c6cfc613..c7ff496d39bf 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -21,11 +21,11 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, - GetRunRequest, PingRequest, PullTaskInsRequest, PushTaskResRequest, ) +from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611 from flwr.server.superlink.fleet.message_handler import message_handler from flwr.server.superlink.state import State diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 8dcf4c6474d6..6aec4251c384 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 8 + assert len(PROTO_FILES) == 9