From 084835ac05e95cc6733e97c13f4ef1d866e88918 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 2 Apr 2024 15:25:19 +0100 Subject: [PATCH] Warn once if `TTL` is set (#3195) --- src/py/flwr/common/message.py | 15 ++++++++++++++ src/py/flwr/server/driver/driver.py | 20 ++++++++++++++----- .../flwr/server/workflow/default_workflows.py | 5 +---- .../secure_aggregation/secaggplus_workflow.py | 5 ----- 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index 7707f3c72de1..2105eabda27e 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -17,6 +17,7 @@ from __future__ import annotations import time +import warnings from dataclasses import dataclass from .record import RecordSet @@ -311,6 +312,13 @@ def create_error_reply(self, error: Error, ttl: float | None = None) -> Message: ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at) """ + if ttl: + warnings.warn( + "A custom TTL was set, but note that the SuperLink does not enforce " + "the TTL yet. The SuperLink will start enforcing the TTL in a future " + "version of Flower.", + stacklevel=2, + ) # If no TTL passed, use default for message creation (will update after # message creation) ttl_ = DEFAULT_TTL if ttl is None else ttl @@ -349,6 +357,13 @@ def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message: Message A new `Message` instance representing the reply. """ + if ttl: + warnings.warn( + "A custom TTL was set, but note that the SuperLink does not enforce " + "the TTL yet. The SuperLink will start enforcing the TTL in a future " + "version of Flower.", + stacklevel=2, + ) # If no TTL passed, use default for message creation (will update after # message creation) ttl_ = DEFAULT_TTL if ttl is None else ttl diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index afebd90ea265..a917912c0f63 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -14,8 +14,8 @@ # ============================================================================== """Flower driver service client.""" - import time +import warnings from typing import Iterable, List, Optional, Tuple from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet @@ -91,7 +91,7 @@ def create_message( # pylint: disable=too-many-arguments message_type: str, dst_node_id: int, group_id: str, - ttl: float = DEFAULT_TTL, + ttl: Optional[float] = None, ) -> Message: """Create a new message with specified parameters. @@ -111,10 +111,11 @@ def create_message( # pylint: disable=too-many-arguments group_id : str The ID of the group to which this message is associated. In some settings, this is used as the FL round. - ttl : float (default: common.DEFAULT_TTL) + ttl : Optional[float] (default: None) Time-to-live for the round trip of this message, i.e., the time from sending this message to receiving a reply. It specifies in seconds the duration for - which the message and its potential reply are considered valid. + which the message and its potential reply are considered valid. If unset, + the default TTL (i.e., `common.DEFAULT_TTL`) will be used. Returns ------- @@ -122,6 +123,15 @@ def create_message( # pylint: disable=too-many-arguments A new `Message` instance with the specified content and metadata. """ _, run_id = self._get_grpc_driver_and_run_id() + if ttl: + warnings.warn( + "A custom TTL was set, but note that the SuperLink does not enforce " + "the TTL yet. The SuperLink will start enforcing the TTL in a future " + "version of Flower.", + stacklevel=2, + ) + + ttl_ = DEFAULT_TTL if ttl is None else ttl metadata = Metadata( run_id=run_id, message_id="", # Will be set by the server @@ -129,7 +139,7 @@ def create_message( # pylint: disable=too-many-arguments dst_node_id=dst_node_id, reply_to_message="", group_id=group_id, - ttl=ttl, + ttl=ttl_, message_type=message_type, ) return Message(metadata=metadata, content=content) diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py index 42b1151f9835..ac023cc98ca5 100644 --- a/src/py/flwr/server/workflow/default_workflows.py +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -21,7 +21,7 @@ from typing import Optional, cast import flwr.common.recordset_compat as compat -from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, GetParametersIns, log +from flwr.common import ConfigsRecord, Context, GetParametersIns, log from flwr.common.constant import MessageType, MessageTypeLegacy from ..compat.app_utils import start_update_client_manager_thread @@ -127,7 +127,6 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: message_type=MessageTypeLegacy.GET_PARAMETERS, dst_node_id=random_client.node_id, group_id="0", - ttl=DEFAULT_TTL, ) ] ) @@ -226,7 +225,6 @@ def default_fit_workflow( # pylint: disable=R0914 message_type=MessageType.TRAIN, dst_node_id=proxy.node_id, group_id=str(current_round), - ttl=DEFAULT_TTL, ) for proxy, fitins in client_instructions ] @@ -306,7 +304,6 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: message_type=MessageType.EVALUATE, dst_node_id=proxy.node_id, group_id=str(current_round), - ttl=DEFAULT_TTL, ) for proxy, evalins in client_instructions ] diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py index 326947b653ff..d6d97c28f313 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -22,7 +22,6 @@ import flwr.common.recordset_compat as compat from flwr.common import ( - DEFAULT_TTL, ConfigsRecord, Context, FitRes, @@ -374,7 +373,6 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl=DEFAULT_TTL, ) log( @@ -422,7 +420,6 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl=DEFAULT_TTL, ) # Broadcast public keys to clients and receive secret key shares @@ -493,7 +490,6 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(cfg[WorkflowKey.CURRENT_ROUND]), - ttl=DEFAULT_TTL, ) log( @@ -564,7 +560,6 @@ def make(nid: int) -> Message: message_type=MessageType.TRAIN, dst_node_id=nid, group_id=str(current_round), - ttl=DEFAULT_TTL, ) log(