Skip to content

Commit

Permalink
auto generate task_res for unavailable node
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Apr 2, 2024
1 parent 71f74d9 commit 84ba2e5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 25 deletions.
1 change: 1 addition & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ErrorCode:
UNKNOWN = 0
LOAD_CLIENT_APP_EXCEPTION = 1
CLIENT_APP_RAISED_EXCEPTION = 2
NODE_UNAVAILABLE = 3

def __new__(cls) -> ErrorCode:
"""Prevent instantiation."""
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,20 +411,20 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe

result = [dict_to_task_res(row) for row in rows]

# 1. Query: Fetch producer_node_id of remaining task_ids
# 1. Query: Fetch consumer_node_id of remaining task_ids
# Assume the ancestry field only contains one element
data.clear()
replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
remaining_task_ids = task_ids - replied_task_ids
placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
query = f"""
SELECT producer_node_id
SELECT consumer_node_id
FROM task_ins
WHERE task_id IN ({placeholders});
"""
for index, task_id in enumerate(remaining_task_ids):
data[f"id_{index}"] = str(task_id)
node_ids = [int(row["producer_node_id"]) for row in self.query(query, data)]
node_ids = [int(row["consumer_node_id"]) for row in self.query(query, data)]

# 2. Query: Select offline nodes
placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
Expand Down
45 changes: 32 additions & 13 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from uuid import uuid4

from flwr.common import DEFAULT_TTL
from flwr.common.constant import ErrorCode
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
Expand Down Expand Up @@ -412,30 +413,48 @@ def test_acknowledge_ping(self) -> None:

# Assert
self.assertSetEqual(actual_node_ids, set(node_ids[70:]))

def test_node_unavailable_error(self) -> None:
"""Test if get_task_res return TaskRes containing node unavailable error."""
# Prepare
state: State = self.state_factory()
run_id = state.create_run()
node_id_0 = state.create_node()
node_id_1 = state.create_node()
# state.acknowledge_ping(node_ids[0], ping_interval=30)
# state.acknowledge_ping(node_ids[1], ping_interval=90)
# task_ins_0 = create_task_ins(
# consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
# )
# task_ins_1 = create_task_ins(
# consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
# )
node_id_0 = state.create_node(ping_interval=90)
node_id_1 = state.create_node(ping_interval=30)
# Create and store TaskIns
task_ins_0 = create_task_ins(
consumer_node_id=node_id_0, anonymous=False, run_id=run_id
)
task_ins_1 = create_task_ins(
consumer_node_id=node_id_1, anonymous=False, run_id=run_id
)
task_id_0 = state.store_task_ins(task_ins=task_ins_0)
task_id_1 = state.store_task_ins(task_ins=task_ins_1)
assert task_id_0 is not None and task_id_1 is not None

# Get TaskIns to mark them delivered
state.get_task_ins(node_id=node_id_0, limit=None)

# Create and store TaskRes
task_res_0 = create_task_res(
producer_node_id=100,
anonymous=False,
ancestry=[str(task_id_0)],
run_id=run_id,
)
state.store_task_res(task_res_0)

# Execute
current_time = time.time()
task_res_list: List[TaskRes] = []
with patch("time.time", side_effect=lambda: current_time + 50):
actual_node_ids = state.get_nodes(run_id)
task_res_list = state.get_task_res({task_id_0, task_id_1}, limit=None)

# Assert
self.assertSetEqual(actual_node_ids, set(node_ids[70:]))
assert len(task_res_list) == 2
err_taskres = task_res_list[1]
assert err_taskres.task.HasField("error")
assert err_taskres.task.error.code == ErrorCode.NODE_UNAVAILABLE


def create_task_ins(
Expand Down
17 changes: 8 additions & 9 deletions src/py/flwr/server/superlink/state/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,13 @@
import time
from uuid import uuid4

# pylint: disable=E0611
from flwr.proto.error_pb2 import Error
from flwr.proto.node_pb2 import Node
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes

# pylint: enable=E0611

from flwr.common.constant import ErrorCode
from flwr.proto.error_pb2 import Error # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611

NODE_UNAVAILABLE_ERROR_REASON = (
"Error: Node Unavailable - The destination node is currently unavailable"
"Error: Node Unavailable - The destination node is currently unavailable."
)


Expand All @@ -45,6 +42,8 @@ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
ttl=0,
ancestry=[ref_taskins.task_id],
task_type=ref_taskins.task.task_type,
error=Error(code=3, reason=NODE_UNAVAILABLE_ERROR_REASON),
error=Error(
code=ErrorCode.NODE_UNAVAILABLE, reason=NODE_UNAVAILABLE_ERROR_REASON
),
),
)

0 comments on commit 84ba2e5

Please sign in to comment.