diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py
index e0fa0dab0384..216f9d6022ff 100644
--- a/openhands/controller/agent_controller.py
+++ b/openhands/controller/agent_controller.py
@@ -223,6 +223,25 @@ async def on_event(self, event: Event):
 
         # if the event is not filtered out, add it to the history
         if not any(isinstance(event, filter_type) for filter_type in self.filter_out):
+            # Check if adding this event would exceed context window
+            if self.agent.llm.config.max_input_tokens is not None:
+                # Create temporary history with new event
+                # NOTE(review): this passes raw events to get_token_count;
+                # confirm it accepts events rather than LLM messages.
+                temp_history = self.state.history + [event]
+                try:
+                    token_count = self.agent.llm.get_token_count(temp_history)
+                except Exception as e:
+                    # NOTE(review): the message says "NO TRUNCATION", yet inf
+                    # makes the check below true and so forces truncation.
+                    logger.error(f'NO TRUNCATION: Error getting token count: {e}.')
+                    token_count = float('inf')
+
+                if token_count > self.agent.llm.config.max_input_tokens:
+                    # Need to truncate history if there are too many tokens
+                    self.state.history = self._apply_conversation_window(self.state.history)
+
+            # Now add the new event
             self.state.history.append(event)
 
             if isinstance(event, Action):
@@ -828,6 +847,10 @@ def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
             None,
         )
 
+        # Always set start_id to first user message id if found, regardless of truncation
+        if first_user_msg:
+            self.state.start_id = first_user_msg.id
+
         # cut in half
         mid_point = max(1, len(events) // 2)
         kept_events = events[mid_point:]
diff --git a/tests/unit/test_truncation.py b/tests/unit/test_truncation.py
index 7d03d2f619a5..5535d520b3be 100644
--- a/tests/unit/test_truncation.py
+++ b/tests/unit/test_truncation.py
@@ -186,3 +186,73 @@ def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
         assert len(new_controller.state.history) == saved_history_len
         assert new_controller.state.history[0] == first_msg
         assert new_controller.state.start_id == saved_start_id
+
+    def test_context_window_parameter_truncation(self, mock_event_stream, mock_agent):
+        """Exceeding ``max_input_tokens`` on a new event should truncate history.
+
+        The mocked token counter reports more tokens than the configured
+        limit, so delivering one more event is expected to truncate the
+        history while keeping the first user message and ``start_id``.
+        """
+        # Local import keeps this test self-contained.
+        import asyncio
+
+        # Configure mock agent's LLM to return specific token counts
+        mock_agent.llm.get_token_count.return_value = 100
+
+        # Set max_input_tokens in LLM config
+        mock_agent.llm.config.max_input_tokens = 80
+
+        # Create controller
+        controller = AgentController(
+            agent=mock_agent,
+            event_stream=mock_event_stream,
+            max_iterations=10,
+            sid='test_truncation',
+            confirmation_mode=False,
+            headless_mode=True,
+        )
+
+        # Create initial events
+        first_msg = MessageAction(content='Start task', wait_for_response=False)
+        first_msg._source = EventSource.USER
+        first_msg._id = 1
+
+        events = [first_msg]
+        for i in range(5):
+            cmd = CmdRunAction(command=f'cmd{i}')
+            # Give every event a distinct id: i + 2 / i + 3 would make each
+            # observation share its id with the next command.
+            cmd._id = 2 * i + 2
+            obs = CmdOutputObservation(
+                command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
+            )
+            obs._id = 2 * i + 3
+            obs._cause = cmd._id
+            events.extend([cmd, obs])
+
+        # Set initial history
+        controller.state.history = events[:3]  # Start with a few events
+        controller.state.start_id = first_msg._id  # Explicitly set start_id
+        initial_history_len = len(controller.state.history)
+
+        # Add a new event that should trigger truncation due to token count
+        mock_agent.llm.get_token_count.return_value = 90  # Exceed our context window
+        # on_event is a coroutine function; it has to be driven to completion,
+        # otherwise no truncation logic runs at all.
+        asyncio.run(controller.on_event(events[3]))
+
+        # Verify truncation occurred
+        assert len(controller.state.history) < initial_history_len + 1
+        assert controller.state.start_id == first_msg._id
+        assert controller.state.truncation_id is not None
+        assert (
+            first_msg in controller.state.history
+        )  # First message should be preserved
+
+        # Verify action-observation pairs weren't split
+        for i, event in enumerate(controller.state.history[1:]):
+            if isinstance(event, CmdOutputObservation):
+                assert any(
+                    e._id == event._cause for e in controller.state.history[: i + 1]
+                )