From 94e7774450f5900046c205eb0302aa2b756177a5 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 21 Mar 2024 21:02:36 +0100 Subject: [PATCH 01/14] wip --- src/proto/flwr/proto/task.proto | 2 +- src/py/flwr/client/grpc_client/connection.py | 3 ++- .../client/grpc_client/connection_test.py | 6 ++--- .../client/message_handler/message_handler.py | 14 +++++++++--- .../message_handler/message_handler_test.py | 13 ++++++----- .../mod/secure_aggregation/secaggplus_mod.py | 3 ++- .../secure_aggregation/secaggplus_mod_test.py | 13 ++++++++--- src/py/flwr/client/mod/utils_test.py | 3 ++- src/py/flwr/common/__init__.py | 2 ++ src/py/flwr/common/message.py | 22 ++++++++++--------- src/py/flwr/common/serde_test.py | 2 +- src/py/flwr/proto/task_pb2.py | 2 +- src/py/flwr/proto/task_pb2.pyi | 4 ++-- src/py/flwr/server/driver/driver.py | 8 +++---- src/py/flwr/server/driver/driver_test.py | 12 +++++----- .../fleet/vce/backend/raybackend_test.py | 3 ++- .../superlink/fleet/vce/vce_api_test.py | 10 +++++++-- .../server/superlink/state/in_memory_state.py | 10 +++------ .../server/superlink/state/sqlite_state.py | 10 +++------ src/py/flwr/server/utils/validator.py | 14 ++++++++---- .../flwr/server/workflow/default_workflows.py | 8 +++---- .../secure_aggregation/secaggplus_workflow.py | 9 ++++---- .../ray_transport/ray_client_proxy.py | 4 ++-- .../ray_transport/ray_client_proxy_test.py | 3 ++- 24 files changed, 106 insertions(+), 74 deletions(-) diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 423df76f1335..2a05faa74687 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -27,7 +27,7 @@ message Task { Node consumer = 2; string created_at = 3; string delivered_at = 4; - string ttl = 5; + sint64 ttl = 5; repeated string ancestry = 6; string task_type = 7; RecordSet recordset = 8; diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 163a58542c9e..4431b53d2592 
100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -23,6 +23,7 @@ from typing import Callable, Iterator, Optional, Tuple, Union, cast from flwr.common import ( + DEFAULT_TTL, GRPC_MAX_MESSAGE_LENGTH, ConfigsRecord, Message, @@ -180,7 +181,7 @@ def receive() -> Message: dst_node_id=0, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type=message_type, ), content=recordset, diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index b7737f511a2a..061e7d4377a0 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -23,7 +23,7 @@ import grpc -from flwr.common import ConfigsRecord, Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, ConfigsRecord, Message, Metadata, RecordSet from flwr.common import recordset_compat as compat from flwr.common.constant import MessageTypeLegacy from flwr.common.retry_invoker import RetryInvoker, exponential @@ -50,7 +50,7 @@ dst_node_id=0, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=compat.getpropertiesres_to_recordset( @@ -65,7 +65,7 @@ dst_node_id=0, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type="reconnect", ), content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}), diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 9a5d70b1ac4d..c2d67bad018b 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -26,7 +26,15 @@ ) from flwr.client.numpy_client import NumPyClient from flwr.client.typing import ClientFn -from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log +from flwr.common import ( + DEFAULT_TTL, + 
ConfigsRecord, + Context, + Message, + Metadata, + RecordSet, + log, +) from flwr.common.constant import MessageType, MessageTypeLegacy from flwr.common.recordset_compat import ( evaluateres_to_recordset, @@ -81,7 +89,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: reason = cast(int, disconnect_msg.disconnect_res.reason) recordset = RecordSet() recordset.configs_records["config"] = ConfigsRecord({"reason": reason}) - out_message = message.create_reply(recordset, ttl="") + out_message = message.create_reply(recordset, ttl=DEFAULT_TTL) # Return TaskRes and sleep duration return out_message, sleep_duration @@ -143,7 +151,7 @@ def handle_legacy_message_from_msgtype( raise ValueError(f"Invalid message type: {message_type}") # Return Message - return message.create_reply(out_recordset, ttl="") + return message.create_reply(out_recordset, ttl=DEFAULT_TTL) def _reconnect( diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index eaf16f7dc993..e3f6487421cc 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -23,6 +23,7 @@ from flwr.client import Client from flwr.client.typing import ClientFn from flwr.common import ( + DEFAULT_TTL, Code, Context, EvaluateIns, @@ -131,7 +132,7 @@ def test_client_without_get_properties() -> None: src_node_id=0, dst_node_id=1123, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=recordset, @@ -161,7 +162,7 @@ def test_client_without_get_properties() -> None: src_node_id=1123, dst_node_id=0, reply_to_message=message.metadata.message_id, - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=expected_rs, @@ -184,7 +185,7 @@ def test_client_with_get_properties() -> None: src_node_id=0, dst_node_id=1123, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, 
message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=recordset, @@ -214,7 +215,7 @@ def test_client_with_get_properties() -> None: src_node_id=1123, dst_node_id=0, reply_to_message=message.metadata.message_id, - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=expected_rs, @@ -237,7 +238,7 @@ def setUp(self) -> None: dst_node_id=20, reply_to_message="", group_id="group1", - ttl="60", + ttl=DEFAULT_TTL, message_type="mock", ) self.valid_out_metadata = Metadata( @@ -247,7 +248,7 @@ def setUp(self) -> None: dst_node_id=10, reply_to_message="qwerty", group_id="group1", - ttl="60", + ttl=DEFAULT_TTL, message_type="mock", ) self.common_content = RecordSet() diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 989d5f6e1361..0a901bc734c3 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -22,6 +22,7 @@ from flwr.client.typing import ClientAppCallable from flwr.common import ( + DEFAULT_TTL, ConfigsRecord, Context, Message, @@ -187,7 +188,7 @@ def secaggplus_mod( # Return message out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False) - return msg.create_reply(out_content, ttl="") + return msg.create_reply(out_content, ttl=DEFAULT_TTL) def check_stage(current_stage: str, configs: ConfigsRecord) -> None: diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index db5ed67c02a4..0b88fcd00a2f 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -19,7 +19,14 @@ from typing import Callable, Dict, List from flwr.client.mod import make_ffn -from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet +from flwr.common import ( + DEFAULT_TTL, + 
ConfigsRecord, + Context, + Message, + Metadata, + RecordSet, +) from flwr.common.constant import MessageType from flwr.common.secure_aggregation.secaggplus_constants import ( RECORD_KEY_CONFIGS, @@ -38,7 +45,7 @@ def get_test_handler( """.""" def empty_ffn(_msg: Message, _2: Context) -> Message: - return _msg.create_reply(RecordSet(), ttl="") + return _msg.create_reply(RecordSet(), ttl=DEFAULT_TTL) app = make_ffn(empty_ffn, [secaggplus_mod]) @@ -51,7 +58,7 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: dst_node_id=123, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageType.TRAIN, ), content=RecordSet( diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index e588b8b53b3b..4676a2c02c4b 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -20,6 +20,7 @@ from flwr.client.typing import ClientAppCallable, Mod from flwr.common import ( + DEFAULT_TTL, ConfigsRecord, Context, Message, @@ -84,7 +85,7 @@ def _get_dummy_flower_message() -> Message: src_node_id=0, dst_node_id=0, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type="mock", ), ) diff --git a/src/py/flwr/common/__init__.py b/src/py/flwr/common/__init__.py index 9f9ff7ebc68a..2fb98c82dd6f 100644 --- a/src/py/flwr/common/__init__.py +++ b/src/py/flwr/common/__init__.py @@ -22,6 +22,7 @@ from .grpc import GRPC_MAX_MESSAGE_LENGTH from .logger import configure as configure from .logger import log as log +from .message import DEFAULT_TTL from .message import Error as Error from .message import Message as Message from .message import Metadata as Metadata @@ -87,6 +88,7 @@ "Message", "MessageType", "MessageTypeLegacy", + "DEFAULT_TTL", "Metadata", "Metrics", "MetricsAggregationFn", diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index 88cf750f1a94..c4a842aa9652 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -20,6 
+20,8 @@ from .record import RecordSet +DEFAULT_TTL = 3600 + @dataclass class Metadata: # pylint: disable=too-many-instance-attributes @@ -40,7 +42,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes group_id : str An identifier for grouping messages. In some settings, this is used as the FL round. - ttl : str + ttl : int Time-to-live for this message. message_type : str A string that encodes the action to be executed on @@ -57,7 +59,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes _dst_node_id: int _reply_to_message: str _group_id: str - _ttl: str + _ttl: int _message_type: str _partition_id: int | None @@ -69,7 +71,7 @@ def __init__( # pylint: disable=too-many-arguments dst_node_id: int, reply_to_message: str, group_id: str, - ttl: str, + ttl: int, message_type: str, partition_id: int | None = None, ) -> None: @@ -124,12 +126,12 @@ def group_id(self, value: str) -> None: self._group_id = value @property - def ttl(self) -> str: + def ttl(self) -> int: """Time-to-live for this message.""" return self._ttl @ttl.setter - def ttl(self, value: str) -> None: + def ttl(self, value: int) -> None: """Set ttl.""" self._ttl = value @@ -266,7 +268,7 @@ def has_error(self) -> bool: """Return True if message has an error, else False.""" return self._error is not None - def _create_reply_metadata(self, ttl: str) -> Metadata: + def _create_reply_metadata(self, ttl: int) -> Metadata: """Construct metadata for a reply message.""" return Metadata( run_id=self.metadata.run_id, @@ -283,7 +285,7 @@ def _create_reply_metadata(self, ttl: str) -> Metadata: def create_error_reply( self, error: Error, - ttl: str, + ttl: int, ) -> Message: """Construct a reply message indicating an error happened. @@ -291,14 +293,14 @@ def create_error_reply( ---------- error : Error The error that was encountered. - ttl : str + ttl : int Time-to-live for this message. 
""" # Create reply with error message = Message(metadata=self._create_reply_metadata(ttl), error=error) return message - def create_reply(self, content: RecordSet, ttl: str) -> Message: + def create_reply(self, content: RecordSet, ttl: int) -> Message: """Create a reply to this message with specified content and TTL. The method generates a new `Message` as a reply to this message. @@ -309,7 +311,7 @@ def create_reply(self, content: RecordSet, ttl: str) -> Message: ---------- content : RecordSet The content for the reply message. - ttl : str + ttl : int Time-to-live for this message. Returns diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 8596e5d2f330..e2a2bea5c255 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -219,7 +219,7 @@ def metadata(self) -> Metadata: src_node_id=self.rng.randint(0, 1 << 63), dst_node_id=self.rng.randint(0, 1 << 63), reply_to_message=self.get_str(64), - ttl=self.get_str(10), + ttl=self.rng.randint(0, 1 << 30), message_type=self.get_str(10), ) diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 4d5f863e88dd..5507522b6630 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 
\n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\x12\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index b9c10139cfb3..1d012cb33450 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -31,7 +31,7 @@ class Task(google.protobuf.message.Message): def consumer(self) -> flwr.proto.node_pb2.Node: ... 
created_at: typing.Text delivered_at: typing.Text - ttl: typing.Text + ttl: builtins.int @property def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... task_type: typing.Text @@ -45,7 +45,7 @@ class Task(google.protobuf.message.Message): consumer: typing.Optional[flwr.proto.node_pb2.Node] = ..., created_at: typing.Text = ..., delivered_at: typing.Text = ..., - ttl: typing.Text = ..., + ttl: builtins.int = ..., ancestry: typing.Optional[typing.Iterable[typing.Text]] = ..., task_type: typing.Text = ..., recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 0098e0ce97c2..b9c01e108eb4 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -18,7 +18,7 @@ import time from typing import Iterable, List, Optional, Tuple -from flwr.common import Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, @@ -90,7 +90,7 @@ def create_message( # pylint: disable=too-many-arguments message_type: str, dst_node_id: int, group_id: str, - ttl: str, + ttl: int = DEFAULT_TTL, ) -> Message: """Create a new message with specified parameters. @@ -110,7 +110,7 @@ def create_message( # pylint: disable=too-many-arguments group_id : str The ID of the group to which this message is associated. In some settings, this is used as the FL round. - ttl : str + ttl : int Time-to-live for the round trip of this message, i.e., the time from sending this message to receiving a reply. It specifies the duration for which the message and its potential reply are considered valid. 
@@ -128,7 +128,7 @@ def create_message( # pylint: disable=too-many-arguments dst_node_id=dst_node_id, reply_to_message="", group_id=group_id, - ttl=ttl, + ttl=time.time_ns() + ttl, message_type=message_type, ) return Message(metadata=metadata, content=content) diff --git a/src/py/flwr/server/driver/driver_test.py b/src/py/flwr/server/driver/driver_test.py index 5136f4f90210..3f1cd552250f 100644 --- a/src/py/flwr/server/driver/driver_test.py +++ b/src/py/flwr/server/driver/driver_test.py @@ -19,7 +19,7 @@ import unittest from unittest.mock import Mock, patch -from flwr.common import RecordSet +from flwr.common import DEFAULT_TTL, RecordSet from flwr.common.message import Error from flwr.common.serde import error_to_proto, recordset_to_proto from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 @@ -99,7 +99,8 @@ def test_push_messages_valid(self) -> None: mock_response = Mock(task_ids=["id1", "id2"]) self.mock_grpc_driver.push_task_ins.return_value = mock_response msgs = [ - self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) + for _ in range(2) ] # Execute @@ -121,7 +122,8 @@ def test_push_messages_invalid(self) -> None: mock_response = Mock(task_ids=["id1", "id2"]) self.mock_grpc_driver.push_task_ins.return_value = mock_response msgs = [ - self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) + for _ in range(2) ] # Use invalid run_id msgs[1].metadata._run_id += 1 # pylint: disable=protected-access @@ -170,7 +172,7 @@ def test_send_and_receive_messages_complete(self) -> None: task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] ) self.mock_grpc_driver.pull_task_res.return_value = mock_response - msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute ret_msgs = 
list(self.driver.send_and_receive(msgs)) @@ -187,7 +189,7 @@ def test_send_and_receive_messages_timeout(self) -> None: self.mock_grpc_driver.push_task_ins.return_value = mock_response mock_response = Mock(task_res_list=[]) self.mock_grpc_driver.pull_task_res.return_value = mock_response - msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)): diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index 2610307bb749..dcac0b81d666 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -25,6 +25,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import ( + DEFAULT_TTL, Config, ConfigsRecord, Context, @@ -111,7 +112,7 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: src_node_id=0, dst_node_id=0, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), ) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 8c37399ae295..2c917c3eed27 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -26,7 +26,13 @@ from unittest import IsolatedAsyncioTestCase from uuid import UUID -from flwr.common import GetPropertiesIns, Message, MessageTypeLegacy, Metadata +from flwr.common import ( + DEFAULT_TTL, + GetPropertiesIns, + Message, + MessageTypeLegacy, + Metadata, +) from flwr.common.recordset_compat import getpropertiesins_to_recordset from flwr.common.serde import message_from_taskres, message_to_taskins from 
flwr.server.superlink.fleet.vce.vce_api import ( @@ -97,7 +103,7 @@ def register_messages_into_state( src_node_id=0, dst_node_id=dst_node_id, # indicate destination node reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=( "a bad message" if erroneous_message diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index ac1ab158e254..7bff8ab4befc 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -17,7 +17,7 @@ import os import threading -from datetime import datetime, timedelta +from datetime import datetime from logging import ERROR from typing import Dict, List, Optional, Set from uuid import UUID, uuid4 @@ -50,15 +50,13 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: log(ERROR, "`run_id` is invalid") return None - # Create task_id, created_at and ttl + # Create task_id and created_at task_id = uuid4() created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl = ttl.isoformat() with self.lock: self.task_ins_store[task_id] = task_ins @@ -113,15 +111,13 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, "`run_id` is invalid") return None - # Create task_id, created_at and ttl + # Create task_id and created_at task_id = uuid4() created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskRes task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl = ttl.isoformat() with self.lock: self.task_res_store[task_id] = task_res diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 224c16cdf013..0b7835fa73f3 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ 
b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -18,7 +18,7 @@ import os import re import sqlite3 -from datetime import datetime, timedelta +from datetime import datetime from logging import DEBUG, ERROR from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from uuid import UUID, uuid4 @@ -185,15 +185,13 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: log(ERROR, errors) return None - # Create task_id, created_at and ttl + # Create task_id and created_at task_id = uuid4() created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl = ttl.isoformat() data = (task_ins_to_dict(task_ins),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" @@ -320,15 +318,13 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, errors) return None - # Create task_id, created_at and ttl + # Create task_id and created_at task_id = uuid4() created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskIns task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl = ttl.isoformat() data = (task_res_to_dict(task_res),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index f9b271beafdc..846217b085a1 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -66,8 +66,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str # Content check if tasks_ins_res.task.task_type == "": validation_errors.append("`task_type` MUST be set") - if not tasks_ins_res.task.HasField("recordset"): - validation_errors.append("`recordset` MUST be set") + if not ( + 
tasks_ins_res.task.HasField("recordset") + ^ tasks_ins_res.task.HasField("error") + ): + validation_errors.append("Either `recordset` or `error` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) != 0: @@ -106,8 +109,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str # Content check if tasks_ins_res.task.task_type == "": validation_errors.append("`task_type` MUST be set") - if not tasks_ins_res.task.HasField("recordset"): - validation_errors.append("`recordset` MUST be set") + if not ( + tasks_ins_res.task.HasField("recordset") + ^ tasks_ins_res.task.HasField("error") + ): + validation_errors.append("Either `recordset` or `error` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) == 0: diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py index 876ae56dcadc..42b1151f9835 100644 --- a/src/py/flwr/server/workflow/default_workflows.py +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -21,7 +21,7 @@ from typing import Optional, cast import flwr.common.recordset_compat as compat -from flwr.common import ConfigsRecord, Context, GetParametersIns, log +from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, GetParametersIns, log from flwr.common.constant import MessageType, MessageTypeLegacy from ..compat.app_utils import start_update_client_manager_thread @@ -127,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: message_type=MessageTypeLegacy.GET_PARAMETERS, dst_node_id=random_client.node_id, group_id="0", - ttl="", + ttl=DEFAULT_TTL, ) ] ) @@ -226,7 +226,7 @@ def default_fit_workflow( # pylint: disable=R0914 message_type=MessageType.TRAIN, dst_node_id=proxy.node_id, group_id=str(current_round), - ttl="", + ttl=DEFAULT_TTL, ) for proxy, fitins in client_instructions ] @@ -306,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: message_type=MessageType.EVALUATE, 
dst_node_id=proxy.node_id, group_id=str(current_round), - ttl="", + ttl=DEFAULT_TTL, ) for proxy, evalins in client_instructions ] diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py index 42ee9c15f1cd..326947b653ff 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -22,6 +22,7 @@ import flwr.common.recordset_compat as compat from flwr.common import ( + DEFAULT_TTL, ConfigsRecord, Context, FitRes, @@ -373,7 +374,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl="", + ttl=DEFAULT_TTL, ) log( @@ -421,7 +422,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl="", + ttl=DEFAULT_TTL, ) # Broadcast public keys to clients and receive secret key shares @@ -492,7 +493,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl="", + ttl=DEFAULT_TTL, ) log( @@ -563,7 +564,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(current_round), - ttl="", + ttl=DEFAULT_TTL, ) log( diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index c3493163ac52..82bb2628debd 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -23,7 +23,7 @@ from flwr.client import ClientFn from flwr.client.client_app import ClientApp from flwr.client.node_state import NodeState -from flwr.common import Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.constant import MessageType, 
MessageTypeLegacy from flwr.common.logger import log from flwr.common.recordset_compat import ( @@ -105,7 +105,7 @@ def _wrap_recordset_in_message( src_node_id=0, dst_node_id=int(self.cid), reply_to_message="", - ttl=str(timeout) if timeout else "", + ttl=int(timeout) if timeout else DEFAULT_TTL, message_type=message_type, partition_id=int(self.cid), ), diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 22c5425cd9fd..9680b3846f1d 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -24,6 +24,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp from flwr.common import ( + DEFAULT_TTL, Config, ConfigsRecord, Context, @@ -202,7 +203,7 @@ def _load_app() -> ClientApp: src_node_id=0, dst_node_id=12345, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, partition_id=int(cid), ), From ef1634a7c447219caca8eec4bbaf02f2aa2dbafa Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 21 Mar 2024 21:16:52 +0100 Subject: [PATCH 02/14] full --- src/py/flwr/server/superlink/state/state_test.py | 6 +----- src/py/flwr/server/utils/validator.py | 4 ++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index d0470a7ce7f7..c7cc911e48c7 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -73,7 +73,6 @@ def test_store_task_ins_one(self) -> None: assert task_ins.task.created_at == "" # pylint: disable=no-member assert task_ins.task.delivered_at == "" # pylint: disable=no-member - assert task_ins.task.ttl == "" # pylint: disable=no-member # Execute state.store_task_ins(task_ins=task_ins) @@ -91,7 +90,6 @@ def test_store_task_ins_one(self) -> 
None: assert actual_task.created_at != "" assert actual_task.delivered_at != "" - assert actual_task.ttl != "" assert datetime.fromisoformat(actual_task.created_at) > datetime( 2020, 1, 1, tzinfo=timezone.utc @@ -99,9 +97,7 @@ def test_store_task_ins_one(self) -> None: assert datetime.fromisoformat(actual_task.delivered_at) > datetime( 2020, 1, 1, tzinfo=timezone.utc ) - assert datetime.fromisoformat(actual_task.ttl) > datetime( - 2020, 1, 1, tzinfo=timezone.utc - ) + assert actual_task.ttl > 0 def test_store_and_delete_tasks(self) -> None: """Test delete_tasks.""" diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index 846217b085a1..da0d9cd9be52 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -36,8 +36,8 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("`created_at` must be an empty str") if tasks_ins_res.task.delivered_at != "": validation_errors.append("`delivered_at` must be an empty str") - if tasks_ins_res.task.ttl != "": - validation_errors.append("`ttl` must be an empty str") + if tasks_ins_res.task.ttl > 0: + validation_errors.append("`ttl` must be higher than zero") # TaskIns specific if isinstance(tasks_ins_res, TaskIns): From 893aa9bc73adf44302f78b22755f0e46186e11bb Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 21 Mar 2024 21:42:04 +0100 Subject: [PATCH 03/14] fix --- src/py/flwr/server/utils/validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index da0d9cd9be52..285807d8d0e7 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -36,7 +36,7 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("`created_at` must be an empty str") if tasks_ins_res.task.delivered_at != "": 
validation_errors.append("`delivered_at` must be an empty str") - if tasks_ins_res.task.ttl > 0: + if tasks_ins_res.task.ttl <= 0: validation_errors.append("`ttl` must be higher than zero") # TaskIns specific From 7c67465adfc694d27bbde582663daad86c7bf957 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 21 Mar 2024 21:59:54 +0100 Subject: [PATCH 04/14] fixes for sqlite --- src/py/flwr/server/superlink/state/sqlite_state.py | 4 ++-- src/py/flwr/server/superlink/state/state_test.py | 3 +++ src/py/flwr/server/utils/validator_test.py | 3 +++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 0b7835fa73f3..95513890dd87 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -54,7 +54,7 @@ consumer_node_id INTEGER, created_at TEXT, delivered_at TEXT, - ttl TEXT, + ttl INTEGER, ancestry TEXT, task_type TEXT, recordset BLOB, @@ -74,7 +74,7 @@ consumer_node_id INTEGER, created_at TEXT, delivered_at TEXT, - ttl TEXT, + ttl INTEGER, ancestry TEXT, task_type TEXT, recordset BLOB, diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index c7cc911e48c7..01ac64de1380 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -22,6 +22,7 @@ from typing import List from uuid import uuid4 +from flwr.common import DEFAULT_TTL from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -416,6 +417,7 @@ def create_task_ins( consumer=consumer, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, ), ) return task @@ -438,6 +440,7 @@ def create_task_res( ancestry=ancestry, 
task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, ), ) return task_res diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index 8e0849508020..926103c6b09a 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -18,6 +18,7 @@ import unittest from typing import List, Tuple +from flwr.common import DEFAULT_TTL from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -96,6 +97,7 @@ def create_task_ins( consumer=consumer, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, ), ) return task @@ -117,6 +119,7 @@ def create_task_res( ancestry=ancestry, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, ), ) return task_res From 4ff2845daae62f8c29b44add87d659c85ffbbbd2 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sat, 23 Mar 2024 09:23:10 +0100 Subject: [PATCH 05/14] timestamp taken in SuperLink --- src/py/flwr/server/driver/driver.py | 6 +++--- src/py/flwr/server/superlink/state/in_memory_state.py | 3 +++ src/py/flwr/server/superlink/state/sqlite_state.py | 3 +++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index b9c01e108eb4..1c880cf1234e 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -18,7 +18,7 @@ import time from typing import Iterable, List, Optional, Tuple -from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet +from flwr.common import Message, Metadata, RecordSet from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, @@ -90,7 +90,7 @@ 
def create_message( # pylint: disable=too-many-arguments message_type: str, dst_node_id: int, group_id: str, - ttl: int = DEFAULT_TTL, + ttl: int, ) -> Message: """Create a new message with specified parameters. @@ -128,7 +128,7 @@ def create_message( # pylint: disable=too-many-arguments dst_node_id=dst_node_id, reply_to_message="", group_id=group_id, - ttl=time.time_ns() + ttl, + ttl=ttl, message_type=message_type, ) return Message(metadata=metadata, content=content) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 7bff8ab4befc..fe4cdce46c37 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -17,6 +17,7 @@ import os import threading +import time from datetime import datetime from logging import ERROR from typing import Dict, List, Optional, Set @@ -57,6 +58,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() + task_ins.task.ttl = time.time_ns() with self.lock: self.task_ins_store[task_id] = task_ins @@ -118,6 +120,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Store TaskRes task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() + task_res.task.ttl = time.time_ns() with self.lock: self.task_res_store[task_id] = task_res diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 95513890dd87..14b9bdd0c319 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -18,6 +18,7 @@ import os import re import sqlite3 +import time from datetime import datetime from logging import DEBUG, ERROR from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast @@ -192,6 +193,7 @@ def store_task_ins(self, task_ins: TaskIns) -> 
Optional[UUID]: # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() + task_ins.task.ttl = time.time_ns() data = (task_ins_to_dict(task_ins),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" @@ -325,6 +327,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Store TaskIns task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() + task_res.task.ttl = time.time_ns() data = (task_res_to_dict(task_res),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" From 718310f8656bbd2e7b65cf979e73f91333de16c0 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 25 Mar 2024 13:28:48 +0100 Subject: [PATCH 06/14] fix --- examples/app-pytorch/client_low_level.py | 8 ++++---- examples/app-pytorch/server_custom.py | 3 ++- examples/app-pytorch/server_low_level.py | 4 ++-- src/py/flwr/server/superlink/state/in_memory_state.py | 4 ++-- src/py/flwr/server/superlink/state/sqlite_state.py | 4 ++-- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/app-pytorch/client_low_level.py b/examples/app-pytorch/client_low_level.py index feea1ee658fe..1373377cc322 100644 --- a/examples/app-pytorch/client_low_level.py +++ b/examples/app-pytorch/client_low_level.py @@ -1,5 +1,5 @@ from flwr.client import ClientApp -from flwr.common import Message, Context +from flwr.common import Message, Context,DEFAULT_TTL def hello_world_mod(msg, ctx, call_next) -> Message: @@ -20,16 +20,16 @@ def hello_world_mod(msg, ctx, call_next) -> Message: @app.train() def train(msg: Message, ctx: Context): print("`train` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl="") + return msg.create_reply(msg.content, ttl=DEFAULT_TTL) @app.evaluate() def eval(msg: Message, ctx: Context): print("`evaluate` is not implemented, echoing original message") - return 
msg.create_reply(msg.content, ttl="") + return msg.create_reply(msg.content, ttl=DEFAULT_TTL) @app.query() def query(msg: Message, ctx: Context): print("`query` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl="") + return msg.create_reply(msg.content, ttl=DEFAULT_TTL) diff --git a/examples/app-pytorch/server_custom.py b/examples/app-pytorch/server_custom.py index 0c2851e2afee..ba9cdb11d694 100644 --- a/examples/app-pytorch/server_custom.py +++ b/examples/app-pytorch/server_custom.py @@ -13,6 +13,7 @@ Message, MessageType, Metrics, + DEFAULT_TTL, ) from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres from flwr.server import Driver, History @@ -89,7 +90,7 @@ def main(driver: Driver, context: Context) -> None: message_type=MessageType.TRAIN, dst_node_id=node_id, group_id=str(server_round), - ttl="", + ttl=DEFAULT_TTL, ) messages.append(message) diff --git a/examples/app-pytorch/server_low_level.py b/examples/app-pytorch/server_low_level.py index 560babac1b95..c7c3dc2513a6 100644 --- a/examples/app-pytorch/server_low_level.py +++ b/examples/app-pytorch/server_low_level.py @@ -3,7 +3,7 @@ import time import flwr as fl -from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet +from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet, DEFAULT_TTL from flwr.server import Driver @@ -30,7 +30,7 @@ def main(driver: Driver, context: Context) -> None: message_type=MessageType.TRAIN, dst_node_id=node_id, group_id=str(server_round), - ttl="", + ttl=DEFAULT_TTL, ) messages.append(message) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index fe4cdce46c37..889eb84f9761 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -58,7 +58,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: # Store 
TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl = time.time_ns() + task_ins.task.ttl += time.time_ns() with self.lock: self.task_ins_store[task_id] = task_ins @@ -120,7 +120,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Store TaskRes task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl = time.time_ns() + task_res.task.ttl += time.time_ns() with self.lock: self.task_res_store[task_id] = task_res diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 14b9bdd0c319..eba44a12e838 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -193,7 +193,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl = time.time_ns() + task_ins.task.ttl += time.time_ns() data = (task_ins_to_dict(task_ins),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" @@ -327,7 +327,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Store TaskIns task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl = time.time_ns() + task_res.task.ttl += time.time_ns() data = (task_res_to_dict(task_res),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" From fbcc95aeebb3e51b9cfc6251dca822d9850ec344 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 25 Mar 2024 13:31:53 +0100 Subject: [PATCH 07/14] format --- examples/app-pytorch/client_low_level.py | 2 +- examples/app-pytorch/server_low_level.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/app-pytorch/client_low_level.py 
b/examples/app-pytorch/client_low_level.py index 1373377cc322..905744bf09e3 100644 --- a/examples/app-pytorch/client_low_level.py +++ b/examples/app-pytorch/client_low_level.py @@ -1,5 +1,5 @@ from flwr.client import ClientApp -from flwr.common import Message, Context,DEFAULT_TTL +from flwr.common import Message, Context, DEFAULT_TTL def hello_world_mod(msg, ctx, call_next) -> Message: diff --git a/examples/app-pytorch/server_low_level.py b/examples/app-pytorch/server_low_level.py index c7c3dc2513a6..7ab79a4a04c8 100644 --- a/examples/app-pytorch/server_low_level.py +++ b/examples/app-pytorch/server_low_level.py @@ -3,7 +3,15 @@ import time import flwr as fl -from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet, DEFAULT_TTL +from flwr.common import ( + Context, + NDArrays, + Message, + MessageType, + Metrics, + RecordSet, + DEFAULT_TTL, +) from flwr.server import Driver From 99ed10007a7b8228a415437d4af222b7589d1288 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 25 Mar 2024 14:06:40 +0100 Subject: [PATCH 08/14] fixes --- src/py/flwr/common/serde_test.py | 2 +- src/py/flwr/server/compat/driver_client_proxy.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index e2a2bea5c255..fc12ce95328f 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -219,7 +219,7 @@ def metadata(self) -> Metadata: src_node_id=self.rng.randint(0, 1 << 63), dst_node_id=self.rng.randint(0, 1 << 63), reply_to_message=self.get_str(64), - ttl=self.rng.randint(0, 1 << 30), + ttl=self.rng.randint(1, 1 << 30), message_type=self.get_str(10), ) diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py index 84c67149fad7..99ba50d3e2d1 100644 --- a/src/py/flwr/server/compat/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -19,7 +19,7 @@ from typing import List, 
Optional from flwr import common -from flwr.common import MessageType, MessageTypeLegacy, RecordSet +from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet from flwr.common import recordset_compat as compat from flwr.common import serde from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 @@ -129,6 +129,7 @@ def _send_receive_recordset( ), task_type=task_type, recordset=serde.recordset_to_proto(recordset), + ttl=DEFAULT_TTL, ), ) push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 From 654c087dc6aab08b7782febd7090553217d9db91 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 25 Mar 2024 14:20:33 +0100 Subject: [PATCH 09/14] more --- src/py/flwr/client/client_app.py | 8 ++++---- src/py/flwr/server/driver/driver.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index ad7a01326991..eeeb960db9f1 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -115,7 +115,7 @@ def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def train(message: Message, context: Context) -> Message: >>> print("ClientApp training running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl="") + >>> return message.create_reply(content=message.content(), ttl=DEFAULT_TTL) """ def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable: @@ -143,7 +143,7 @@ def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def evaluate(message: Message, context: Context) -> Message: >>> print("ClientApp evaluation running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl="") + >>> return message.create_reply(content=message.content(), ttl=DEFAULT_TTL) """ def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable: @@ -171,7 +171,7 
@@ def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def query(message: Message, context: Context) -> Message: >>> print("ClientApp query running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl="") + >>> return message.create_reply(content=message.content(), ttl=DEFAULT_TTL) """ def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: @@ -218,7 +218,7 @@ def _registration_error(fn_name: str) -> ValueError: >>> print("ClientApp {fn_name} running") >>> # Create and return an echo reply message >>> return message.create_reply( - >>> content=message.content(), ttl="" + >>> content=message.content(), ttl=DEFAULT_TTL >>> ) """, ) diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 1c880cf1234e..74d4e9b97fa1 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -81,6 +81,7 @@ def _check_message(self, message: Message) -> None: and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" + and message.metadata.ttl > 0 ): raise ValueError(f"Invalid message: {message}") From afa2dfa4886e63476ba9c3612f4b35566dd35e46 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 25 Mar 2024 15:30:48 +0100 Subject: [PATCH 10/14] ttl is float; `create_reply` comes with default ttl --- examples/app-pytorch/client_low_level.py | 8 ++++---- src/proto/flwr/proto/task.proto | 2 +- src/py/flwr/client/client_app.py | 8 ++++---- .../client/message_handler/message_handler.py | 14 +++----------- .../mod/secure_aggregation/secaggplus_mod.py | 3 +-- .../secure_aggregation/secaggplus_mod_test.py | 2 +- src/py/flwr/common/message.py | 18 +++++++++--------- src/py/flwr/proto/task_pb2.py | 2 +- src/py/flwr/proto/task_pb2.pyi | 4 ++-- src/py/flwr/server/driver/driver.py | 8 ++++---- .../server/superlink/state/in_memory_state.py | 4 ++-- 
.../server/superlink/state/sqlite_state.py | 8 ++++---- 12 files changed, 36 insertions(+), 45 deletions(-) diff --git a/examples/app-pytorch/client_low_level.py b/examples/app-pytorch/client_low_level.py index 905744bf09e3..19268ff84ba4 100644 --- a/examples/app-pytorch/client_low_level.py +++ b/examples/app-pytorch/client_low_level.py @@ -1,5 +1,5 @@ from flwr.client import ClientApp -from flwr.common import Message, Context, DEFAULT_TTL +from flwr.common import Message, Context def hello_world_mod(msg, ctx, call_next) -> Message: @@ -20,16 +20,16 @@ def hello_world_mod(msg, ctx, call_next) -> Message: @app.train() def train(msg: Message, ctx: Context): print("`train` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl=DEFAULT_TTL) + return msg.create_reply(msg.content) @app.evaluate() def eval(msg: Message, ctx: Context): print("`evaluate` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl=DEFAULT_TTL) + return msg.create_reply(msg.content) @app.query() def query(msg: Message, ctx: Context): print("`query` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl=DEFAULT_TTL) + return msg.create_reply(msg.content) diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 2a05faa74687..4c86ebae9562 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -27,7 +27,7 @@ message Task { Node consumer = 2; string created_at = 3; string delivered_at = 4; - sint64 ttl = 5; + double ttl = 5; repeated string ancestry = 6; string task_type = 7; RecordSet recordset = 8; diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index eeeb960db9f1..0b56219807c6 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -115,7 +115,7 @@ def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def train(message: Message, context: Context) -> 
Message: >>> print("ClientApp training running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl=DEFAULT_TTL) + >>> return message.create_reply(content=message.content()) """ def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable: @@ -143,7 +143,7 @@ def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def evaluate(message: Message, context: Context) -> Message: >>> print("ClientApp evaluation running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl=DEFAULT_TTL) + >>> return message.create_reply(content=message.content()) """ def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable: @@ -171,7 +171,7 @@ def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def query(message: Message, context: Context) -> Message: >>> print("ClientApp query running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl=DEFAULT_TTL) + >>> return message.create_reply(content=message.content()) """ def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: @@ -218,7 +218,7 @@ def _registration_error(fn_name: str) -> ValueError: >>> print("ClientApp {fn_name} running") >>> # Create and return an echo reply message >>> return message.create_reply( - >>> content=message.content(), ttl=DEFAULT_TTL + >>> content=message.content() >>> ) """, ) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index c2d67bad018b..87014f436cf7 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -26,15 +26,7 @@ ) from flwr.client.numpy_client import NumPyClient from flwr.client.typing import ClientFn -from flwr.common import ( - DEFAULT_TTL, - ConfigsRecord, - Context, - Message, - Metadata, - 
RecordSet, - log, -) +from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log from flwr.common.constant import MessageType, MessageTypeLegacy from flwr.common.recordset_compat import ( evaluateres_to_recordset, @@ -89,7 +81,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: reason = cast(int, disconnect_msg.disconnect_res.reason) recordset = RecordSet() recordset.configs_records["config"] = ConfigsRecord({"reason": reason}) - out_message = message.create_reply(recordset, ttl=DEFAULT_TTL) + out_message = message.create_reply(recordset) # Return TaskRes and sleep duration return out_message, sleep_duration @@ -151,7 +143,7 @@ def handle_legacy_message_from_msgtype( raise ValueError(f"Invalid message type: {message_type}") # Return Message - return message.create_reply(out_recordset, ttl=DEFAULT_TTL) + return message.create_reply(out_recordset) def _reconnect( diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 0a901bc734c3..5b196ad84321 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -22,7 +22,6 @@ from flwr.client.typing import ClientAppCallable from flwr.common import ( - DEFAULT_TTL, ConfigsRecord, Context, Message, @@ -188,7 +187,7 @@ def secaggplus_mod( # Return message out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False) - return msg.create_reply(out_content, ttl=DEFAULT_TTL) + return msg.create_reply(out_content) def check_stage(current_stage: str, configs: ConfigsRecord) -> None: diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index 0b88fcd00a2f..36844a2983a1 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ 
b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -45,7 +45,7 @@ def get_test_handler( """.""" def empty_ffn(_msg: Message, _2: Context) -> Message: - return _msg.create_reply(RecordSet(), ttl=DEFAULT_TTL) + return _msg.create_reply(RecordSet()) app = make_ffn(empty_ffn, [secaggplus_mod]) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index c4a842aa9652..c82236ad8a8d 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -42,8 +42,8 @@ class Metadata: # pylint: disable=too-many-instance-attributes group_id : str An identifier for grouping messages. In some settings, this is used as the FL round. - ttl : int - Time-to-live for this message. + ttl : float + Time-to-live for this message in seconds. message_type : str A string that encodes the action to be executed on the receiving end. @@ -59,7 +59,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes _dst_node_id: int _reply_to_message: str _group_id: str - _ttl: int + _ttl: float _message_type: str _partition_id: int | None @@ -71,7 +71,7 @@ def __init__( # pylint: disable=too-many-arguments dst_node_id: int, reply_to_message: str, group_id: str, - ttl: int, + ttl: float, message_type: str, partition_id: int | None = None, ) -> None: @@ -126,12 +126,12 @@ def group_id(self, value: str) -> None: self._group_id = value @property - def ttl(self) -> int: + def ttl(self) -> float: """Time-to-live for this message.""" return self._ttl @ttl.setter - def ttl(self, value: int) -> None: + def ttl(self, value: float) -> None: """Set ttl.""" self._ttl = value @@ -268,7 +268,7 @@ def has_error(self) -> bool: """Return True if message has an error, else False.""" return self._error is not None - def _create_reply_metadata(self, ttl: int) -> Metadata: + def _create_reply_metadata(self, ttl: float) -> Metadata: """Construct metadata for a reply message.""" return Metadata( run_id=self.metadata.run_id, @@ -285,7 +285,7 @@ def 
_create_reply_metadata(self, ttl: int) -> Metadata: def create_error_reply( self, error: Error, - ttl: int, + ttl: float, ) -> Message: """Construct a reply message indicating an error happened. @@ -300,7 +300,7 @@ def create_error_reply( message = Message(metadata=self._create_reply_metadata(ttl), error=error) return message - def create_reply(self, content: RecordSet, ttl: int) -> Message: + def create_reply(self, content: RecordSet, ttl: float = DEFAULT_TTL) -> Message: """Create a reply to this message with specified content and TTL. The method generates a new `Message` as a reply to this message. diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 5507522b6630..abf7d72d7174 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\x12\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') 
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index 1d012cb33450..735400eca701 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -31,7 +31,7 @@ class Task(google.protobuf.message.Message): def consumer(self) -> flwr.proto.node_pb2.Node: ... created_at: typing.Text delivered_at: typing.Text - ttl: builtins.int + ttl: builtins.float @property def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... 
task_type: typing.Text @@ -45,7 +45,7 @@ class Task(google.protobuf.message.Message): consumer: typing.Optional[flwr.proto.node_pb2.Node] = ..., created_at: typing.Text = ..., delivered_at: typing.Text = ..., - ttl: builtins.int = ..., + ttl: builtins.float = ..., ancestry: typing.Optional[typing.Iterable[typing.Text]] = ..., task_type: typing.Text = ..., recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 74d4e9b97fa1..ab943a51a674 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -91,7 +91,7 @@ def create_message( # pylint: disable=too-many-arguments message_type: str, dst_node_id: int, group_id: str, - ttl: int, + ttl: float, ) -> Message: """Create a new message with specified parameters. @@ -111,10 +111,10 @@ def create_message( # pylint: disable=too-many-arguments group_id : str The ID of the group to which this message is associated. In some settings, this is used as the FL round. - ttl : int + ttl : float Time-to-live for the round trip of this message, i.e., the time from sending - this message to receiving a reply. It specifies the duration for which the - message and its potential reply are considered valid. + this message to receiving a reply. It specifies in seconds the duration for + which the message and its potential reply are considered valid. 
Returns ------- diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 889eb84f9761..e895958518c8 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -58,7 +58,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl += time.time_ns() + task_ins.task.ttl += 1e9 * time.time_ns() with self.lock: self.task_ins_store[task_id] = task_ins @@ -120,7 +120,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Store TaskRes task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl += time.time_ns() + task_res.task.ttl += 1e9 * time.time_ns() with self.lock: self.task_res_store[task_id] = task_res diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index eba44a12e838..91faff36c396 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -55,7 +55,7 @@ consumer_node_id INTEGER, created_at TEXT, delivered_at TEXT, - ttl INTEGER, + ttl FLOAT, ancestry TEXT, task_type TEXT, recordset BLOB, @@ -75,7 +75,7 @@ consumer_node_id INTEGER, created_at TEXT, delivered_at TEXT, - ttl INTEGER, + ttl FLOAT, ancestry TEXT, task_type TEXT, recordset BLOB, @@ -193,7 +193,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl += time.time_ns() + task_ins.task.ttl += 1e9 * time.time_ns() data = (task_ins_to_dict(task_ins),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" @@ -327,7 +327,7 @@ def store_task_res(self, task_res: TaskRes) -> 
Optional[UUID]: # Store TaskIns task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl += time.time_ns() + task_res.task.ttl += 1e9 * time.time_ns() data = (task_res_to_dict(task_res),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" From cef80c7f2020e050eb0aec6a890dd63c7444039d Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 26 Mar 2024 10:10:20 +0000 Subject: [PATCH 11/14] Apply suggestions from code review Co-authored-by: Daniel J. Beutel --- src/py/flwr/common/message.py | 8 ++++---- src/py/flwr/simulation/ray_transport/ray_client_proxy.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index c82236ad8a8d..d3ea069946d2 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -293,8 +293,8 @@ def create_error_reply( ---------- error : Error The error that was encountered. - ttl : int - Time-to-live for this message. + ttl : float + Time-to-live for this message in seconds. """ # Create reply with error message = Message(metadata=self._create_reply_metadata(ttl), error=error) @@ -311,8 +311,8 @@ def create_reply(self, content: RecordSet, ttl: float = DEFAULT_TTL) -> Message: ---------- content : RecordSet The content for the reply message. - ttl : int - Time-to-live for this message. + ttl : float + Time-to-live for this message in seconds. 
Returns ------- diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 82bb2628debd..5e344eb087ee 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -105,7 +105,7 @@ def _wrap_recordset_in_message( src_node_id=0, dst_node_id=int(self.cid), reply_to_message="", - ttl=int(timeout) if timeout else DEFAULT_TTL, + ttl=timeout if timeout else DEFAULT_TTL, message_type=message_type, partition_id=int(self.cid), ), From 54bb2ebff88b2a6c8e9651ccb3ef756a5c1974fd Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 26 Mar 2024 16:09:14 +0100 Subject: [PATCH 12/14] no edits to `ttl` by superlink --- src/py/flwr/server/superlink/state/in_memory_state.py | 3 --- src/py/flwr/server/superlink/state/sqlite_state.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index e895958518c8..7bff8ab4befc 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -17,7 +17,6 @@ import os import threading -import time from datetime import datetime from logging import ERROR from typing import Dict, List, Optional, Set @@ -58,7 +57,6 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl += 1e9 * time.time_ns() with self.lock: self.task_ins_store[task_id] = task_ins @@ -120,7 +118,6 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Store TaskRes task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl += 1e9 * time.time_ns() with self.lock: self.task_res_store[task_id] = task_res diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py 
b/src/py/flwr/server/superlink/state/sqlite_state.py index 91faff36c396..886b585e3796 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -18,7 +18,6 @@ import os import re import sqlite3 -import time from datetime import datetime from logging import DEBUG, ERROR from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast @@ -193,7 +192,6 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl += 1e9 * time.time_ns() data = (task_ins_to_dict(task_ins),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" @@ -327,7 +325,6 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Store TaskIns task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl += 1e9 * time.time_ns() data = (task_res_to_dict(task_res),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" From 8932af3313b807a9a8a24a7eb95ec22414ab82b7 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 26 Mar 2024 18:37:25 +0100 Subject: [PATCH 13/14] make `message.create_reply` ttl optional --- src/py/flwr/common/message.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index d3ea069946d2..25607179764d 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -300,7 +300,7 @@ def create_error_reply( message = Message(metadata=self._create_reply_metadata(ttl), error=error) return message - def create_reply(self, content: RecordSet, ttl: float = DEFAULT_TTL) -> Message: + def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message: """Create a reply to this message with specified content and TTL. 
The method generates a new `Message` as a reply to this message. @@ -311,14 +311,18 @@ def create_reply(self, content: RecordSet, ttl: float = DEFAULT_TTL) -> Message: ---------- content : RecordSet The content for the reply message. - ttl : float - Time-to-live for this message in seconds. + ttl : Optional[float] (default: None) + Time-to-live for this message in seconds. If unset, it will use + the `common.DEFAULT_TTL` value. Returns ------- Message A new `Message` instance representing the reply. """ + if ttl is None: + ttl = DEFAULT_TTL + return Message( metadata=self._create_reply_metadata(ttl), content=content, From d18bd624ef645f12c5080d5f5bca198344b66f6f Mon Sep 17 00:00:00 2001 From: jafermarq Date: Tue, 26 Mar 2024 19:20:58 +0100 Subject: [PATCH 14/14] set sql type to `REAL`; default TTL in `[driver].create_message()` --- src/py/flwr/server/driver/driver.py | 6 +++--- src/py/flwr/server/superlink/state/sqlite_state.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index ab943a51a674..afebd90ea265 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -18,7 +18,7 @@ import time from typing import Iterable, List, Optional, Tuple -from flwr.common import Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, @@ -91,7 +91,7 @@ def create_message( # pylint: disable=too-many-arguments message_type: str, dst_node_id: int, group_id: str, - ttl: float, + ttl: float = DEFAULT_TTL, ) -> Message: """Create a new message with specified parameters. @@ -111,7 +111,7 @@ def create_message( # pylint: disable=too-many-arguments group_id : str The ID of the group to which this message is associated. In some settings, this is used as the FL round.
- ttl : float + ttl : float (default: common.DEFAULT_TTL) Time-to-live for the round trip of this message, i.e., the time from sending this message to receiving a reply. It specifies in seconds the duration for which the message and its potential reply are considered valid. diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 886b585e3796..25d138f94203 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -54,7 +54,7 @@ consumer_node_id INTEGER, created_at TEXT, delivered_at TEXT, - ttl FLOAT, + ttl REAL, ancestry TEXT, task_type TEXT, recordset BLOB, @@ -74,7 +74,7 @@ consumer_node_id INTEGER, created_at TEXT, delivered_at TEXT, - ttl FLOAT, + ttl REAL, ancestry TEXT, task_type TEXT, recordset BLOB,