diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index 24844eea4e33..a5b8a12cc9d5 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -480,6 +480,19 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: run_record.status = new_status return True + def get_pending_run_id(self) -> Optional[int]: + """Get the `run_id` of a run with `Status.PENDING` status, if any.""" + pending_run_id = None + + # Loop through all registered runs + for run_id, run_rec in self.run_ids.items(): + # Break once a pending run is found + if run_rec.status.status == Status.PENDING: + pending_run_id = run_id + break + + return pending_run_id + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" with self.lock: diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 1ea38681a1ad..6e20b6717207 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -215,6 +215,17 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: True if the status update is successful; False otherwise. """ + @abc.abstractmethod + def get_pending_run_id(self) -> Optional[int]: + """Get the `run_id` of a run with `Status.PENDING` status. + + Returns + ------- + Optional[int] + The `run_id` of a `Run` that is pending to be started; None if + there is no Run pending. + """ + @abc.abstractmethod def store_server_private_public_key( self, private_key: bytes, public_key: bytes diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 4cbb3c8b2e68..418e61168915 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -67,6 +67,26 @@ def test_create_and_get_run(self) -> None: assert run.fab_hash == "9f86d08" assert run.override_config["test_key"] == "test_value" + def test_get_pending_run_id(self) -> None: + """Test if get_pending_run_id works correctly.""" + # Prepare + state = self.state_factory() + _ = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_id2 = state.create_run(None, None, "fffffff", {"mock_key": "mock_value"}) + state.update_run_status(run_id2, RunStatus(Status.STARTING, "", "")) + + # Execute + pending_run_id = state.get_pending_run_id() + assert pending_run_id is not None + run_status_dict = state.get_run_status({pending_run_id}) + assert run_status_dict[pending_run_id].status == Status.PENDING + + # Change state + state.update_run_status(pending_run_id, RunStatus(Status.STARTING, "", "")) + # Attempt get pending run + pending_run_id = state.get_pending_run_id() + assert pending_run_id is None + def test_get_and_update_run_status(self) -> None: """Test if get_run_status and update_run_status work correctly.""" # Prepare diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 4b77020b2324..bcf0b319f307 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -943,6 +943,18 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool: self.query(query % timestamp_fld, data) return True + def get_pending_run_id(self) -> Optional[int]: + """Get the `run_id` of a run with `Status.PENDING` status, if any.""" + pending_run_id = None + + # Fetch all runs with unset `starting_at` (i.e. they are in PENDING status) + query = "SELECT * FROM run WHERE starting_at = '' LIMIT 1;" + rows = self.query(query) + if rows: + pending_run_id = convert_sint64_to_uint64(rows[0]["run_id"]) + + return pending_run_id + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" sint64_node_id = convert_uint64_to_sint64(node_id)