Skip to content

Commit

Permalink
temp
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Apr 2, 2024
1 parent 0173567 commit e7ea7cb
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 5 deletions.
27 changes: 23 additions & 4 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from flwr.server.superlink.state.state import State
from flwr.server.utils import validate_task_ins_or_res

from .utils import make_node_unavailable_taskres


class InMemoryState(State):
"""In-memory State implementation."""
Expand Down Expand Up @@ -129,14 +131,31 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe
with self.lock:
# Find TaskRes that were not delivered yet
task_res_list: List[TaskRes] = []
replied_task_ids: Set[UUID] = set()
for _, task_res in self.task_res_store.items():
if (
UUID(task_res.task.ancestry[0]) in task_ids
and task_res.task.delivered_at == ""
):
reply_to = UUID(task_res.task.ancestry[0])
if reply_to in task_ids and task_res.task.delivered_at == "":
task_res_list.append(task_res)
replied_task_ids.add(reply_to)
if limit and len(task_res_list) == limit:
break

# Check if the node is offline
for task_id in task_ids - replied_task_ids:
if limit and len(task_res_list) == limit:
break
task_ins = self.task_ins_store.get(task_id)
if task_ins is None:
continue
node_id = task_ins.task.consumer.node_id
online_until, _ = self.node_ids[node_id]
# Generate a TaskRes containing an error reply if the node is offline.
if online_until < time.time():
err_taskres = make_node_unavailable_taskres(
ref_taskins=task_ins,
)
self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
task_res_list.append(err_taskres)

# Mark all of them as delivered
delivered_at = now().isoformat()
Expand Down
52 changes: 51 additions & 1 deletion src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from flwr.server.utils.validator import validate_task_ins_or_res

from .state import State
from .utils import make_node_unavailable_taskres

SQL_CREATE_TABLE_NODE = """
CREATE TABLE IF NOT EXISTS node(
Expand Down Expand Up @@ -344,6 +345,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:

return task_id

# pylint: disable-next=R0914
def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]:
"""Get TaskRes for task_ids.
Expand Down Expand Up @@ -374,7 +376,7 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe
AND delivered_at = ""
"""

data: Dict[str, Union[str, int]] = {}
data: Dict[str, Union[str, float, int]] = {}

if limit is not None:
query += " LIMIT :limit"
Expand Down Expand Up @@ -408,6 +410,54 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe
rows = self.query(query, data)

result = [dict_to_task_res(row) for row in rows]

# 1. Query: Fetch producer_node_id of remaining task_ids
# Assume the ancestry field only contains one element
data.clear()
replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows}
remaining_task_ids = task_ids - replied_task_ids
placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))])
query = f"""
SELECT producer_node_id
FROM task_ins
WHERE task_id IN ({placeholders});
"""
for index, task_id in enumerate(remaining_task_ids):
data[f"id_{index}"] = str(task_id)
node_ids = [int(row["producer_node_id"]) for row in self.query(query, data)]

# 2. Query: Select offline nodes
placeholders = ",".join([f":id_{i}" for i in range(len(node_ids))])
query = f"""
SELECT node_id
FROM node
WHERE node_id IN ({placeholders})
AND online_until < :time;
"""
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(node_ids)}
data["time"] = time.time()
offline_node_ids = [int(row["node_id"]) for row in self.query(query, data)]

# 3. Query: Select TaskIns for offline nodes
placeholders = ",".join([f":id_{i}" for i in range(len(offline_node_ids))])
query = f"""
SELECT *
FROM task_ins
WHERE consumer_node_id IN ({placeholders});
"""
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
task_ins_rows = self.query(query, data)

# Make TaskRes containing node unavailable error
for row in task_ins_rows:
if limit and len(result) == limit:
break
task_ins = dict_to_task_ins(row)
err_taskres = make_node_unavailable_taskres(
ref_taskins=task_ins,
)
result.append(err_taskres)

return result

def num_task_ins(self) -> int:
Expand Down
24 changes: 24 additions & 0 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,30 @@ def test_acknowledge_ping(self) -> None:

# Assert
self.assertSetEqual(actual_node_ids, set(node_ids[70:]))

def test_node_unavailable_error(self) -> None:
    """Test if get_task_res returns a TaskRes with a node unavailable error."""
    # Prepare
    state: State = self.state_factory()
    run_id = state.create_run()
    node_id_0 = state.create_node()
    node_id_1 = state.create_node()
    # With time patched to +50s below, node 0 (30s ping interval) is expected
    # to be seen as offline while node 1 (90s ping interval) is still online.
    state.acknowledge_ping(node_id_0, ping_interval=30)
    state.acknowledge_ping(node_id_1, ping_interval=90)
    task_ins_0 = create_task_ins(
        consumer_node_id=node_id_0, anonymous=False, run_id=run_id
    )
    task_ins_1 = create_task_ins(
        consumer_node_id=node_id_1, anonymous=False, run_id=run_id
    )
    # NOTE(review): assumes State.store_task_ins returns the stored task's
    # UUID (or None on failure), mirroring store_task_res -- confirm.
    task_id_0 = state.store_task_ins(task_ins_0)
    task_id_1 = state.store_task_ins(task_ins_1)
    assert task_id_0 is not None and task_id_1 is not None

    # Execute
    current_time = time.time()
    with patch("time.time", side_effect=lambda: current_time + 50):
        task_res_list = state.get_task_res({task_id_0, task_id_1}, limit=None)

    # Assert
    # No node has replied, so the only result should be the generated error
    # reply for the TaskIns addressed to the offline node 0.
    self.assertEqual(len(task_res_list), 1)
    err_taskres = task_res_list[0]
    self.assertTrue(err_taskres.task.HasField("error"))
    # 3 is the error code set by make_node_unavailable_taskres
    self.assertEqual(err_taskres.task.error.code, 3)
    self.assertEqual(err_taskres.task.ancestry[0], str(task_id_0))


def create_task_ins(
Expand Down
50 changes: 50 additions & 0 deletions src/py/flwr/server/superlink/state/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for State."""


