From eb6d9be853b13449b395b4160d41492a495fefa4 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 24 Oct 2024 16:42:40 +0100 Subject: [PATCH] feat(framework) Add `PullServerAppInputs` and `PushServerAppOutputs` rpcs to `Driver` service (#4363) --- src/proto/flwr/proto/driver.proto | 24 +++++++ src/py/flwr/proto/driver_pb2.py | 39 +++++++---- src/py/flwr/proto/driver_pb2.pyi | 59 ++++++++++++++++ src/py/flwr/proto/driver_pb2_grpc.py | 68 +++++++++++++++++++ src/py/flwr/proto/driver_pb2_grpc.pyi | 26 +++++++ .../superlink/driver/driver_servicer.py | 16 +++++ 6 files changed, 217 insertions(+), 15 deletions(-) diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index e26003862a76..2385d578cfde 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -18,6 +18,7 @@ syntax = "proto3"; package flwr.proto; import "flwr/proto/node.proto"; +import "flwr/proto/message.proto"; import "flwr/proto/task.proto"; import "flwr/proto/run.proto"; import "flwr/proto/fab.proto"; @@ -40,6 +41,14 @@ service Driver { // Get FAB rpc GetFab(GetFabRequest) returns (GetFabResponse) {} + + // Pull ServerApp inputs + rpc PullServerAppInputs(PullServerAppInputsRequest) + returns (PullServerAppInputsResponse) {} + + // Push ServerApp outputs + rpc PushServerAppOutputs(PushServerAppOutputsRequest) + returns (PushServerAppOutputsResponse) {} } // GetNodes messages @@ -56,3 +65,18 @@ message PullTaskResRequest { repeated string task_ids = 2; } message PullTaskResResponse { repeated TaskRes task_res_list = 1; } + +// PullServerAppInputs messages +message PullServerAppInputsRequest { uint64 run_id = 1; } +message PullServerAppInputsResponse { + Context context = 1; + Run run = 2; + Fab fab = 3; +} + +// PushServerAppOutputs messages +message PushServerAppOutputsRequest { + uint64 run_id = 1; + Context context = 2; +} +message PushServerAppOutputsResponse {} diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index d294b03be5af..f64d8373013c 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -13,30 +13,39 @@ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 +from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2 from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc7\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\",\n\x1aPullServerAppInputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9e\x05\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_GETNODESREQUEST']._serialized_start=129 - _globals['_GETNODESREQUEST']._serialized_end=162 - _globals['_GETNODESRESPONSE']._serialized_start=164 - _globals['_GETNODESRESPONSE']._serialized_end=215 - _globals['_PUSHTASKINSREQUEST']._serialized_start=217 - _globals['_PUSHTASKINSREQUEST']._serialized_end=281 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=283 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=322 - _globals['_PULLTASKRESREQUEST']._serialized_start=324 - _globals['_PULLTASKRESREQUEST']._serialized_end=394 - _globals['_PULLTASKRESRESPONSE']._serialized_start=396 - _globals['_PULLTASKRESRESPONSE']._serialized_end=461 - _globals['_DRIVER']._serialized_start=464 - _globals['_DRIVER']._serialized_end=919 + _globals['_GETNODESREQUEST']._serialized_start=155 + _globals['_GETNODESREQUEST']._serialized_end=188 + _globals['_GETNODESRESPONSE']._serialized_start=190 + _globals['_GETNODESRESPONSE']._serialized_end=241 + _globals['_PUSHTASKINSREQUEST']._serialized_start=243 + _globals['_PUSHTASKINSREQUEST']._serialized_end=307 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=309 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=348 + _globals['_PULLTASKRESREQUEST']._serialized_start=350 + _globals['_PULLTASKRESREQUEST']._serialized_end=420 + _globals['_PULLTASKRESRESPONSE']._serialized_start=422 + _globals['_PULLTASKRESRESPONSE']._serialized_end=487 + _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=489 + _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=533 + _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=535 + _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=662 + _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=664 + _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=747 + _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=749 + _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=779 + _globals['_DRIVER']._serialized_start=782 + _globals['_DRIVER']._serialized_end=1452 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 77ceb496d70c..f52562b4f467 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -3,7 +3,10 @@ isort:skip_file """ import builtins +import flwr.proto.fab_pb2 +import flwr.proto.message_pb2 import flwr.proto.node_pb2 +import flwr.proto.run_pb2 import flwr.proto.task_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers @@ -91,3 +94,59 @@ class PullTaskResResponse(google.protobuf.message.Message): ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["task_res_list",b"task_res_list"]) -> None: ... global___PullTaskResResponse = PullTaskResResponse + +class PullServerAppInputsRequest(google.protobuf.message.Message): + """PullServerAppInputs messages""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int + def __init__(self, + *, + run_id: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... +global___PullServerAppInputsRequest = PullServerAppInputsRequest + +class PullServerAppInputsResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + CONTEXT_FIELD_NUMBER: builtins.int + RUN_FIELD_NUMBER: builtins.int + FAB_FIELD_NUMBER: builtins.int + @property + def context(self) -> flwr.proto.message_pb2.Context: ... + @property + def run(self) -> flwr.proto.run_pb2.Run: ... + @property + def fab(self) -> flwr.proto.fab_pb2.Fab: ... + def __init__(self, + *, + context: typing.Optional[flwr.proto.message_pb2.Context] = ..., + run: typing.Optional[flwr.proto.run_pb2.Run] = ..., + fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> None: ... +global___PullServerAppInputsResponse = PullServerAppInputsResponse + +class PushServerAppOutputsRequest(google.protobuf.message.Message): + """PushServerAppOutputs messages""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + CONTEXT_FIELD_NUMBER: builtins.int + run_id: builtins.int + @property + def context(self) -> flwr.proto.message_pb2.Context: ... + def __init__(self, + *, + run_id: builtins.int = ..., + context: typing.Optional[flwr.proto.message_pb2.Context] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["context",b"context"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["context",b"context","run_id",b"run_id"]) -> None: ... +global___PushServerAppOutputsRequest = PushServerAppOutputsRequest + +class PushServerAppOutputsResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + def __init__(self, + ) -> None: ... +global___PushServerAppOutputsResponse = PushServerAppOutputsResponse diff --git a/src/py/flwr/proto/driver_pb2_grpc.py b/src/py/flwr/proto/driver_pb2_grpc.py index 91e9fd8b9bdd..ab2627e5b907 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.py +++ b/src/py/flwr/proto/driver_pb2_grpc.py @@ -46,6 +46,16 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString, ) + self.PullServerAppInputs = channel.unary_unary( + '/flwr.proto.Driver/PullServerAppInputs', + request_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.FromString, + ) + self.PushServerAppOutputs = channel.unary_unary( + '/flwr.proto.Driver/PushServerAppOutputs', + request_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString, + ) class DriverServicer(object): @@ -93,6 +103,20 @@ def GetFab(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def PullServerAppInputs(self, request, context): + """Pull ServerApp inputs + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PushServerAppOutputs(self, request, context): + """Push ServerApp outputs + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_DriverServicer_to_server(servicer, server): rpc_method_handlers = { @@ -126,6 +150,16 @@ def add_DriverServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.FromString, response_serializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.SerializeToString, ), + 'PullServerAppInputs': grpc.unary_unary_rpc_method_handler( + servicer.PullServerAppInputs, + request_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.FromString, + response_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.SerializeToString, + ), + 'PushServerAppOutputs': grpc.unary_unary_rpc_method_handler( + servicer.PushServerAppOutputs, + request_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.FromString, + response_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'flwr.proto.Driver', rpc_method_handlers) @@ -237,3 +271,37 @@ def GetFab(request, flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PullServerAppInputs(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PullServerAppInputs', + flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.SerializeToString, + flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PushServerAppOutputs(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PushServerAppOutputs', + flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString, + flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/driver_pb2_grpc.pyi b/src/py/flwr/proto/driver_pb2_grpc.pyi index 8f665301073d..5a3e1f964b7f 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.pyi +++ b/src/py/flwr/proto/driver_pb2_grpc.pyi @@ -40,6 +40,16 @@ class DriverStub: flwr.proto.fab_pb2.GetFabResponse] """Get FAB""" + PullServerAppInputs: grpc.UnaryUnaryMultiCallable[ + flwr.proto.driver_pb2.PullServerAppInputsRequest, + flwr.proto.driver_pb2.PullServerAppInputsResponse] + """Pull ServerApp inputs""" + + PushServerAppOutputs: grpc.UnaryUnaryMultiCallable[ + flwr.proto.driver_pb2.PushServerAppOutputsRequest, + flwr.proto.driver_pb2.PushServerAppOutputsResponse] + """Push ServerApp outputs""" + class DriverServicer(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -90,5 +100,21 @@ class DriverServicer(metaclass=abc.ABCMeta): """Get FAB""" pass + @abc.abstractmethod + def PullServerAppInputs(self, + request: flwr.proto.driver_pb2.PullServerAppInputsRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.driver_pb2.PullServerAppInputsResponse: + """Pull ServerApp inputs""" + pass + + @abc.abstractmethod + def PushServerAppOutputs(self, + request: flwr.proto.driver_pb2.PushServerAppOutputsRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.driver_pb2.PushServerAppOutputsResponse: + """Push ServerApp outputs""" + pass + def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 41a1a64e8879..31d13eff9c78 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -34,8 +34,12 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 GetNodesRequest, GetNodesResponse, + PullServerAppInputsRequest, + PullServerAppInputsResponse, PullTaskResRequest, PullTaskResResponse, + PushServerAppOutputsRequest, + PushServerAppOutputsResponse, PushTaskInsRequest, PushTaskInsResponse, ) @@ -200,6 +204,18 @@ def GetFab( raise ValueError(f"Found no FAB with hash: {request.hash_str}") + def PullServerAppInputs( + self, request: PullServerAppInputsRequest, context: grpc.ServicerContext + ) -> PullServerAppInputsResponse: + """Pull ServerApp process inputs.""" + raise NotImplementedError() + + def PushServerAppOutputs( + self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext + ) -> PushServerAppOutputsResponse: + """Push ServerApp process outputs.""" + raise NotImplementedError() + def _raise_if(validation_error: bool, detail: str) -> None: if validation_error: