Skip to content

Commit

Permalink
Merge branch 'main' into openhands-fix-issue-5633
Browse files Browse the repository at this point in the history
  • Loading branch information
neubig authored Dec 16, 2024
2 parents dc63963 + 50478c7 commit 33bd1b7
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 13 deletions.
29 changes: 24 additions & 5 deletions frontend/src/components/features/markdown/code.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,38 @@ export function code({
const match = /language-(\w+)/.exec(className || ""); // get the language

if (!match) {
const isMultiline = String(children).includes("\n");

if (!isMultiline) {
return (
<code
className={className}
style={{
backgroundColor: "#2a3038",
padding: "0.2em 0.4em",
borderRadius: "4px",
color: "#e6edf3",
border: "1px solid #30363d",
}}
>
{children}
</code>
);
}

return (
<code
className={className}
<pre
style={{
backgroundColor: "#2a3038",
padding: "0.2em 0.4em",
padding: "1em",
borderRadius: "4px",
color: "#e6edf3",
border: "1px solid #30363d",
overflow: "auto",
}}
>
{children}
</code>
<code className={className}>{String(children).replace(/\n$/, "")}</code>
</pre>
);
}

Expand Down
16 changes: 16 additions & 0 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,20 @@ async def _handle_message_action(self, action: MessageAction) -> None:
str(action),
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
)
# Extend max iterations when the user sends a message (only in non-headless mode)
if self._initial_max_iterations is not None and not self.headless_mode:
self.state.max_iterations = (
self.state.iteration + self._initial_max_iterations
)
if (
self.state.traffic_control_state == TrafficControlState.THROTTLING
or self.state.traffic_control_state == TrafficControlState.PAUSED
):
self.state.traffic_control_state = TrafficControlState.NORMAL
self.log(
'debug',
f'Extended max iterations to {self.state.max_iterations} after user message',
)
if self.get_agent_state() != AgentState.RUNNING:
await self.set_agent_state_to(AgentState.RUNNING)
elif action.source == EventSource.AGENT and action.wait_for_response:
Expand Down Expand Up @@ -342,6 +356,7 @@ async def set_agent_state_to(self, new_state: AgentState) -> None:
elif (
new_state == AgentState.RUNNING
and self.state.agent_state == AgentState.PAUSED
# TODO: do we really need both THROTTLING and PAUSED states, or can we clean up one of them completely?
and self.state.traffic_control_state == TrafficControlState.THROTTLING
):
# user intends to interrupt traffic control and let the task resume temporarily
Expand All @@ -351,6 +366,7 @@ async def set_agent_state_to(self, new_state: AgentState) -> None:
self.state.iteration is not None
and self.state.max_iterations is not None
and self._initial_max_iterations is not None
and not self.headless_mode
):
if self.state.iteration >= self.state.max_iterations:
self.state.max_iterations += self._initial_max_iterations
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ jupyter_kernel_gateway = "*"
flake8 = "*"
opencv-python = "*"


