From 206f578337d445640c87951f79b39efc41afdeff Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 23 Oct 2024 18:25:23 +0100 Subject: [PATCH] feat(framework) Add `get_run_status` and `update_run_status` to `LinkState` (#4229) Co-authored-by: Daniel J. Beutel Co-authored-by: Javier --- src/py/flwr/common/constant.py | 25 ++++ src/py/flwr/common/typing.py | 9 ++ .../superlink/fleet/vce/vce_api_test.py | 23 ++- .../linkstate/in_memory_linkstate.py | 99 +++++++++++-- .../server/superlink/linkstate/linkstate.py | 39 +++++- .../superlink/linkstate/linkstate_test.py | 79 ++++++++++- .../superlink/linkstate/sqlite_linkstate.py | 131 +++++++++++++++--- .../flwr/server/superlink/linkstate/utils.py | 58 +++++++- src/py/flwr/simulation/run_simulation.py | 16 ++- 9 files changed, 433 insertions(+), 46 deletions(-) diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 081fa49b2153..4dc18457a852 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -132,3 +132,28 @@ class ErrorCode: def __new__(cls) -> ErrorCode: """Prevent instantiation.""" raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class Status: + """Run status.""" + + PENDING = "pending" + STARTING = "starting" + RUNNING = "running" + FINISHED = "finished" + + def __new__(cls) -> Status: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") + + +class SubStatus: + """Run sub-status.""" + + COMPLETED = "completed" + FAILED = "failed" + STOPPED = "stopped" + + def __new__(cls) -> SubStatus: + """Prevent instantiation.""" + raise TypeError(f"{cls.__name__} cannot be instantiated.") diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 081a957f28ff..6b07fe2c1c38 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -218,6 +218,15 @@ class Run: override_config: UserConfig +@dataclass +class RunStatus: + """Run status information.""" + + status: str + sub_status: str + details: str + + @dataclass class Fab: """Fab file representation.""" diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index d14ce86c58c4..b490b0de58d7 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -40,15 +40,17 @@ RecordSet, Scalar, ) +from flwr.common.constant import Status from flwr.common.recordset_compat import getpropertiesins_to_recordset from flwr.common.serde import message_from_taskres, message_to_taskins -from flwr.common.typing import Run +from flwr.common.typing import Run, RunStatus from flwr.server.superlink.fleet.vce.vce_api import ( NodeToPartitionMapping, _register_nodes, start_vce, ) from flwr.server.superlink.linkstate import InMemoryLinkState, LinkStateFactory +from flwr.server.superlink.linkstate.in_memory_linkstate import RunRecord class DummyClient(NumPyClient): @@ -113,12 +115,19 @@ def register_messages_into_state( ) -> dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryLinkState = state_factory.state() # type: ignore - state.run_ids[run_id] = Run( - run_id=run_id, - fab_id="Mock/mock", - fab_version="v1.0.0", - fab_hash="hash", - override_config={}, + state.run_ids[run_id] = RunRecord( + Run( + run_id=run_id, + fab_id="Mock/mock", + fab_version="v1.0.0", + fab_hash="hash", + override_config={}, + ), + RunStatus( + status=Status.PENDING, + sub_status="", + details="", + ), ) # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index 8fdb5a1ed9ec..24844eea4e33 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -17,6 +17,7 @@ import threading import time +from dataclasses import dataclass from logging import ERROR, WARNING from typing import Optional from uuid import UUID, uuid4 @@ -26,13 +27,31 @@ MESSAGE_TTL_TOLERANCE, NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES, + Status, ) -from flwr.common.typing import Run, UserConfig +from flwr.common.typing import Run, RunStatus, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.linkstate.linkstate import LinkState from flwr.server.utils import validate_task_ins_or_res -from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres +from .utils import ( + generate_rand_int_from_bytes, + has_valid_sub_status, + is_valid_transition, + make_node_unavailable_taskres, +) + + +@dataclass +class RunRecord: + """The record of a specific run, including its status and timestamps.""" + + run: Run + status: RunStatus + pending_at: str = "" + starting_at: str = "" + running_at: str = "" + finished_at: str = "" class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904 @@ -44,8 +63,8 @@ def __init__(self) -> None: self.node_ids: dict[int, tuple[float, float]] = {} self.public_key_to_node_id: dict[bytes, int] = {} - # Map run_id to (fab_id, fab_version) - self.run_ids: dict[int, Run] = {} + # Map run_id to RunRecord + self.run_ids: dict[int, RunRecord] = {} self.task_ins_store: dict[UUID, TaskIns] = {} self.task_res_store: dict[UUID, TaskRes] = {} @@ -351,13 +370,22 @@ def create_run( run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) if run_id not in self.run_ids: - self.run_ids[run_id] = Run( - run_id=run_id, - fab_id=fab_id if fab_id else "", - fab_version=fab_version if fab_version else "", - fab_hash=fab_hash if fab_hash else "", - override_config=override_config, + run_record = RunRecord( + run=Run( + run_id=run_id, + fab_id=fab_id if fab_id else "", + fab_version=fab_version if fab_version else "", + fab_hash=fab_hash if fab_hash else "", + override_config=override_config, + ), + status=RunStatus( + status=Status.PENDING, + sub_status="", + details="", + ), + pending_at=now().isoformat(), ) + self.run_ids[run_id] = run_record return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -401,7 +429,56 @@ def get_run(self, run_id: int) -> Optional[Run]: if run_id not in self.run_ids: log(ERROR, "`run_id` is invalid") return None - return self.run_ids[run_id] + return self.run_ids[run_id].run + + def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: + """Retrieve the statuses for the specified runs.""" + with self.lock: + return { + run_id: self.run_ids[run_id].status + for run_id in set(run_ids) + if run_id in self.run_ids + } + + def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: + """Update the status of the run with the specified `run_id`.""" + with self.lock: + # Check if the run_id exists + if run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") + return False + + # Check if the status transition is valid + current_status = self.run_ids[run_id].status + if not is_valid_transition(current_status, new_status): + log( + ERROR, + 'Invalid status transition: from "%s" to "%s"', + current_status.status, + new_status.status, + ) + return False + + # Check if the sub-status is valid + if not has_valid_sub_status(current_status): + log( + ERROR, + 'Invalid sub-status "%s" for status "%s"', + current_status.sub_status, + current_status.status, + ) + return False + + # Update the status + run_record = self.run_ids[run_id] + if new_status.status == Status.STARTING: + run_record.starting_at = now().isoformat() + elif new_status.status == Status.RUNNING: + run_record.running_at = now().isoformat() + elif new_status.status == Status.FINISHED: + run_record.finished_at = now().isoformat() + run_record.status = new_status + return True def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index e8e254873957..1ea38681a1ad 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -19,7 +19,7 @@ from typing import Optional from uuid import UUID -from flwr.common.typing import Run, UserConfig +from flwr.common.typing import Run, RunStatus, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -178,6 +178,43 @@ def get_run(self, run_id: int) -> Optional[Run]: - `fab_version`: The version of the FAB used in the specified run. """ + @abc.abstractmethod + def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: + """Retrieve the statuses for the specified runs. + + Parameters + ---------- + run_ids : set[int] + A set of run identifiers for which to retrieve statuses. + + Returns + ------- + dict[int, RunStatus] + A dictionary mapping each valid run ID to its corresponding status. + + Notes + ----- + Only valid run IDs that exist in the State will be included in the returned + dictionary. If a run ID is not found, it will be omitted from the result. + """ + + @abc.abstractmethod + def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: + """Update the status of the run with the specified `run_id`. + + Parameters + ---------- + run_id : int + The identifier of the run. + new_status : RunStatus + The new status to be assigned to the run. + + Returns + ------- + bool + True if the status update is successful; False otherwise. + """ + @abc.abstractmethod def store_server_private_public_key( self, private_key: bytes, public_key: bytes diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index dec0a3b705e7..4cbb3c8b2e68 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -24,12 +24,13 @@ from uuid import UUID from flwr.common import DEFAULT_TTL -from flwr.common.constant import ErrorCode +from flwr.common.constant import ErrorCode, Status, SubStatus from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( generate_key_pairs, private_key_to_bytes, public_key_to_bytes, ) +from flwr.common.typing import RunStatus from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -66,6 +67,82 @@ def test_create_and_get_run(self) -> None: assert run.fab_hash == "9f86d08" assert run.override_config["test_key"] == "test_value" + def test_get_and_update_run_status(self) -> None: + """Test if get_run_status and update_run_status work correctly.""" + # Prepare + state = self.state_factory() + run_id1 = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_id2 = state.create_run(None, None, "fffffff", {"mock_key": "mock_value"}) + state.update_run_status(run_id2, RunStatus(Status.STARTING, "", "")) + state.update_run_status(run_id2, RunStatus(Status.RUNNING, "", "")) + + # Execute + run_status_dict = state.get_run_status({run_id1, run_id2}) + status1 = run_status_dict[run_id1] + status2 = run_status_dict[run_id2] + + # Assert + assert status1.status == Status.PENDING + assert status2.status == Status.RUNNING + + def test_status_transition_valid(self) -> None: + """Test valid run status transactions.""" + # Prepare + state = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + + # Execute and assert + status1 = state.get_run_status({run_id})[run_id] + assert state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + status2 = state.get_run_status({run_id})[run_id] + assert state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) + status3 = state.get_run_status({run_id})[run_id] + assert state.update_run_status( + run_id, RunStatus(Status.FINISHED, SubStatus.FAILED, "mock failure") + ) + status4 = state.get_run_status({run_id})[run_id] + + assert status1.status == Status.PENDING + assert status2.status == Status.STARTING + assert status3.status == Status.RUNNING + assert status4.status == Status.FINISHED + + def test_status_transition_invalid(self) -> None: + """Test invalid run status transitions.""" + # Prepare + state = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_statuses = [ + RunStatus(Status.PENDING, "", ""), + RunStatus(Status.STARTING, "", ""), + RunStatus(Status.PENDING, "", ""), + RunStatus(Status.FINISHED, SubStatus.COMPLETED, ""), + RunStatus(Status.FINISHED, SubStatus.FAILED, ""), + RunStatus(Status.FINISHED, SubStatus.STOPPED, ""), + ] + + # Execute and assert + # Cannot transition from RunStatus.PENDING + # to RunStatus.PENDING, RunStatus.RUNNING, or RunStatus.FINISHED + for run_status in [s for s in run_statuses if s.status != Status.STARTING]: + assert not state.update_run_status(run_id, run_status) + state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + # Cannot transition from RunStatus.STARTING + # to RunStatus.PENDING, RunStatus.STARTING, or RunStatus.FINISHED + for run_status in [s for s in run_statuses if s.status != Status.RUNNING]: + assert not state.update_run_status(run_id, run_status) + state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) + # Cannot transition from RunStatus.RUNNING + # to RunStatus.PENDING, RunStatus.STARTING, or RunStatus.RUNNING + for run_status in [s for s in run_statuses if s.status != Status.FINISHED]: + assert not state.update_run_status(run_id, run_status) + state.update_run_status( + run_id, RunStatus(Status.FINISHED, SubStatus.COMPLETED, "") + ) + # Cannot transition to any status from RunStatus.FINISHED + for run_status in run_statuses: + assert not state.update_run_status(run_id, run_status) + def test_get_task_ins_empty(self) -> None: """Validate that a new state has no TaskIns.""" # Prepare diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 4344ce8b062d..4b77020b2324 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -30,8 +30,9 @@ MESSAGE_TTL_TOLERANCE, NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES, + Status, ) -from flwr.common.typing import Run, UserConfig +from flwr.common.typing import Run, RunStatus, UserConfig from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -44,6 +45,8 @@ convert_uint64_to_sint64, convert_uint64_values_in_dict_to_sint64, generate_rand_int_from_bytes, + has_valid_sub_status, + is_valid_transition, make_node_unavailable_taskres, ) @@ -79,7 +82,13 @@ fab_id TEXT, fab_version TEXT, fab_hash TEXT, - override_config TEXT + override_config TEXT, + pending_at TEXT, + starting_at TEXT, + running_at TEXT, + finished_at TEXT, + sub_status TEXT, + details TEXT ); """ @@ -133,7 +142,7 @@ def __init__( self, database_path: str, ) -> None: - """Initialize an SqliteState. + """Initialize an SqliteLinkState. Parameters ---------- @@ -773,26 +782,16 @@ def create_run( if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0: query = ( "INSERT INTO run " - "(run_id, fab_id, fab_version, fab_hash, override_config)" - "VALUES (?, ?, ?, ?, ?);" + "(run_id, fab_id, fab_version, fab_hash, override_config, pending_at, " + "starting_at, running_at, finished_at, sub_status, details)" + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);" ) if fab_hash: - self.query( - query, - (sint64_run_id, "", "", fab_hash, json.dumps(override_config)), - ) - else: - self.query( - query, - ( - sint64_run_id, - fab_id, - fab_version, - "", - json.dumps(override_config), - ), - ) - # Note: we need to return the uint64 value of the run_id + fab_id, fab_version = "", "" + override_config_json = json.dumps(override_config) + data = [sint64_run_id, fab_id, fab_version, fab_hash, override_config_json] + data += [now().isoformat(), "", "", "", "", ""] + self.query(query, tuple(data)) return uint64_run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -868,6 +867,82 @@ def get_run(self, run_id: int) -> Optional[Run]: log(ERROR, "`run_id` does not exist.") return None + def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: + """Retrieve the statuses for the specified runs.""" + # Convert the uint64 value to sint64 for SQLite + sint64_run_ids = (convert_uint64_to_sint64(run_id) for run_id in set(run_ids)) + query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});" + rows = self.query(query, tuple(sint64_run_ids)) + + return { + # Restore uint64 run IDs + convert_sint64_to_uint64(row["run_id"]): RunStatus( + status=determine_run_status(row), + sub_status=row["sub_status"], + details=row["details"], + ) + for row in rows + } + + def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: + """Update the status of the run with the specified `run_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_run_id = convert_uint64_to_sint64(run_id) + query = "SELECT * FROM run WHERE run_id = ?;" + rows = self.query(query, (sint64_run_id,)) + + # Check if the run_id exists + if not rows: + log(ERROR, "`run_id` is invalid") + return False + + # Check if the status transition is valid + row = rows[0] + current_status = RunStatus( + status=determine_run_status(row), + sub_status=row["sub_status"], + details=row["details"], + ) + if not is_valid_transition(current_status, new_status): + log( + ERROR, + 'Invalid status transition: from "%s" to "%s"', + current_status.status, + new_status.status, + ) + return False + + # Check if the sub-status is valid + if not has_valid_sub_status(current_status): + log( + ERROR, + 'Invalid sub-status "%s" for status "%s"', + current_status.sub_status, + current_status.status, + ) + return False + + # Update the status + query = "UPDATE run SET %s= ?, sub_status = ?, details = ? " + query += "WHERE run_id = ?;" + + timestamp_fld = "" + if new_status.status == Status.STARTING: + timestamp_fld = "starting_at" + elif new_status.status == Status.RUNNING: + timestamp_fld = "running_at" + elif new_status.status == Status.FINISHED: + timestamp_fld = "finished_at" + + data = ( + now().isoformat(), + new_status.sub_status, + new_status.details, + sint64_run_id, + ) + self.query(query % timestamp_fld, data) + return True + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" sint64_node_id = convert_uint64_to_sint64(node_id) @@ -1023,3 +1098,17 @@ def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes: ), ) return result + + +def determine_run_status(row: dict[str, Any]) -> str: + """Determine the status of the run based on timestamp fields.""" + if row["pending_at"]: + if row["starting_at"]: + if row["running_at"]: + if row["finished_at"]: + return Status.FINISHED + return Status.RUNNING + return Status.STARTING + return Status.PENDING + run_id = convert_sint64_to_uint64(row["run_id"]) + raise sqlite3.IntegrityError(f"The run {run_id} does not have a valid status.") diff --git a/src/py/flwr/server/superlink/linkstate/utils.py b/src/py/flwr/server/superlink/linkstate/utils.py index db44719c6a8a..1e5c5de612a5 100644 --- a/src/py/flwr/server/superlink/linkstate/utils.py +++ b/src/py/flwr/server/superlink/linkstate/utils.py @@ -21,7 +21,8 @@ from uuid import uuid4 from flwr.common import log -from flwr.common.constant import ErrorCode +from flwr.common.constant import ErrorCode, Status, SubStatus +from flwr.common.typing import RunStatus from flwr.proto.error_pb2 import Error # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -31,6 +32,17 @@ "It exceeds the time limit specified in its last ping." ) +VALID_RUN_STATUS_TRANSITIONS = { + (Status.PENDING, Status.STARTING), + (Status.STARTING, Status.RUNNING), + (Status.RUNNING, Status.FINISHED), +} +VALID_RUN_SUB_STATUSES = { + SubStatus.COMPLETED, + SubStatus.FAILED, + SubStatus.STOPPED, +} + def generate_rand_int_from_bytes(num_bytes: int) -> int: """Generate a random unsigned integer from `num_bytes` bytes.""" @@ -146,3 +158,47 @@ def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes: ), ), ) + + +def is_valid_transition(current_status: RunStatus, new_status: RunStatus) -> bool: + """Check if a transition between two run statuses is valid. + + Parameters + ---------- + current_status : RunStatus + The current status of the run. + new_status : RunStatus + The new status to transition to. + + Returns + ------- + bool + True if the transition is valid, False otherwise. + """ + return ( + current_status.status, + new_status.status, + ) in VALID_RUN_STATUS_TRANSITIONS + + +def has_valid_sub_status(status: RunStatus) -> bool: + """Check if the 'sub_status' field of the given status is valid. + + Parameters + ---------- + status : RunStatus + The status object to be checked. + + Returns + ------- + bool + True if the status object has a valid sub-status, False otherwise. + + Notes + ----- + Only an empty string (i.e., "") is considered a valid sub-status for + non-finished statuses. The sub-status of a finished status cannot be empty. + """ + if status.status == Status.FINISHED: + return status.sub_status in VALID_RUN_SUB_STATUSES + return status.sub_status == "" diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index e9b2352e0c0c..15ff6bf7d206 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -29,22 +29,23 @@ from flwr.cli.config_utils import load_and_validate from flwr.client import ClientApp -from flwr.common import EventType, event, log +from flwr.common import EventType, event, log, now from flwr.common.config import get_fused_config_from_dir, parse_config_args -from flwr.common.constant import RUN_ID_NUM_BYTES +from flwr.common.constant import RUN_ID_NUM_BYTES, Status from flwr.common.logger import ( set_logger_propagation, update_console_handler, warn_deprecated_feature, warn_deprecated_feature_with_example, ) -from flwr.common.typing import Run, UserConfig +from flwr.common.typing import Run, RunStatus, UserConfig from flwr.server.driver import Driver, InMemoryDriver from flwr.server.run_serverapp import run as run_server_app from flwr.server.server_app import ServerApp from flwr.server.superlink.fleet import vce from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig from flwr.server.superlink.linkstate import LinkStateFactory +from flwr.server.superlink.linkstate.in_memory_linkstate import RunRecord from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes from flwr.simulation.ray_transport.utils import ( enable_tf_gpu_growth as enable_gpu_growth, @@ -399,7 +400,14 @@ def _main_loop( try: # Register run log(DEBUG, "Pre-registering run with id %s", run.run_id) - state_factory.state().run_ids[run.run_id] = run # type: ignore + init_status = RunStatus(Status.RUNNING, "", "") + state_factory.state().run_ids[run.run_id] = RunRecord( # type: ignore + run=run, + status=init_status, + starting_at=now().isoformat(), + running_at=now().isoformat(), + finished_at="", + ) if server_app_run_config is None: server_app_run_config = {}