Skip to content

Commit

Permalink
Merge branch 'main' into improve-loadapp-import-mgmt
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Nov 7, 2024
2 parents cd315c5 + ebe73a4 commit 2757e12
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 108 deletions.
1 change: 0 additions & 1 deletion src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ class ErrorCode:
UNKNOWN = 0
LOAD_CLIENT_APP_EXCEPTION = 1
CLIENT_APP_RAISED_EXCEPTION = 2
NODE_UNAVAILABLE = 3

def __new__(cls) -> ErrorCode:
"""Prevent instantiation."""
Expand Down
16 changes: 0 additions & 16 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
generate_rand_int_from_bytes,
has_valid_sub_status,
is_valid_transition,
make_node_unavailable_taskres,
)


Expand Down Expand Up @@ -257,21 +256,6 @@ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
task_res_list.append(task_res)
replied_task_ids.add(reply_to)

# Check if the node is offline
for task_id in task_ids - replied_task_ids:
task_ins = self.task_ins_store.get(task_id)
if task_ins is None:
continue
node_id = task_ins.task.consumer.node_id
online_until, _ = self.node_ids[node_id]
# Generate a TaskRes containing an error reply if the node is offline.
if online_until < time.time():
err_taskres = make_node_unavailable_taskres(
ref_taskins=task_ins,
)
self.task_res_store[UUID(err_taskres.task_id)] = err_taskres
task_res_list.append(err_taskres)

# Mark all of them as delivered
delivered_at = now().isoformat()
for task_res in task_res_list:
Expand Down
44 changes: 1 addition & 43 deletions src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from uuid import UUID

from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, RecordSet, now
from flwr.common.constant import ErrorCode, Status, SubStatus
from flwr.common.constant import Status, SubStatus
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
generate_key_pairs,
private_key_to_bytes,
Expand Down Expand Up @@ -786,48 +786,6 @@ def test_acknowledge_ping(self) -> None:
# Assert
self.assertSetEqual(actual_node_ids, set(node_ids[70:]))

def test_node_unavailable_error(self) -> None:
    """Check that get_task_res yields a node-unavailable error TaskRes.

    Two nodes are registered (ping intervals 90 and 30) and each is sent a
    TaskIns, but only the first node stores a reply. ``time.time`` is then
    patched to return a value 50 seconds in the future while both task IDs
    are queried. The call is expected to return two TaskRes, the second of
    which carries an error with code ``ErrorCode.NODE_UNAVAILABLE``.
    """
    # Prepare: one run and two nodes with different ping intervals
    state: LinkState = self.state_factory()
    run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord())
    responsive_node = state.create_node(ping_interval=90)
    silent_node = state.create_node(ping_interval=30)

    # Build one TaskIns per node, then store both
    ins_for_responsive = create_task_ins(
        consumer_node_id=responsive_node, anonymous=False, run_id=run_id
    )
    ins_for_silent = create_task_ins(
        consumer_node_id=silent_node, anonymous=False, run_id=run_id
    )
    responsive_task_id = state.store_task_ins(task_ins=ins_for_responsive)
    silent_task_id = state.store_task_ins(task_ins=ins_for_silent)
    assert responsive_task_id is not None
    assert silent_task_id is not None

    # Pull the TaskIns of the first node so that it is marked as delivered
    state.get_task_ins(node_id=responsive_node, limit=None)

    # Only the first node stores a reply
    reply = create_task_res(
        producer_node_id=responsive_node,
        anonymous=False,
        ancestry=[str(responsive_task_id)],
        run_id=run_id,
    )
    state.store_task_res(reply)

    # Execute: query both task IDs with the clock moved 50 seconds ahead
    shifted_time = time.time() + 50
    with patch("time.time", side_effect=lambda: shifted_time):
        replies: list[TaskRes] = state.get_task_res(
            {responsive_task_id, silent_task_id}
        )

    # Assert: the second entry is the generated node-unavailable error
    assert len(replies) == 2
    _, error_reply = replies
    assert error_reply.task.HasField("error")
    assert error_reply.task.error.code == ErrorCode.NODE_UNAVAILABLE