[build-system]
build-backend = "poetry.core.masonry.api"
requires = [
Expand Down
45 changes: 38 additions & 7 deletions tests/unit/test_agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import TrafficControlState
from openhands.controller.state.state import State, TrafficControlState
from openhands.core.config import AppConfig
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
Expand Down Expand Up @@ -41,7 +41,9 @@ def mock_agent():

@pytest.fixture
def mock_event_stream():
return MagicMock(spec=EventStream)
mock = MagicMock(spec=EventStream)
mock.get_latest_event_id.return_value = 0
return mock


@pytest.fixture
Expand Down Expand Up @@ -278,40 +280,69 @@ async def test_delegate_step_different_states(


@pytest.mark.asyncio
async def test_step_max_iterations(mock_agent, mock_event_stream):
async def test_max_iterations_extension(mock_agent, mock_event_stream):
# Test with headless_mode=False - should extend max_iterations
initial_state = State(max_iterations=10)
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=False,
initial_state=initial_state,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.iteration = 10
assert controller.state.traffic_control_state == TrafficControlState.NORMAL

# Trigger throttling by calling _step() when we hit max_iterations
await controller._step()
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
assert controller.state.agent_state == AgentState.ERROR
await controller.close()

# Simulate a new user message
message_action = MessageAction(content='Test message')
message_action._source = EventSource.USER
await controller.on_event(message_action)

# Max iterations should be extended to current iteration + initial max_iterations
assert (
controller.state.max_iterations == 20
) # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10)
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
assert controller.state.agent_state == AgentState.RUNNING

@pytest.mark.asyncio
async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
# Close the controller to clean up
await controller.close()

# Test with headless_mode=True - should NOT extend max_iterations
initial_state = State(max_iterations=10)
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
initial_state=initial_state,
)
controller.state.agent_state = AgentState.RUNNING
controller.state.iteration = 10
assert controller.state.traffic_control_state == TrafficControlState.NORMAL

# Simulate a new user message
message_action = MessageAction(content='Test message')
message_action._source = EventSource.USER
await controller.on_event(message_action)

# Max iterations should NOT be extended in headless mode
assert controller.state.max_iterations == 10 # Original value unchanged

# Trigger throttling by calling _step() when we hit max_iterations
await controller._step()

assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
# In headless mode, throttling results in an error
assert controller.state.agent_state == AgentState.ERROR
await controller.close()

Expand Down
62 changes: 62 additions & 0 deletions tests/unit/test_iteration_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import asyncio

import pytest

from openhands.controller.agent_controller import AgentController
from openhands.core.schema import AgentState
from openhands.events import EventStream
from openhands.events.action import MessageAction
from openhands.events.event import EventSource


class DummyAgent:
    """Minimal agent stand-in for controller tests.

    Exposes only what this test file sets up: a ``name``, an ``llm`` object
    carrying a ``metrics`` object with a no-op ``merge``, and a no-op
    ``reset``.
    """

    def __init__(self):
        self.name = 'dummy'
        # ``merge`` lives on the class, so Python calls it as a bound method:
        # the first argument is the metrics instance itself and any metrics
        # object handed in by the caller comes after it. Accept both so that
        # ``metrics.merge(other)`` is a harmless no-op instead of a TypeError.
        dummy_metrics_cls = type(
            'DummyMetrics',
            (),
            {'merge': lambda self, *args, **kwargs: None},
        )
        self.llm = type('DummyLLM', (), {'metrics': dummy_metrics_cls()})()

    def reset(self):
        """Do nothing; the dummy agent keeps no state to reset."""
        pass


@pytest.mark.asyncio
async def test_iteration_limit_extends_on_user_message():
    """Check that each user message moves ``max_iterations`` forward.

    With ``headless_mode=False``, the test expects ``max_iterations`` to equal
    the current iteration plus the initial budget after every user message
    delivered through the event stream.
    """
    # Initialize test components
    from openhands.storage.memory import InMemoryFileStore

    file_store = InMemoryFileStore()
    event_stream = EventStream(sid='test', file_store=file_store)
    agent = DummyAgent()
    initial_max_iterations = 100
    controller = AgentController(
        agent=agent,
        event_stream=event_stream,
        max_iterations=initial_max_iterations,
        sid='test',
        headless_mode=False,
    )

    try:
        # Set initial state
        await controller.set_agent_state_to(AgentState.RUNNING)
        controller.state.iteration = 90  # Close to the limit
        assert controller.state.max_iterations == initial_max_iterations

        # Simulate user message. The source is supplied via add_event, so the
        # action itself is built from its content only.
        user_message = MessageAction(content='test message')
        event_stream.add_event(user_message, EventSource.USER)
        await asyncio.sleep(0.1)  # Give time for event to be processed

        # Verify max_iterations was extended
        assert controller.state.max_iterations == 90 + initial_max_iterations

        # Simulate more iterations and another user message
        controller.state.iteration = 180  # Close to new limit
        user_message2 = MessageAction(content='another message')
        event_stream.add_event(user_message2, EventSource.USER)
        await asyncio.sleep(0.1)  # Give time for event to be processed

        # Verify max_iterations was extended again
        assert controller.state.max_iterations == 180 + initial_max_iterations
    finally:
        # Always release the controller, even when an assertion above fails,
        # so it does not outlive the test.
        await controller.close()

0 comments on commit 33bd1b7

Please sign in to comment.