From ba9bc74179f95cbb9cbe16167c49a8c98102a45e Mon Sep 17 00:00:00 2001 From: Raymond Xu Date: Fri, 15 Nov 2024 20:20:04 -0800 Subject: [PATCH] pass lint and truncation tests --- openhands/controller/agent_controller.py | 21 ++++++++++++++++----- tests/unit/test_truncation.py | 1 + 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 9c3e083a83b0..6cf0f45bf0d7 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -228,11 +228,11 @@ async def on_event(self, event: Event): # Create temporary history with new event temp_history = self.state.history + [event] token_count = self.agent.llm.get_token_count(temp_history) - + if token_count > self.context_window: # Truncate existing history before adding new event self.state.history = self._apply_conversation_window(self.state.history) - + # Now add the new event self.state.history.append(event) @@ -839,6 +839,10 @@ def _apply_conversation_window(self, events: list[Event]) -> list[Event]: None, ) + # Always set start_id to first user message id if found, regardless of truncation + if first_user_msg: + self.state.start_id = first_user_msg.id + # cut in half mid_point = max(1, len(events) // 2) kept_events = events[mid_point:] @@ -889,9 +893,16 @@ def _apply_conversation_window(self, events: list[Event]) -> list[Event]: if first_user_msg and first_user_msg not in kept_events: kept_events = [first_user_msg] + kept_events - # start_id points to first user message - if first_user_msg: - self.state.start_id = first_user_msg.id + # Verify truncated history fits in context window + while ( + self.context_window is not None + and self.agent.llm.get_token_count(kept_events) > self.context_window + ): + # Need to truncate more - remove oldest non-first-message event + if len(kept_events) > 2: # Keep at least first message and one more event + kept_events.pop(1) + else: + break return kept_events diff --git a/tests/unit/test_truncation.py b/tests/unit/test_truncation.py index b930c11a6697..183255eed136 100644 --- a/tests/unit/test_truncation.py +++ b/tests/unit/test_truncation.py @@ -220,6 +220,7 @@ def test_context_window_parameter_truncation(self, mock_event_stream, mock_agent # Set initial history controller.state.history = events[:3] # Start with a few events + controller.state.start_id = first_msg._id # Explicitly set start_id initial_history_len = len(controller.state.history) # Add a new event that should trigger truncation due to token count