-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
320 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
|
||
from openhands.controller.agent_controller import AgentController | ||
from openhands.events import EventSource | ||
from openhands.events.action import CmdRunAction, MessageAction | ||
from openhands.events.observation import CmdOutputObservation | ||
|
||
|
||
@pytest.fixture | ||
def mock_event_stream(): | ||
stream = MagicMock() | ||
# Mock get_events to return an empty list by default | ||
stream.get_events.return_value = [] | ||
return stream | ||
|
||
|
||
@pytest.fixture | ||
def mock_agent(): | ||
agent = MagicMock() | ||
agent.llm = MagicMock() | ||
agent.llm.config = MagicMock() | ||
return agent | ||
|
||
|
||
class TestTruncation: | ||
def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent): | ||
controller = AgentController( | ||
agent=mock_agent, | ||
event_stream=mock_event_stream, | ||
max_iterations=10, | ||
sid='test_truncation', | ||
confirmation_mode=False, | ||
headless_mode=True, | ||
) | ||
|
||
# Create a sequence of events with IDs | ||
first_msg = MessageAction(content='Hello, start task', wait_for_response=False) | ||
first_msg._source = EventSource.USER | ||
first_msg._id = 1 | ||
|
||
cmd1 = CmdRunAction(command='ls') | ||
cmd1._id = 2 | ||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2) | ||
obs1._id = 3 | ||
obs1._cause = 2 | ||
|
||
cmd2 = CmdRunAction(command='pwd') | ||
cmd2._id = 4 | ||
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4) | ||
obs2._id = 5 | ||
obs2._cause = 4 | ||
|
||
events = [first_msg, cmd1, obs1, cmd2, obs2] | ||
|
||
# Apply truncation | ||
truncated = controller._apply_conversation_window(events) | ||
|
||
# Should keep first user message and roughly half of other events | ||
assert ( | ||
len(truncated) >= 3 | ||
) # First message + at least one action-observation pair | ||
assert truncated[0] == first_msg # First message always preserved | ||
assert controller.state.start_id == first_msg._id | ||
assert controller.state.truncation_id is not None | ||
|
||
# Verify pairs aren't split | ||
for i, event in enumerate(truncated[1:]): | ||
if isinstance(event, CmdOutputObservation): | ||
assert any(e._id == event._cause for e in truncated[: i + 1]) | ||
|
||
def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent): | ||
controller = AgentController( | ||
agent=mock_agent, | ||
event_stream=mock_event_stream, | ||
max_iterations=10, | ||
sid='test_truncation', | ||
confirmation_mode=False, | ||
headless_mode=True, | ||
) | ||
|
||
# Setup initial history with IDs | ||
first_msg = MessageAction(content='Start task', wait_for_response=False) | ||
first_msg._source = EventSource.USER | ||
first_msg._id = 1 | ||
|
||
# Add agent question | ||
agent_msg = MessageAction( | ||
content='What task would you like me to perform?', wait_for_response=True | ||
) | ||
agent_msg._source = EventSource.AGENT | ||
agent_msg._id = 2 | ||
|
||
# Add user response | ||
user_response = MessageAction( | ||
content='Please list all files and show me current directory', | ||
wait_for_response=False, | ||
) | ||
user_response._source = EventSource.USER | ||
user_response._id = 3 | ||
|
||
cmd1 = CmdRunAction(command='ls') | ||
cmd1._id = 4 | ||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4) | ||
obs1._id = 5 | ||
obs1._cause = 4 | ||
|
||
# Update mock event stream to include new messages | ||
mock_event_stream.get_events.return_value = [ | ||
first_msg, | ||
agent_msg, | ||
user_response, | ||
cmd1, | ||
obs1, | ||
] | ||
controller.state.history = [first_msg, agent_msg, user_response, cmd1, obs1] | ||
original_history_len = len(controller.state.history) | ||
|
||
# Simulate ContextWindowExceededError and truncation | ||
controller.state.history = controller._apply_conversation_window( | ||
controller.state.history | ||
) | ||
|
||
# Verify truncation occurred | ||
assert len(controller.state.history) < original_history_len | ||
assert controller.state.start_id == first_msg._id | ||
assert controller.state.truncation_id is not None | ||
assert controller.state.truncation_id > controller.state.start_id | ||
|
||
def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent): | ||
controller = AgentController( | ||
agent=mock_agent, | ||
event_stream=mock_event_stream, | ||
max_iterations=10, | ||
sid='test_truncation', | ||
confirmation_mode=False, | ||
headless_mode=True, | ||
) | ||
|
||
# Create events with IDs | ||
first_msg = MessageAction(content='Start task', wait_for_response=False) | ||
first_msg._source = EventSource.USER | ||
first_msg._id = 1 | ||
|
||
events = [first_msg] | ||
for i in range(5): | ||
cmd = CmdRunAction(command=f'cmd{i}') | ||
cmd._id = i + 2 | ||
obs = CmdOutputObservation( | ||
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id | ||
) | ||
obs._cause = cmd._id | ||
events.extend([cmd, obs]) | ||
|
||
# Set up initial history | ||
controller.state.history = events.copy() | ||
|
||
# Force truncation | ||
controller.state.history = controller._apply_conversation_window( | ||
controller.state.history | ||
) | ||
|
||
# Save state | ||
saved_start_id = controller.state.start_id | ||
saved_truncation_id = controller.state.truncation_id | ||
saved_history_len = len(controller.state.history) | ||
|
||
# Set up mock event stream for new controller | ||
mock_event_stream.get_events.return_value = controller.state.history | ||
|
||
# Create new controller with saved state | ||
new_controller = AgentController( | ||
agent=mock_agent, | ||
event_stream=mock_event_stream, | ||
max_iterations=10, | ||
sid='test_truncation', | ||
confirmation_mode=False, | ||
headless_mode=True, | ||
) | ||
new_controller.state.start_id = saved_start_id | ||
new_controller.state.truncation_id = saved_truncation_id | ||
new_controller.state.history = mock_event_stream.get_events() | ||
|
||
# Verify restoration | ||
assert len(new_controller.state.history) == saved_history_len | ||
assert new_controller.state.history[0] == first_msg | ||
assert new_controller.state.start_id == saved_start_id |