import time
from uuid import uuid4

# pylint: disable=E0611
from flwr.proto.error_pb2 import Error
from flwr.proto.node_pb2 import Node
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes

# pylint: enable=E0611


# Reason text placed in the Error of every TaskRes built by
# make_node_unavailable_taskres.
NODE_UNAVAILABLE_ERROR_REASON = (
"Error: Node Unavailable - The destination node is currently unavailable"
)


def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
    """Create a TaskRes carrying the node unavailable error for a TaskIns.

    The reply gets a freshly generated random task_id, copies ``group_id``,
    ``run_id`` and ``task_type`` from ``ref_taskins``, and lists the
    ``task_id`` of ``ref_taskins`` as its only ancestry entry.

    Parameters
    ----------
    ref_taskins : TaskIns
        The TaskIns that the generated error reply refers to.

    Returns
    -------
    TaskRes
        A TaskRes whose task has the error field set (code 3 with
        ``NODE_UNAVAILABLE_ERROR_REASON`` as the reason).
    """
    reply_task_id = str(uuid4())

    # The reply is produced by an anonymous node with id 0 and is addressed to
    # the node id recorded as the producer of the reference TaskIns.
    sender = Node(node_id=0, anonymous=True)
    receiver = Node(node_id=ref_taskins.task.producer.node_id, anonymous=False)
    unavailable_error = Error(code=3, reason=NODE_UNAVAILABLE_ERROR_REASON)

    error_task = Task(
        producer=sender,
        consumer=receiver,
        created_at=time.time(),
        ttl=0,
        ancestry=[ref_taskins.task_id],
        task_type=ref_taskins.task.task_type,
        error=unavailable_error,
    )
    return TaskRes(
        task_id=reply_task_id,
        group_id=ref_taskins.group_id,
        run_id=ref_taskins.run_id,
        task=error_task,
    )

0 comments on commit e7ea7cb

Please sign in to comment.