diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 6fe4b8d975e7..e6ff1ddd915b 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -22,7 +22,6 @@ from datetime import datetime, timezone from unittest.mock import patch -from parameterized import parameterized from flwr.common import DEFAULT_TTL from flwr.common.constant import ErrorCode @@ -709,59 +708,63 @@ def test_store_task_res_task_ins_expired(self) -> None: # Assert assert result is None - @parameterized.expand( - [ # type: ignore + def test_store_task_res_limit_ttl(self) -> None: + """Test the behavior of store_task_res regarding the TTL limit of TaskRes.""" + current_time = time.time() + + test_cases = [ ( - time.time() - 5, - 250, - time.time() - 4, - 50, + current_time - 5, + 10, + current_time - 2, + 6, True, ), # TaskRes within allowed TTL ( - time.time() - 5, - 100, - time.time() - 4, - 250, + current_time - 5, + 10, + current_time - 2, + 15, False, ), # TaskRes TTL exceeds max allowed TTL ] - ) - def test_store_task_res_limit_ttl( - self, - task_ins_created_at: float, - task_ins_ttl: float, - task_res_created_at: float, - task_res_ttl: float, - expected_store_result: bool, - ) -> None: - """Test the behavior of store_task_res regarding the TTL limit of TaskRes.""" - # Prepare - state: State = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) - task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) - task_ins.task.created_at = task_ins_created_at - task_ins.task.ttl = task_ins_ttl - task_ins_id = state.store_task_ins(task_ins) + for ( + task_ins_created_at, + task_ins_ttl, + task_res_created_at, + task_res_ttl, + expected_store_result, + ) in test_cases: - task_res = create_task_res( - producer_node_id=0, - anonymous=True, - ancestry=[str(task_ins_id)], - run_id=run_id, - ) - task_res.task.created_at = task_res_created_at - task_res.task.ttl = task_res_ttl + # Prepare + state: State = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {}) - # Execute - res = state.store_task_res(task_res) + task_ins = create_task_ins( + consumer_node_id=0, anonymous=True, run_id=run_id + ) + task_ins.task.created_at = task_ins_created_at + task_ins.task.ttl = task_ins_ttl + task_ins_id = state.store_task_ins(task_ins) - # Assert - if expected_store_result: - assert res is not None - else: - assert res is None + task_res = create_task_res( + producer_node_id=0, + anonymous=True, + ancestry=[str(task_ins_id)], + run_id=run_id, + ) + task_res.task.created_at = task_res_created_at + task_res.task.ttl = task_res_ttl + + # Execute + res = state.store_task_res(task_res) + + # Assert + if expected_store_result: + assert res is not None + else: + assert res is None def create_task_ins(