diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 5a184114914d..5ca510338316 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -352,7 +352,7 @@ def get_run(self, run_id: int) -> Optional[Run]: return self.run_ids[run_id][0] def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: - """Retrieve the status information for the specified runs.""" + """Retrieve the statuses for the specified runs.""" with self.lock: return { run_id: self.run_ids[run_id][1] @@ -360,7 +360,7 @@ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: if run_id in self.run_ids } - def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool: + def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: """Update the status of the run with the specified `run_id`.""" with self.lock: # Check if the run_id exists @@ -369,23 +369,28 @@ def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool: return False # Check if the status transition is valid - info = self.run_ids[run_id][1] - if not is_valid_transition(info, new_status_info): + status = self.run_ids[run_id][1] + if not is_valid_transition(status, new_status): log( ERROR, 'Invalid status transition: from "%s" to "%s"', - info.status, - new_status_info.status, + status.status, + new_status.status, ) return False # Check if the sub-status is valid - if not has_valid_sub_status(info): - log(ERROR, 'Invalid run status: "%s:%s"', info.status, info.sub_status) + if not has_valid_sub_status(status): + log( + ERROR, + 'Invalid run status: "%s:%s"', + status.status, + status.sub_status, + ) return False # Update the status - self.run_ids[run_id] = (self.run_ids[run_id][0], new_status_info) + self.run_ids[run_id] = (self.run_ids[run_id][0], new_status) return True def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 4904697601ae..7c3e53d08bf8 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -79,7 +79,7 @@ override_config TEXT, status TEXT, sub_status TEXT, - reason TEXT + details TEXT ); """ @@ -706,7 +706,7 @@ def create_run( query = ( "INSERT INTO run " "(run_id, fab_id, fab_version, fab_hash, override_config, " - "status, sub_status, reason)" + "status, sub_status, details)" "VALUES (?, ?, ?, ?, ?, ?, ?, ?);" ) if fab_hash: @@ -790,7 +790,7 @@ def get_run(self, run_id: int) -> Optional[Run]: return None def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: - """Retrieve the status information for the specified runs.""" + """Retrieve the statuses for the specified runs.""" query = f"SELECT * FROM run WHERE run_id IN ({','.join(['?'] * len(run_ids))});" rows = self.query(query, tuple(run_ids)) @@ -798,12 +798,12 @@ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: row["run_id"]: RunStatus( status=row["status"], sub_status=row["sub_status"], - details=row["reason"], + details=row["details"], ) for row in rows } - def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool: + def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: """Update the status of the run with the specified `run_id`.""" query = "SELECT * FROM run WHERE run_id = ?;" rows = self.query(query, (run_id,)) @@ -818,14 +818,14 @@ def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool: status = RunStatus( status=row["status"], sub_status=row["sub_status"], - details=row["reason"], + details=row["details"], ) - if not is_valid_transition(status, new_status_info): + if not is_valid_transition(status, new_status): log( ERROR, 'Invalid status transition: from "%s" to "%s"', status.status, - new_status_info.status, + new_status.status, ) return False @@ -835,12 +835,12 @@ def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool: return False # Update the status - query = "UPDATE run SET status= ?, sub_status = ?, reason = ? " + query = "UPDATE run SET status= ?, sub_status = ?, details = ? " query += "WHERE run_id = ?;" data = ( - new_status_info.status, - new_status_info.sub_status, - new_status_info.details, + new_status.status, + new_status.sub_status, + new_status.details, run_id, ) self.query(query, data) diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index 092ba448f9f2..3f5f9274c964 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -186,7 +186,7 @@ def get_run(self, run_id: int) -> Optional[Run]: @abc.abstractmethod def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: - """Retrieve the status information for the specified runs. + """Retrieve the statuses for the specified runs. Parameters ---------- @@ -195,8 +195,8 @@ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: Returns ------- - dict[int, StatusInfo] - A dictionary mapping each valid run ID to its corresponding `StatusInfo`. + dict[int, RunStatus] + A dictionary mapping each valid run ID to its corresponding status. Notes ----- @@ -205,14 +205,14 @@ def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]: """ @abc.abstractmethod - def update_run_status(self, run_id: int, new_status_info: RunStatus) -> bool: + def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: """Update the status of the run with the specified `run_id`. Parameters ---------- run_id : int The identifier of the run. - new_status_info : StatusInfo + new_status_info : RunStatus The new status info to be assigned to the run. Returns