Skip to content

Commit

Permalink
add timestamps for in memory state
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Sep 23, 2024
1 parent 0eb8ad1 commit 8765100
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 21 deletions.
6 changes: 5 additions & 1 deletion src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
start_vce,
)
from flwr.server.superlink.state import InMemoryState, StateFactory
from flwr.server.superlink.state.in_memory_state import RunRecord


class DummyClient(NumPyClient):
Expand Down Expand Up @@ -114,7 +115,7 @@ def register_messages_into_state(
) -> dict[UUID, float]:
"""Register `num_messages` into the state factory."""
state: InMemoryState = state_factory.state() # type: ignore
state.run_ids[run_id] = (
state.run_ids[run_id] = RunRecord(
Run(
run_id=run_id,
fab_id="Mock/mock",
Expand All @@ -127,6 +128,9 @@ def register_messages_into_state(
sub_status="",
details="",
),
"",
"",
"",
)
# Artificially add TaskIns to state so they can be processed
# by the Simulation Engine logic
Expand Down
58 changes: 40 additions & 18 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import threading
import time
from dataclasses import dataclass
from logging import ERROR
from typing import Optional
from uuid import UUID, uuid4
Expand All @@ -36,6 +37,17 @@
)


@dataclass
class RunRecord:
"""The record of a specific run, including its status and timestamps."""

run: Run
status: RunStatus
starting_at: str
running_at: str
finished_at: str


class InMemoryState(State): # pylint: disable=R0902,R0904
"""In-memory State implementation."""

Expand All @@ -45,8 +57,8 @@ def __init__(self) -> None:
self.node_ids: dict[int, tuple[float, float]] = {}
self.public_key_to_node_id: dict[bytes, int] = {}

# Map run_id to (fab_id, fab_version)
self.run_ids: dict[int, tuple[Run, RunStatus]] = {}
# Map run_id to RunRecord
self.run_ids: dict[int, RunRecord] = {}
self.task_ins_store: dict[UUID, TaskIns] = {}
self.task_res_store: dict[UUID, TaskRes] = {}

Expand Down Expand Up @@ -293,19 +305,24 @@ def create_run(
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

if run_id not in self.run_ids:
run = Run(
run_id=run_id,
fab_id=fab_id if fab_id else "",
fab_version=fab_version if fab_version else "",
fab_hash=fab_hash if fab_hash else "",
override_config=override_config,
)
initial_status_info = RunStatus(
status=Status.STARTING,
sub_status="",
details="",
run_record = RunRecord(
run=Run(
run_id=run_id,
fab_id=fab_id if fab_id else "",
fab_version=fab_version if fab_version else "",
fab_hash=fab_hash if fab_hash else "",
override_config=override_config,
),
status=RunStatus(
status=Status.STARTING,
sub_status="",
details="",
),
starting_at=now().isoformat(),
running_at="",
finished_at="",
)
self.run_ids[run_id] = (run, initial_status_info)
self.run_ids[run_id] = run_record
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -349,13 +366,13 @@ def get_run(self, run_id: int) -> Optional[Run]:
if run_id not in self.run_ids:
log(ERROR, "`run_id` is invalid")
return None
return self.run_ids[run_id][0]
return self.run_ids[run_id].run

def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""Retrieve the statuses for the specified runs."""
with self.lock:
return {
run_id: self.run_ids[run_id][1]
run_id: self.run_ids[run_id].status
for run_id in run_ids
if run_id in self.run_ids
}
Expand All @@ -369,7 +386,7 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
return False

# Check if the status transition is valid
status = self.run_ids[run_id][1]
status = self.run_ids[run_id].status
if not is_valid_transition(status, new_status):
log(
ERROR,
Expand All @@ -390,7 +407,12 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
return False

# Update the status
self.run_ids[run_id] = (self.run_ids[run_id][0], new_status)
run_record = self.run_ids[run_id]
if new_status.status == Status.RUNNING:
run_record.running_at = now().isoformat()
elif new_status.status == Status.FINISHED:
run_record.finished_at = now().isoformat()
run_record.status = new_status
return True

def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
Expand Down
11 changes: 9 additions & 2 deletions src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from flwr.cli.config_utils import load_and_validate
from flwr.client import ClientApp
from flwr.common import EventType, event, log
from flwr.common import EventType, event, log, now
from flwr.common.config import get_fused_config_from_dir, parse_config_args
from flwr.common.constant import RUN_ID_NUM_BYTES, Status
from flwr.common.logger import (
Expand All @@ -45,6 +45,7 @@
from flwr.server.superlink.fleet import vce
from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
from flwr.server.superlink.state import StateFactory
from flwr.server.superlink.state.in_memory_state import RunRecord
from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
from flwr.simulation.ray_transport.utils import (
enable_tf_gpu_growth as enable_gpu_growth,
Expand Down Expand Up @@ -400,7 +401,13 @@ def _main_loop(
# Register run
log(DEBUG, "Pre-registering run with id %s", run.run_id)
init_status = RunStatus(Status.RUNNING, "", "")
state_factory.state().run_ids[run.run_id] = (run, init_status) # type: ignore
state_factory.state().run_ids[run.run_id] = RunRecord( # type: ignore
run=run,
status=init_status,
starting_at=now().isoformat(),
running_at=now().isoformat(),
finished_at="",
)

if server_app_run_config is None:
server_app_run_config = {}
Expand Down

0 comments on commit 8765100

Please sign in to comment.