Skip to content

Commit

Permalink
Fix issue #5559: The turn limit should be measured from the last user interaction (#5560)
Browse files Browse the repository at this point in the history

Co-authored-by: Graham Neubig <[email protected]>
Co-authored-by: Engel Nyst <[email protected]>
  • Loading branch information
3 people authored Dec 16, 2024
1 parent dd79acd commit 4998b5d
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 10 deletions.
3 changes: 1 addition & 2 deletions evaluation/benchmarks/swe_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from datasets import load_dataset

import openhands.agenthub

from evaluation.utils.shared import (
EvalException,
EvalMetadata,
Expand Down Expand Up @@ -76,7 +75,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
'4. Rerun your reproduce script and confirm that the error is fixed!\n'
'5. Think about edgecases and make sure your fix handles them as well\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
)
)

if RUN_WITH_BROWSING:
instruction += (
Expand Down
16 changes: 16 additions & 0 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,20 @@ async def _handle_message_action(self, action: MessageAction) -> None:
str(action),
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
)
# Extend max iterations when the user sends a message (only in non-headless mode)
if self._initial_max_iterations is not None and not self.headless_mode:
self.state.max_iterations = (
self.state.iteration + self._initial_max_iterations
)
if (
self.state.traffic_control_state == TrafficControlState.THROTTLING
or self.state.traffic_control_state == TrafficControlState.PAUSED
):
self.state.traffic_control_state = TrafficControlState.NORMAL
self.log(
'debug',
f'Extended max iterations to {self.state.max_iterations} after user message',
)
if self.get_agent_state() != AgentState.RUNNING:
await self.set_agent_state_to(AgentState.RUNNING)
elif action.source == EventSource.AGENT and action.wait_for_response:
Expand Down Expand Up @@ -342,6 +356,7 @@ async def set_agent_state_to(self, new_state: AgentState) -> None:
elif (
new_state == AgentState.RUNNING
and self.state.agent_state == AgentState.PAUSED
# TODO: do we really need both THROTTLING and PAUSED states, or can we clean up one of them completely?
and self.state.traffic_control_state == TrafficControlState.THROTTLING
):
# user intends to interrupt traffic control and let the task resume temporarily
Expand All @@ -351,6 +366,7 @@ async def set_agent_state_to(self, new_state: AgentState) -> None:
self.state.iteration is not None
and self.state.max_iterations is not None
and self._initial_max_iterations is not None
and not self.headless_mode
):
if self.state.iteration >= self.state.max_iterations:
self.state.max_iterations += self._initial_max_iterations
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ reportlab = "*"
[tool.coverage.run]
concurrency = ["gevent"]


