From 83cd4ba34f864e9134913e310079d02870fa64be Mon Sep 17 00:00:00 2001 From: Mohammad Naseri Date: Thu, 26 Sep 2024 19:59:16 +0100 Subject: [PATCH] feat(framework) Verify the TaskIns TTL when saving TaskRes (#3609) Co-authored-by: Heng Pan --- .../server/superlink/state/in_memory_state.py | 17 ++++++ .../server/superlink/state/sqlite_state.py | 40 +++++++++++++- .../flwr/server/superlink/state/state_test.py | 53 +++++++++++++++++-- 3 files changed, 104 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index e34d15374350..e09df8dc76f6 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -117,6 +117,23 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, errors) return None + with self.lock: + # Check if the TaskIns it is replying to exists and is valid + task_ins_id = task_res.task.ancestry[0] + task_ins = self.task_ins_store.get(UUID(task_ins_id)) + + if task_ins is None: + log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id) + return None + + if task_ins.task.created_at + task_ins.task.ttl <= time.time(): + log( + ERROR, + "Failed to store TaskRes: TaskIns with task_id %s has expired.", + task_ins_id, + ) + return None + # Validate run_id if task_res.run_id not in self.run_ids: log(ERROR, "`run_id` is invalid") diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 28d957a90bd3..d18683286196 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -372,7 +372,18 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Create task_id task_id = uuid4() - # Store TaskIns + task_ins_id = task_res.task.ancestry[0] + task_ins = self.get_valid_task_ins(task_ins_id) + if task_ins is None: + log( + ERROR, + "Failed to store TaskRes: " + "TaskIns with task_id %s does not exist or has expired.", + task_ins_id, + ) + return None + + # Store TaskRes task_res.task_id = str(task_id) data = (task_res_to_dict(task_res),) @@ -810,6 +821,33 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: log(ERROR, "`node_id` does not exist.") return False + def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]: + """Check if the TaskIns exists and is valid (not expired). + + Return TaskIns if valid. + """ + query = """ + SELECT * + FROM task_ins + WHERE task_id = :task_id + """ + data = {"task_id": task_id} + rows = self.query(query, data) + if not rows: + # TaskIns does not exist + return None + + task_ins = rows[0] + created_at = task_ins["created_at"] + ttl = task_ins["ttl"] + current_time = time.time() + + # Check if TaskIns is expired + if ttl is not None and created_at + ttl <= current_time: + return None + + return task_ins + def dict_factory( cursor: sqlite3.Cursor, diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 42c0768f1c7d..85cda1a5af9c 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -21,7 +21,6 @@ from abc import abstractmethod from datetime import datetime, timezone from unittest.mock import patch -from uuid import uuid4 from flwr.common import DEFAULT_TTL from flwr.common.constant import ErrorCode @@ -302,7 +301,10 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: # Prepare state: State = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) - task_ins_id = uuid4() + + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_ins_id = state.store_task_ins(task_ins) + task_res = create_task_res( producer_node_id=0, anonymous=True, @@ -312,7 +314,9 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: # Execute task_res_uuid = state.store_task_res(task_res) - task_res_list = state.get_task_res(task_ids={task_ins_id}, limit=None) + + if task_ins_id is not None: + task_res_list = state.get_task_res(task_ids={task_ins_id}, limit=None) # Assert retrieved_task_res = task_res_list[0] @@ -507,11 +511,23 @@ def test_num_task_res(self) -> None: # Prepare state: State = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) + + task_ins_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_ins_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_ins_id_0 = state.store_task_ins(task_ins_0) + task_ins_id_1 = state.store_task_ins(task_ins_1) + task_0 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id + producer_node_id=0, + anonymous=True, + ancestry=[str(task_ins_id_0)], + run_id=run_id, ) task_1 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id + producer_node_id=0, + anonymous=True, + ancestry=[str(task_ins_id_1)], + run_id=run_id, ) # Store two tasks @@ -664,6 +680,33 @@ def test_node_unavailable_error(self) -> None: assert err_taskres.task.HasField("error") assert err_taskres.task.error.code == ErrorCode.NODE_UNAVAILABLE + def test_store_task_res_task_ins_expired(self) -> None: + """Test behavior of store_task_res when the TaskIns it references is expired.""" + # Prepare + state: State = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {}) + + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_ins.task.created_at = time.time() - task_ins.task.ttl + 0.5 + task_ins_id = state.store_task_ins(task_ins) + + with patch( + "time.time", + side_effect=lambda: task_ins.task.created_at + task_ins.task.ttl + 0.1, + ): # Expired by 0.1 seconds + task = create_task_res( + producer_node_id=0, + anonymous=True, + ancestry=[str(task_ins_id)], + run_id=run_id, + ) + + # Execute + result = state.store_task_res(task) + + # Assert + assert result is None + def create_task_ins( consumer_node_id: int,