diff --git a/examples/app-pytorch/client_low_level.py b/examples/app-pytorch/client_low_level.py index feea1ee658fe..19268ff84ba4 100644 --- a/examples/app-pytorch/client_low_level.py +++ b/examples/app-pytorch/client_low_level.py @@ -20,16 +20,16 @@ def hello_world_mod(msg, ctx, call_next) -> Message: @app.train() def train(msg: Message, ctx: Context): print("`train` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl="") + return msg.create_reply(msg.content) @app.evaluate() def eval(msg: Message, ctx: Context): print("`evaluate` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl="") + return msg.create_reply(msg.content) @app.query() def query(msg: Message, ctx: Context): print("`query` is not implemented, echoing original message") - return msg.create_reply(msg.content, ttl="") + return msg.create_reply(msg.content) diff --git a/examples/app-pytorch/server_custom.py b/examples/app-pytorch/server_custom.py index 51ac8a5c006c..67c1bce99c55 100644 --- a/examples/app-pytorch/server_custom.py +++ b/examples/app-pytorch/server_custom.py @@ -13,6 +13,7 @@ Message, MessageType, Metrics, + DEFAULT_TTL, ) from flwr.common.recordset_compat import fitins_to_recordset, recordset_to_fitres from flwr.server import Driver, History @@ -89,7 +90,7 @@ def main(driver: Driver, context: Context) -> None: message_type=MessageType.TRAIN, dst_node_id=node_id, group_id=str(server_round), - ttl="", + ttl=DEFAULT_TTL, ) messages.append(message) diff --git a/examples/app-pytorch/server_low_level.py b/examples/app-pytorch/server_low_level.py index 560babac1b95..7ab79a4a04c8 100644 --- a/examples/app-pytorch/server_low_level.py +++ b/examples/app-pytorch/server_low_level.py @@ -3,7 +3,15 @@ import time import flwr as fl -from flwr.common import Context, NDArrays, Message, MessageType, Metrics, RecordSet +from flwr.common import ( + Context, + NDArrays, + Message, + MessageType, + Metrics, + RecordSet, + DEFAULT_TTL, +) from flwr.server import Driver @@ -30,7 +38,7 @@ def main(driver: Driver, context: Context) -> None: message_type=MessageType.TRAIN, dst_node_id=node_id, group_id=str(server_round), - ttl="", + ttl=DEFAULT_TTL, ) messages.append(message) diff --git a/examples/llm-flowertune/requirements.txt b/examples/llm-flowertune/requirements.txt index c7ff57b403f7..196531c99b92 100644 --- a/examples/llm-flowertune/requirements.txt +++ b/examples/llm-flowertune/requirements.txt @@ -6,3 +6,4 @@ bitsandbytes==0.41.3 scipy==1.11.2 peft==0.4.0 fschat[model_worker,webui]==0.2.35 +transformers==4.38.1 diff --git a/src/proto/flwr/proto/fleet.proto b/src/proto/flwr/proto/fleet.proto index c900a3b1148d..fcb301181f5a 100644 --- a/src/proto/flwr/proto/fleet.proto +++ b/src/proto/flwr/proto/fleet.proto @@ -23,6 +23,7 @@ import "flwr/proto/task.proto"; service Fleet { rpc CreateNode(CreateNodeRequest) returns (CreateNodeResponse) {} rpc DeleteNode(DeleteNodeRequest) returns (DeleteNodeResponse) {} + rpc Ping(PingRequest) returns (PingResponse) {} // Retrieve one or more tasks, if possible // @@ -43,6 +44,10 @@ message CreateNodeResponse { Node node = 1; } message DeleteNodeRequest { Node node = 1; } message DeleteNodeResponse {} +// Ping messages +message PingRequest { Node node = 1; } +message PingResponse { bool success = 1; } + // PullTaskIns messages message PullTaskInsRequest { Node node = 1; diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 423df76f1335..4c86ebae9562 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -27,7 +27,7 @@ message Task { Node consumer = 2; string created_at = 3; string delivered_at = 4; - string ttl = 5; + double ttl = 5; repeated string ancestry = 6; string task_type = 7; RecordSet recordset = 8; diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index ad7a01326991..0b56219807c6 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -115,7 +115,7 @@ def train(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def train(message: Message, context: Context) -> Message: >>> print("ClientApp training running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl="") + >>> return message.create_reply(content=message.content()) """ def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable: @@ -143,7 +143,7 @@ def evaluate(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def evaluate(message: Message, context: Context) -> Message: >>> print("ClientApp evaluation running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl="") + >>> return message.create_reply(content=message.content()) """ def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable: @@ -171,7 +171,7 @@ def query(self) -> Callable[[ClientAppCallable], ClientAppCallable]: >>> def query(message: Message, context: Context) -> Message: >>> print("ClientApp query running") >>> # Create and return an echo reply message - >>> return message.create_reply(content=message.content(), ttl="") + >>> return message.create_reply(content=message.content()) """ def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: @@ -218,7 +218,7 @@ def _registration_error(fn_name: str) -> ValueError: >>> print("ClientApp {fn_name} running") >>> # Create and return an echo reply message >>> return message.create_reply( - >>> content=message.content(), ttl="" + >>> content=message.content() >>> ) """, ) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 163a58542c9e..4431b53d2592 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -23,6 +23,7 @@ from typing import Callable, Iterator, Optional, Tuple, Union, cast from flwr.common import ( + DEFAULT_TTL, GRPC_MAX_MESSAGE_LENGTH, ConfigsRecord, Message, @@ -180,7 +181,7 @@ def receive() -> Message: dst_node_id=0, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type=message_type, ), content=recordset, diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index b7737f511a2a..061e7d4377a0 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -23,7 +23,7 @@ import grpc -from flwr.common import ConfigsRecord, Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, ConfigsRecord, Message, Metadata, RecordSet from flwr.common import recordset_compat as compat from flwr.common.constant import MessageTypeLegacy from flwr.common.retry_invoker import RetryInvoker, exponential @@ -50,7 +50,7 @@ dst_node_id=0, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=compat.getpropertiesres_to_recordset( @@ -65,7 +65,7 @@ dst_node_id=0, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type="reconnect", ), content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}), diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 9a5d70b1ac4d..87014f436cf7 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -81,7 +81,7 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: reason = cast(int, disconnect_msg.disconnect_res.reason) recordset = RecordSet() recordset.configs_records["config"] = ConfigsRecord({"reason": reason}) - out_message = message.create_reply(recordset, ttl="") + out_message = message.create_reply(recordset) # Return TaskRes and sleep duration return out_message, sleep_duration @@ -143,7 +143,7 @@ def handle_legacy_message_from_msgtype( raise ValueError(f"Invalid message type: {message_type}") # Return Message - return message.create_reply(out_recordset, ttl="") + return message.create_reply(out_recordset) def _reconnect( diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index eaf16f7dc993..e3f6487421cc 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -23,6 +23,7 @@ from flwr.client import Client from flwr.client.typing import ClientFn from flwr.common import ( + DEFAULT_TTL, Code, Context, EvaluateIns, @@ -131,7 +132,7 @@ def test_client_without_get_properties() -> None: src_node_id=0, dst_node_id=1123, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=recordset, @@ -161,7 +162,7 @@ def test_client_without_get_properties() -> None: src_node_id=1123, dst_node_id=0, reply_to_message=message.metadata.message_id, - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=expected_rs, @@ -184,7 +185,7 @@ def test_client_with_get_properties() -> None: src_node_id=0, dst_node_id=1123, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=recordset, @@ -214,7 +215,7 @@ def test_client_with_get_properties() -> None: src_node_id=1123, dst_node_id=0, reply_to_message=message.metadata.message_id, - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=expected_rs, @@ -237,7 +238,7 @@ def setUp(self) -> None: dst_node_id=20, reply_to_message="", group_id="group1", - ttl="60", + ttl=DEFAULT_TTL, message_type="mock", ) self.valid_out_metadata = Metadata( @@ -247,7 +248,7 @@ def setUp(self) -> None: dst_node_id=10, reply_to_message="qwerty", group_id="group1", - ttl="60", + ttl=DEFAULT_TTL, message_type="mock", ) self.common_content = RecordSet() diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 989d5f6e1361..5b196ad84321 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -187,7 +187,7 @@ def secaggplus_mod( # Return message out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False) - return msg.create_reply(out_content, ttl="") + return msg.create_reply(out_content) def check_stage(current_stage: str, configs: ConfigsRecord) -> None: diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index db5ed67c02a4..36844a2983a1 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -19,7 +19,14 @@ from typing import Callable, Dict, List from flwr.client.mod import make_ffn -from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet +from flwr.common import ( + DEFAULT_TTL, + ConfigsRecord, + Context, + Message, + Metadata, + RecordSet, +) from flwr.common.constant import MessageType from flwr.common.secure_aggregation.secaggplus_constants import ( RECORD_KEY_CONFIGS, @@ -38,7 +45,7 @@ def get_test_handler( """.""" def empty_ffn(_msg: Message, _2: Context) -> Message: - return _msg.create_reply(RecordSet(), ttl="") + return _msg.create_reply(RecordSet()) app = make_ffn(empty_ffn, [secaggplus_mod]) @@ -51,7 +58,7 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: dst_node_id=123, reply_to_message="", group_id="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageType.TRAIN, ), content=RecordSet( diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index e588b8b53b3b..4676a2c02c4b 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -20,6 +20,7 @@ from flwr.client.typing import ClientAppCallable, Mod from flwr.common import ( + DEFAULT_TTL, ConfigsRecord, Context, Message, @@ -84,7 +85,7 @@ def _get_dummy_flower_message() -> Message: src_node_id=0, dst_node_id=0, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type="mock", ), ) diff --git a/src/py/flwr/common/__init__.py b/src/py/flwr/common/__init__.py index 9f9ff7ebc68a..2fb98c82dd6f 100644 --- a/src/py/flwr/common/__init__.py +++ b/src/py/flwr/common/__init__.py @@ -22,6 +22,7 @@ from .grpc import GRPC_MAX_MESSAGE_LENGTH from .logger import configure as configure from .logger import log as log +from .message import DEFAULT_TTL from .message import Error as Error from .message import Message as Message from .message import Metadata as Metadata @@ -87,6 +88,7 @@ "Message", "MessageType", "MessageTypeLegacy", + "DEFAULT_TTL", "Metadata", "Metrics", "MetricsAggregationFn", diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index 88cf750f1a94..25607179764d 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -20,6 +20,8 @@ from .record import RecordSet +DEFAULT_TTL = 3600 + @dataclass class Metadata: # pylint: disable=too-many-instance-attributes @@ -40,8 +42,8 @@ class Metadata: # pylint: disable=too-many-instance-attributes group_id : str An identifier for grouping messages. In some settings, this is used as the FL round. - ttl : str - Time-to-live for this message. + ttl : float + Time-to-live for this message in seconds. message_type : str A string that encodes the action to be executed on the receiving end. @@ -57,7 +59,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes _dst_node_id: int _reply_to_message: str _group_id: str - _ttl: str + _ttl: float _message_type: str _partition_id: int | None @@ -69,7 +71,7 @@ def __init__( # pylint: disable=too-many-arguments dst_node_id: int, reply_to_message: str, group_id: str, - ttl: str, + ttl: float, message_type: str, partition_id: int | None = None, ) -> None: @@ -124,12 +126,12 @@ def group_id(self, value: str) -> None: self._group_id = value @property - def ttl(self) -> str: + def ttl(self) -> float: """Time-to-live for this message.""" return self._ttl @ttl.setter - def ttl(self, value: str) -> None: + def ttl(self, value: float) -> None: """Set ttl.""" self._ttl = value @@ -266,7 +268,7 @@ def has_error(self) -> bool: """Return True if message has an error, else False.""" return self._error is not None - def _create_reply_metadata(self, ttl: str) -> Metadata: + def _create_reply_metadata(self, ttl: float) -> Metadata: """Construct metadata for a reply message.""" return Metadata( run_id=self.metadata.run_id, @@ -283,7 +285,7 @@ def _create_reply_metadata(self, ttl: str) -> Metadata: def create_error_reply( self, error: Error, - ttl: str, + ttl: float, ) -> Message: """Construct a reply message indicating an error happened. @@ -291,14 +293,14 @@ def create_error_reply( ---------- error : Error The error that was encountered. - ttl : str - Time-to-live for this message. + ttl : float + Time-to-live for this message in seconds. """ # Create reply with error message = Message(metadata=self._create_reply_metadata(ttl), error=error) return message - def create_reply(self, content: RecordSet, ttl: str) -> Message: + def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message: """Create a reply to this message with specified content and TTL. The method generates a new `Message` as a reply to this message. @@ -309,14 +311,18 @@ def create_reply(self, content: RecordSet, ttl: str) -> Message: ---------- content : RecordSet The content for the reply message. - ttl : str - Time-to-live for this message. + ttl : Optional[float] (default: None) + Time-to-live for this message in seconds. If unset, it will use + the `common.DEFAULT_TTL` value. Returns ------- Message A new `Message` instance representing the reply. """ + if ttl is None: + ttl = DEFAULT_TTL + return Message( metadata=self._create_reply_metadata(ttl), content=content, diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 8596e5d2f330..fc12ce95328f 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -219,7 +219,7 @@ def metadata(self) -> Metadata: src_node_id=self.rng.randint(0, 1 << 63), dst_node_id=self.rng.randint(0, 1 << 63), reply_to_message=self.get_str(64), - ttl=self.get_str(10), + ttl=self.rng.randint(1, 1 << 30), message_type=self.get_str(10), ) diff --git a/src/py/flwr/proto/fleet_pb2.py b/src/py/flwr/proto/fleet_pb2.py index e8443c296f0c..dbf64fb850a5 100644 --- a/src/py/flwr/proto/fleet_pb2.py +++ b/src/py/flwr/proto/fleet_pb2.py @@ -16,7 +16,7 @@ from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xc9\x02\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x13\n\x11\x43reateNodeRequest\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"-\n\x0bPingRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x1f\n\x0cPingResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"F\n\x12PullTaskInsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"k\n\x13PullTaskInsResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rtask_ins_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"@\n\x12PushTaskResRequest\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\"\xae\x01\n\x13PushTaskResResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12=\n\x07results\x18\x02 \x03(\x0b\x32,.flwr.proto.PushTaskResResponse.ResultsEntry\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\x86\x03\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12;\n\x04Ping\x12\x17.flwr.proto.PingRequest\x1a\x18.flwr.proto.PingResponse\"\x00\x12P\n\x0bPullTaskIns\x12\x1e.flwr.proto.PullTaskInsRequest\x1a\x1f.flwr.proto.PullTaskInsResponse\"\x00\x12P\n\x0bPushTaskRes\x12\x1e.flwr.proto.PushTaskResRequest\x1a\x1f.flwr.proto.PushTaskResResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -33,18 +33,22 @@ _globals['_DELETENODEREQUEST']._serialized_end=210 _globals['_DELETENODERESPONSE']._serialized_start=212 _globals['_DELETENODERESPONSE']._serialized_end=232 - _globals['_PULLTASKINSREQUEST']._serialized_start=234 - _globals['_PULLTASKINSREQUEST']._serialized_end=304 - _globals['_PULLTASKINSRESPONSE']._serialized_start=306 - _globals['_PULLTASKINSRESPONSE']._serialized_end=413 - _globals['_PUSHTASKRESREQUEST']._serialized_start=415 - _globals['_PUSHTASKRESREQUEST']._serialized_end=479 - _globals['_PUSHTASKRESRESPONSE']._serialized_start=482 - _globals['_PUSHTASKRESRESPONSE']._serialized_end=656 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=610 - _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=656 - _globals['_RECONNECT']._serialized_start=658 - _globals['_RECONNECT']._serialized_end=688 - _globals['_FLEET']._serialized_start=691 - _globals['_FLEET']._serialized_end=1020 + _globals['_PINGREQUEST']._serialized_start=234 + _globals['_PINGREQUEST']._serialized_end=279 + _globals['_PINGRESPONSE']._serialized_start=281 + _globals['_PINGRESPONSE']._serialized_end=312 + _globals['_PULLTASKINSREQUEST']._serialized_start=314 + _globals['_PULLTASKINSREQUEST']._serialized_end=384 + _globals['_PULLTASKINSRESPONSE']._serialized_start=386 + _globals['_PULLTASKINSRESPONSE']._serialized_end=493 + _globals['_PUSHTASKRESREQUEST']._serialized_start=495 + _globals['_PUSHTASKRESREQUEST']._serialized_end=559 + _globals['_PUSHTASKRESRESPONSE']._serialized_start=562 + _globals['_PUSHTASKRESRESPONSE']._serialized_end=736 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_start=690 + _globals['_PUSHTASKRESRESPONSE_RESULTSENTRY']._serialized_end=736 + _globals['_RECONNECT']._serialized_start=738 + _globals['_RECONNECT']._serialized_end=768 + _globals['_FLEET']._serialized_start=771 + _globals['_FLEET']._serialized_end=1161 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/fleet_pb2.pyi b/src/py/flwr/proto/fleet_pb2.pyi index 86bc358858d2..39edb61ca0d7 100644 --- a/src/py/flwr/proto/fleet_pb2.pyi +++ b/src/py/flwr/proto/fleet_pb2.pyi @@ -53,6 +53,31 @@ class DeleteNodeResponse(google.protobuf.message.Message): ) -> None: ... global___DeleteNodeResponse = DeleteNodeResponse +class PingRequest(google.protobuf.message.Message): + """Ping messages""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + NODE_FIELD_NUMBER: builtins.int + @property + def node(self) -> flwr.proto.node_pb2.Node: ... + def __init__(self, + *, + node: typing.Optional[flwr.proto.node_pb2.Node] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["node",b"node"]) -> None: ... +global___PingRequest = PingRequest + +class PingResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + SUCCESS_FIELD_NUMBER: builtins.int + success: builtins.bool + def __init__(self, + *, + success: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["success",b"success"]) -> None: ... +global___PingResponse = PingResponse + class PullTaskInsRequest(google.protobuf.message.Message): """PullTaskIns messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/src/py/flwr/proto/fleet_pb2_grpc.py b/src/py/flwr/proto/fleet_pb2_grpc.py index 2b53ec43e851..c31a4ec73f0e 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.py +++ b/src/py/flwr/proto/fleet_pb2_grpc.py @@ -24,6 +24,11 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.FromString, ) + self.Ping = channel.unary_unary( + '/flwr.proto.Fleet/Ping', + request_serializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString, + ) self.PullTaskIns = channel.unary_unary( '/flwr.proto.Fleet/PullTaskIns', request_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.SerializeToString, @@ -51,6 +56,12 @@ def DeleteNode(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def Ping(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def PullTaskIns(self, request, context): """Retrieve one or more tasks, if possible @@ -82,6 +93,11 @@ def add_FleetServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.FromString, response_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.SerializeToString, ), + 'Ping': grpc.unary_unary_rpc_method_handler( + servicer.Ping, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.PingRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.PingResponse.SerializeToString, + ), 'PullTaskIns': grpc.unary_unary_rpc_method_handler( servicer.PullTaskIns, request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString, @@ -136,6 +152,23 @@ def DeleteNode(request, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod + def Ping(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Fleet/Ping', + flwr_dot_proto_dot_fleet__pb2.PingRequest.SerializeToString, + flwr_dot_proto_dot_fleet__pb2.PingResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + @staticmethod def PullTaskIns(request, target, diff --git a/src/py/flwr/proto/fleet_pb2_grpc.pyi b/src/py/flwr/proto/fleet_pb2_grpc.pyi index cfa83f737439..33ba9440793a 100644 --- a/src/py/flwr/proto/fleet_pb2_grpc.pyi +++ b/src/py/flwr/proto/fleet_pb2_grpc.pyi @@ -16,6 +16,10 @@ class FleetStub: flwr.proto.fleet_pb2.DeleteNodeRequest, flwr.proto.fleet_pb2.DeleteNodeResponse] + Ping: grpc.UnaryUnaryMultiCallable[ + flwr.proto.fleet_pb2.PingRequest, + flwr.proto.fleet_pb2.PingResponse] + PullTaskIns: grpc.UnaryUnaryMultiCallable[ flwr.proto.fleet_pb2.PullTaskInsRequest, flwr.proto.fleet_pb2.PullTaskInsResponse] @@ -46,6 +50,12 @@ class FleetServicer(metaclass=abc.ABCMeta): context: grpc.ServicerContext, ) -> flwr.proto.fleet_pb2.DeleteNodeResponse: ... + @abc.abstractmethod + def Ping(self, + request: flwr.proto.fleet_pb2.PingRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.fleet_pb2.PingResponse: ... + @abc.abstractmethod def PullTaskIns(self, request: flwr.proto.fleet_pb2.PullTaskInsRequest, diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 4d5f863e88dd..abf7d72d7174 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -18,7 +18,7 @@ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index b9c10139cfb3..735400eca701 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -31,7 +31,7 @@ class Task(google.protobuf.message.Message): def consumer(self) -> flwr.proto.node_pb2.Node: ... created_at: typing.Text delivered_at: typing.Text - ttl: typing.Text + ttl: builtins.float @property def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ... task_type: typing.Text @@ -45,7 +45,7 @@ class Task(google.protobuf.message.Message): consumer: typing.Optional[flwr.proto.node_pb2.Node] = ..., created_at: typing.Text = ..., delivered_at: typing.Text = ..., - ttl: typing.Text = ..., + ttl: builtins.float = ..., ancestry: typing.Optional[typing.Iterable[typing.Text]] = ..., task_type: typing.Text = ..., recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py index 1e6a84cbb42d..d666525c36ab 100644 --- a/src/py/flwr/server/compat/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -19,7 +19,7 @@ from typing import List, Optional from flwr import common -from flwr.common import MessageType, MessageTypeLegacy, RecordSet +from flwr.common import DEFAULT_TTL, MessageType, MessageTypeLegacy, RecordSet from flwr.common import recordset_compat as compat from flwr.common import serde from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 @@ -129,6 +129,7 @@ def _send_receive_recordset( ), task_type=task_type, recordset=serde.recordset_to_proto(recordset), + ttl=DEFAULT_TTL, ), ) push_task_ins_req = driver_pb2.PushTaskInsRequest( # pylint: disable=E1101 diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 0098e0ce97c2..afebd90ea265 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -18,7 +18,7 @@ import time from typing import Iterable, List, Optional, Tuple -from flwr.common import Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, @@ -81,6 +81,7 @@ def _check_message(self, message: Message) -> None: and message.metadata.src_node_id == self.node.node_id and message.metadata.message_id == "" and message.metadata.reply_to_message == "" + and message.metadata.ttl > 0 ): raise ValueError(f"Invalid message: {message}") @@ -90,7 +91,7 @@ def create_message( # pylint: disable=too-many-arguments message_type: str, dst_node_id: int, group_id: str, - ttl: str, + ttl: float = DEFAULT_TTL, ) -> Message: """Create a new message with specified parameters. @@ -110,10 +111,10 @@ def create_message( # pylint: disable=too-many-arguments group_id : str The ID of the group to which this message is associated. In some settings, this is used as the FL round. - ttl : str + ttl : float (default: common.DEFAULT_TTL) Time-to-live for the round trip of this message, i.e., the time from sending - this message to receiving a reply. It specifies the duration for which the - message and its potential reply are considered valid. + this message to receiving a reply. It specifies in seconds the duration for + which the message and its potential reply are considered valid. Returns ------- diff --git a/src/py/flwr/server/driver/driver_test.py b/src/py/flwr/server/driver/driver_test.py index 5136f4f90210..3f1cd552250f 100644 --- a/src/py/flwr/server/driver/driver_test.py +++ b/src/py/flwr/server/driver/driver_test.py @@ -19,7 +19,7 @@ import unittest from unittest.mock import Mock, patch -from flwr.common import RecordSet +from flwr.common import DEFAULT_TTL, RecordSet from flwr.common.message import Error from flwr.common.serde import error_to_proto, recordset_to_proto from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 @@ -99,7 +99,8 @@ def test_push_messages_valid(self) -> None: mock_response = Mock(task_ids=["id1", "id2"]) self.mock_grpc_driver.push_task_ins.return_value = mock_response msgs = [ - self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) + for _ in range(2) ] # Execute @@ -121,7 +122,8 @@ def test_push_messages_invalid(self) -> None: mock_response = Mock(task_ids=["id1", "id2"]) self.mock_grpc_driver.push_task_ins.return_value = mock_response msgs = [ - self.driver.create_message(RecordSet(), "", 0, "", "") for _ in range(2) + self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) + for _ in range(2) ] # Use invalid run_id msgs[1].metadata._run_id += 1 # pylint: disable=protected-access @@ -170,7 +172,7 @@ def test_send_and_receive_messages_complete(self) -> None: task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] ) self.mock_grpc_driver.pull_task_res.return_value = mock_response - msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute ret_msgs = list(self.driver.send_and_receive(msgs)) @@ -187,7 +189,7 @@ def test_send_and_receive_messages_timeout(self) -> None: self.mock_grpc_driver.push_task_ins.return_value = mock_response mock_response = Mock(task_res_list=[]) self.mock_grpc_driver.pull_task_res.return_value = mock_response - msgs = [self.driver.create_message(RecordSet(), "", 0, "", "")] + msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)): diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index 278474477379..eb8dd800ea37 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -15,7 +15,7 @@ """Fleet API gRPC request-response servicer.""" -from logging import INFO +from logging import DEBUG, INFO import grpc @@ -26,6 +26,8 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + PingRequest, + PingResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, @@ -61,6 +63,14 @@ def DeleteNode( state=self.state_factory.state(), ) + def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse: + """.""" + log(DEBUG, "FleetServicer.Ping") + return message_handler.ping( + request=request, + state=self.state_factory.state(), + ) + def PullTaskIns( self, request: PullTaskInsRequest, context: grpc.ServicerContext ) -> PullTaskInsResponse: diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index c99a7854d53a..2e696dde78e1 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -23,6 +23,8 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + PingRequest, + PingResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, @@ -55,6 +57,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse: return DeleteNodeResponse() +def ping( + request: PingRequest, # pylint: disable=unused-argument + state: State, # pylint: disable=unused-argument +) -> PingResponse: + """.""" + return PingResponse(success=True) + + def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse: """Pull TaskIns handler.""" # Get node_id if client node is not anonymous diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index 2610307bb749..dcac0b81d666 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -25,6 +25,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import ( + DEFAULT_TTL, Config, ConfigsRecord, Context, @@ -111,7 +112,7 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: src_node_id=0, dst_node_id=0, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, ), ) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 8c37399ae295..2c917c3eed27 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -26,7 +26,13 @@ from unittest import IsolatedAsyncioTestCase from uuid import UUID -from flwr.common import GetPropertiesIns, Message, MessageTypeLegacy, Metadata +from flwr.common import ( + DEFAULT_TTL, + GetPropertiesIns, + Message, + MessageTypeLegacy, + Metadata, +) from flwr.common.recordset_compat import getpropertiesins_to_recordset from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.server.superlink.fleet.vce.vce_api import ( @@ -97,7 +103,7 @@ def register_messages_into_state( src_node_id=0, dst_node_id=dst_node_id, # indicate destination node reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=( "a bad message" if erroneous_message diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index ac1ab158e254..7bff8ab4befc 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -17,7 +17,7 @@ import os import threading -from datetime import datetime, timedelta +from datetime import datetime from logging import ERROR from typing import Dict, List, Optional, Set from uuid import UUID, uuid4 @@ -50,15 +50,13 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: log(ERROR, "`run_id` is invalid") return None - # Create task_id, created_at and ttl + # Create task_id and created_at task_id = uuid4() created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl = ttl.isoformat() with self.lock: self.task_ins_store[task_id] = task_ins @@ -113,15 +111,13 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, "`run_id` is invalid") return None - # Create task_id, created_at and ttl + # Create task_id and created_at task_id = uuid4() created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskRes task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl = ttl.isoformat() with self.lock: self.task_res_store[task_id] = task_res diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 224c16cdf013..25d138f94203 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -18,7 +18,7 @@ import os import re import sqlite3 -from datetime import datetime, timedelta +from datetime import datetime from logging import DEBUG, ERROR from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from uuid import UUID, uuid4 @@ -54,7 +54,7 @@ consumer_node_id INTEGER, created_at TEXT, delivered_at TEXT, - ttl TEXT, + ttl REAL, ancestry TEXT, task_type TEXT, recordset BLOB, @@ -74,7 +74,7 @@ consumer_node_id INTEGER, created_at TEXT, delivered_at TEXT, - ttl TEXT, + ttl REAL, ancestry TEXT, task_type TEXT, recordset BLOB, @@ -185,15 +185,13 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: log(ERROR, errors) return None - # Create task_id, created_at and ttl + # Create task_id and created_at task_id = uuid4() created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskIns task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() - task_ins.task.ttl = ttl.isoformat() data = (task_ins_to_dict(task_ins),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" @@ -320,15 +318,13 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, errors) return None - # Create task_id, created_at and ttl + # Create task_id and created_at task_id = uuid4() created_at: datetime = now() - ttl: datetime = created_at + timedelta(hours=24) # Store TaskIns task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() - task_res.task.ttl = ttl.isoformat() data = (task_res_to_dict(task_res),) columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index d0470a7ce7f7..01ac64de1380 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -22,6 +22,7 @@ from typing import List from uuid import uuid4 +from flwr.common import DEFAULT_TTL from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -73,7 +74,6 @@ def test_store_task_ins_one(self) -> None: assert task_ins.task.created_at == "" # pylint: disable=no-member assert task_ins.task.delivered_at == "" # pylint: disable=no-member - assert task_ins.task.ttl == "" # pylint: disable=no-member # Execute state.store_task_ins(task_ins=task_ins) @@ -91,7 +91,6 @@ def test_store_task_ins_one(self) -> None: assert actual_task.created_at != "" assert actual_task.delivered_at != "" - assert actual_task.ttl != "" assert datetime.fromisoformat(actual_task.created_at) > datetime( 2020, 1, 1, tzinfo=timezone.utc @@ -99,9 +98,7 @@ def test_store_task_ins_one(self) -> None: assert datetime.fromisoformat(actual_task.delivered_at) > datetime( 2020, 1, 1, tzinfo=timezone.utc ) - assert datetime.fromisoformat(actual_task.ttl) > datetime( - 2020, 1, 1, tzinfo=timezone.utc - ) + assert actual_task.ttl > 0 def test_store_and_delete_tasks(self) -> None: """Test delete_tasks.""" @@ -420,6 +417,7 @@ def create_task_ins( consumer=consumer, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, ), ) return task @@ -442,6 +440,7 @@ def create_task_res( ancestry=ancestry, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, ), ) return task_res diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index 846217b085a1..285807d8d0e7 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -36,8 +36,8 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str validation_errors.append("`created_at` must be an empty str") if tasks_ins_res.task.delivered_at != "": validation_errors.append("`delivered_at` must be an empty str") - if tasks_ins_res.task.ttl != "": - validation_errors.append("`ttl` must be an empty str") + if tasks_ins_res.task.ttl <= 0: + validation_errors.append("`ttl` must be higher than zero") # TaskIns specific if isinstance(tasks_ins_res, TaskIns): diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index 8e0849508020..926103c6b09a 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -18,6 +18,7 @@ import unittest from typing import List, Tuple +from flwr.common import DEFAULT_TTL from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -96,6 +97,7 @@ def create_task_ins( consumer=consumer, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, ), ) return task @@ -117,6 +119,7 @@ def create_task_res( ancestry=ancestry, task_type="mock", recordset=RecordSet(parameters={}, metrics={}, configs={}), + ttl=DEFAULT_TTL, ), ) return task_res diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py index 876ae56dcadc..42b1151f9835 100644 --- a/src/py/flwr/server/workflow/default_workflows.py +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -21,7 +21,7 @@ from typing import Optional, cast import flwr.common.recordset_compat as compat -from flwr.common import ConfigsRecord, Context, GetParametersIns, log +from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, GetParametersIns, log from flwr.common.constant import MessageType, MessageTypeLegacy from ..compat.app_utils import start_update_client_manager_thread @@ -127,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: message_type=MessageTypeLegacy.GET_PARAMETERS, dst_node_id=random_client.node_id, group_id="0", - ttl="", + ttl=DEFAULT_TTL, ) ] ) @@ -226,7 +226,7 @@ def default_fit_workflow( # pylint: disable=R0914 message_type=MessageType.TRAIN, dst_node_id=proxy.node_id, group_id=str(current_round), - ttl="", + ttl=DEFAULT_TTL, ) for proxy, fitins in client_instructions ] @@ -306,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: message_type=MessageType.EVALUATE, dst_node_id=proxy.node_id, group_id=str(current_round), - ttl="", + ttl=DEFAULT_TTL, ) for proxy, evalins in client_instructions ] diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py index 42ee9c15f1cd..326947b653ff 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -22,6 +22,7 @@ import flwr.common.recordset_compat as compat from flwr.common import ( + DEFAULT_TTL, ConfigsRecord, Context, FitRes, @@ -373,7 +374,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl="", + ttl=DEFAULT_TTL, ) log( @@ -421,7 +422,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl="", + ttl=DEFAULT_TTL, ) # Broadcast public keys to clients and receive secret key shares @@ -492,7 +493,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl="", + ttl=DEFAULT_TTL, ) log( @@ -563,7 +564,7 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(current_round), - ttl="", + ttl=DEFAULT_TTL, ) log( diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index c3493163ac52..5e344eb087ee 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -23,7 +23,7 @@ from flwr.client import ClientFn from flwr.client.client_app import ClientApp from flwr.client.node_state import NodeState -from flwr.common import Message, Metadata, RecordSet +from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.constant import MessageType, MessageTypeLegacy from flwr.common.logger import log from flwr.common.recordset_compat import ( @@ -105,7 +105,7 @@ def _wrap_recordset_in_message( src_node_id=0, dst_node_id=int(self.cid), reply_to_message="", - ttl=str(timeout) if timeout else "", + ttl=timeout if timeout else DEFAULT_TTL, message_type=message_type, partition_id=int(self.cid), ), diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 22c5425cd9fd..9680b3846f1d 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -24,6 +24,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp from flwr.common import ( + DEFAULT_TTL, Config, ConfigsRecord, Context, @@ -202,7 +203,7 @@ def _load_app() -> ClientApp: src_node_id=0, dst_node_id=12345, reply_to_message="", - ttl="", + ttl=DEFAULT_TTL, message_type=MessageTypeLegacy.GET_PROPERTIES, partition_id=int(cid), ),