diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 08d1e2424c8f..e49240f31359 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -22,6 +22,8 @@ from datetime import datetime, timezone from unittest.mock import patch +from parameterized import parameterized + from flwr.common import DEFAULT_TTL from flwr.common.constant import ErrorCode from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( @@ -707,63 +709,59 @@ def test_store_task_res_task_ins_expired(self) -> None: # Assert assert result is None - def test_store_task_res_limit_ttl(self) -> None: - """Test the behavior of store_task_res regarding the TTL limit of TaskRes.""" - current_time = time.time() - - test_cases = [ + @parameterized.expand( + [ # type: ignore ( - current_time - 5, + time.time() - 5, 10, - current_time - 2, + time.time() - 2, 6, True, ), # TaskRes within allowed TTL ( - current_time - 5, + time.time() - 5, 10, - current_time - 2, + time.time() - 2, 15, False, ), # TaskRes TTL exceeds max allowed TTL ] + ) + def test_store_task_res_limit_ttl( + self, + task_ins_created_at: float, + task_ins_ttl: float, + task_res_created_at: float, + task_res_ttl: float, + expected_store_result: bool, + ) -> None: + """Test the behavior of store_task_res regarding the TTL limit of TaskRes.""" + # Prepare + state: State = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {}) - for ( - task_ins_created_at, - task_ins_ttl, - task_res_created_at, - task_res_ttl, - expected_store_result, - ) in test_cases: - - # Prepare - state: State = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) - - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, run_id=run_id - ) - task_ins.task.created_at = task_ins_created_at - task_ins.task.ttl = task_ins_ttl - task_ins_id = state.store_task_ins(task_ins) + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_ins.task.created_at = task_ins_created_at + task_ins.task.ttl = task_ins_ttl + task_ins_id = state.store_task_ins(task_ins) - task_res = create_task_res( - producer_node_id=0, - anonymous=True, - ancestry=[str(task_ins_id)], - run_id=run_id, - ) - task_res.task.created_at = task_res_created_at - task_res.task.ttl = task_res_ttl + task_res = create_task_res( + producer_node_id=0, + anonymous=True, + ancestry=[str(task_ins_id)], + run_id=run_id, + ) + task_res.task.created_at = task_res_created_at + task_res.task.ttl = task_res_ttl - # Execute - res = state.store_task_res(task_res) + # Execute + res = state.store_task_res(task_res) - # Assert - if expected_store_result: - assert res is not None - else: - assert res is None + # Assert + if expected_store_result: + assert res is not None + else: + assert res is None def create_task_ins(