From 41fc0663495b3ce7e461a3f3293aa19ce98faa60 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Thu, 26 Sep 2024 18:22:35 +0200 Subject: [PATCH 1/6] docs(datasets) Embed HF Space in the flwr-datasets docs (#4260) --- datasets/doc/source/index.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst index 84e25a920f2f..070655550fa1 100644 --- a/datasets/doc/source/index.rst +++ b/datasets/doc/source/index.rst @@ -3,6 +3,15 @@ Flower Datasets Flower Datasets (``flwr-datasets``) is a library that enables the quick and easy creation of datasets for federated learning/analytics/evaluation. It enables heterogeneity (non-iidness) simulation and division of datasets with the preexisting notion of IDs. The library was created by the ``Flower Labs`` team that also created `Flower `_ : A Friendly Federated Learning Framework. +.. raw:: html + + + + + Flower Datasets Framework ------------------------- From 3aadc21ab83417ba657fb29d0139d734c3e46912 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 26 Sep 2024 18:44:46 +0100 Subject: [PATCH 2/6] fix(framework:skip) Use unsigned int for node IDs in the SecAgg+ protocol (#4246) --- src/py/flwr/common/secure_aggregation/secaggplus_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py index 7bfb80f57891..919894d5388f 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py @@ -43,8 +43,8 @@ def share_keys_plaintext_concat( """ return b"".join( [ - int.to_bytes(src_node_id, 8, "little", signed=True), - int.to_bytes(dst_node_id, 8, "little", signed=True), + int.to_bytes(src_node_id, 8, "little", signed=False), + int.to_bytes(dst_node_id, 8, "little", signed=False), int.to_bytes(len(b_share), 4, "little"), b_share, sk_share, @@ -72,8 +72,8 @@ def share_keys_plaintext_separate(plaintext: bytes) -> tuple[int, int, bytes, by the secret key share of the source sent to the destination. """ src, dst, mark = ( - int.from_bytes(plaintext[:8], "little", signed=True), - int.from_bytes(plaintext[8:16], "little", signed=True), + int.from_bytes(plaintext[:8], "little", signed=False), + int.from_bytes(plaintext[8:16], "little", signed=False), int.from_bytes(plaintext[16:20], "little"), ) ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :]) From 7bc3a5bff1c07247c57291da2a978edca1e75244 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 26 Sep 2024 18:52:54 +0100 Subject: [PATCH 3/6] feat(framework) Add `node` to all Fleet requests (#4250) Co-authored-by: Daniel Nata Nugraha --- src/proto/flwr/proto/fab.proto | 7 ++- src/proto/flwr/proto/fleet.proto | 5 +- src/proto/flwr/proto/run.proto | 11 +++- .../client/grpc_rere_client/connection.py | 6 +- src/py/flwr/client/rest_client/connection.py | 6 +- src/py/flwr/proto/fab_pb2.py | 15 ++--- src/py/flwr/proto/fab_pb2.pyi | 8 ++- src/py/flwr/proto/fleet_pb2.py | 20 +++---- src/py/flwr/proto/fleet_pb2.pyi | 7 ++- src/py/flwr/proto/run_pb2.py | 55 ++++++++++--------- src/py/flwr/proto/run_pb2.pyi | 15 ++++- 11 files changed, 97 insertions(+), 58 deletions(-) diff --git a/src/proto/flwr/proto/fab.proto b/src/proto/flwr/proto/fab.proto index 6f8e6b87808d..367b6e5b5c13 100644 --- a/src/proto/flwr/proto/fab.proto +++ b/src/proto/flwr/proto/fab.proto @@ -17,6 +17,8 @@ syntax = "proto3"; package flwr.proto; +import "flwr/proto/node.proto"; + message Fab { // This field is the hash of the data field. It is used to identify the data. // The hash is calculated using the SHA-256 algorithm and is represented as a @@ -26,5 +28,8 @@ message Fab { bytes content = 2; } -message GetFabRequest { string hash_str = 1; } +message GetFabRequest { + Node node = 1; + string hash_str = 2; +} message GetFabResponse { Fab fab = 1; } diff --git a/src/proto/flwr/proto/fleet.proto b/src/proto/flwr/proto/fleet.proto index b87214ac52f3..130b30b96669 100644 --- a/src/proto/flwr/proto/fleet.proto +++ b/src/proto/flwr/proto/fleet.proto @@ -69,7 +69,10 @@ message PullTaskInsResponse { } // PushTaskRes messages -message PushTaskResRequest { repeated TaskRes task_res_list = 1; } +message PushTaskResRequest { + Node node = 1; + repeated TaskRes task_res_list = 2; +} message PushTaskResResponse { Reconnect reconnect = 1; map results = 2; diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index 2c9bd877f66c..4312e1127cc2 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -18,6 +18,7 @@ syntax = "proto3"; package flwr.proto; import "flwr/proto/fab.proto"; +import "flwr/proto/node.proto"; import "flwr/proto/transport.proto"; message Run { @@ -47,7 +48,10 @@ message CreateRunRequest { message CreateRunResponse { uint64 run_id = 1; } // GetRun -message GetRunRequest { uint64 run_id = 1; } +message GetRunRequest { + Node node = 1; + uint64 run_id = 2; +} message GetRunResponse { Run run = 1; } // UpdateRunStatus @@ -58,5 +62,8 @@ message UpdateRunStatusRequest { message UpdateRunStatusResponse {} // GetRunStatus -message GetRunStatusRequest { repeated uint64 run_ids = 1; } +message GetRunStatusRequest { + Node node = 1; + repeated uint64 run_ids = 2; +} message GetRunStatusResponse { map run_status_dict = 1; } diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 7ce3d37b7a17..b4fa28373600 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -269,7 +269,7 @@ def send(message: Message) -> None: task_res = message_to_taskres(message) # Serialize ProtoBuf to bytes - request = PushTaskResRequest(task_res_list=[task_res]) + request = PushTaskResRequest(node=node, task_res_list=[task_res]) _ = retry_invoker.invoke(stub.PushTaskRes, request) # Cleanup @@ -277,7 +277,7 @@ def send(message: Message) -> None: def get_run(run_id: int) -> Run: # Call FleetAPI - get_run_request = GetRunRequest(run_id=run_id) + get_run_request = GetRunRequest(node=node, run_id=run_id) get_run_response: GetRunResponse = retry_invoker.invoke( stub.GetRun, request=get_run_request, @@ -294,7 +294,7 @@ def get_run(run_id: int) -> Run: def get_fab(fab_hash: str) -> Fab: # Call FleetAPI - get_fab_request = GetFabRequest(hash_str=fab_hash) + get_fab_request = GetFabRequest(node=node, hash_str=fab_hash) get_fab_response: GetFabResponse = retry_invoker.invoke( stub.GetFab, request=get_fab_request, diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 72b6be25a708..485bbd7a1810 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -340,7 +340,7 @@ def send(message: Message) -> None: task_res = message_to_taskres(message) # Serialize ProtoBuf to bytes - req = PushTaskResRequest(task_res_list=[task_res]) + req = PushTaskResRequest(node=node, task_res_list=[task_res]) # Send the request res = _request(req, PushTaskResResponse, PATH_PUSH_TASK_RES) @@ -356,7 +356,7 @@ def send(message: Message) -> None: def get_run(run_id: int) -> Run: # Construct the request - req = GetRunRequest(run_id=run_id) + req = GetRunRequest(node=node, run_id=run_id) # Send the request res = _request(req, GetRunResponse, PATH_GET_RUN) @@ -373,7 +373,7 @@ def get_run(run_id: int) -> Run: def get_fab(fab_hash: str) -> Fab: # Construct the request - req = GetFabRequest(hash_str=fab_hash) + req = GetFabRequest(node=node, hash_str=fab_hash) # Send the request res = _request(req, GetFabResponse, PATH_GET_FAB) diff --git a/src/py/flwr/proto/fab_pb2.py b/src/py/flwr/proto/fab_pb2.py index 3f04e6693ab8..3a5e50000c10 100644 --- a/src/py/flwr/proto/fab_pb2.py +++ b/src/py/flwr/proto/fab_pb2.py @@ -12,19 +12,20 @@ _sym_db = _symbol_database.Default() +from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"!\n\rGetFabRequest\x12\x10\n\x08hash_str\x18\x01 \x01(\t\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/fab.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"(\n\x03\x46\x61\x62\x12\x10\n\x08hash_str\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"A\n\rGetFabRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08hash_str\x18\x02 \x01(\t\".\n\x0eGetFabResponse\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fabb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.fab_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_FAB']._serialized_start=36 - _globals['_FAB']._serialized_end=76 - _globals['_GETFABREQUEST']._serialized_start=78 - _globals['_GETFABREQUEST']._serialized_end=111 - _globals['_GETFABRESPONSE']._serialized_start=113 - _globals['_GETFABRESPONSE']._serialized_end=159 + _globals['_FAB']._serialized_start=59 + _globals['_FAB']._serialized_end=99 + _globals['_GETFABREQUEST']._serialized_start=101 + _globals['_GETFABREQUEST']._serialized_end=166 + _globals['_GETFABRESPONSE']._serialized_start=168 + _globals['_GETFABRESPONSE']._serialized_end=214 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/fab_pb2.pyi b/src/py/flwr/proto/fab_pb2.pyi index b2715dde5021..8cfdcbaf76ad 100644 --- a/src/py/flwr/proto/fab_pb2.pyi +++ b/src/py/flwr/proto/fab_pb2.pyi @@ -3,6 +3,7 @@ isort:skip_file """ import builtins +import flwr.proto.node_pb2 import google.protobuf.descriptor import google.protobuf.message import typing @@ -33,13 +34,18 @@ global___Fab = Fab class GetFabRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int HASH_STR_FIELD_NUMBER: builtins.int + @property + def node(self) -> flwr.proto.node_pb2.Node: ... hash_str: typing.Text def __init__(self, *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., hash_str: typing.Text = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node"]) -> None: ... global___GetFabRequest = GetFabRequest class GetFabResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/fleet_pb2.py b/src/py/flwr/proto/fleet_pb2.py index d1fe719f2d91..3185bc2ce111 100644 --- a/src/py/flwr/proto/fleet_pb2.py +++ b/src/py/flwr/proto/fleet_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x8c\x04\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"*\n\x11\x43reateNodeRequest\x12\x15\n\rping_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"D\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x15\n\rping_interval\x18\x02 \x01(\x01\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"`\n\x12PushTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rtask_res_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x8c\x04\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -44,13 +44,13 @@ _globals['_PULLTASKINSRESPONSE']._serialized_start=476 _globals['_PULLTASKINSRESPONSE']._serialized_end=583 _globals['_PUSHTASKRESREQUEST']._serialized_start=585 - _globals['_PUSHTASKRESREQUEST']._serialized_end=649 - _globals['_PUSHTASKRESRESPONSE']._serialized_start=652 - _globals['_PUSHTASKRESRESPONSE']._serialized_end=826 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=780 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=826 - _globals['_RECONNECT']._serialized_start=828 - _globals['_RECONNECT']._serialized_end=858 - _globals['_FLEET']._serialized_start=861 - _globals['_FLEET']._serialized_end=1385 + _globals['_PUSHTASKRESREQUEST']._serialized_end=681 + _globals['_PUSHTASKRESRESPONSE']._serialized_start=684 + _globals['_PUSHTASKRESRESPONSE']._serialized_end=858 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=812 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=858 + _globals['_RECONNECT']._serialized_start=860 + _globals['_RECONNECT']._serialized_end=890 + _globals['_FLEET']._serialized_start=893 + _globals['_FLEET']._serialized_end=1417 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/fleet_pb2.pyi b/src/py/flwr/proto/fleet_pb2.pyi index 5989f45c5c60..76875bc1a4b9 100644 --- a/src/py/flwr/proto/fleet_pb2.pyi +++ b/src/py/flwr/proto/fleet_pb2.pyi @@ -124,14 +124,19 @@ global___PullTaskInsResponse = PullTaskInsResponse class PushTaskResRequest(google.protobuf.message.Message): """PushTaskRes messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int TASK_RES_LIST_FIELD_NUMBER: builtins.int @property + def node(self) -> flwr.proto.node_pb2.Node: ... + @property def task_res_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskRes]: ... def __init__(self, *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., task_res_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskRes]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["task_res_list",b"task_res_list"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_res_list",b"task_res_list"]) -> None: ... global___PushTaskResRequest = PushTaskResRequest class PushTaskResResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index d59cc26fbb48..cc3f6897918f 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -13,10 +13,11 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 +from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"&\n\x13GetRunStatusRequest\x12\x0f\n\x07run_ids\x18\x01 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"F\n\x13GetRunStatusRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0f\n\x07run_ids\x18\x02 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -29,30 +30,30 @@ _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._options = None _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001' - _globals['_RUN']._serialized_start=87 - _globals['_RUN']._serialized_end=300 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=227 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=300 - _globals['_RUNSTATUS']._serialized_start=302 - _globals['_RUNSTATUS']._serialized_end=366 - _globals['_CREATERUNREQUEST']._serialized_start=369 - _globals['_CREATERUNREQUEST']._serialized_end=604 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=227 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=300 - _globals['_CREATERUNRESPONSE']._serialized_start=606 - _globals['_CREATERUNRESPONSE']._serialized_end=641 - _globals['_GETRUNREQUEST']._serialized_start=643 - _globals['_GETRUNREQUEST']._serialized_end=674 - _globals['_GETRUNRESPONSE']._serialized_start=676 - _globals['_GETRUNRESPONSE']._serialized_end=722 - _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=724 - _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=807 - _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=809 - _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=834 - _globals['_GETRUNSTATUSREQUEST']._serialized_start=836 - _globals['_GETRUNSTATUSREQUEST']._serialized_end=874 - _globals['_GETRUNSTATUSRESPONSE']._serialized_start=877 - _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1054 - _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=979 - _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1054 + _globals['_RUN']._serialized_start=110 + _globals['_RUN']._serialized_end=323 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=250 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=323 + _globals['_RUNSTATUS']._serialized_start=325 + _globals['_RUNSTATUS']._serialized_end=389 + _globals['_CREATERUNREQUEST']._serialized_start=392 + _globals['_CREATERUNREQUEST']._serialized_end=627 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=250 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=323 + _globals['_CREATERUNRESPONSE']._serialized_start=629 + _globals['_CREATERUNRESPONSE']._serialized_end=664 + _globals['_GETRUNREQUEST']._serialized_start=666 + _globals['_GETRUNREQUEST']._serialized_end=729 + _globals['_GETRUNRESPONSE']._serialized_start=731 + _globals['_GETRUNRESPONSE']._serialized_end=777 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=779 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=862 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=864 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=889 + _globals['_GETRUNSTATUSREQUEST']._serialized_start=891 + _globals['_GETRUNSTATUSREQUEST']._serialized_end=961 + _globals['_GETRUNSTATUSRESPONSE']._serialized_start=964 + _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1141 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=1066 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1141 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index cec90c4d2d4c..16411712eaf2 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -4,6 +4,7 @@ isort:skip_file """ import builtins import flwr.proto.fab_pb2 +import flwr.proto.node_pb2 import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers @@ -128,13 +129,18 @@ global___CreateRunResponse = CreateRunResponse class GetRunRequest(google.protobuf.message.Message): """GetRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int RUN_ID_FIELD_NUMBER: builtins.int + @property + def node(self) -> flwr.proto.node_pb2.Node: ... run_id: builtins.int def __init__(self, *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_id",b"run_id"]) -> None: ... global___GetRunRequest = GetRunRequest class GetRunResponse(google.protobuf.message.Message): @@ -176,14 +182,19 @@ global___UpdateRunStatusResponse = UpdateRunStatusResponse class GetRunStatusRequest(google.protobuf.message.Message): """GetRunStatus""" DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int RUN_IDS_FIELD_NUMBER: builtins.int @property + def node(self) -> flwr.proto.node_pb2.Node: ... + @property def run_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... def __init__(self, *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., run_ids: typing.Optional[typing.Iterable[builtins.int]] = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_ids",b"run_ids"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node","run_ids",b"run_ids"]) -> None: ... global___GetRunStatusRequest = GetRunStatusRequest class GetRunStatusResponse(google.protobuf.message.Message): From 6d37e2510557daa0c13dc632a844f72a6a433bff Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 26 Sep 2024 19:35:26 +0100 Subject: [PATCH 4/6] feat(framework) Add log functions to `flwr log` (#3611) --- src/py/flwr/cli/log.py | 37 ++++++++++++++++-- src/py/flwr/cli/log_test.py | 78 +++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 3 deletions(-) create mode 100644 src/py/flwr/cli/log_test.py diff --git a/src/py/flwr/cli/log.py b/src/py/flwr/cli/log.py index 6915de1e00c5..cd4079c1c131 100644 --- a/src/py/flwr/cli/log.py +++ b/src/py/flwr/cli/log.py @@ -26,18 +26,49 @@ from flwr.cli.config_utils import load_and_validate from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log as logger +from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611 +from flwr.proto.exec_pb2_grpc import ExecStub CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds) -# pylint: disable=unused-argument -def stream_logs(run_id: int, channel: grpc.Channel, period: int) -> None: +def stream_logs(run_id: int, channel: grpc.Channel, duration: int) -> None: """Stream logs from the beginning of a run with connection refresh.""" + start_time = time.time() + stub = ExecStub(channel) + req = StreamLogsRequest(run_id=run_id) + + for res in stub.StreamLogs(req): + print(res.log_output) + if time.time() - start_time > duration: + break -# pylint: disable=unused-argument def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None: """Print logs from the beginning of a run.""" + stub = ExecStub(channel) + req = StreamLogsRequest(run_id=run_id) + + try: + while True: + try: + # Enforce timeout for graceful exit + for res in stub.StreamLogs(req, timeout=timeout): + print(res.log_output) + except grpc.RpcError as e: + # pylint: disable=E1101 + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + break + if e.code() == grpc.StatusCode.NOT_FOUND: + logger(ERROR, "Invalid run_id `%s`, exiting", run_id) + break + if e.code() == grpc.StatusCode.CANCELLED: + break + except KeyboardInterrupt: + logger(DEBUG, "Stream interrupted by user") + finally: + channel.close() + logger(DEBUG, "Channel closed") def on_channel_state_change(channel_connectivity: str) -> None: diff --git a/src/py/flwr/cli/log_test.py b/src/py/flwr/cli/log_test.py new file mode 100644 index 000000000000..932610bea2f3 --- /dev/null +++ b/src/py/flwr/cli/log_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for Flower command line interface `log` command.""" + + +import unittest +from typing import NoReturn +from unittest.mock import Mock, call, patch + +from flwr.proto.exec_pb2 import StreamLogsResponse # pylint: disable=E0611 + +from .log import print_logs, stream_logs + + +class InterruptedStreamLogsResponse: + """Create a StreamLogsResponse object with KeyboardInterrupt.""" + + @property + def log_output(self) -> NoReturn: + """Raise KeyboardInterrupt to exit logstream test gracefully.""" + raise KeyboardInterrupt + + +class TestFlwrLog(unittest.TestCase): + """Unit tests for `flwr log` CLI functions.""" + + def setUp(self) -> None: + """Initialize mock ExecStub before each test.""" + self.expected_calls = [ + call("log_output_1"), + call("log_output_2"), + call("log_output_3"), + ] + mock_response_iterator = [ + iter( + [StreamLogsResponse(log_output=f"log_output_{i}") for i in range(1, 4)] + + [InterruptedStreamLogsResponse()] + ) + ] + self.mock_stub = Mock() + self.mock_stub.StreamLogs.side_effect = mock_response_iterator + self.patcher = patch("flwr.cli.log.ExecStub", return_value=self.mock_stub) + + self.patcher.start() + + # Create mock channel + self.mock_channel = Mock() + + def tearDown(self) -> None: + """Cleanup.""" + self.patcher.stop() + + def test_flwr_log_stream_method(self) -> None: + """Test stream_logs.""" + with patch("builtins.print") as mock_print: + with self.assertRaises(KeyboardInterrupt): + stream_logs(run_id=123, channel=self.mock_channel, duration=1) + # Assert that mock print was called with the expected arguments + mock_print.assert_has_calls(self.expected_calls) + + def test_flwr_log_print_method(self) -> None: + """Test print_logs.""" + with patch("builtins.print") as mock_print: + print_logs(run_id=123, channel=self.mock_channel, timeout=0) + # Assert that mock print was called with the expected arguments + mock_print.assert_has_calls(self.expected_calls) From 8fa9c5641fb06d4e3c9af1dd2990dfcdcf556b8e Mon Sep 17 00:00:00 2001 From: Mohammad Naseri Date: Thu, 26 Sep 2024 19:44:36 +0100 Subject: [PATCH 5/6] feat(framework) Verify message TTL when storing TaskIns and TaskRes (#3596) Co-authored-by: Heng Pan --- src/py/flwr/server/utils/validator.py | 6 ++++++ src/py/flwr/server/utils/validator_test.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index fb3d0425db86..01f926c4985d 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -15,6 +15,7 @@ """Validators.""" +import time from typing import Union from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -47,6 +48,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str # unix timestamp of 27 March 2024 00h:00m:00s UTC validation_errors.append("`pushed_at` is not a recent timestamp") + # Verify TTL and created_at time + current_time = time.time() + if tasks_ins_res.task.created_at + tasks_ins_res.task.ttl <= current_time: + validation_errors.append("Task TTL has expired") + # TaskIns specific if isinstance(tasks_ins_res, TaskIns): # Task producer diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index 20162883efea..ce8e3636467c 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -76,6 +76,24 @@ def test_is_valid_task_res(self) -> None: val_errors = validate_task_ins_or_res(msg) self.assertTrue(val_errors, (producer_node_id, anonymous, ancestry)) + def test_task_ttl_expired(self) -> None: + """Test validation for expired Task TTL.""" + # Prepare an expired TaskIns + expired_task_ins = create_task_ins(0, True) + expired_task_ins.task.created_at = time.time() - 10 # 10 seconds ago + expired_task_ins.task.ttl = 6 # 6 seconds TTL + + expired_task_res = create_task_res(0, True, ["1"]) + expired_task_res.task.created_at = time.time() - 10 # 10 seconds ago + expired_task_res.task.ttl = 6 # 6 seconds TTL + + # Execute & Assert + val_errors_ins = validate_task_ins_or_res(expired_task_ins) + self.assertIn("Task TTL has expired", val_errors_ins) + + val_errors_res = validate_task_ins_or_res(expired_task_res) + self.assertIn("Task TTL has expired", val_errors_res) + def create_task_ins( consumer_node_id: int, From 83cd4ba34f864e9134913e310079d02870fa64be Mon Sep 17 00:00:00 2001 From: Mohammad Naseri Date: Thu, 26 Sep 2024 19:59:16 +0100 Subject: [PATCH 6/6] feat(framework) Verify the TaskIns TTL when saving TaskRes (#3609) Co-authored-by: Heng Pan --- .../server/superlink/state/in_memory_state.py | 17 ++++++ .../server/superlink/state/sqlite_state.py | 40 +++++++++++++- .../flwr/server/superlink/state/state_test.py | 53 +++++++++++++++++-- 3 files changed, 104 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index e34d15374350..e09df8dc76f6 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -117,6 +117,23 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, errors) return None + with self.lock: + # Check if the TaskIns it is replying to exists and is valid + task_ins_id = task_res.task.ancestry[0] + task_ins = self.task_ins_store.get(UUID(task_ins_id)) + + if task_ins is None: + log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id) + return None + + if task_ins.task.created_at + task_ins.task.ttl <= time.time(): + log( + ERROR, + "Failed to store TaskRes: TaskIns with task_id %s has expired.", + task_ins_id, + ) + return None + # Validate run_id if task_res.run_id not in self.run_ids: log(ERROR, "`run_id` is invalid") diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 28d957a90bd3..d18683286196 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -372,7 +372,18 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Create task_id task_id = uuid4() - # Store TaskIns + task_ins_id = task_res.task.ancestry[0] + task_ins = self.get_valid_task_ins(task_ins_id) + if task_ins is None: + log( + ERROR, + "Failed to store TaskRes: " + "TaskIns with task_id %s does not exist or has expired.", + task_ins_id, + ) + return None + + # Store TaskRes task_res.task_id = str(task_id) data = (task_res_to_dict(task_res),) @@ -810,6 +821,33 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: log(ERROR, "`node_id` does not exist.") return False + def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]: + """Check if the TaskIns exists and is valid (not expired). + + Return TaskIns if valid. + """ + query = """ + SELECT * + FROM task_ins + WHERE task_id = :task_id + """ + data = {"task_id": task_id} + rows = self.query(query, data) + if not rows: + # TaskIns does not exist + return None + + task_ins = rows[0] + created_at = task_ins["created_at"] + ttl = task_ins["ttl"] + current_time = time.time() + + # Check if TaskIns is expired + if ttl is not None and created_at + ttl <= current_time: + return None + + return task_ins + def dict_factory( cursor: sqlite3.Cursor, diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 42c0768f1c7d..85cda1a5af9c 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -21,7 +21,6 @@ from abc import abstractmethod from datetime import datetime, timezone from unittest.mock import patch -from uuid import uuid4 from flwr.common import DEFAULT_TTL from flwr.common.constant import ErrorCode @@ -302,7 +301,10 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: # Prepare state: State = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) - task_ins_id = uuid4() + + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_ins_id = state.store_task_ins(task_ins) + task_res = create_task_res( producer_node_id=0, anonymous=True, @@ -312,7 +314,9 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: # Execute task_res_uuid = state.store_task_res(task_res) - task_res_list = state.get_task_res(task_ids={task_ins_id}, limit=None) + + if task_ins_id is not None: + task_res_list = state.get_task_res(task_ids={task_ins_id}, limit=None) # Assert retrieved_task_res = task_res_list[0] @@ -507,11 +511,23 @@ def test_num_task_res(self) -> None: # Prepare state: State = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) + + task_ins_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_ins_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_ins_id_0 = state.store_task_ins(task_ins_0) + task_ins_id_1 = state.store_task_ins(task_ins_1) + task_0 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id + producer_node_id=0, + anonymous=True, + ancestry=[str(task_ins_id_0)], + run_id=run_id, ) task_1 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id + producer_node_id=0, + anonymous=True, + ancestry=[str(task_ins_id_1)], + run_id=run_id, ) # Store two tasks @@ -664,6 +680,33 @@ def test_node_unavailable_error(self) -> None: assert err_taskres.task.HasField("error") assert err_taskres.task.error.code == ErrorCode.NODE_UNAVAILABLE + def test_store_task_res_task_ins_expired(self) -> None: + """Test behavior of store_task_res when the TaskIns it references is expired.""" + # Prepare + state: State = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {}) + + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_ins.task.created_at = time.time() - task_ins.task.ttl + 0.5 + task_ins_id = state.store_task_ins(task_ins) + + with patch( + "time.time", + side_effect=lambda: task_ins.task.created_at + task_ins.task.ttl + 0.1, + ): # Expired by 0.1 seconds + task = create_task_res( + producer_node_id=0, + anonymous=True, + ancestry=[str(task_ins_id)], + run_id=run_id, + ) + + # Execute + result = state.store_task_res(task) + + # Assert + assert result is None + def create_task_ins( consumer_node_id: int,