diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 6a4061a72505..b6d39b6e8932 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -83,6 +83,7 @@ class ErrorCode: UNKNOWN = 0 LOAD_CLIENT_APP_EXCEPTION = 1 CLIENT_APP_RAISED_EXCEPTION = 2 + NODE_UNAVAILABLE = 3 def __new__(cls) -> ErrorCode: """Prevent instantiation.""" diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index c8ebc5e5b21c..a40dbde16aaf 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -411,20 +411,20 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe result = [dict_to_task_res(row) for row in rows] - # 1. Query: Fetch producer_node_id of remaining task_ids + # 1. Query: Fetch consumer_node_id of remaining task_ids # Assume the ancestry field only contains one element data.clear() replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows} remaining_task_ids = task_ids - replied_task_ids placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))]) query = f""" - SELECT producer_node_id + SELECT consumer_node_id FROM task_ins WHERE task_id IN ({placeholders}); """ for index, task_id in enumerate(remaining_task_ids): data[f"id_{index}"] = str(task_id) - node_ids = [int(row["producer_node_id"]) for row in self.query(query, data)] + node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)] # 2. Query: Select offline nodes placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))]) diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 151520f87a2c..f48f498aed06 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -25,6 +25,7 @@ from uuid import uuid4 from flwr.common import DEFAULT_TTL +from flwr.common.constant import ErrorCode from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -412,30 +413,48 @@ def test_acknowledge_ping(self) -> None: # Assert self.assertSetEqual(actual_node_ids, set(node_ids[70:])) - + def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare state: State = self.state_factory() run_id = state.create_run() - node_id_0 = state.create_node() - node_id_1 = state.create_node() - # state.acknowledge_ping(node_ids[0], ping_interval=30) - # state.acknowledge_ping(node_ids[1], ping_interval=90) - # task_ins_0 = create_task_ins( - # consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id - # ) - # task_ins_1 = create_task_ins( - # consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id - # ) + node_id_0 = state.create_node(ping_interval=90) + node_id_1 = state.create_node(ping_interval=30) + # Create and store TaskIns + task_ins_0 = create_task_ins( + consumer_node_id=node_id_0, anonymous=False, run_id=run_id + ) + task_ins_1 = create_task_ins( + consumer_node_id=node_id_1, anonymous=False, run_id=run_id + ) + task_id_0 = state.store_task_ins(task_ins=task_ins_0) + task_id_1 = state.store_task_ins(task_ins=task_ins_1) + assert task_id_0 is not None and task_id_1 is not None + + # Get TaskIns to mark them delivered + state.get_task_ins(node_id=node_id_0, limit=None) + + # Create and store TaskRes + task_res_0 = create_task_res( + producer_node_id=100, + anonymous=False, + ancestry=[str(task_id_0)], + run_id=run_id, + ) + state.store_task_res(task_res_0) # Execute current_time = time.time() + task_res_list: List[TaskRes] = [] with patch("time.time", side_effect=lambda: current_time + 50): - actual_node_ids = state.get_nodes(run_id) + task_res_list = state.get_task_res({task_id_0, task_id_1}, limit=None) # Assert - self.assertSetEqual(actual_node_ids, set(node_ids[70:])) + assert len(task_res_list) == 2 + err_taskres = task_res_list[1] + assert err_taskres.task.HasField("error") + assert err_taskres.task.error.code == ErrorCode.NODE_UNAVAILABLE def create_task_ins( diff --git a/src/py/flwr/server/superlink/state/utils.py b/src/py/flwr/server/superlink/state/utils.py index 1eaea40b2c2b..53186749353a 100644 --- a/src/py/flwr/server/superlink/state/utils.py +++ b/src/py/flwr/server/superlink/state/utils.py @@ -18,16 +18,13 @@ import time from uuid import uuid4 -# pylint: disable=E0611 -from flwr.proto.error_pb2 import Error -from flwr.proto.node_pb2 import Node -from flwr.proto.task_pb2 import Task, TaskIns, TaskRes - -# pylint: enable=E0611 - +from flwr.common.constant import ErrorCode +from flwr.proto.error_pb2 import Error # pylint: disable=E0611 +from flwr.proto.node_pb2 import Node # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 NODE_UNAVAILABLE_ERROR_REASON = ( - "Error: Node Unavailable - The destination node is currently unavailable" + "Error: Node Unavailable - The destination node is currently unavailable." ) @@ -45,6 +42,8 @@ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes: ttl=0, ancestry=[ref_taskins.task_id], task_type=ref_taskins.task.task_type, - error=Error(code=3, reason=NODE_UNAVAILABLE_ERROR_REASON), + error=Error( + code=ErrorCode.NODE_UNAVAILABLE, reason=NODE_UNAVAILABLE_ERROR_REASON + ), ), )