def test_store_task_res_task_ins_expired(self) -> None:
"""Test behavior of store_task_res when the TaskIns it references is expired."""
# Prepare
Expand Down
15 changes: 0 additions & 15 deletions src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
generate_rand_int_from_bytes,
has_valid_sub_status,
is_valid_transition,
make_node_unavailable_taskres,
)

SQL_CREATE_TABLE_NODE = """
Expand Down Expand Up @@ -640,20 +639,6 @@ def get_task_res(self, task_ids: set[UUID]) -> list[TaskRes]:
data = {f"id_{i}": str(node_id) for i, node_id in enumerate(offline_node_ids)}
task_ins_rows = self.query(query, data)

# Make TaskRes containing node unavailable error
for row in task_ins_rows:
for row in rows:
# Convert values from sint64 to uint64
convert_sint64_values_in_dict_to_uint64(
row, ["run_id", "producer_node_id", "consumer_node_id"]
)

task_ins = dict_to_task_ins(row)
err_taskres = make_node_unavailable_taskres(
ref_taskins=task_ins,
)
result.append(err_taskres)

return result

def num_task_ins(self) -> int:
Expand Down
35 changes: 2 additions & 33 deletions src/py/flwr/server/superlink/linkstate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,15 @@
"""Utility functions for State."""


import time
from logging import ERROR
from os import urandom
from uuid import uuid4

from flwr.common import ConfigsRecord, Context, log, serde
from flwr.common.constant import ErrorCode, Status, SubStatus
from flwr.common import ConfigsRecord, Context, serde
from flwr.common.constant import Status, SubStatus
from flwr.common.typing import RunStatus
from flwr.proto.error_pb2 import Error # pylint: disable=E0611
from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611

# pylint: disable=E0611
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611

NODE_UNAVAILABLE_ERROR_REASON = (
"Error: Node Unavailable - The destination node is currently unavailable. "
Expand Down Expand Up @@ -161,31 +155,6 @@ def configsrecord_from_bytes(configsrecord_bytes: bytes) -> ConfigsRecord:
)


def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
    """Build a TaskRes that answers a TaskIns with a node-unavailable error.

    The reply swaps the roles of the referenced TaskIns: its consumer node
    becomes the producer and its producer node becomes the consumer. The
    ``group_id``, ``run_id`` and ``task_type`` are copied over, and the
    TaskIns ``task_id`` is recorded as the sole ancestry entry.

    The TTL of the reply is the TTL of the TaskIns minus the time elapsed
    since the TaskIns was created. If that value is negative, an error is
    logged and the TTL is set to 0.
    """
    ref_task = ref_taskins.task
    created_at = time.time()

    # Remaining lifetime of the referenced TaskIns, never below zero
    remaining_ttl = ref_task.ttl - (created_at - ref_task.created_at)
    if remaining_ttl < 0:
        log(ERROR, "Creating TaskRes for TaskIns that exceeds its TTL.")
        remaining_ttl = 0

    unavailable_error = Error(
        code=ErrorCode.NODE_UNAVAILABLE, reason=NODE_UNAVAILABLE_ERROR_REASON
    )
    # Producer and consumer are swapped relative to the TaskIns
    error_task = Task(
        producer=Node(node_id=ref_task.consumer.node_id, anonymous=False),
        consumer=Node(node_id=ref_task.producer.node_id, anonymous=False),
        created_at=created_at,
        ttl=remaining_ttl,
        ancestry=[ref_taskins.task_id],
        task_type=ref_task.task_type,
        error=unavailable_error,
    )
    return TaskRes(
        task_id=str(uuid4()),
        group_id=ref_taskins.group_id,
        run_id=ref_taskins.run_id,
        task=error_task,
    )


def is_valid_transition(current_status: RunStatus, new_status: RunStatus) -> bool:
"""Check if a transition between two run statuses is valid.
Expand Down

0 comments on commit 2757e12

Please sign in to comment.