[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
jupyter_kernel_gateway = "*"
flake8 = "*"
opencv-python = "*"


[build-system]
build-backend = "poetry.core.masonry.api"
requires = [
Expand All @@ -130,6 +130,7 @@ ignore = ["D1"]
[tool.ruff.lint.pydocstyle]
convention = "google"


[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"
Expand Down
45 changes: 38 additions & 7 deletions tests/unit/test_agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import TrafficControlState
from openhands.controller.state.state import State, TrafficControlState
from openhands.core.config import AppConfig
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
Expand Down Expand Up @@ -41,7 +41,9 @@ def mock_agent():

@pytest.fixture
def mock_event_stream():
return MagicMock(spec=EventStream)
mock = MagicMock(spec=EventStream)
mock.get_latest_event_id.return_value = 0
return mock


@pytest.fixture
Expand Down Expand Up @@ -278,40 +280,69 @@ async def test_delegate_step_different_states(


@pytest.mark.asyncio
async def test_step_max_iterations(mock_agent, mock_event_stream):
async def test_max_iterations_extension(mock_agent, mock_event_stream):
# Test with headless_mode=False - should extend max_iterations
initial_state = State(max_iterations=10)
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=False,
initial_state=initial_state,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.iteration = 10
assert controller.state.traffic_control_state == TrafficControlState.NORMAL

# Trigger throttling by calling _step() when we hit max_iterations
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
assert controller.state.agent_state == AgentState.ERROR
await controller.close()

# Simulate a new user message
message_action = MessageAction(content='Test message')
message_action._source = EventSource.USER
await controller.on_event(message_action)

# Max iterations should be extended to current iteration + initial max_iterations
assert (
controller.state.max_iterations == 20
) # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10)
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
assert controller.state.agent_state == AgentState.RUNNING

@pytest.mark.asyncio
async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
# Close the controller to clean up
await controller.close()

# Test with headless_mode=True - should NOT extend max_iterations
initial_state = State(max_iterations=10)
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
initial_state=initial_state,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.iteration = 10
assert controller.state.traffic_control_state == TrafficControlState.NORMAL

# Simulate a new user message
message_action = MessageAction(content='Test message')
message_action._source = EventSource.USER
await controller.on_event(message_action)

# Max iterations should NOT be extended in headless mode
assert controller.state.max_iterations == 10 # Original value unchanged

# Trigger throttling by calling _step() when we hit max_iterations
await controller._step()

assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
# In headless mode, throttling results in an error
assert controller.state.agent_state == AgentState.ERROR
await controller.close()

Expand Down
62 changes: 62 additions & 0 deletions tests/unit/test_iteration_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import asyncio

import pytest

from openhands.controller.agent_controller import AgentController
from openhands.core.schema import AgentState
from openhands.events import EventStream
from openhands.events.action import MessageAction
from openhands.events.event import EventSource


class DummyAgent:
    """Minimal stand-in for an agent, used to build a controller in tests.

    It exposes only three things: a ``name``, an ``llm`` object carrying a
    ``metrics`` attribute whose ``merge`` is a no-op, and a no-op ``reset``.
    """

    def __init__(self):
        self.name = 'dummy'
        # ``merge`` is stored on the class, so looking it up on the instance
        # binds it and passes the instance as the first argument. The stub
        # therefore has to accept ``self`` (plus whatever the caller passes);
        # otherwise ``metrics.merge(other)`` raises TypeError.
        metrics_cls = type(
            'DummyMetrics',
            (),
            {'merge': lambda self, *args, **kwargs: None},
        )
        llm_cls = type('DummyLLM', (), {'metrics': metrics_cls()})
        self.llm = llm_cls()

    def reset(self):
        """Do nothing; the stub keeps no state to reset."""
        pass


@pytest.mark.asyncio
async def test_iteration_limit_extends_on_user_message():
    """Check that each user message moves the limit to ``iteration + initial``.

    The controller is created in non-headless mode with a limit of 100. After
    a user message arrives at iteration 90 the limit must be 190, and after a
    second one at iteration 180 it must be 280.
    """
    # Initialize test components
    from openhands.storage.memory import InMemoryFileStore

    file_store = InMemoryFileStore()
    event_stream = EventStream(sid='test', file_store=file_store)
    agent = DummyAgent()
    initial_max_iterations = 100
    controller = AgentController(
        agent=agent,
        event_stream=event_stream,
        max_iterations=initial_max_iterations,
        sid='test',
        headless_mode=False,
    )

    # Always close the controller, even when an assertion fails, so it does
    # not outlive this test.
    try:
        # Set initial state
        await controller.set_agent_state_to(AgentState.RUNNING)
        controller.state.iteration = 90  # Close to the limit
        assert controller.state.max_iterations == initial_max_iterations

        # Simulate user message
        # NOTE(review): the second positional argument may not be the event
        # source; add_event() below is what supplies it -- confirm against
        # the MessageAction signature.
        user_message = MessageAction('test message', EventSource.USER)
        event_stream.add_event(user_message, EventSource.USER)
        await asyncio.sleep(0.1)  # Give time for event to be processed

        # Verify max_iterations was extended
        assert controller.state.max_iterations == 90 + initial_max_iterations

        # Simulate more iterations and another user message
        controller.state.iteration = 180  # Close to new limit
        user_message2 = MessageAction('another message', EventSource.USER)
        event_stream.add_event(user_message2, EventSource.USER)
        await asyncio.sleep(0.1)  # Give time for event to be processed

        # Verify max_iterations was extended again
        assert controller.state.max_iterations == 180 + initial_max_iterations
    finally:
        await controller.close()

0 comments on commit 4998b5d

Please sign in to comment.