diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index a5e893c707fa..c10c57648900 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -45,9 +45,8 @@ def push_messages(driver: InMemoryDriver, num_nodes: int) -> tuple[Iterable[str], int]: """Help push messages to state.""" for _ in range(num_nodes): - driver.state.create_node(ping_interval=PING_MAX_INTERVAL) + node_id = driver.state.create_node(ping_interval=PING_MAX_INTERVAL) num_messages = 3 - node_id = 1 msgs = [ driver.create_message(RecordSet(), "message_type", node_id, "") for _ in range(num_messages) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index 0830c26fc49c..52194a5a9ac8 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -87,8 +87,25 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: return None # Validate run_id if task_ins.run_id not in self.run_ids: - log(ERROR, "`run_id` is invalid") + log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id) + return None + # Validate source node ID + if task_ins.task.producer.node_id != 0: + log( + ERROR, + "Invalid source node ID for TaskIns: %s", + task_ins.task.producer.node_id, + ) return None + # Validate destination node ID + if not task_ins.task.consumer.anonymous: + if task_ins.task.consumer.node_id not in self.node_ids: + log( + ERROR, + "Invalid destination node ID for TaskIns: %s", + task_ins.task.consumer.node_id, + ) + return None # Create task_id task_id = uuid4() diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 1fc21bf02a2a..9e00e4a0c49a 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -192,11 +192,11 @@ def test_get_task_res_empty(self) -> None: def test_store_task_ins_one(self) -> None: """Test store_task_ins.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) assert task_ins.task.created_at < time.time() # pylint: disable=no-member @@ -204,7 +204,7 @@ def test_store_task_ins_one(self) -> None: # Execute state.store_task_ins(task_ins=task_ins) - task_ins_list = state.get_task_ins(node_id=consumer_node_id, limit=10) + task_ins_list = state.get_task_ins(node_id=node_id, limit=10) # Assert assert len(task_ins_list) == 1 @@ -224,20 +224,39 @@ def test_store_task_ins_one(self) -> None: ) assert actual_task.ttl > 0 + def test_store_task_ins_invalid_node_id(self) -> None: + """Test store_task_ins with invalid node_id.""" + # Prepare + state = self.state_factory() + node_id = state.create_node(1e3) + invalid_node_id = 61016 if node_id != 61016 else 61017 + run_id = state.create_run(None, None, "9f86d08", {}) + task_ins = create_task_ins( + consumer_node_id=invalid_node_id, anonymous=False, run_id=run_id + ) + task_ins2 = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id + ) + task_ins2.task.producer.node_id = 61016 + + # Execute and assert + assert state.store_task_ins(task_ins) is None + assert state.store_task_ins(task_ins2) is None + def test_store_and_delete_tasks(self) -> None: """Test delete_tasks.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins_0 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins_1 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins_2 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) # Insert three TaskIns @@ -250,11 +269,11 @@ def test_store_and_delete_tasks(self) -> None: assert task_id_2 # Get TaskIns to mark them delivered - _ = state.get_task_ins(node_id=consumer_node_id, limit=None) + _ = state.get_task_ins(node_id=node_id, limit=None) # Insert one TaskRes and retrive it to mark it as delivered task_res_0 = create_task_res( - producer_node_id=consumer_node_id, + producer_node_id=node_id, anonymous=False, ancestry=[str(task_id_0)], run_id=run_id, @@ -265,7 +284,7 @@ def test_store_and_delete_tasks(self) -> None: # Insert one TaskRes, but don't retrive it task_res_1: TaskRes = create_task_res( - producer_node_id=consumer_node_id, + producer_node_id=node_id, anonymous=False, ancestry=[str(task_id_1)], run_id=run_id, @@ -332,8 +351,11 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: LinkState = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) - task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) + task_ins = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id + ) # Execute _ = state.store_task_ins(task_ins) @@ -346,12 +368,15 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: LinkState = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) - task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) + task_ins = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id + ) # Execute task_ins_uuid = state.store_task_ins(task_ins) - task_ins_list = state.get_task_ins(node_id=1, limit=None) + task_ins_list = state.get_task_ins(node_id=node_id, limit=None) # Assert assert len(task_ins_list) == 1 @@ -363,14 +388,17 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: LinkState = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) - task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) + task_ins = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id + ) # Execute _ = state.store_task_ins(task_ins) # 1st get: set to delivered - task_ins_list = state.get_task_ins(node_id=1, limit=None) + task_ins_list = state.get_task_ins(node_id=node_id, limit=None) assert len(task_ins_list) == 1 @@ -874,11 +902,11 @@ def test_store_task_res_limit_ttl(self) -> None: def test_get_task_ins_not_return_expired(self) -> None: """Test get_task_ins not to return expired tasks.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins.task.created_at = time.time() - 5 task_ins.task.ttl = 5.0 @@ -894,11 +922,11 @@ def test_get_task_ins_not_return_expired(self) -> None: def test_get_task_res_not_return_expired(self) -> None: """Test get_task_res not to return TaskRes if its TaskIns is expired.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins.task.created_at = time.time() - 5 task_ins.task.ttl = 5.1 @@ -948,11 +976,11 @@ def test_get_task_res_return_if_not_expired(self) -> None: """Test get_task_res to return TaskRes if its TaskIns exists and is not expired.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins.task.created_at = time.time() - 5 task_ins.task.ttl = 7.1 @@ -960,7 +988,7 @@ def test_get_task_res_return_if_not_expired(self) -> None: task_id = state.store_task_ins(task_ins=task_ins) task_res = create_task_res( - producer_node_id=1, + producer_node_id=node_id, anonymous=False, ancestry=[str(task_id)], run_id=run_id, @@ -980,17 +1008,18 @@ def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None: """Test store_task_res to fail if there is a mismatch between the consumer_node_id of taskIns and the producer_node_id of taskRes.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_id = state.store_task_ins(task_ins=task_ins) task_res = create_task_res( - producer_node_id=100, # different than consumer_node_id + # Different than consumer_node_id + producer_node_id=100 if node_id != 100 else 101, anonymous=False, ancestry=[str(task_id)], run_id=run_id, diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 2094bd1d8592..ad73bd4fcce0 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -271,7 +271,6 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: if any(errors): log(ERROR, errors) return None - # Create task_id task_id = uuid4() @@ -284,16 +283,36 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: data[0], ["run_id", "producer_node_id", "consumer_node_id"] ) + # Validate run_id + query = "SELECT run_id FROM run WHERE run_id = ?;" + if not self.query(query, (data[0]["run_id"],)): + log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id) + return None + # Validate source node ID + if task_ins.task.producer.node_id != 0: + log( + ERROR, + "Invalid source node ID for TaskIns: %s", + task_ins.task.producer.node_id, + ) + return None + # Validate destination node ID + query = "SELECT node_id FROM node WHERE node_id = ?;" + if not task_ins.task.consumer.anonymous: + if not self.query(query, (data[0]["consumer_node_id"],)): + log( + ERROR, + "Invalid destination node ID for TaskIns: %s", + task_ins.task.consumer.node_id, + ) + return None + columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" # Only invalid run_id can trigger IntegrityError. # This may need to be changed in the future version with more integrity checks. - try: - self.query(query, data) - except sqlite3.IntegrityError: - log(ERROR, "`run` is invalid") - return None + self.query(query, data) return task_id