From be506a38d58829b58a2dc0ed5e4997cddc117692 Mon Sep 17 00:00:00 2001 From: Javier Date: Thu, 13 Feb 2025 08:30:08 +0000 Subject: [PATCH] refactor(framework) Remove `pushed_at` in Task (#4924) --- src/proto/flwr/proto/task.proto | 1 - src/py/flwr/proto/task_pb2.py | 12 ++++++------ src/py/flwr/proto/task_pb2.pyi | 5 +---- src/py/flwr/server/driver/inmemory_driver.py | 1 - src/py/flwr/server/driver/inmemory_driver_test.py | 1 - .../server/superlink/driver/serverappio_servicer.py | 6 +----- .../fleet/message_handler/message_handler.py | 4 ---- src/py/flwr/server/superlink/fleet/vce/vce_api.py | 1 - .../flwr/server/superlink/fleet/vce/vce_api_test.py | 4 ---- .../server/superlink/linkstate/linkstate_test.py | 3 --- .../server/superlink/linkstate/sqlite_linkstate.py | 6 ------ .../superlink/linkstate/sqlite_linkstate_test.py | 1 - src/py/flwr/server/utils/validator.py | 3 --- src/py/flwr/server/utils/validator_test.py | 2 -- 14 files changed, 8 insertions(+), 42 deletions(-) diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 324a70a5359c..3eda3ce278da 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -26,7 +26,6 @@ message Task { Node consumer = 2; double created_at = 3; string delivered_at = 4; - double pushed_at = 5; double ttl = 6; repeated string ancestry = 7; string task_type = 8; diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 24b3b5fa024a..24feab2c449b 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -17,7 +17,7 @@ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x04\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,9 +25,9 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _globals['_TASK']._serialized_start=113 - _globals['_TASK']._serialized_end=378 - _globals['_TASKINS']._serialized_start=380 - _globals['_TASKINS']._serialized_end=472 - _globals['_TASKRES']._serialized_start=474 - _globals['_TASKRES']._serialized_end=566 + _globals['_TASK']._serialized_end=359 + _globals['_TASKINS']._serialized_start=361 + _globals['_TASKINS']._serialized_end=453 + _globals['_TASKRES']._serialized_start=455 + _globals['_TASKRES']._serialized_end=547 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index 455791ac9e6e..2b463bde9ca6 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -20,7 +20,6 @@ class Task(google.protobuf.message.Message): CONSUMER_FIELD_NUMBER: builtins.int CREATED_AT_FIELD_NUMBER: builtins.int DELIVERED_AT_FIELD_NUMBER: builtins.int - PUSHED_AT_FIELD_NUMBER: builtins.int TTL_FIELD_NUMBER: builtins.int ANCESTRY_FIELD_NUMBER: builtins.int TASK_TYPE_FIELD_NUMBER: builtins.int @@ -32,7 +31,6 @@ class Task(google.protobuf.message.Message): def consumer(self) -> flwr.proto.node_pb2.Node: ... created_at: builtins.float delivered_at: typing.Text - pushed_at: builtins.float ttl: builtins.float @property def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... @@ -47,7 +45,6 @@ class Task(google.protobuf.message.Message): consumer: typing.Optional[flwr.proto.node_pb2.Node] = ..., created_at: builtins.float = ..., delivered_at: typing.Text = ..., - pushed_at: builtins.float = ..., ttl: builtins.float = ..., ancestry: typing.Optional[typing.Iterable[typing.Text]] = ..., task_type: typing.Text = ..., @@ -55,7 +52,7 @@ class Task(google.protobuf.message.Message): error: typing.Optional[flwr.proto.error_pb2.Error] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","pushed_at",b"pushed_at","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ... global___Task = Task class TaskIns(google.protobuf.message.Message): diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index ede37ee64c02..ecdaad01c0eb 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -126,7 +126,6 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: # Convert Message to TaskIns taskins = message_to_taskins(msg) # Store in state - taskins.task.pushed_at = time.time() task_id = self.state.store_task_ins(taskins) if task_id: task_ids.append(str(task_id)) diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index d755454391e6..9874cc31e60e 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -65,7 +65,6 @@ def get_replies( msg = message_from_taskins(taskin) reply_msg = msg.create_reply(RecordSet()) task_res = message_to_taskres(reply_msg) - task_res.task.pushed_at = time.time() driver.state.store_task_res(task_res=task_res) # Execute: Pull messages diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index 910320d50568..646440ec820c 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -22,7 +22,7 @@ import grpc -from flwr.common import ConfigsRecord, now +from flwr.common import ConfigsRecord from flwr.common.constant import Status from flwr.common.logger import log from flwr.common.serde import ( @@ -151,9 +151,6 @@ def PushMessages( context, ) - # Set pushed_at (timestamp in seconds) - pushed_at = now().timestamp() - # Validate request and insert in State _raise_if( validation_error=len(request.messages_list) == 0, @@ -165,7 +162,6 @@ def PushMessages( message_proto = request.messages_list.pop(0) message = message_from_proto(message_proto=message_proto) task_ins = message_to_taskins(message=message) - task_ins.task.pushed_at = pushed_at validation_errors = validate_task_ins_or_res(task_ins) _raise_if( validation_error=bool(validation_errors), diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index fcba1afe3008..185c086c9370 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -15,7 +15,6 @@ """Fleet API message handlers.""" -import time from typing import Optional from uuid import UUID @@ -122,9 +121,6 @@ def push_messages( if abort_msg: raise InvalidRunStatusException(abort_msg) - # Set pushed_at (timestamp in seconds) - task_res.task.pushed_at = time.time() - # Store TaskRes in State message_id: Optional[UUID] = state.store_task_res(task_res=task_res) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index e278e1902e3d..0e93a55788d5 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -140,7 +140,6 @@ def worker( # Convert to TaskRes task_res = message_to_taskres(out_mssg) # Store TaskRes in state - task_res.task.pushed_at = time.time() taskres_queue.put(task_res) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 9faf9f341af6..b7e934076d3b 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -16,7 +16,6 @@ import threading -import time from itertools import cycle from json import JSONDecodeError from math import pi @@ -160,9 +159,6 @@ def register_messages_into_state( ) # Convert Message to TaskIns taskins = message_to_taskins(message) - # Normally recorded by the driver servicer - # but since we don't have one in this test, we do this manually - taskins.task.pushed_at = time.time() # Instert in state task_id = state.store_task_ins(taskins) if task_id: diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 95247156a3d8..371e507e662e 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -275,7 +275,6 @@ def test_store_task_ins_one(self) -> None: assert actual_task.delivered_at != "" - assert actual_task.created_at < actual_task.pushed_at assert datetime.fromisoformat(actual_task.delivered_at) > datetime( 2020, 1, 1, tzinfo=timezone.utc ) @@ -1128,7 +1127,6 @@ def create_task_ins( created_at=time.time(), ), ) - task.task.pushed_at = time.time() return task @@ -1193,7 +1191,6 @@ def create_task_res( created_at=time.time(), ), ) - task_res.task.pushed_at = time.time() return task_res diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 61da0e5f136a..734354530490 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -126,7 +126,6 @@ consumer_node_id INTEGER, created_at REAL, delivered_at TEXT, - pushed_at REAL, ttl REAL, ancestry TEXT, task_type TEXT, @@ -144,7 +143,6 @@ consumer_node_id INTEGER, created_at REAL, delivered_at TEXT, - pushed_at REAL, ttl REAL, ancestry TEXT, task_type TEXT, @@ -1053,7 +1051,6 @@ def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]: "consumer_node_id": task_msg.task.consumer.node_id, "created_at": task_msg.task.created_at, "delivered_at": task_msg.task.delivered_at, - "pushed_at": task_msg.task.pushed_at, "ttl": task_msg.task.ttl, "ancestry": ",".join(task_msg.task.ancestry), "task_type": task_msg.task.task_type, @@ -1072,7 +1069,6 @@ def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]: "consumer_node_id": task_msg.task.consumer.node_id, "created_at": task_msg.task.created_at, "delivered_at": task_msg.task.delivered_at, - "pushed_at": task_msg.task.pushed_at, "ttl": task_msg.task.ttl, "ancestry": ",".join(task_msg.task.ancestry), "task_type": task_msg.task.task_type, @@ -1099,7 +1095,6 @@ def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns: ), created_at=task_dict["created_at"], delivered_at=task_dict["delivered_at"], - pushed_at=task_dict["pushed_at"], ttl=task_dict["ttl"], ancestry=task_dict["ancestry"].split(","), task_type=task_dict["task_type"], @@ -1127,7 +1122,6 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes: ), created_at=task_dict["created_at"], delivered_at=task_dict["delivered_at"], - pushed_at=task_dict["pushed_at"], ttl=task_dict["ttl"], ancestry=task_dict["ancestry"].split(","), task_type=task_dict["task_type"], diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate_test.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate_test.py index 39ea57486368..ffce6d420ed3 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate_test.py @@ -36,7 +36,6 @@ def test_ins_res_to_dict(self) -> None: "consumer_node_id", "created_at", "delivered_at", - "pushed_at", "ttl", "ancestry", "task_type", diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index 279862ab0eb5..238f3d63c71e 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -45,9 +45,6 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str validation_errors.append("`delivered_at` must be an empty str") if tasks_ins_res.task.ttl <= 0: validation_errors.append("`ttl` must be higher than zero") - if tasks_ins_res.task.pushed_at < 1711497600.0: - # unix timestamp of 27 March 2024 00h:00m:00s UTC - validation_errors.append("`pushed_at` is not a recent timestamp") # Verify TTL and created_at time current_time = time.time() diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index 5d5323d5295d..035622a9dbd5 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -115,7 +115,6 @@ def create_task_ins( ), ) - task.task.pushed_at = time.time() return task @@ -139,5 +138,4 @@ def create_task_res( ), ) - task_res.task.pushed_at = time.time() return task_res