Skip to content

Commit

Permalink
Context Window Exceeded fix (#4977)
Browse files Browse the repository at this point in the history
  • Loading branch information
enyst authored Nov 14, 2024
1 parent a93f140 commit 8dee334
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 2 deletions.
132 changes: 130 additions & 2 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable, ClassVar, Type

import litellm
from litellm.exceptions import ContextWindowExceededError

from openhands.controller.agent import Agent
from openhands.controller.state.state import State, TrafficControlState
Expand Down Expand Up @@ -485,6 +486,15 @@ async def _step(self) -> None:
EventSource.AGENT,
)
return
except ContextWindowExceededError:
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(self.state.history)

# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Don't add error event - let the agent retry with reduced context
return

if action.runnable:
if self.state.confirmation_mode and (
Expand Down Expand Up @@ -659,6 +669,12 @@ def _init_history(self):
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
- Excludes all events between the action and observation
- Includes the delegate action and observation themselves
The history is loaded in two parts if truncation_id is set:
1. First user message from start_id onwards
2. Rest of history from truncation_id to the end
Otherwise loads normally from start_id.
"""

# define range of events to fetch
Expand All @@ -680,8 +696,33 @@ def _init_history(self):
self.state.history = []
return

# Get all events, filtering out backend events and hidden events
events = list(
events: list[Event] = []

# If we have a truncation point, get first user message and then rest of history
if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
# Find first user message from stream
first_user_msg = next(
(
e
for e in self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=False,
filter_out_type=self.filter_out,
filter_hidden=True,
)
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)
if first_user_msg:
events.append(first_user_msg)

# the rest of the events are from the truncation point
start_id = self.state.truncation_id

# Get rest of history
events_to_add = list(
self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
Expand All @@ -690,6 +731,7 @@ def _init_history(self):
filter_hidden=True,
)
)
events.extend(events_to_add)

# Find all delegate action/observation pairs
delegate_ranges: list[tuple[int, int]] = []
Expand Down Expand Up @@ -744,6 +786,92 @@ def _init_history(self):
# make sure history is in sync
self.state.start_id = start_id

def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
"""Cuts history roughly in half when context window is exceeded, preserving action-observation pairs
and ensuring the first user message is always included.
The algorithm:
1. Cut history in half
2. Check first event in new history:
- If Observation: find and include its Action
- If MessageAction: ensure its related Action-Observation pair isn't split
3. Always include the first user message
Args:
events: List of events to filter
Returns:
Filtered list of events keeping newest half while preserving pairs
"""
if not events:
return events

# Find first user message - we'll need to ensure it's included
first_user_msg = next(
(
e
for e in events
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)

# cut in half
mid_point = max(1, len(events) // 2)
kept_events = events[mid_point:]

# Handle first event in truncated history
if kept_events:
i = 0
while i < len(kept_events):
first_event = kept_events[i]
if isinstance(first_event, Observation) and first_event.cause:
# Find its action and include it
matching_action = next(
(
e
for e in reversed(events[:mid_point])
if isinstance(e, Action) and e.id == first_event.cause
),
None,
)
if matching_action:
kept_events = [matching_action] + kept_events
else:
self.log(
'warning',
f'Found Observation without matching Action at id={first_event.id}',
)
# drop this observation
kept_events = kept_events[1:]
break

elif isinstance(first_event, MessageAction) or (
isinstance(first_event, Action)
and first_event.source == EventSource.USER
):
# if it's a message action or a user action, keep it and continue to find the next event
i += 1
continue

else:
# if it's an action with source == EventSource.AGENT, we're good
break

# Save where to continue from in next reload
if kept_events:
self.state.truncation_id = kept_events[0].id

# Ensure first user message is included
if first_user_msg and first_user_msg not in kept_events:
kept_events = [first_user_msg] + kept_events

# start_id points to first user message
if first_user_msg:
self.state.start_id = first_user_msg.id

return kept_events

def _is_stuck(self):
"""Checks if the agent or its delegate is stuck in a loop.
Expand Down
2 changes: 2 additions & 0 deletions openhands/controller/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class State:
# start_id and end_id track the range of events in history
start_id: int = -1
end_id: int = -1
# truncation_id tracks where to load history after context window truncation
truncation_id: int = -1
almost_stuck: int = 0
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
# NOTE: This will never be used by the controller, but it can be used by different
Expand Down
188 changes: 188 additions & 0 deletions tests/unit/test_truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from unittest.mock import MagicMock

import pytest

from openhands.controller.agent_controller import AgentController
from openhands.events import EventSource
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation


@pytest.fixture
def mock_event_stream():
stream = MagicMock()
# Mock get_events to return an empty list by default
stream.get_events.return_value = []
return stream


@pytest.fixture
def mock_agent():
agent = MagicMock()
agent.llm = MagicMock()
agent.llm.config = MagicMock()
return agent


class TestTruncation:
def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)

