Skip to content

Commit

Permalink
add RunStatus dataclass, amend State, implement new methods in InMemo…
Browse files Browse the repository at this point in the history
…ryState
  • Loading branch information
panh99 committed Sep 16, 2024
1 parent 1d15221 commit 60ff84b
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 4 deletions.
9 changes: 9 additions & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ class Run:
override_config: UserConfig


@dataclass
class RunStatus:
"""Run status."""

phase: str
result: str
reason: str


@dataclass
class Fab:
"""Fab file representation."""
Expand Down
48 changes: 45 additions & 3 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@

from flwr.common import log, now
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
from flwr.common.typing import Run, UserConfig
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.superlink.state.state import State
from flwr.server.utils import validate_task_ins_or_res

from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
from .utils import (
generate_rand_int_from_bytes,
has_valid_result,
is_valid_transition,
make_node_unavailable_taskres,
)


class InMemoryState(State): # pylint: disable=R0902,R0904
Expand All @@ -41,7 +46,7 @@ def __init__(self) -> None:
self.public_key_to_node_id: dict[bytes, int] = {}

# Map run_id to (fab_id, fab_version)
self.run_ids: dict[int, Run] = {}
self.run_ids: dict[int, tuple[Run, RunStatus]] = {}
self.task_ins_store: dict[UUID, TaskIns] = {}
self.task_res_store: dict[UUID, TaskRes] = {}

Expand Down Expand Up @@ -340,6 +345,43 @@ def get_run(self, run_id: int) -> Optional[Run]:
return None
return self.run_ids[run_id]

def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""Get the status of the run with the specified `run_id`."""
with self.lock:
return {
run_id: self.run_ids[run_id][1]
for run_id in run_ids
if run_id in self.run_ids
}

def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
"""Update the status of the run with the specified `run_id`."""
with self.lock:
# Check if the run_id exists
if run_id not in self.run_ids:
log(ERROR, "`run_id` is invalid")
return False

# Check if the status transition is valid
status = self.run_ids[run_id][1]
if not is_valid_transition(status, new_status):
log(
ERROR,
'Invalid status transition: from "%s" to "%s"',
status.phase,
new_status.phase,
)
return False

# Check if the result is valid
if not has_valid_result(status):
log(ERROR, 'Invalid run status: "%s:"', status.phase, status.result)
return False

# Update the status
self.run_ids[run_id][1] = new_status
return True

def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
"""Acknowledge a ping received from a node, serving as a heartbeat."""
with self.lock:
Expand Down
28 changes: 27 additions & 1 deletion src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Optional
from uuid import UUID

from flwr.common.typing import Run, UserConfig
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611


Expand Down Expand Up @@ -184,6 +184,32 @@ def get_run(self, run_id: int) -> Optional[Run]:
- `fab_version`: The version of the FAB used in the specified run.
"""

@abc.abstractmethod
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""Get the status of the run with the specified `run_id`.
Parameters
----------
run_ids : set[int]
"""

@abc.abstractmethod
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
"""Update the status of the run with the specified `run_id`.
Parameters
----------
run_id : int
The identifier of the run.
new_status : RunStatus
The new status to be assigned to the run.
Returns
-------
bool
True if the status update is successful; False otherwise.
"""

@abc.abstractmethod
def store_server_private_public_key(
self, private_key: bytes, public_key: bytes
Expand Down
45 changes: 45 additions & 0 deletions src/py/flwr/server/superlink/state/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from flwr.common import log
from flwr.common.constant import ErrorCode
from flwr.common.typing import RunStatus
from flwr.proto.error_pb2 import Error # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
Expand All @@ -31,6 +32,9 @@
"It exceeds the time limit specified in its last ping."
)

VALID_RUN_STATUS_PHASE_TRANSITIONS = {("starting", "running"), ("running", "finished")}
VALID_RUN_STATUS_RESULTS = {"completed", "failed", "stopped", ""}


def generate_rand_int_from_bytes(num_bytes: int) -> int:
"""Generate a random `num_bytes` integer."""
Expand Down Expand Up @@ -60,3 +64,44 @@ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
),
),
)


def is_valid_transition(current_status: RunStatus, new_status: RunStatus) -> bool:
"""Check if a transition between two run statuses is valid.
Parameters
----------
current_status : RunStatus
The current status of the run.
new_status : RunStatus
The new status to transition to.
Returns
-------
bool
True if the transition is valid, False otherwise.
"""
return (
current_status.phase,
new_status.phase,
) in VALID_RUN_STATUS_PHASE_TRANSITIONS


def has_valid_result(status: RunStatus) -> bool:
"""Check if the 'result' field of the given status is valid.
Parameters
----------
status : RunStatus
The run status object to be checked.
Returns
-------
bool
True if the status has a valid result, False otherwise.
Notes
-----
An empty string (i.e., "") is considered a valid result.
"""
return status.result in VALID_RUN_STATUS_RESULTS

0 comments on commit 60ff84b

Please sign in to comment.