Skip to content

Commit

Permalink
Edit
Browse files Browse the repository at this point in the history
  • Loading branch information
mohammadnaseri committed Sep 28, 2024
1 parent 216491c commit b7008f6
Showing 1 changed file with 46 additions and 43 deletions.
89 changes: 46 additions & 43 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from datetime import datetime, timezone
from unittest.mock import patch

from parameterized import parameterized

from flwr.common import DEFAULT_TTL
from flwr.common.constant import ErrorCode
Expand Down Expand Up @@ -709,59 +708,63 @@ def test_store_task_res_task_ins_expired(self) -> None:
# Assert
assert result is None

@parameterized.expand(
[ # type: ignore
def test_store_task_res_limit_ttl(self) -> None:
"""Test the behavior of store_task_res regarding the TTL limit of TaskRes."""
current_time = time.time()

test_cases = [
(
time.time() - 5,
250,
time.time() - 4,
50,
current_time - 5,
10,
current_time - 2,
6,
True,
), # TaskRes within allowed TTL
(
time.time() - 5,
100,
time.time() - 4,
250,
current_time - 5,
10,
current_time - 2,
15,
False,
), # TaskRes TTL exceeds max allowed TTL
]
)
def test_store_task_res_limit_ttl(
self,
task_ins_created_at: float,
task_ins_ttl: float,
task_res_created_at: float,
task_res_ttl: float,
expected_store_result: bool,
) -> None:
"""Test the behavior of store_task_res regarding the TTL limit of TaskRes."""
# Prepare
state: State = self.state_factory()
run_id = state.create_run(None, None, "9f86d08", {})

task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id)
task_ins.task.created_at = task_ins_created_at
task_ins.task.ttl = task_ins_ttl
task_ins_id = state.store_task_ins(task_ins)
for (
task_ins_created_at,
task_ins_ttl,
task_res_created_at,
task_res_ttl,
expected_store_result,
) in test_cases:

task_res = create_task_res(
producer_node_id=0,
anonymous=True,
ancestry=[str(task_ins_id)],
run_id=run_id,
)
task_res.task.created_at = task_res_created_at
task_res.task.ttl = task_res_ttl
# Prepare
state: State = self.state_factory()
run_id = state.create_run(None, None, "9f86d08", {})

# Execute
res = state.store_task_res(task_res)
task_ins = create_task_ins(
consumer_node_id=0, anonymous=True, run_id=run_id
)
task_ins.task.created_at = task_ins_created_at
task_ins.task.ttl = task_ins_ttl
task_ins_id = state.store_task_ins(task_ins)

# Assert
if expected_store_result:
assert res is not None
else:
assert res is None
task_res = create_task_res(
producer_node_id=0,
anonymous=True,
ancestry=[str(task_ins_id)],
run_id=run_id,
)
task_res.task.created_at = task_res_created_at
task_res.task.ttl = task_res_ttl

# Execute
res = state.store_task_res(task_res)

# Assert
if expected_store_result:
assert res is not None
else:
assert res is None


def create_task_ins(
Expand Down

0 comments on commit b7008f6

Please sign in to comment.