diff --git a/src/proto/flwr/proto/control.proto b/src/proto/flwr/proto/control.proto index 4e0867bbf9e4..9747c2e3ea11 100644 --- a/src/proto/flwr/proto/control.proto +++ b/src/proto/flwr/proto/control.proto @@ -17,7 +17,7 @@ syntax = "proto3"; package flwr.proto; -import "flwr/proto/driver.proto"; +import "flwr/proto/run.proto"; service Control { // Request to create a new run diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index c7ae7dcf30f0..e26003862a76 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -21,7 +21,6 @@ import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; import "flwr/proto/run.proto"; import "flwr/proto/fab.proto"; -import "flwr/proto/transport.proto"; service Driver { // Request run_id @@ -43,15 +42,6 @@ service Driver { rpc GetFab(GetFabRequest) returns (GetFabResponse) {} } -// CreateRun -message CreateRunRequest { - string fab_id = 1; - string fab_version = 2; - map override_config = 3; - Fab fab = 4; -} -message CreateRunResponse { uint64 run_id = 1; } - // GetNodes messages message GetNodesRequest { uint64 run_id = 1; } message GetNodesResponse { repeated Node nodes = 1; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index fc3294f7a583..ada72610e182 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -17,6 +17,7 @@ syntax = "proto3"; package flwr.proto; +import "flwr/proto/fab.proto"; import "flwr/proto/transport.proto"; message Run { @@ -26,5 +27,16 @@ message Run { map override_config = 4; string fab_hash = 5; } + +// CreateRun +message CreateRunRequest { + string fab_id = 1; + string fab_version = 2; + map override_config = 3; + Fab fab = 4; +} +message CreateRunResponse { uint64 run_id = 1; } + +// GetRun message GetRunRequest { uint64 run_id = 1; } message GetRunResponse { Run run = 1; } diff --git a/src/py/flwr/proto/control_pb2.py b/src/py/flwr/proto/control_pb2.py index d206e75cb388..2b8776509d32 100644 --- a/src/py/flwr/proto/control_pb2.py +++ b/src/py/flwr/proto/control_pb2.py @@ -12,16 +12,16 @@ _sym_db = _symbol_database.Default() -from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x17\x66lwr/proto/driver.proto2U\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2U\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.control_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CONTROL']._serialized_start=65 - _globals['_CONTROL']._serialized_end=150 + _globals['_CONTROL']._serialized_start=62 + _globals['_CONTROL']._serialized_end=147 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/control_pb2_grpc.py b/src/py/flwr/proto/control_pb2_grpc.py index 9c671be88a47..987b9d8d7433 100644 --- a/src/py/flwr/proto/control_pb2_grpc.py +++ b/src/py/flwr/proto/control_pb2_grpc.py @@ -2,7 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 class ControlStub(object): @@ -16,8 +16,8 @@ def __init__(self, channel): """ self.CreateRun = channel.unary_unary( '/flwr.proto.Control/CreateRun', - request_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, ) @@ -36,8 +36,8 @@ def add_ControlServicer_to_server(servicer, server): rpc_method_handlers = { 'CreateRun': grpc.unary_unary_rpc_method_handler( servicer.CreateRun, - request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.FromString, - response_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.SerializeToString, + request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -61,7 +61,7 @@ def CreateRun(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/CreateRun', - flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, - flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/control_pb2_grpc.pyi b/src/py/flwr/proto/control_pb2_grpc.pyi index f4613fa0e2f3..c38b7f15d125 100644 --- a/src/py/flwr/proto/control_pb2_grpc.pyi +++ b/src/py/flwr/proto/control_pb2_grpc.pyi @@ -3,23 +3,23 @@ isort:skip_file """ import abc -import flwr.proto.driver_pb2 +import flwr.proto.run_pb2 import grpc class ControlStub: def __init__(self, channel: grpc.Channel) -> None: ... CreateRun: grpc.UnaryUnaryMultiCallable[ - flwr.proto.driver_pb2.CreateRunRequest, - flwr.proto.driver_pb2.CreateRunResponse] + flwr.proto.run_pb2.CreateRunRequest, + flwr.proto.run_pb2.CreateRunResponse] """Request to create a new run""" class ControlServicer(metaclass=abc.ABCMeta): @abc.abstractmethod def CreateRun(self, - request: flwr.proto.driver_pb2.CreateRunRequest, + request: flwr.proto.run_pb2.CreateRunRequest, context: grpc.ServicerContext, - ) -> flwr.proto.driver_pb2.CreateRunResponse: + ) -> flwr.proto.run_pb2.CreateRunResponse: """Request to create a new run""" pass diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index 1322660cfac1..d294b03be5af 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -16,36 +16,27 @@ from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 -from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc7\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc7\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_CREATERUNREQUEST']._serialized_start=158 - _globals['_CREATERUNREQUEST']._serialized_end=393 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=320 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=393 - _globals['_CREATERUNRESPONSE']._serialized_start=395 - _globals['_CREATERUNRESPONSE']._serialized_end=430 - _globals['_GETNODESREQUEST']._serialized_start=432 - _globals['_GETNODESREQUEST']._serialized_end=465 - _globals['_GETNODESRESPONSE']._serialized_start=467 - _globals['_GETNODESRESPONSE']._serialized_end=518 - _globals['_PUSHTASKINSREQUEST']._serialized_start=520 - _globals['_PUSHTASKINSREQUEST']._serialized_end=584 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=586 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=625 - _globals['_PULLTASKRESREQUEST']._serialized_start=627 - _globals['_PULLTASKRESREQUEST']._serialized_end=697 - _globals['_PULLTASKRESRESPONSE']._serialized_start=699 - _globals['_PULLTASKRESRESPONSE']._serialized_end=764 - _globals['_DRIVER']._serialized_start=767 - _globals['_DRIVER']._serialized_end=1222 + _globals['_GETNODESREQUEST']._serialized_start=129 + _globals['_GETNODESREQUEST']._serialized_end=162 + _globals['_GETNODESRESPONSE']._serialized_start=164 + _globals['_GETNODESRESPONSE']._serialized_end=215 + _globals['_PUSHTASKINSREQUEST']._serialized_start=217 + _globals['_PUSHTASKINSREQUEST']._serialized_end=281 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=283 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=322 + _globals['_PULLTASKRESREQUEST']._serialized_start=324 + _globals['_PULLTASKRESREQUEST']._serialized_end=394 + _globals['_PULLTASKRESRESPONSE']._serialized_start=396 + _globals['_PULLTASKRESRESPONSE']._serialized_end=461 + _globals['_DRIVER']._serialized_start=464 + _globals['_DRIVER']._serialized_end=919 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index d025e00474eb..77ceb496d70c 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -3,10 +3,8 @@ isort:skip_file """ import builtins -import flwr.proto.fab_pb2 import flwr.proto.node_pb2 import flwr.proto.task_pb2 -import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -15,56 +13,6 @@ import typing_extensions DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -class CreateRunRequest(google.protobuf.message.Message): - """CreateRun""" - DESCRIPTOR: google.protobuf.descriptor.Descriptor - class OverrideConfigEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: typing.Text - @property - def value(self) -> flwr.proto.transport_pb2.Scalar: ... - def __init__(self, - *, - key: typing.Text = ..., - value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - - FAB_ID_FIELD_NUMBER: builtins.int - FAB_VERSION_FIELD_NUMBER: builtins.int - OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int - FAB_FIELD_NUMBER: builtins.int - fab_id: typing.Text - fab_version: typing.Text - @property - def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... - @property - def fab(self) -> flwr.proto.fab_pb2.Fab: ... - def __init__(self, - *, - fab_id: typing.Text = ..., - fab_version: typing.Text = ..., - override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., - fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config"]) -> None: ... -global___CreateRunRequest = CreateRunRequest - -class CreateRunResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_ID_FIELD_NUMBER: builtins.int - run_id: builtins.int - def __init__(self, - *, - run_id: builtins.int = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... -global___CreateRunResponse = CreateRunResponse - class GetNodesRequest(google.protobuf.message.Message): """GetNodes messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/src/py/flwr/proto/driver_pb2_grpc.py b/src/py/flwr/proto/driver_pb2_grpc.py index 6745bc7af62a..91e9fd8b9bdd 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.py +++ b/src/py/flwr/proto/driver_pb2_grpc.py @@ -18,8 +18,8 @@ def __init__(self, channel): """ self.CreateRun = channel.unary_unary( '/flwr.proto.Driver/CreateRun', - request_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, ) self.GetNodes = channel.unary_unary( '/flwr.proto.Driver/GetNodes', @@ -98,8 +98,8 @@ def add_DriverServicer_to_server(servicer, server): rpc_method_handlers = { 'CreateRun': grpc.unary_unary_rpc_method_handler( servicer.CreateRun, - request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.FromString, - response_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.SerializeToString, + request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString, ), 'GetNodes': grpc.unary_unary_rpc_method_handler( servicer.GetNodes, @@ -148,8 +148,8 @@ def CreateRun(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateRun', - flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, - flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/driver_pb2_grpc.pyi b/src/py/flwr/proto/driver_pb2_grpc.pyi index 7f9fd0acbd82..8f665301073d 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.pyi +++ b/src/py/flwr/proto/driver_pb2_grpc.pyi @@ -11,8 +11,8 @@ import grpc class DriverStub: def __init__(self, channel: grpc.Channel) -> None: ... CreateRun: grpc.UnaryUnaryMultiCallable[ - flwr.proto.driver_pb2.CreateRunRequest, - flwr.proto.driver_pb2.CreateRunResponse] + flwr.proto.run_pb2.CreateRunRequest, + flwr.proto.run_pb2.CreateRunResponse] """Request run_id""" GetNodes: grpc.UnaryUnaryMultiCallable[ @@ -44,9 +44,9 @@ class DriverStub: class DriverServicer(metaclass=abc.ABCMeta): @abc.abstractmethod def CreateRun(self, - request: flwr.proto.driver_pb2.CreateRunRequest, + request: flwr.proto.run_pb2.CreateRunRequest, context: grpc.ServicerContext, - ) -> flwr.proto.driver_pb2.CreateRunResponse: + ) -> flwr.proto.run_pb2.CreateRunResponse: """Request run_id""" pass diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index 13fc43f90f8c..99ca4df5c44c 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -12,10 +12,11 @@ _sym_db = _symbol_database.Default() +from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -24,12 +25,20 @@ DESCRIPTOR._options = None _globals['_RUN_OVERRIDECONFIGENTRY']._options = None _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_RUN']._serialized_start=65 - _globals['_RUN']._serialized_end=278 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=205 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=278 - _globals['_GETRUNREQUEST']._serialized_start=280 - _globals['_GETRUNREQUEST']._serialized_end=311 - _globals['_GETRUNRESPONSE']._serialized_start=313 - _globals['_GETRUNRESPONSE']._serialized_end=359 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_RUN']._serialized_start=87 + _globals['_RUN']._serialized_end=300 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=227 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=300 + _globals['_CREATERUNREQUEST']._serialized_start=303 + _globals['_CREATERUNREQUEST']._serialized_end=538 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=227 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=300 + _globals['_CREATERUNRESPONSE']._serialized_start=540 + _globals['_CREATERUNRESPONSE']._serialized_end=575 + _globals['_GETRUNREQUEST']._serialized_start=577 + _globals['_GETRUNREQUEST']._serialized_end=608 + _globals['_GETRUNRESPONSE']._serialized_start=610 + _globals['_GETRUNRESPONSE']._serialized_end=656 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index e65feee9c518..26b69e7eed27 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -3,6 +3,7 @@ isort:skip_file """ import builtins +import flwr.proto.fab_pb2 import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers @@ -51,7 +52,58 @@ class Run(google.protobuf.message.Message): def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ... global___Run = Run +class CreateRunRequest(google.protobuf.message.Message): + """CreateRun""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> flwr.proto.transport_pb2.Scalar: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + FAB_ID_FIELD_NUMBER: builtins.int + FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int + FAB_FIELD_NUMBER: builtins.int + fab_id: typing.Text + fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... + @property + def fab(self) -> flwr.proto.fab_pb2.Fab: ... + def __init__(self, + *, + fab_id: typing.Text = ..., + fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., + fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config"]) -> None: ... +global___CreateRunRequest = CreateRunRequest + +class CreateRunResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int + def __init__(self, + *, + run_id: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... +global___CreateRunResponse = CreateRunResponse + class GetRunRequest(google.protobuf.message.Message): + """GetRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor RUN_ID_FIELD_NUMBER: builtins.int run_id: builtins.int diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index c0628db04360..a9ec05fe90e0 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -35,11 +35,11 @@ from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app from flwr.common.typing import UserConfig -from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 +from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 +from flwr.proto.run_pb2 import ( # pylint: disable=E0611 CreateRunRequest, CreateRunResponse, ) -from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from .driver import Driver from .driver.grpc_driver import GrpcDriver diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 4d7d6cb6ce89..3cafb3e71f7c 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -32,8 +32,6 @@ from flwr.common.typing import Fab from flwr.proto import driver_pb2_grpc # pylint: disable=E0611 from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 - CreateRunRequest, - CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -44,6 +42,8 @@ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + CreateRunResponse, GetRunRequest, GetRunResponse, Run, diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 55e519cf5f27..d60870995d39 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -28,8 +28,8 @@ from flwr.common.logger import log from flwr.common.serde import fab_to_proto, user_config_to_proto from flwr.common.typing import Fab, UserConfig -from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 from flwr.proto.driver_pb2_grpc import DriverStub +from flwr.proto.run_pb2 import CreateRunRequest # pylint: disable=E0611 from .executor import Executor, RunTracker