Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context Window Exceeded fix #4977

Merged
merged 16 commits into from
Nov 14, 2024
132 changes: 130 additions & 2 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable, ClassVar, Type

import litellm
from litellm.exceptions import ContextWindowExceededError

from openhands.controller.agent import Agent
from openhands.controller.state.state import State, TrafficControlState
Expand Down Expand Up @@ -485,6 +486,15 @@ async def _step(self) -> None:
EventSource.AGENT,
)
return
except ContextWindowExceededError:
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(self.state.history)

# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Don't add error event - let the agent retry with reduced context
return

if action.runnable:
if self.state.confirmation_mode and (
Expand Down Expand Up @@ -659,6 +669,12 @@ def _init_history(self):
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
- Excludes all events between the action and observation
- Includes the delegate action and observation themselves

The history is loaded in two parts if truncation_id is set:
1. First user message from start_id onwards
2. Rest of history from truncation_id to the end

Otherwise loads normally from start_id.
"""

# define range of events to fetch
Expand All @@ -680,8 +696,33 @@ def _init_history(self):
self.state.history = []
return

# Get all events, filtering out backend events and hidden events
events = list(
events: list[Event] = []

# If we have a truncation point, get first user message and then rest of history
if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
# Find first user message from stream
first_user_msg = next(
(
e
for e in self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=False,
filter_out_type=self.filter_out,
filter_hidden=True,
)
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)
if first_user_msg:
events.append(first_user_msg)

# the rest of the events are from the truncation point
start_id = self.state.truncation_id

# Get rest of history
events_to_add = list(
self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
Expand All @@ -690,6 +731,7 @@ def _init_history(self):
filter_hidden=True,
)
)
events.extend(events_to_add)

# Find all delegate action/observation pairs
delegate_ranges: list[tuple[int, int]] = []
Expand Down Expand Up @@ -744,6 +786,92 @@ def _init_history(self):
# make sure history is in sync
self.state.start_id = start_id

def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
    """Cut the history roughly in half after the context window is exceeded.

    The newest half of ``events`` is kept. The head of that half is then
    repaired so that an action/observation pair is never split, and the first
    user message of the full history is always re-included.

    The algorithm:
    1. Keep ``events[len(events) // 2:]`` (at least one event is cut).
    2. Walk the head of the kept slice:
       - leading ``MessageAction``s and user-sourced actions are skipped over;
       - if the first other event is an ``Observation`` with a ``cause``, its
         ``Action`` is looked up in the discarded half and prepended; when no
         such action exists, that observation is dropped and a warning logged;
       - any other event ends the walk.
    3. Prepend the first user message if it is not already kept.

    Side effects:
        Sets ``self.state.truncation_id`` to the id of the first kept event
        (before the first user message is re-added) and ``self.state.start_id``
        to the id of the first user message, when those exist.

    Args:
        events: List of events to filter, oldest first.

    Returns:
        Filtered list of events keeping the newest half while preserving
        pairs. An empty input is returned unchanged.
    """
    if not events:
        return events

    # Find first user message - we'll need to ensure it's included
    first_user_msg = next(
        (
            e
            for e in events
            if isinstance(e, MessageAction) and e.source == EventSource.USER
        ),
        None,
    )

    # cut in half; max(1, ...) guarantees at least one event is discarded
    mid_point = max(1, len(events) // 2)
    kept_events = events[mid_point:]

    # Repair the head of the truncated history
    i = 0
    while i < len(kept_events):
        event = kept_events[i]

        if isinstance(event, Observation) and event.cause:
            # The observation's action, if any, was cut off: search the
            # discarded half from newest to oldest.
            matching_action = next(
                (
                    e
                    for e in reversed(events[:mid_point])
                    if isinstance(e, Action) and e.id == event.cause
                ),
                None,
            )
            if matching_action is not None:
                # Everything in the discarded half predates the kept slice,
                # so prepending keeps chronological order.
                kept_events = [matching_action] + kept_events
            else:
                self.log(
                    'warning',
                    f'Found Observation without matching Action at id={event.id}',
                )
                # Drop this observation. It sits at index i, which is not
                # index 0 when leading messages were skipped over.
                kept_events = kept_events[:i] + kept_events[i + 1 :]
            break

        if isinstance(event, MessageAction) or (
            isinstance(event, Action) and event.source == EventSource.USER
        ):
            # a message action or a user action: keep it and look at the next event
            i += 1
            continue

        # anything else (e.g. an agent action) is a safe place to start
        break

    # Save where to continue from in next reload
    if kept_events:
        self.state.truncation_id = kept_events[0].id

    if first_user_msg is not None:
        # Ensure the first user message itself is included. Identity is used
        # so that a later message that merely compares equal cannot stand in
        # for it.
        if not any(e is first_user_msg for e in kept_events):
            kept_events = [first_user_msg] + kept_events

        # start_id points to first user message
        self.state.start_id = first_user_msg.id

    return kept_events

def _is_stuck(self):
"""Checks if the agent or its delegate is stuck in a loop.

Expand Down
2 changes: 2 additions & 0 deletions openhands/controller/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class State:
# start_id and end_id track the range of events in history
start_id: int = -1
end_id: int = -1
# truncation_id tracks where to load history after context window truncation
truncation_id: int = -1
almost_stuck: int = 0
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
# NOTE: This will never be used by the controller, but it can be used by different
Expand Down
188 changes: 188 additions & 0 deletions tests/unit/test_truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from unittest.mock import MagicMock

import pytest

from openhands.controller.agent_controller import AgentController
from openhands.events import EventSource
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation


@pytest.fixture
def mock_event_stream():
    """Provide a mocked event stream whose ``get_events()`` yields nothing."""
    event_stream = MagicMock()
    # By default the stream holds no events; tests override this as needed.
    event_stream.configure_mock(**{'get_events.return_value': []})
    return event_stream


@pytest.fixture
def mock_agent():
    """Provide a mocked agent exposing mocked ``llm`` and ``llm.config``."""
    llm = MagicMock()
    llm.config = MagicMock()

    agent = MagicMock()
    agent.llm = llm
    return agent


class TestTruncation:
    """Tests for history truncation in ``AgentController``."""

    @staticmethod
    def _make_controller(agent, event_stream):
        """Build a headless controller around the mocked agent and stream."""
        return AgentController(
            agent=agent,
            event_stream=event_stream,
            max_iterations=10,
            sid='test_truncation',
            confirmation_mode=False,
            headless_mode=True,
        )

    def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
        """An observation at the cut point pulls its action back in."""
        controller = self._make_controller(mock_agent, mock_event_stream)

        # Create a sequence of events with IDs
        first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
        first_msg._source = EventSource.USER
        first_msg._id = 1

        cmd1 = CmdRunAction(command='ls')
        cmd1._id = 2
        obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2)
        obs1._id = 3
        obs1._cause = 2

        cmd2 = CmdRunAction(command='pwd')
        cmd2._id = 4
        obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4)
        obs2._id = 5
        obs2._cause = 4

        events = [first_msg, cmd1, obs1, cmd2, obs2]

        # Apply truncation
        truncated = controller._apply_conversation_window(events)

        # Should keep first user message and at least one action-observation pair
        assert len(truncated) >= 3
        assert truncated[0] == first_msg  # First message always preserved
        assert controller.state.start_id == first_msg._id
        # The cut lands on obs1, so cmd1 is restored and becomes the first
        # kept event. (The default truncation_id is -1, so an
        # ``is not None`` check would pass even if nothing was recorded.)
        assert controller.state.truncation_id == cmd1._id

        # Verify pairs aren't split: every observation is preceded by its action
        for i, event in enumerate(truncated[1:]):
            if isinstance(event, CmdOutputObservation):
                assert any(e._id == event._cause for e in truncated[: i + 1])

    def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
        """Truncation shrinks the history and records where it resumed."""
        controller = self._make_controller(mock_agent, mock_event_stream)

        # Setup initial history with IDs
        first_msg = MessageAction(content='Start task', wait_for_response=False)
        first_msg._source = EventSource.USER
        first_msg._id = 1

        # Add agent question
        agent_msg = MessageAction(
            content='What task would you like me to perform?', wait_for_response=True
        )
        agent_msg._source = EventSource.AGENT
        agent_msg._id = 2

        # Add user response
        user_response = MessageAction(
            content='Please list all files and show me current directory',
            wait_for_response=False,
        )
        user_response._source = EventSource.USER
        user_response._id = 3

        cmd1 = CmdRunAction(command='ls')
        cmd1._id = 4
        obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
        obs1._id = 5
        obs1._cause = 4

        history = [first_msg, agent_msg, user_response, cmd1, obs1]

        # Update mock event stream to include new messages
        mock_event_stream.get_events.return_value = list(history)
        controller.state.history = list(history)
        original_history_len = len(controller.state.history)

        # Simulate ContextWindowExceededError and truncation
        controller.state.history = controller._apply_conversation_window(
            controller.state.history
        )

        # Verify truncation occurred
        assert len(controller.state.history) < original_history_len
        assert controller.state.history[0] == first_msg
        assert controller.state.start_id == first_msg._id
        # The kept half starts at the second user message.
        assert controller.state.truncation_id == user_response._id
        assert controller.state.truncation_id > controller.state.start_id

    def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
        """Saved truncation state can be carried over to a new controller.

        NOTE: the history of the new controller is assigned directly from the
        mocked stream; this does not exercise ``_init_history``.
        """
        controller = self._make_controller(mock_agent, mock_event_stream)

        # Create events with IDs
        first_msg = MessageAction(content='Start task', wait_for_response=False)
        first_msg._source = EventSource.USER
        first_msg._id = 1

        events = [first_msg]
        for i in range(5):
            # Unique, increasing ids: 2, 3 for the first pair, 4, 5 for the next, ...
            cmd = CmdRunAction(command=f'cmd{i}')
            cmd._id = 2 * i + 2
            obs = CmdOutputObservation(
                command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
            )
            obs._id = cmd._id + 1
            obs._cause = cmd._id
            events.extend([cmd, obs])

        # Set up initial history
        controller.state.history = events.copy()

        # Force truncation
        controller.state.history = controller._apply_conversation_window(
            controller.state.history
        )

        # Save state
        saved_start_id = controller.state.start_id
        saved_truncation_id = controller.state.truncation_id
        saved_history_len = len(controller.state.history)

        assert saved_history_len < len(events)
        assert saved_truncation_id > saved_start_id

        # Set up mock event stream for new controller
        mock_event_stream.get_events.return_value = controller.state.history

        # Create new controller with saved state
        new_controller = self._make_controller(mock_agent, mock_event_stream)
        new_controller.state.start_id = saved_start_id
        new_controller.state.truncation_id = saved_truncation_id
        new_controller.state.history = mock_event_stream.get_events()

        # Verify restoration
        assert len(new_controller.state.history) == saved_history_len
        assert new_controller.state.history[0] == first_msg
        assert new_controller.state.start_id == saved_start_id
        assert new_controller.state.truncation_id == saved_truncation_id
Loading