Skip to content

Commit

Permalink
feat(framework) Check TTL when retrieving TaskIns (#3620)
Browse files Browse the repository at this point in the history
  • Loading branch information
mohammadnaseri authored Oct 7, 2024
1 parent 7ef7426 commit 8bf5b82
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def get_task_ins(

# Find TaskIns for node_id that were not delivered yet
task_ins_list: list[TaskIns] = []
current_time = time.time()
with self.lock:
for _, task_ins in self.task_ins_store.items():
# pylint: disable=too-many-boolean-expressions
Expand All @@ -95,11 +96,13 @@ def get_task_ins(
and task_ins.task.consumer.anonymous is False
and task_ins.task.consumer.node_id == node_id
and task_ins.task.delivered_at == ""
and task_ins.task.created_at + task_ins.task.ttl > current_time
) or (
node_id is None # Anonymous
and task_ins.task.consumer.anonymous is True
and task_ins.task.consumer.node_id == 0
and task_ins.task.delivered_at == ""
and task_ins.task.created_at + task_ins.task.ttl > current_time
):
task_ins_list.append(task_ins)
if limit and len(task_ins_list) == limit:
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def get_task_ins(
WHERE consumer_anonymous == 1
AND consumer_node_id == 0
AND delivered_at = ""
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
"""
else:
# Convert the uint64 value to sint64 for SQLite
Expand All @@ -311,6 +312,7 @@ def get_task_ins(
WHERE consumer_anonymous == 0
AND consumer_node_id == :node_id
AND delivered_at = ""
AND (created_at + ttl) > CAST(strftime('%s', 'now') AS REAL)
"""

if limit is not None:
Expand Down
20 changes: 20 additions & 0 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,26 @@ def test_store_task_res_limit_ttl(self) -> None:
else:
assert res is None

def test_get_task_ins_not_return_expired(self) -> None:
"""Test get_task_ins not to return expired tasks."""
# Prepare
consumer_node_id = 1
state = self.state_factory()
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
)
task_ins.task.created_at = time.time() - 5
task_ins.task.ttl = 5.0

# Execute
state.store_task_ins(task_ins=task_ins)

# Assert
with patch("time.time", side_effect=lambda: task_ins.task.created_at + 6.1):
task_ins_list = state.get_task_ins(node_id=1, limit=None)
assert len(task_ins_list) == 0


def create_task_ins(
consumer_node_id: int,
Expand Down

0 comments on commit 8bf5b82

Please sign in to comment.