From 227abfa9e6e8bf9867ed360a4f80461edc48e6af Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 23 Sep 2024 12:16:43 +0100 Subject: [PATCH 1/2] amend protocols --- src/proto/flwr/proto/fab.proto | 7 +++- src/proto/flwr/proto/fleet.proto | 5 ++- src/proto/flwr/proto/run.proto | 11 +++++-- src/py/flwr/proto/fab_pb2.py | 15 +++++---- src/py/flwr/proto/fab_pb2.pyi | 8 ++++- src/py/flwr/proto/fleet_pb2.py | 20 ++++++------ src/py/flwr/proto/fleet_pb2.pyi | 7 +++- src/py/flwr/proto/run_pb2.py | 55 ++++++++++++++++---------------- src/py/flwr/proto/run_pb2.pyi | 15 +++++++-- 9 files changed, 91 insertions(+), 52 deletions(-) diff --git a/src/proto/flwr/proto/fab.proto b/src/proto/flwr/proto/fab.proto index 6f8e6b87808d..367b6e5b5c13 100644 --- a/src/proto/flwr/proto/fab.proto +++ b/src/proto/flwr/proto/fab.proto @@ -17,6 +17,8 @@ syntax = "proto3"; package flwr.proto; +import "flwr/proto/node.proto"; + message Fab { // This field is the hash of the data field. It is used to identify the data. // The hash is calculated using the SHA-256 algorithm and is represented as a @@ -26,5 +28,8 @@ message Fab { bytes content = 2; } -message GetFabRequest { string hash_str = 1; } +message GetFabRequest { + Node node = 1; + string hash_str = 2; +} message GetFabResponse { Fab fab = 1; } diff --git a/src/proto/flwr/proto/fleet.proto b/src/proto/flwr/proto/fleet.proto index b87214ac52f3..130b30b96669 100644 --- a/src/proto/flwr/proto/fleet.proto +++ b/src/proto/flwr/proto/fleet.proto @@ -69,7 +69,10 @@ message PullTaskInsResponse { } // PushTaskRes messages -message PushTaskResRequest { repeated TaskRes task_res_list = 1; } +message PushTaskResRequest { + Node node = 1; + repeated TaskRes task_res_list = 2; +} message PushTaskResResponse { Reconnect reconnect = 1; map results = 2; diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index 2c9bd877f66c..4312e1127cc2 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -18,6 +18,7 @@ syntax = "proto3"; package flwr.proto; import "flwr/proto/fab.proto"; +import "flwr/proto/node.proto"; import "flwr/proto/transport.proto"; message Run { @@ -47,7 +48,10 @@ message CreateRunRequest { message CreateRunResponse { uint64 run_id = 1; } // GetRun -message GetRunRequest { uint64 run_id = 1; } +message GetRunRequest { + Node node = 1; + uint64 run_id = 2; +} message GetRunResponse { Run run = 1; } // UpdateRunStatus @@ -58,5 +62,8 @@ message UpdateRunStatusRequest { message UpdateRunStatusResponse {} // GetRunStatus -message GetRunStatusRequest { repeated uint64 run_ids = 1; } +message GetRunStatusRequest { + Node node = 1; + repeated uint64 run_ids = 2; +} message GetRunStatusResponse { map run_status_dict = 1; } diff --git a/src/py/flwr/proto/fab_pb2.py b/src/py/flwr/proto/fab_pb2.py index 3f04e6693ab8..3a5e50000c10 100644 --- a/src/py/flwr/proto/fab_pb2.py +++ b/src/py/flwr/proto/fab_pb2.py @@ -12,19 +12,20 @@ _sym_db = _symbol_database.Default() +from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"!\n\rGetFabRequest\x12\x10\n\x08hash_str\x18\x01 \x01(\t\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"A\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.fab_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_FAB']._serialized_start=36 - _globals['_FAB']._serialized_end=76 - _globals['_GETFABREQUEST']._serialized_start=78 - _globals['_GETFABREQUEST']._serialized_end=111 - _globals['_GETFABRESPONSE']._serialized_start=113 - _globals['_GETFABRESPONSE']._serialized_end=159 + _globals['_FAB']._serialized_start=59 + _globals['_FAB']._serialized_end=99 + _globals['_GETFABREQUEST']._serialized_start=101 + _globals['_GETFABREQUEST']._serialized_end=166 + _globals['_GETFABRESPONSE']._serialized_start=168 + _globals['_GETFABRESPONSE']._serialized_end=214 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/fab_pb2.pyi b/src/py/flwr/proto/fab_pb2.pyi index b2715dde5021..8cfdcbaf76ad 100644 --- a/src/py/flwr/proto/fab_pb2.pyi +++ b/src/py/flwr/proto/fab_pb2.pyi @@ -3,6 +3,7 @@ isort:skip_file """ import builtins +import flwr.proto.node_pb2 import google.protobuf.descriptor import google.protobuf.message import typing @@ -33,13 +34,18 @@ global___Fab = Fab class GetFabRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int HASH_STR_FIELD_NUMBER: builtins.int + @property + def node(self) -> flwr.proto.node_pb2.Node: ... hash_str: typing.Text def __init__(self, *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., hash_str: typing.Text = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node"]) -> None: ... global___GetFabRequest = GetFabRequest class GetFabResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/fleet_pb2.py b/src/py/flwr/proto/fleet_pb2.py index d1fe719f2d91..3185bc2ce111 100644 --- a/src/py/flwr/proto/fleet_pb2.py +++ b/src/py/flwr/proto/fleet_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x8c\x04\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"`\n\x12PushTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rtask_res_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x8c\x04\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -44,13 +44,13 @@ _globals['_PULLTASKINSRESPONSE']._serialized_start=476 _globals['_PULLTASKINSRESPONSE']._serialized_end=583 _globals['_PUSHTASKRESREQUEST']._serialized_start=585 - _globals['_PUSHTASKRESREQUEST']._serialized_end=649 - _globals['_PUSHTASKRESRESPONSE']._serialized_start=652 - _globals['_PUSHTASKRESRESPONSE']._serialized_end=826 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=780 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=826 - _globals['_RECONNECT']._serialized_start=828 - _globals['_RECONNECT']._serialized_end=858 - _globals['_FLEET']._serialized_start=861 - _globals['_FLEET']._serialized_end=1385 + _globals['_PUSHTASKRESREQUEST']._serialized_end=681 + _globals['_PUSHTASKRESRESPONSE']._serialized_start=684 + _globals['_PUSHTASKRESRESPONSE']._serialized_end=858 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=812 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=858 + _globals['_RECONNECT']._serialized_start=860 + _globals['_RECONNECT']._serialized_end=890 + _globals['_FLEET']._serialized_start=893 + _globals['_FLEET']._serialized_end=1417 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/fleet_pb2.pyi b/src/py/flwr/proto/fleet_pb2.pyi index 5989f45c5c60..76875bc1a4b9 100644 --- a/src/py/flwr/proto/fleet_pb2.pyi +++ b/src/py/flwr/proto/fleet_pb2.pyi @@ -124,14 +124,19 @@ global___PullTaskInsResponse = PullTaskInsResponse class PushTaskResRequest(google.protobuf.message.Message): """PushTaskRes messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int TASK_RES_LIST_FIELD_NUMBER: builtins.int @property + def node(self) -> flwr.proto.node_pb2.Node: ... + @property def task_res_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskRes]: ... def __init__(self, *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., task_res_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskRes]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["task_res_list",b"task_res_list"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_res_list",b"task_res_list"]) -> None: ... global___PushTaskResRequest = PushTaskResRequest class PushTaskResResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index d59cc26fbb48..cc3f6897918f 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -13,10 +13,11 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 +from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"&\n\x13GetRunStatusRequest\x12\x0f\n\x07run_ids\x18\x01 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"F\n\x13GetRunStatusRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0f\n\x07run_ids\x18\x02 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -29,30 +30,30 @@ _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._options = None _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001' - _globals['_RUN']._serialized_start=87 - _globals['_RUN']._serialized_end=300 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=227 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=300 - _globals['_RUNSTATUS']._serialized_start=302 - _globals['_RUNSTATUS']._serialized_end=366 - _globals['_CREATERUNREQUEST']._serialized_start=369 - _globals['_CREATERUNREQUEST']._serialized_end=604 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=227 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=300 - _globals['_CREATERUNRESPONSE']._serialized_start=606 - _globals['_CREATERUNRESPONSE']._serialized_end=641 - _globals['_GETRUNREQUEST']._serialized_start=643 - _globals['_GETRUNREQUEST']._serialized_end=674 - _globals['_GETRUNRESPONSE']._serialized_start=676 - _globals['_GETRUNRESPONSE']._serialized_end=722 - _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=724 - _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=807 - _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=809 - _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=834 - _globals['_GETRUNSTATUSREQUEST']._serialized_start=836 - _globals['_GETRUNSTATUSREQUEST']._serialized_end=874 - _globals['_GETRUNSTATUSRESPONSE']._serialized_start=877 - _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1054 - _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=979 - _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1054 + _globals['_RUN']._serialized_start=110 + _globals['_RUN']._serialized_end=323 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=250 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=323 + _globals['_RUNSTATUS']._serialized_start=325 + _globals['_RUNSTATUS']._serialized_end=389 + _globals['_CREATERUNREQUEST']._serialized_start=392 + _globals['_CREATERUNREQUEST']._serialized_end=627 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=250 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=323 + _globals['_CREATERUNRESPONSE']._serialized_start=629 + _globals['_CREATERUNRESPONSE']._serialized_end=664 + _globals['_GETRUNREQUEST']._serialized_start=666 + _globals['_GETRUNREQUEST']._serialized_end=729 + _globals['_GETRUNRESPONSE']._serialized_start=731 + _globals['_GETRUNRESPONSE']._serialized_end=777 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=779 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=862 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=864 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=889 + _globals['_GETRUNSTATUSREQUEST']._serialized_start=891 + _globals['_GETRUNSTATUSREQUEST']._serialized_end=961 + _globals['_GETRUNSTATUSRESPONSE']._serialized_start=964 + _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1141 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=1066 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1141 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index cec90c4d2d4c..16411712eaf2 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import flwr.proto.fab_pb2 +import flwr.proto.node_pb2 import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers @@ -128,13 +129,18 @@ global___CreateRunResponse = CreateRunResponse class GetRunRequest(google.protobuf.message.Message): """GetRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int RUN_ID_FIELD_NUMBER: builtins.int + @property + def node(self) -> flwr.proto.node_pb2.Node: ... run_id: builtins.int def __init__(self, *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id"]) -> None: ... global___GetRunRequest = GetRunRequest class GetRunResponse(google.protobuf.message.Message): @@ -176,14 +182,19 @@ global___UpdateRunStatusResponse = UpdateRunStatusResponse class GetRunStatusRequest(google.protobuf.message.Message): """GetRunStatus""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int RUN_IDS_FIELD_NUMBER: builtins.int @property + def node(self) -> flwr.proto.node_pb2.Node: ... + @property def run_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... def __init__(self, *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., run_ids: typing.Optional[typing.Iterable[builtins.int]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_ids",b"run_ids"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_ids",b"run_ids"]) -> None: ... global___GetRunStatusRequest = GetRunStatusRequest class GetRunStatusResponse(google.protobuf.message.Message): From 8b4bf5d3d8182dedb188d035cc6dfb0174f2b1ab Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 23 Sep 2024 12:20:21 +0100 Subject: [PATCH 2/2] amend grpc-rere and rest connection --- src/py/flwr/client/grpc_rere_client/connection.py | 6 +++--- src/py/flwr/client/rest_client/connection.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 7ce3d37b7a17..b4fa28373600 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -269,7 +269,7 @@ def send(message: Message) -> None: task_res = message_to_taskres(message) # Serialize ProtoBuf to bytes - request = PushTaskResRequest(task_res_list=[task_res]) + request = PushTaskResRequest(node=node, task_res_list=[task_res]) _ = retry_invoker.invoke(stub.PushTaskRes, request) # Cleanup @@ -277,7 +277,7 @@ def send(message: Message) -> None: def get_run(run_id: int) -> Run: # Call FleetAPI - get_run_request = GetRunRequest(run_id=run_id) + get_run_request = GetRunRequest(node=node, run_id=run_id) get_run_response: GetRunResponse = retry_invoker.invoke( stub.GetRun, request=get_run_request, @@ -294,7 +294,7 @@ def get_run(run_id: int) -> Run: def get_fab(fab_hash: str) -> Fab: # Call FleetAPI - get_fab_request = GetFabRequest(hash_str=fab_hash) + get_fab_request = GetFabRequest(node=node, hash_str=fab_hash) get_fab_response: GetFabResponse = retry_invoker.invoke( stub.GetFab, request=get_fab_request, diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 72b6be25a708..485bbd7a1810 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -340,7 +340,7 @@ def send(message: Message) -> None: task_res = message_to_taskres(message) # Serialize ProtoBuf to bytes - req = PushTaskResRequest(task_res_list=[task_res]) + req = PushTaskResRequest(node=node, task_res_list=[task_res]) # Send the request res = _request(req, PushTaskResResponse, PATH_PUSH_TASK_RES) @@ -356,7 +356,7 @@ def send(message: Message) -> None: def get_run(run_id: int) -> Run: # Construct the request - req = GetRunRequest(run_id=run_id) + req = GetRunRequest(node=node, run_id=run_id) # Send the request res = _request(req, GetRunResponse, PATH_GET_RUN) @@ -373,7 +373,7 @@ def get_run(run_id: int) -> Run: def get_fab(fab_hash: str) -> Fab: # Construct the request - req = GetFabRequest(hash_str=fab_hash) + req = GetFabRequest(node=node, hash_str=fab_hash) # Send the request res = _request(req, GetFabResponse, PATH_GET_FAB)