From e7ea7cbae64af583e526a9188bcabfdfbb4db2d8 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 2 Apr 2024 15:10:52 +0100 Subject: [PATCH] temp --- .../server/superlink/state/in_memory_state.py | 27 ++++++++-- .../server/superlink/state/sqlite_state.py | 52 ++++++++++++++++++- .../flwr/server/superlink/state/state_test.py | 24 +++++++++ src/py/flwr/server/superlink/state/utils.py | 50 ++++++++++++++++++ 4 files changed, 148 insertions(+), 5 deletions(-) create mode 100644 src/py/flwr/server/superlink/state/utils.py diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 6fc57707ac36..4acb32021cd5 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -27,6 +27,8 @@ from flwr.server.superlink.state.state import State from flwr.server.utils import validate_task_ins_or_res +from .utils import make_node_unavailable_taskres + class InMemoryState(State): """In-memory State implementation.""" @@ -129,14 +131,31 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe with self.lock: # Find TaskRes that were not delivered yet task_res_list: List[TaskRes] = [] + replied_task_ids: Set[UUID] = set() for _, task_res in self.task_res_store.items(): - if ( - UUID(task_res.task.ancestry[0]) in task_ids - and task_res.task.delivered_at == "" - ): + reply_to = UUID(task_res.task.ancestry[0]) + if reply_to in task_ids and task_res.task.delivered_at == "": task_res_list.append(task_res) + replied_task_ids.add(reply_to) + if limit and len(task_res_list) == limit: + break + + # Check if the node is offline + for task_id in task_ids - replied_task_ids: if limit and len(task_res_list) == limit: break + task_ins = self.task_ins_store.get(task_id) + if task_ins is None: + continue + node_id = task_ins.task.consumer.node_id + online_until, _ = self.node_ids[node_id] + # Generate a TaskRes containing an error reply if 
the node is offline. + if online_until < time.time(): + err_taskres = make_node_unavailable_taskres( + ref_taskins=task_ins, + ) + self.task_res_store[UUID(err_taskres.task_id)] = err_taskres + task_res_list.append(err_taskres) # Mark all of them as delivered delivered_at = now().isoformat() diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 6996d51d2a9b..8af048a25485 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -30,6 +30,7 @@ from flwr.server.utils.validator import validate_task_ins_or_res from .state import State +from .utils import make_node_unavailable_taskres SQL_CREATE_TABLE_NODE = """ CREATE TABLE IF NOT EXISTS node( @@ -344,6 +345,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: return task_id + # pylint: disable-next=R0914 def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]: """Get TaskRes for task_ids. @@ -374,7 +376,7 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe AND delivered_at = "" """ - data: Dict[str, Union[str, int]] = {} + data: Dict[str, Union[str, float, int]] = {} if limit is not None: query += " LIMIT :limit" @@ -408,6 +410,54 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe rows = self.query(query, data) result = [dict_to_task_res(row) for row in rows] + + # 1. 
Query: Fetch producer_node_id of remaining task_ids + # Assume the ancestry field only contains one element + data.clear() + replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows} + remaining_task_ids = task_ids - replied_task_ids + placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))]) + query = f""" + SELECT producer_node_id + FROM task_ins + WHERE task_id IN ({placeholders}); + """ + for index, task_id in enumerate(remaining_task_ids): + data[f"id_{index}"] = str(task_id) + node_ids = [int(row["producer_node_id"]) for row in self.query(query, data)] + + # 2. Query: Select offline nodes + placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))]) + query = f""" + SELECT node_id + FROM node + WHERE node_id IN ({placeholders}) + AND online_until < :time; + """ + data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)} + data["time"] = time.time() + offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)] + + # 3. 
Query: Select TaskIns for offline nodes + placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))]) + query = f""" + SELECT * + FROM task_ins + WHERE consumer_node_id IN ({placeholders}); + """ + data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)} + task_ins_rows = self.query(query, data) + + # Make TaskRes containing node unavailabe error + for row in task_ins_rows: + if limit and len(result) == limit: + break + task_ins = dict_to_task_ins(row) + err_taskres = make_node_unavailable_taskres( + ref_taskins=task_ins, + ) + result.append(err_taskres) + return result def num_task_ins(self) -> int: diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 1757cfac4255..7cefebc0f318 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -412,6 +412,30 @@ def test_acknowledge_ping(self) -> None: # Assert self.assertSetEqual(actual_node_ids, set(node_ids[70:])) + + def test_node_unavailable_error(self) -> None: + """Test if get_task_res return TaskRes containing node unavailable error.""" + # Prepare + state: State = self.state_factory() + run_id = state.create_run() + node_id_0 = state.create_node() + node_id_1 = state.create_node() + # state.acknowledge_ping(node_ids[0], ping_interval=30) + # state.acknowledge_ping(node_ids[1], ping_interval=90) + # task_ins_0 = create_task_ins( + # consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + # ) + # task_ins_1 = create_task_ins( + # consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + # ) + + # Execute + current_time = time.time() + with patch("time.time", side_effect=lambda: current_time + 50): + actual_node_ids = state.get_nodes(run_id) + + # Assert + self.assertSetEqual(actual_node_ids, set(node_ids[70:])) def create_task_ins( diff --git a/src/py/flwr/server/superlink/state/utils.py b/src/py/flwr/server/superlink/state/utils.py 
new file mode 100644
index 000000000000..1eaea40b2c2b
--- /dev/null
+++ b/src/py/flwr/server/superlink/state/utils.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for State."""
+
+
+import time
+from uuid import uuid4
+
+# pylint: disable=E0611
+from flwr.proto.error_pb2 import Error
+from flwr.proto.node_pb2 import Node
+from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
+
+# pylint: enable=E0611
+
+
+# Error code and reason attached to the synthetic "node unavailable" reply.
+NODE_UNAVAILABLE_ERROR_CODE = 3
+NODE_UNAVAILABLE_ERROR_REASON = (
+    "Error: Node Unavailable - The destination node is currently unavailable"
+)
+
+
+def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
+    """Create a TaskRes carrying a node-unavailable error for a TaskIns.
+
+    Parameters
+    ----------
+    ref_taskins : TaskIns
+        The TaskIns that could not be answered. Its `group_id`, `run_id`,
+        `task_id`, `task.task_type` and `task.producer.node_id` are copied
+        into the reply.
+
+    Returns
+    -------
+    TaskRes
+        A new TaskRes with a fresh random `task_id` whose `task.error` holds
+        `NODE_UNAVAILABLE_ERROR_CODE` and `NODE_UNAVAILABLE_ERROR_REASON`.
+    """
+    return TaskRes(
+        task_id=str(uuid4()),
+        group_id=ref_taskins.group_id,
+        run_id=ref_taskins.run_id,
+        task=Task(
+            # The reply is attributed to an anonymous node (node_id 0), not to
+            # the unavailable node itself.
+            producer=Node(node_id=0, anonymous=True),
+            # Address the reply to whoever produced the reference TaskIns.
+            consumer=Node(node_id=ref_taskins.task.producer.node_id, anonymous=False),
+            created_at=time.time(),
+            ttl=0,
+            # Record the reference TaskIns as the task this reply answers.
+            ancestry=[ref_taskins.task_id],
+            task_type=ref_taskins.task.task_type,
+            error=Error(
+                code=NODE_UNAVAILABLE_ERROR_CODE,
+                reason=NODE_UNAVAILABLE_ERROR_REASON,
+            ),
+        ),
+    )