diff --git a/src/gha_runner/gh.py b/src/gha_runner/gh.py index b637da8..04907bf 100644 --- a/src/gha_runner/gh.py +++ b/src/gha_runner/gh.py @@ -246,7 +246,9 @@ def get_runner(self, label: str) -> SelfHostedRunner: return runner raise MissingRunnerLabel(f"Runner {label} not found") - def wait_for_runner(self, label: str, timeout: int, wait: int = 15): + def wait_for_runner( + self, label: str, timeout: int, wait: int = 15 + ) -> SelfHostedRunner: """Wait for the runner with the given label to be online. Parameters @@ -257,16 +259,30 @@ def wait_for_runner(self, label: str, timeout: int, wait: int = 15): The time in seconds to wait between checks. Defaults to 15 seconds. timeout : int The maximum time in seconds to wait for the runner to be online. + + Returns + ------- + SelfHostedRunner + The runner with the given label. + """ max = time.time() + timeout - runner = self.get_runner(label) - while runner is None: - if time.time() > max: - raise RuntimeError(f"Timeout reached: Runner {label} not found") - print(f"Runner {label} not found. Waiting...") + try: runner = self.get_runner(label) - time.sleep(wait) - print(f"Runner {label} found!") + return runner + except MissingRunnerLabel: + print(f"Waiting for runner {label}...") + while True: + if time.time() > max: + raise RuntimeError( + f"Timeout reached: Runner {label} not found" + ) + try: + runner = self.get_runner(label) + return runner + except MissingRunnerLabel: + print(f"Runner {label} not found. Waiting...") + time.sleep(wait) def remove_runner(self, label: str): """Remove a runner by a given label. diff --git a/tests/test_gh.py b/tests/test_gh.py index 917fbc4..fa6caa5 100644 --- a/tests/test_gh.py +++ b/tests/test_gh.py @@ -114,7 +114,9 @@ def test_get_runners_no_json(github_instance): body="", status=200, ) - with pytest.raises(RunnerListError, match="Did not receive mapping object: *"): + with pytest.raises( + RunnerListError, match="Did not receive mapping object: *" + ): github_instance.get_runners() @@ -205,19 +207,57 @@ def test_get_latest_runner_release_invalid_arch(github_instance): @patch("time.sleep") # Prevent actual sleeping in tests def test_wait_for_runner_success(mock_sleep, github_instance, mock_runner): with patch.object( - github_instance, "get_runner", side_effect=[None, mock_runner] + github_instance, + "get_runner", + side_effect=[ + MissingRunnerLabel("First fail"), + MissingRunnerLabel("Second fail"), + mock_runner, + ], ): - github_instance.wait_for_runner("test-label", timeout=30) + runner = github_instance.wait_for_runner("test-label", timeout=30) assert mock_sleep.call_count == 1 + assert runner == mock_runner -@patch("time.sleep") -@patch("time.time", side_effect=[0, 31]) +@patch("time.sleep") # Prevent actual sleeping +def test_wait_for_runner_already_exists( + mock_sleep, github_instance, mock_runner +): + with patch.object( + github_instance, + "get_runner", + side_effect=[ + mock_runner, + ], + ): + runner = github_instance.wait_for_runner("test-label", timeout=30) + assert mock_sleep.call_count == 0 + assert runner == mock_runner + + +@patch("time.sleep") # Prevent actual sleeping in tests +@patch("time.time") # Control time for timeout logic def test_wait_for_runner_timeout(mock_time, mock_sleep, github_instance): - with patch.object(github_instance, "get_runner", return_value=None): - with pytest.raises(RuntimeError): + # Mock time.time() to first return 0, then return 31 (past the timeout) + mock_time.side_effect = [0, 29, 31] + + with patch.object( + github_instance, + "get_runner", + side_effect=[ + MissingRunnerLabel("Initial fail"), + MissingRunnerLabel("Fail after timeout"), + ], + ): + with pytest.raises(RuntimeError) as excinfo: github_instance.wait_for_runner("test-label", timeout=30) + assert "Timeout reached: Runner test-label not found" in str( + excinfo.value + ) + assert mock_sleep.call_count == 1 + @responses.activate def test_get_latest_release(github_instance):