# Create a sequence of events with IDs
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1

cmd1 = CmdRunAction(command='ls')
cmd1._id = 2
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2)
obs1._id = 3
obs1._cause = 2

cmd2 = CmdRunAction(command='pwd')
cmd2._id = 4
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4)
obs2._id = 5
obs2._cause = 4

events = [first_msg, cmd1, obs1, cmd2, obs2]

# Apply truncation
truncated = controller._apply_conversation_window(events)

# Should keep first user message and roughly half of other events
assert (
len(truncated) >= 3
) # First message + at least one action-observation pair
assert truncated[0] == first_msg # First message always preserved
assert controller.state.start_id == first_msg._id
assert controller.state.truncation_id is not None

# Verify pairs aren't split
for i, event in enumerate(truncated[1:]):
if isinstance(event, CmdOutputObservation):
assert any(e._id == event._cause for e in truncated[: i + 1])

def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)

# Setup initial history with IDs
first_msg = MessageAction(content='Start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1

# Add agent question
agent_msg = MessageAction(
content='What task would you like me to perform?', wait_for_response=True
)
agent_msg._source = EventSource.AGENT
agent_msg._id = 2

# Add user response
user_response = MessageAction(
content='Please list all files and show me current directory',
wait_for_response=False,
)
user_response._source = EventSource.USER
user_response._id = 3

cmd1 = CmdRunAction(command='ls')
cmd1._id = 4
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
obs1._id = 5
obs1._cause = 4

# Update mock event stream to include new messages
mock_event_stream.get_events.return_value = [
first_msg,
agent_msg,
user_response,
cmd1,
obs1,
]
controller.state.history = [first_msg, agent_msg, user_response, cmd1, obs1]
original_history_len = len(controller.state.history)

# Simulate ContextWindowExceededError and truncation
controller.state.history = controller._apply_conversation_window(
controller.state.history
)

# Verify truncation occurred
assert len(controller.state.history) < original_history_len
assert controller.state.start_id == first_msg._id
assert controller.state.truncation_id is not None
assert controller.state.truncation_id > controller.state.start_id

def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)

# Create events with IDs
first_msg = MessageAction(content='Start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1

events = [first_msg]
for i in range(5):
cmd = CmdRunAction(command=f'cmd{i}')
cmd._id = i + 2
obs = CmdOutputObservation(
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
)
obs._cause = cmd._id
events.extend([cmd, obs])

# Set up initial history
controller.state.history = events.copy()

# Force truncation
controller.state.history = controller._apply_conversation_window(
controller.state.history
)

# Save state
saved_start_id = controller.state.start_id
saved_truncation_id = controller.state.truncation_id
saved_history_len = len(controller.state.history)

# Set up mock event stream for new controller
mock_event_stream.get_events.return_value = controller.state.history

# Create new controller with saved state
new_controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
new_controller.state.start_id = saved_start_id
new_controller.state.truncation_id = saved_truncation_id
new_controller.state.history = mock_event_stream.get_events()

# Verify restoration
assert len(new_controller.state.history) == saved_history_len
assert new_controller.state.history[0] == first_msg
assert new_controller.state.start_id == saved_start_id

0 comments on commit 8dee334

Please sign in to comment.