diff --git a/evaluation/benchmarks/swe_bench/run_infer.py b/evaluation/benchmarks/swe_bench/run_infer.py index 01111f75d126..134c98cb9646 100644 --- a/evaluation/benchmarks/swe_bench/run_infer.py +++ b/evaluation/benchmarks/swe_bench/run_infer.py @@ -9,7 +9,6 @@ from datasets import load_dataset import openhands.agenthub - from evaluation.utils.shared import ( EvalException, EvalMetadata, @@ -76,7 +75,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata): '4. Rerun your reproduce script and confirm that the error is fixed!\n' '5. Think about edgecases and make sure your fix handles them as well\n' "Your thinking should be thorough and so it's fine if it's very long.\n" - ) + ) if RUN_WITH_BROWSING: instruction += ( diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index b2de6dd7d5ac..d3cde88f636d 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -312,6 +312,20 @@ async def _handle_message_action(self, action: MessageAction) -> None: str(action), extra={'msg_type': 'ACTION', 'event_source': EventSource.USER}, ) + # Extend max iterations when the user sends a message (only in non-headless mode) + if self._initial_max_iterations is not None and not self.headless_mode: + self.state.max_iterations = ( + self.state.iteration + self._initial_max_iterations + ) + if ( + self.state.traffic_control_state == TrafficControlState.THROTTLING + or self.state.traffic_control_state == TrafficControlState.PAUSED + ): + self.state.traffic_control_state = TrafficControlState.NORMAL + self.log( + 'debug', + f'Extended max iterations to {self.state.max_iterations} after user message', + ) if self.get_agent_state() != AgentState.RUNNING: await self.set_agent_state_to(AgentState.RUNNING) elif action.source == EventSource.AGENT and action.wait_for_response: @@ -342,6 +356,7 @@ async def set_agent_state_to(self, new_state: AgentState) -> None: elif ( new_state == AgentState.RUNNING and self.state.agent_state == AgentState.PAUSED + # TODO: do we really need both THROTTLING and PAUSED states, or can we clean up one of them completely? and self.state.traffic_control_state == TrafficControlState.THROTTLING ): # user intends to interrupt traffic control and let the task resume temporarily @@ -351,6 +366,7 @@ async def set_agent_state_to(self, new_state: AgentState) -> None: self.state.iteration is not None and self.state.max_iterations is not None and self._initial_max_iterations is not None + and not self.headless_mode ): if self.state.iteration >= self.state.max_iterations: self.state.max_iterations += self._initial_max_iterations diff --git a/pyproject.toml b/pyproject.toml index 2b0d3ca1e8a0..2ed685a7c0f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ reportlab = "*" [tool.coverage.run] concurrency = ["gevent"] + [tool.poetry.group.runtime.dependencies] jupyterlab = "*" notebook = "*" @@ -107,7 +108,6 @@ jupyter_kernel_gateway = "*" flake8 = "*" opencv-python = "*" - [build-system] build-backend = "poetry.core.masonry.api" requires = [ @@ -130,6 +130,7 @@ ignore = ["D1"] [tool.ruff.lint.pydocstyle] convention = "google" + [tool.poetry.group.evaluation.dependencies] streamlit = "*" whatthepatch = "*" diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index 9c07969bd090..08fe0e0f5587 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -6,7 +6,7 @@ from openhands.controller.agent import Agent from openhands.controller.agent_controller import AgentController -from openhands.controller.state.state import TrafficControlState +from openhands.controller.state.state import State, TrafficControlState from openhands.core.config import AppConfig from openhands.core.main import run_controller from openhands.core.schema import AgentState @@ -41,7 +41,9 @@ def mock_agent(): @pytest.fixture def mock_event_stream(): - return MagicMock(spec=EventStream) + mock = MagicMock(spec=EventStream) + mock.get_latest_event_id.return_value = 0 + return mock @pytest.fixture @@ -278,7 +280,9 @@ async def test_delegate_step_different_states( @pytest.mark.asyncio -async def test_step_max_iterations(mock_agent, mock_event_stream): +async def test_max_iterations_extension(mock_agent, mock_event_stream): + # Test with headless_mode=False - should extend max_iterations + initial_state = State(max_iterations=10) controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, @@ -286,18 +290,34 @@ async def test_step_max_iterations(mock_agent, mock_event_stream): sid='test', confirmation_mode=False, headless_mode=False, + initial_state=initial_state, ) controller.state.agent_state = AgentState.RUNNING controller.state.iteration = 10 assert controller.state.traffic_control_state == TrafficControlState.NORMAL + + # Trigger throttling by calling _step() when we hit max_iterations await controller._step() assert controller.state.traffic_control_state == TrafficControlState.THROTTLING assert controller.state.agent_state == AgentState.ERROR - await controller.close() + # Simulate a new user message + message_action = MessageAction(content='Test message') + message_action._source = EventSource.USER + await controller.on_event(message_action) + + # Max iterations should be extended to current iteration + initial max_iterations + assert ( + controller.state.max_iterations == 20 + ) # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10) + assert controller.state.traffic_control_state == TrafficControlState.NORMAL + assert controller.state.agent_state == AgentState.RUNNING -@pytest.mark.asyncio -async def test_step_max_iterations_headless(mock_agent, mock_event_stream): + # Close the controller to clean up + await controller.close() + + # Test with headless_mode=True - should NOT extend max_iterations + initial_state = State(max_iterations=10) controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, @@ -305,13 +325,24 @@ async def test_step_max_iterations_headless(mock_agent, mock_event_stream): sid='test', confirmation_mode=False, headless_mode=True, + initial_state=initial_state, ) controller.state.agent_state = AgentState.RUNNING controller.state.iteration = 10 assert controller.state.traffic_control_state == TrafficControlState.NORMAL + + # Simulate a new user message + message_action = MessageAction(content='Test message') + message_action._source = EventSource.USER + await controller.on_event(message_action) + + # Max iterations should NOT be extended in headless mode + assert controller.state.max_iterations == 10 # Original value unchanged + + # Trigger throttling by calling _step() when we hit max_iterations await controller._step() + assert controller.state.traffic_control_state == TrafficControlState.THROTTLING - # In headless mode, throttling results in an error assert controller.state.agent_state == AgentState.ERROR await controller.close() diff --git a/tests/unit/test_iteration_limit.py b/tests/unit/test_iteration_limit.py new file mode 100644 index 000000000000..4520231a0d13 --- /dev/null +++ b/tests/unit/test_iteration_limit.py @@ -0,0 +1,62 @@ +import asyncio + +import pytest + +from openhands.controller.agent_controller import AgentController +from openhands.core.schema import AgentState +from openhands.events import EventStream +from openhands.events.action import MessageAction +from openhands.events.event import EventSource + + +class DummyAgent: + def __init__(self): + self.name = 'dummy' + self.llm = type( + 'DummyLLM', + (), + {'metrics': type('DummyMetrics', (), {'merge': lambda x: None})()}, + )() + + def reset(self): + pass + + +@pytest.mark.asyncio +async def test_iteration_limit_extends_on_user_message(): + # Initialize test components + from openhands.storage.memory import InMemoryFileStore + + file_store = InMemoryFileStore() + event_stream = EventStream(sid='test', file_store=file_store) + agent = DummyAgent() + initial_max_iterations = 100 + controller = AgentController( + agent=agent, + event_stream=event_stream, + max_iterations=initial_max_iterations, + sid='test', + headless_mode=False, + ) + + # Set initial state + await controller.set_agent_state_to(AgentState.RUNNING) + controller.state.iteration = 90 # Close to the limit + assert controller.state.max_iterations == initial_max_iterations + + # Simulate user message + user_message = MessageAction('test message', EventSource.USER) + event_stream.add_event(user_message, EventSource.USER) + await asyncio.sleep(0.1) # Give time for event to be processed + + # Verify max_iterations was extended + assert controller.state.max_iterations == 90 + initial_max_iterations + + # Simulate more iterations and another user message + controller.state.iteration = 180 # Close to new limit + user_message2 = MessageAction('another message', EventSource.USER) + event_stream.add_event(user_message2, EventSource.USER) + await asyncio.sleep(0.1) # Give time for event to be processed + + # Verify max_iterations was extended again + assert controller.state.max_iterations == 180 + initial_max_iterations