diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index f8ae5a7e95b7..2c4519d8c148 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -116,6 +116,7 @@ def get_task_ins( # Return TaskIns return task_ins_list + # pylint: disable=R0911 def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: """Store one TaskRes.""" # Validate task @@ -129,6 +130,17 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: task_ins_id = task_res.task.ancestry[0] task_ins = self.task_ins_store.get(UUID(task_ins_id)) + # Ensure that the consumer_id of taskIns matches the producer_id of taskRes. + if ( + task_ins + and task_res + and not ( + task_ins.task.consumer.anonymous or task_res.task.producer.anonymous + ) + and task_ins.task.consumer.node_id != task_res.task.producer.node_id + ): + return None + if task_ins is None: log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id) return None diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 502b1e2461b2..73c121b01b4b 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -390,6 +390,16 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: ) return None + # Ensure that the consumer_id of taskIns matches the producer_id of taskRes. + if ( + task_ins + and task_res + and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous) + and convert_sint64_to_uint64(task_ins["consumer_node_id"]) + != task_res.task.producer.node_id + ): + return None + # Fail if the TaskRes TTL exceeds the # expiration time of the TaskIns it replies to. # Condition: TaskIns.created_at + TaskIns.ttl ≥ diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 2a5eab30b4b7..c3e0ac70d567 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== """Tests all state implemenations have to conform to.""" -# pylint: disable=invalid-name, disable=R0904,R0913 +# pylint: disable=invalid-name, too-many-lines, R0904, R0913 import tempfile import time @@ -149,7 +149,7 @@ def test_store_and_delete_tasks(self) -> None: # Insert one TaskRes and retrive it to mark it as delivered task_res_0 = create_task_res( - producer_node_id=100, + producer_node_id=consumer_node_id, anonymous=False, ancestry=[str(task_id_0)], run_id=run_id, @@ -160,7 +160,7 @@ def test_store_and_delete_tasks(self) -> None: # Insert one TaskRes, but don't retrive it task_res_1: TaskRes = create_task_res( - producer_node_id=100, + producer_node_id=consumer_node_id, anonymous=False, ancestry=[str(task_id_1)], run_id=run_id, @@ -662,7 +662,7 @@ def test_node_unavailable_error(self) -> None: # Create and store TaskRes task_res_0 = create_task_res( - producer_node_id=100, + producer_node_id=node_id_0, anonymous=False, ancestry=[str(task_id_0)], run_id=run_id, @@ -871,6 +871,32 @@ def test_get_task_res_return_if_not_expired(self) -> None: # Assert assert len(task_res_list) != 0 + def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None: + """Test store_task_res to fail if there is a mismatch between the + consumer_node_id of taskIns and the producer_node_id of taskRes.""" + # Prepare + consumer_node_id = 1 + state = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {}) + task_ins = create_task_ins( + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + ) + + task_id = state.store_task_ins(task_ins=task_ins) + + task_res = create_task_res( + producer_node_id=100, # different than consumer_node_id + anonymous=False, + ancestry=[str(task_id)], + run_id=run_id, + ) + + # Execute + task_res_uuid = state.store_task_res(task_res=task_res) + + # Assert + assert task_res_uuid is None + def create_task_ins( consumer_node_id: int,