Add option to reduce context window #5193

Closed
19 changes: 19 additions & 0 deletions openhands/controller/agent_controller.py
@@ -223,6 +223,21 @@ async def on_event(self, event: Event):

        # if the event is not filtered out, add it to the history
        if not any(isinstance(event, filter_type) for filter_type in self.filter_out):
            # Check if adding this event would exceed context window
            if self.agent.llm.config.max_input_tokens is not None:
                # Create temporary history with new event
                temp_history = self.state.history + [event]
                try:
                    token_count = self.agent.llm.get_token_count(temp_history)
Collaborator:
I'm afraid this line doesn't really do what we want. We need to do the token counting on the messages that are sent to the LLM API. OpenHands works with events, and these are events, which it then 'translates' into messages and sends to the LLM API. You can see, as you try to make it work on your machine for your use case, that the number of tokens will not match; some events differ more than others, but they all differ. 😅

An alternative is to define a custom exception, let's say TokenLimitExceeded, and move this check to the LLM class, then raise the exception when the token comparison fails. Then maybe the exception can be treated the way ContextWindowExceededError is. What do you think, does that make sense?
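
A minimal sketch of that idea, purely illustrative; the class and method names below are assumptions, not the actual OpenHands API:

class TokenLimitExceededError(Exception):
    """Raised when the prompt messages would exceed max_input_tokens."""


class LLMWrapper:
    """Stand-in for the LLM class that talks to the provider API."""

    def __init__(self, config, token_counter):
        self.config = config
        self._count_tokens = token_counter  # e.g. a tokenizer-based counter

    def completion(self, messages: list[dict], **kwargs):
        # Count tokens on the exact messages sent to the LLM API, not on raw events.
        limit = self.config.max_input_tokens
        if limit is not None and self._count_tokens(messages) > limit:
            raise TokenLimitExceededError(
                f'prompt would exceed max_input_tokens ({limit})'
            )
        # ... otherwise proceed with the normal provider call ...

The controller could then catch TokenLimitExceededError, truncate the history (for example via _apply_conversation_window), and retry, much like ContextWindowExceededError is handled.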

Author:
Ahh, I see, I think I understand. I was getting unpredictable behavior last night, so that makes more sense, and this definitely doesn't work. I'll take another crack at it when I get a chance today.

I do have a dumb question, but I can't recreate it predictably: how is the config.toml file loaded? Is it only with certain commands? I can't tell if I broke something on my machine or if I'm assuming incorrect behavior.

Collaborator:
Let's figure it out! There is something a bit unpredictable about it: when you set values there and then run with the UI, with main.py, or with cli.py, config.toml is loaded. However, if running with the UI, any settings defined in the UI (in the Settings window) will override the toml values.
Could that be what's happening?
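
For illustration, a config.toml along these lines gets picked up when running main.py or cli.py, with UI settings taking precedence; the exact section and key layout here is an assumption (max_input_tokens is the option this PR reads):

# config.toml (illustrative excerpt, not the full schema)
[llm]
model = "gpt-4o"
# Limit read as agent.llm.config.max_input_tokens in this PR;
# values set in the UI Settings window override what is written here.
max_input_tokens = 80000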

Collaborator:
And it's absolutely a good question! We're still trying to document it properly; the fact that something this simple turns out to be a bit difficult does suggest that something's wrong and we should rethink some of it.
For reference: #3220

                except Exception as e:
                    logger.error(f'NO TRUNCATION: Error getting token count: {e}.')
                    token_count = float('inf')

                if token_count > self.agent.llm.config.max_input_tokens:
                    # Need to truncate history if there are too many tokens
                    self.state.history = self._apply_conversation_window(self.state.history)

            # Now add the new event
            self.state.history.append(event)

        if isinstance(event, Action):
@@ -828,6 +843,10 @@ def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
            None,
        )

        # Always set start_id to first user message id if found, regardless of truncation
        if first_user_msg:
            self.state.start_id = first_user_msg.id

        # cut in half
        mid_point = max(1, len(events) // 2)
        kept_events = events[mid_point:]
57 changes: 57 additions & 0 deletions tests/unit/test_truncation.py
@@ -186,3 +186,60 @@ def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
        assert len(new_controller.state.history) == saved_history_len
        assert new_controller.state.history[0] == first_msg
        assert new_controller.state.start_id == saved_start_id

    def test_context_window_parameter_truncation(self, mock_event_stream, mock_agent):
        # Configure mock agent's LLM to return specific token counts
        mock_agent.llm.get_token_count.return_value = 100

        # Set max_input_tokens in LLM config
        mock_agent.llm.config.max_input_tokens = 80

        # Create controller
        controller = AgentController(
            agent=mock_agent,
            event_stream=mock_event_stream,
            max_iterations=10,
            sid='test_truncation',
            confirmation_mode=False,
            headless_mode=True,
        )

        # Create initial events
        first_msg = MessageAction(content='Start task', wait_for_response=False)
        first_msg._source = EventSource.USER
        first_msg._id = 1

        events = [first_msg]
        for i in range(5):
            cmd = CmdRunAction(command=f'cmd{i}')
            cmd._id = i + 2
            obs = CmdOutputObservation(
                command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
            )
            obs._id = i + 3
            obs._cause = cmd._id
            events.extend([cmd, obs])

        # Set initial history
        controller.state.history = events[:3]  # Start with a few events
        controller.state.start_id = first_msg._id  # Explicitly set start_id
        initial_history_len = len(controller.state.history)

        # Add a new event that should trigger truncation due to token count
        mock_agent.llm.get_token_count.return_value = 90  # Exceed our context window
        controller.on_event(events[3])

        # Verify truncation occurred
        assert len(controller.state.history) < initial_history_len + 1
        assert controller.state.start_id == first_msg._id
        assert controller.state.truncation_id is not None
        assert (
            first_msg in controller.state.history
        )  # First message should be preserved

        # Verify action-observation pairs weren't split
        for i, event in enumerate(controller.state.history[1:]):
            if isinstance(event, CmdOutputObservation):
                assert any(
                    e._id == event._cause for e in controller.state.history[: i + 1]
                )