From 8765100ef2c309f4c92d7751713d15e256d2feb2 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 23 Sep 2024 16:48:46 +0100 Subject: [PATCH] add timestamps for in memory state --- .../superlink/fleet/vce/vce_api_test.py | 6 +- .../server/superlink/state/in_memory_state.py | 58 +++++++++++++------ src/py/flwr/simulation/run_simulation.py | 11 +++- 3 files changed, 54 insertions(+), 21 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 1985e311e523..dce7347d6e74 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -50,6 +50,7 @@ start_vce, ) from flwr.server.superlink.state import InMemoryState, StateFactory +from flwr.server.superlink.state.in_memory_state import RunRecord class DummyClient(NumPyClient): @@ -114,7 +115,7 @@ def register_messages_into_state( ) -> dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore - state.run_ids[run_id] = ( + state.run_ids[run_id] = RunRecord( Run( run_id=run_id, fab_id="Mock/mock", @@ -127,6 +128,9 @@ def register_messages_into_state( sub_status="", details="", ), + "", + "", + "", ) # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 5ca510338316..d4df4faabba6 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -17,6 +17,7 @@ import threading import time +from dataclasses import dataclass from logging import ERROR from typing import Optional from uuid import UUID, uuid4 @@ -36,6 +37,17 @@ ) +@dataclass +class RunRecord: + """The record of a specific run, including its status and timestamps.""" + + run: Run + status: RunStatus + starting_at: str + running_at: str + finished_at: str + + class InMemoryState(State): # pylint: disable=R0902,R0904 """In-memory State implementation.""" @@ -45,8 +57,8 @@ def __init__(self) -> None: self.node_ids: dict[int, tuple[float, float]] = {} self.public_key_to_node_id: dict[bytes, int] = {} - # Map run_id to (fab_id, fab_version) - self.run_ids: dict[int, tuple[Run, RunStatus]] = {} + # Map run_id to RunRecord + self.run_ids: dict[int, RunRecord] = {} self.task_ins_store: dict[UUID, TaskIns] = {} self.task_res_store: dict[UUID, TaskRes] = {} @@ -293,19 +305,24 @@ def create_run( run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) if run_id not in self.run_ids: - run = Run( - run_id=run_id, - fab_id=fab_id if fab_id else "", - fab_version=fab_version if fab_version else "", - fab_hash=fab_hash if fab_hash else "", - override_config=override_config, - ) - initial_status_info = RunStatus( - status=Status.STARTING, - sub_status="", - details="", + run_record = RunRecord( + run=Run( + run_id=run_id, + fab_id=fab_id if fab_id else "", + fab_version=fab_version if fab_version else "", + fab_hash=fab_hash if fab_hash else "", + override_config=override_config, + ), + status=RunStatus( + status=Status.STARTING, + sub_status="", + details="", + ), + starting_at=now().isoformat(), + running_at="", + finished_at="", ) - self.run_ids[run_id] = (run, initial_status_info) + self.run_ids[run_id] = run_record return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -349,13 +366,13 @@ def get_run(self, run_id: int) -> Optional[Run]: if run_id not in self.run_ids: log(ERROR, "`run_id` is invalid") return None - return self.run_ids[run_id][0] + return self.run_ids[run_id].run def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: """Retrieve the statuses for the specified runs.""" with self.lock: return { - run_id: self.run_ids[run_id][1] + run_id: self.run_ids[run_id].status for run_id in run_ids if run_id in self.run_ids } @@ -369,7 +386,7 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: return False # Check if the status transition is valid - status = self.run_ids[run_id][1] + status = self.run_ids[run_id].status if not is_valid_transition(status, new_status): log( ERROR, @@ -390,7 +407,12 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: return False # Update the status - self.run_ids[run_id] = (self.run_ids[run_id][0], new_status) + run_record = self.run_ids[run_id] + if new_status.status == Status.RUNNING: + run_record.running_at = now().isoformat() + elif new_status.status == Status.FINISHED: + run_record.finished_at = now().isoformat() + run_record.status = new_status return True def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 2a47c8cdd2b5..6329e3186492 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -29,7 +29,7 @@ from flwr.cli.config_utils import load_and_validate from flwr.client import ClientApp -from flwr.common import EventType, event, log +from flwr.common import EventType, event, log, now from flwr.common.config import get_fused_config_from_dir, parse_config_args from flwr.common.constant import RUN_ID_NUM_BYTES, Status from flwr.common.logger import ( @@ -45,6 +45,7 @@ from flwr.server.superlink.fleet import vce from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.state.in_memory_state import RunRecord from flwr.server.superlink.state.utils import generate_rand_int_from_bytes from flwr.simulation.ray_transport.utils import ( enable_tf_gpu_growth as enable_gpu_growth, @@ -400,7 +401,13 @@ def _main_loop( # Register run log(DEBUG, "Pre-registering run with id %s", run.run_id) init_status = RunStatus(Status.RUNNING, "", "") - state_factory.state().run_ids[run.run_id] = (run, init_status) # type: ignore + state_factory.state().run_ids[run.run_id] = RunRecord( # type: ignore + run=run, + status=init_status, + starting_at=now().isoformat(), + running_at=now().isoformat(), + finished_at="", + ) if server_app_run_config is None: server_app_run_config = {}