diff --git a/tests/unit/test_runtime_reboot.py b/tests/unit/test_runtime_reboot.py index 0e57b12fb3d5..e3ae31815a3e 100644 --- a/tests/unit/test_runtime_reboot.py +++ b/tests/unit/test_runtime_reboot.py @@ -1,12 +1,14 @@ -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest import requests -from openhands.core.exceptions import AgentRuntimeDisconnectedError, AgentRuntimeTimeoutError +from openhands.core.exceptions import ( + AgentRuntimeDisconnectedError, + AgentRuntimeTimeoutError, +) from openhands.events.action import CmdRunAction from openhands.runtime.base import Runtime -from openhands.runtime.utils.request import RequestHTTPError @pytest.fixture @@ -24,44 +26,62 @@ def runtime(mock_session): def test_runtime_timeout_error(runtime, mock_session): # Create a command action - action = CmdRunAction(command="test command") + action = CmdRunAction(command='test command') action.timeout = 120 # Mock the runtime to raise a timeout error runtime.send_action_for_execution.side_effect = AgentRuntimeTimeoutError( - "Runtime failed to return execute_action before the requested timeout of 120s" + 'Runtime failed to return execute_action before the requested timeout of 120s' ) # Verify that the error message indicates a timeout with pytest.raises(AgentRuntimeTimeoutError) as exc_info: runtime.send_action_for_execution(action) - assert str(exc_info.value) == "Runtime failed to return execute_action before the requested timeout of 120s" + assert ( + str(exc_info.value) + == 'Runtime failed to return execute_action before the requested timeout of 120s' + ) @pytest.mark.parametrize( - "status_code,expected_message", + 'status_code,expected_message', [ - (404, "Runtime is not responding. This may be temporary, please try again."), - (502, "Runtime is temporarily unavailable. This may be due to a restart or network issue, please try again."), + (404, 'Runtime is not responding. This may be temporary, please try again.'), + ( + 502, + 'Runtime is temporarily unavailable. This may be due to a restart or network issue, please try again.', + ), ], ) -def test_runtime_disconnected_error(runtime, mock_session, status_code, expected_message): +def test_runtime_disconnected_error( + runtime, mock_session, status_code, expected_message +): # Mock the request to return the specified status code mock_response = Mock() mock_response.status_code = status_code - mock_response.raise_for_status = Mock(side_effect=requests.HTTPError(response=mock_response)) - mock_response.json = Mock(return_value={'observation': 'run', 'content': 'test', 'extras': {'command_id': 'test_id', 'command': 'test command'}}) + mock_response.raise_for_status = Mock( + side_effect=requests.HTTPError(response=mock_response) + ) + mock_response.json = Mock( + return_value={ + 'observation': 'run', + 'content': 'test', + 'extras': {'command_id': 'test_id', 'command': 'test command'}, + } + ) # Mock the runtime to raise the error - runtime.send_action_for_execution.side_effect = AgentRuntimeDisconnectedError(expected_message) + runtime.send_action_for_execution.side_effect = AgentRuntimeDisconnectedError( + expected_message + ) # Create a command action - action = CmdRunAction(command="test command") + action = CmdRunAction(command='test command') action.timeout = 120 # Verify that the error message is correct with pytest.raises(AgentRuntimeDisconnectedError) as exc_info: runtime.send_action_for_execution(action) - assert str(exc_info.value) == expected_message \ No newline at end of file + assert str(exc_info.value) == expected_message