diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 1cc3a8f128b6..5c99e310f579 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -42,7 +42,7 @@ ) from flwr.common.recordset_compat import getpropertiesins_to_recordset from flwr.common.serde import message_from_taskres, message_to_taskins -from flwr.common.typing import Run +from flwr.common.typing import Run, RunStatus from flwr.server.superlink.fleet.vce.vce_api import ( NodeToPartitionMapping, _register_nodes, @@ -113,12 +113,19 @@ def register_messages_into_state( ) -> dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore - state.run_ids[run_id] = Run( - run_id=run_id, - fab_id="Mock/mock", - fab_version="v1.0.0", - fab_hash="hash", - override_config={}, + state.run_ids[run_id] = ( + Run( + run_id=run_id, + fab_id="Mock/mock", + fab_version="v1.0.0", + fab_hash="hash", + override_config={}, + ), + RunStatus( + phase="starting", + result="", + reason="", + ), ) # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 87e65e93e93f..36aebaad8d50 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -293,13 +293,19 @@ def create_run( run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) if run_id not in self.run_ids: - self.run_ids[run_id] = Run( + run = Run( run_id=run_id, fab_id=fab_id if fab_id else "", fab_version=fab_version if fab_version else "", fab_hash=fab_hash if fab_hash else "", override_config=override_config, ) + initial_status = RunStatus( + phase="starting", + result="", + reason="", + ) + self.run_ids[run_id] = (run, 
initial_status) return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -343,7 +349,7 @@ def get_run(self, run_id: int) -> Optional[Run]: if run_id not in self.run_ids: log(ERROR, "`run_id` is invalid") return None - return self.run_ids[run_id] + return self.run_ids[run_id][0] def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: """Get the status of the run with the specified `run_id`.""" @@ -375,11 +381,11 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: # Check if the result is valid if not has_valid_result(status): - log(ERROR, 'Invalid run status: "%s:"', status.phase, status.result) + log(ERROR, 'Invalid run status: "%s:%s"', status.phase, status.result) return False # Update the status - self.run_ids[run_id][1] = new_status + self.run_ids[run_id] = (self.run_ids[run_id][0], new_status) return True def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 4bb31fa6cea5..586fff588eda 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -26,14 +26,19 @@ from flwr.common import log, now from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES -from flwr.common.typing import Run, UserConfig +from flwr.common.typing import Run, RunStatus, UserConfig from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.utils.validator import validate_task_ins_or_res from .state import State -from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres +from .utils import ( + generate_rand_int_from_bytes, + has_valid_result, + is_valid_transition, + make_node_unavailable_taskres, +) SQL_CREATE_TABLE_NODE = """ CREATE 
TABLE IF NOT EXISTS node( @@ -67,7 +72,10 @@ fab_id TEXT, fab_version TEXT, fab_hash TEXT, - override_config TEXT + override_config TEXT, + status_phase TEXT, + status_result TEXT, + status_reason TEXT ); """ @@ -634,18 +642,15 @@ def create_run( if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: query = ( "INSERT INTO run " - "(run_id, fab_id, fab_version, fab_hash, override_config)" - "VALUES (?, ?, ?, ?, ?);" + "(run_id, fab_id, fab_version, fab_hash, override_config, " + "status_phase, status_result, status_reason)" + "VALUES (?, ?, ?, ?, ?, ?, ?, ?);" ) if fab_hash: - self.query( - query, (run_id, "", "", fab_hash, json.dumps(override_config)) - ) - else: - self.query( - query, - (run_id, fab_id, fab_version, "", json.dumps(override_config)), - ) + fab_id, fab_version = "", "" + data = [run_id, fab_id, fab_version, fab_hash, json.dumps(override_config)] + data += ["starting", "", ""] + self.query(query, tuple(data)) return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -719,6 +724,58 @@ def get_run(self, run_id: int) -> Optional[Run]: log(ERROR, "`run_id` does not exist.") return None + def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: + """Get the status of the runs with the specified `run_ids`.""" + query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});" + rows = self.query(query, tuple(run_ids)) + + return { + row["run_id"]: RunStatus( + phase=row["status_phase"], + result=row["status_result"], + reason=row["status_reason"], + ) + for row in rows + } + + def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: + """Update the status of the run with the specified `run_id`.""" + query = "SELECT * FROM run WHERE run_id = ?;" + rows = self.query(query, (run_id,)) + + # Check if the run_id exists + if not rows: + log(ERROR, "`run_id` is invalid") + return False + + # Check if the status transition is valid + row = rows[0] + status = RunStatus( + phase=row["status_phase"], + 
result=row["status_result"], + reason=row["status_reason"], + ) + if not is_valid_transition(status, new_status): + log( + ERROR, + 'Invalid status transition: from "%s" to "%s"', + status.phase, + new_status.phase, + ) + return False + + # Check if the result is valid + if not has_valid_result(status): + log(ERROR, 'Invalid run status: "%s:%s"', status.phase, status.result) + return False + + # Update the status + query = "UPDATE run SET status_phase = ?, status_result = ?, status_reason = ? " + query += "WHERE run_id = ?;" + data = (new_status.phase, new_status.result, new_status.reason, run_id) + self.query(query, data) + return True + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" # Update `online_until` and `ping_interval` for the given `node_id` diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 42c0768f1c7d..bacd49abcf83 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -30,6 +30,7 @@ private_key_to_bytes, public_key_to_bytes, ) +from flwr.common.typing import RunStatus from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -62,6 +63,64 @@ def test_create_and_get_run(self) -> None: assert run.fab_hash == "9f86d08" assert run.override_config["test_key"] == "test_value" + def test_get_and_update_run_status(self) -> None: + """Test if get_run_status and update_run_status work correctly.""" + # Prepare + state = self.state_factory() + run_id1 = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_id2 = state.create_run(None, None, "fffffff", {"mock_key": "mock_value"}) + state.update_run_status(run_id2, RunStatus("running", "", "")) + + # Execute + 
run_status_dict = state.get_run_status({run_id1, run_id2}) + status1 = run_status_dict[run_id1] + status2 = run_status_dict[run_id2] + + # Assert + assert status1.phase == "starting" + assert status2.phase == "running" + + def test_status_transition_valid(self) -> None: + """Test valid run status transitions.""" + # Prepare + state = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + + # Execute and assert + status1 = state.get_run_status({run_id})[run_id] + assert state.update_run_status(run_id, RunStatus("running", "", "")) + status2 = state.get_run_status({run_id})[run_id] + assert state.update_run_status( + run_id, RunStatus("finished", "failed", "mock failure") + ) + status3 = state.get_run_status({run_id})[run_id] + + assert status1.phase == "starting" + assert status2.phase == "running" + assert status3.phase == "finished" + + def test_status_transition_invalid(self) -> None: + """Test invalid run status transitions.""" + # Prepare + state = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + + # Execute and assert + # Cannot transition to "starting" or "finished" from "starting" + assert not state.update_run_status(run_id, RunStatus("starting", "", "")) + assert not state.update_run_status( + run_id, RunStatus("finished", "completed", "") + ) + state.update_run_status(run_id, RunStatus("running", "", "")) + # Cannot transition to "starting" or "running" from "running" + assert not state.update_run_status(run_id, RunStatus("starting", "", "")) + assert not state.update_run_status(run_id, RunStatus("running", "", "")) + state.update_run_status(run_id, RunStatus("finished", "completed", "")) + # Cannot transition to any status from "finished" + assert not state.update_run_status(run_id, RunStatus("starting", "", "")) + assert not state.update_run_status(run_id, RunStatus("running", "", "")) + assert not state.update_run_status(run_id, RunStatus("finished", 
"failed", "")) + def test_get_task_ins_empty(self) -> None: """Validate that a new state has no TaskIns.""" # Prepare