Skip to content

Commit

Permalink
feat(framework) Limit TaskRes TTL when saving it (#3615)
Browse files Browse the repository at this point in the history
Co-authored-by: Heng Pan <[email protected]>
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
3 people authored Oct 7, 2024
1 parent dbf8d6d commit 2fd3bd8
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 20 deletions.
3 changes: 3 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@
GRPC_ADAPTER_METADATA_MESSAGE_MODULE_KEY = "grpc-message-module"
GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY = "grpc-message-qualname"

# Message TTL
MESSAGE_TTL_TOLERANCE = 1e-1


class MessageType:
"""Message type."""
Expand Down
47 changes: 32 additions & 15 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
from __future__ import annotations

import time
import warnings
from logging import WARNING
from typing import Optional, cast

from .constant import MESSAGE_TTL_TOLERANCE
from .logger import log
from .record import RecordSet

DEFAULT_TTL = 3600
Expand Down Expand Up @@ -289,13 +291,6 @@ def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
"""
if ttl:
warnings.warn(
"A custom TTL was set, but note that the SuperLink does not enforce "
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
"version of Flower.",
stacklevel=2,
)
# If no TTL passed, use default for message creation (will update after
# message creation)
ttl_ = DEFAULT_TTL if ttl is None else ttl
Expand All @@ -309,6 +304,8 @@ def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
)
message.metadata.ttl = ttl

self._limit_task_res_ttl(message)

return message

def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
Expand All @@ -334,13 +331,6 @@ def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
Message
A new `Message` instance representing the reply.
"""
if ttl:
warnings.warn(
"A custom TTL was set, but note that the SuperLink does not enforce "
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
"version of Flower.",
stacklevel=2,
)
# If no TTL passed, use default for message creation (will update after
# message creation)
ttl_ = DEFAULT_TTL if ttl is None else ttl
Expand All @@ -357,6 +347,8 @@ def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message:
)
message.metadata.ttl = ttl

self._limit_task_res_ttl(message)

return message

def __repr__(self) -> str:
Expand All @@ -370,6 +362,31 @@ def __repr__(self) -> str:
)
return f"{self.__class__.__qualname__}({view})"

def _limit_task_res_ttl(self, message: Message) -> None:
"""Limit the TaskRes TTL to not exceed the expiration time of the TaskIns it
replies to.
Parameters
----------
message : Message
The message to which the TaskRes is replying.
"""
# Calculate the maximum allowed TTL
max_allowed_ttl = (
self.metadata.created_at + self.metadata.ttl - message.metadata.created_at
)

if message.metadata.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE:
log(
WARNING,
"The reply TTL of %.2f seconds exceeded the "
"allowed maximum of %.2f seconds. "
"The TTL has been updated to the allowed maximum.",
message.metadata.ttl,
max_allowed_ttl,
)
message.metadata.ttl = max_allowed_ttl


def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
"""Construct metadata for a reply message."""
Expand Down
33 changes: 33 additions & 0 deletions src/py/flwr/common/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

# pylint: enable=E0611
from . import RecordSet
from .constant import MESSAGE_TTL_TOLERANCE
from .message import Error, Message, Metadata
from .serde_test import RecordMaker

Expand Down Expand Up @@ -202,3 +203,35 @@ def test_repr(cls: type, kwargs: dict[str, Any]) -> None:

# Assert
assert str(actual) == str(expected)


@pytest.mark.parametrize(
"message_creation_fn,initial_ttl,reply_ttl,expected_reply_ttl",
[
# Case where the reply_ttl is larger than the allowed TTL
(create_message_with_content, 20, 30, 20),
(create_message_with_error, 20, 30, 20),
# Case where the reply_ttl is within the allowed range
(create_message_with_content, 20, 10, 10),
(create_message_with_error, 20, 10, 10),
],
)
def test_reply_ttl_limitation(
message_creation_fn: Callable[[float], Message],
initial_ttl: float,
reply_ttl: float,
expected_reply_ttl: float,
) -> None:
"""Test that the reply TTL does not exceed the allowed TTL."""
message = message_creation_fn(initial_ttl)

if message.has_error():
dummy_error = Error(code=0, reason="test error")
reply_message = message.create_error_reply(dummy_error, ttl=reply_ttl)
else:
reply_message = message.create_reply(content=RecordSet(), ttl=reply_ttl)

assert reply_message.metadata.ttl - expected_reply_ttl <= MESSAGE_TTL_TOLERANCE, (
f"Expected TTL to be <= {expected_reply_ttl}, "
f"but got {reply_message.metadata.ttl}"
)
29 changes: 27 additions & 2 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@

import threading
import time
from logging import ERROR
from logging import ERROR, WARNING
from typing import Optional
from uuid import UUID, uuid4

from flwr.common import log, now
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
from flwr.common.constant import (
MESSAGE_TTL_TOLERANCE,
NODE_ID_NUM_BYTES,
RUN_ID_NUM_BYTES,
)
from flwr.common.typing import Run, UserConfig
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.superlink.state.state import State
Expand Down Expand Up @@ -134,6 +138,27 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
)
return None

# Fail if the TaskRes TTL exceeds the
# expiration time of the TaskIns it replies to.
# Condition: TaskIns.created_at + TaskIns.ttl ≥
# TaskRes.created_at + TaskRes.ttl
# A small tolerance is introduced to account
# for floating-point precision issues.
max_allowed_ttl = (
task_ins.task.created_at + task_ins.task.ttl - task_res.task.created_at
)
if task_res.task.ttl and (
task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
):
log(
WARNING,
"Received TaskRes with TTL %.2f "
"exceeding the allowed maximum TTL %.2f.",
task_res.task.ttl,
max_allowed_ttl,
)
return None

# Validate run_id
if task_res.run_id not in self.run_ids:
log(ERROR, "`run_id` is invalid")
Expand Down
29 changes: 27 additions & 2 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
import sqlite3
import time
from collections.abc import Sequence
from logging import DEBUG, ERROR
from logging import DEBUG, ERROR, WARNING
from typing import Any, Optional, Union, cast
from uuid import UUID, uuid4

from flwr.common import log, now
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
from flwr.common.constant import (
MESSAGE_TTL_TOLERANCE,
NODE_ID_NUM_BYTES,
RUN_ID_NUM_BYTES,
)
from flwr.common.typing import Run, UserConfig
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
Expand Down Expand Up @@ -383,6 +387,27 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
)
return None

# Fail if the TaskRes TTL exceeds the
# expiration time of the TaskIns it replies to.
# Condition: TaskIns.created_at + TaskIns.ttl ≥
# TaskRes.created_at + TaskRes.ttl
# A small tolerance is introduced to account
# for floating-point precision issues.
max_allowed_ttl = (
task_ins["created_at"] + task_ins["ttl"] - task_res.task.created_at
)
if task_res.task.ttl and (
task_res.task.ttl - max_allowed_ttl > MESSAGE_TTL_TOLERANCE
):
log(
WARNING,
"Received TaskRes with TTL %.2f "
"exceeding the allowed maximum TTL %.2f.",
task_res.task.ttl,
max_allowed_ttl,
)
return None

# Store TaskRes
task_res.task_id = str(task_id)
data = (task_res_to_dict(task_res),)
Expand Down
60 changes: 59 additions & 1 deletion src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tests all state implemenations have to conform to."""
# pylint: disable=invalid-name, disable=R0904
# pylint: disable=invalid-name, disable=R0904,R0913

import tempfile
import time
Expand Down Expand Up @@ -707,6 +707,64 @@ def test_store_task_res_task_ins_expired(self) -> None:
# Assert
assert result is None

def test_store_task_res_limit_ttl(self) -> None:
"""Test the behavior of store_task_res regarding the TTL limit of TaskRes."""
current_time = time.time()

test_cases = [
(
current_time - 5,
10,
current_time - 2,
6,
True,
), # TaskRes within allowed TTL
(
current_time - 5,
10,
current_time - 2,
15,
False,
), # TaskRes TTL exceeds max allowed TTL
]

for (
task_ins_created_at,
task_ins_ttl,
task_res_created_at,
task_res_ttl,
expected_store_result,
) in test_cases:

# Prepare
state: State = self.state_factory()
run_id = state.create_run(None, None, "9f86d08", {})

task_ins = create_task_ins(
consumer_node_id=0, anonymous=True, run_id=run_id
)
task_ins.task.created_at = task_ins_created_at
task_ins.task.ttl = task_ins_ttl
task_ins_id = state.store_task_ins(task_ins)

task_res = create_task_res(
producer_node_id=0,
anonymous=True,
ancestry=[str(task_ins_id)],
run_id=run_id,
)
task_res.task.created_at = task_res_created_at
task_res.task.ttl = task_res_ttl

# Execute
res = state.store_task_res(task_res)

# Assert
if expected_store_result:
assert res is not None
else:
assert res is None


def create_task_ins(
consumer_node_id: int,
Expand Down

0 comments on commit 2fd3bd8

Please sign in to comment.