Skip to content

Commit

Permalink
test: Enable prompt extensions in test_prompt_manager_extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
openhands-agent committed Feb 2, 2025
1 parent 7a75236 commit e82aa13
Showing 1 changed file with 165 additions and 1 deletion.
166 changes: 165 additions & 1 deletion tests/unit/test_agent_controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock
from unittest.mock import ANY, AsyncMock, MagicMock, call
from uuid import uuid4

import pytest
Expand All @@ -16,6 +16,7 @@
ChangeAgentStateAction,
CmdRunAction,
MessageAction,
PromptExtensionAction,
SystemMessageAction,
)
from openhands.events.observation import (
Expand Down Expand Up @@ -750,6 +751,169 @@ async def test_prompt_manager_initialization(mock_agent, mock_event_stream):
await controller.close()


@pytest.mark.asyncio
async def test_prompt_manager_extensions(mock_agent, mock_event_stream):
"""Test that prompt extensions are properly added to event stream."""
# Mock the prompt manager and enable extensions
mock_agent.config.enable_prompt_extensions = True
mock_prompt_manager = MagicMock(spec=PromptManager)
mock_prompt_manager.get_system_message.return_value = "Test system message"

# Mock the prompt extension methods
def add_examples(msg):
msg.content[0].text = "Examples added: " + msg.content[0].text
mock_prompt_manager.add_examples_to_initial_message.side_effect = add_examples

def add_info(msg):
msg.content[0].text = "Info added: " + msg.content[0].text
mock_prompt_manager.add_info_to_initial_message.side_effect = add_info

def enhance(msg):
msg.content[0].text = "Enhanced: " + msg.content[0].text
mock_prompt_manager.enhance_message.side_effect = enhance

mock_agent.get_prompt_manager.return_value = mock_prompt_manager

# Create controller
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)

# Send a user message
message_action = MessageAction(content="Test message")
message_action._source = EventSource.USER
await controller._on_event(message_action)

# Get all calls to add_event
actual_calls = mock_event_stream.add_event.call_args_list

# Verify that system message was added
assert any(
isinstance(args[0], SystemMessageAction) and args[0].content == "Test system message"
for args, _ in actual_calls
)

# Verify that prompt extensions were added
expected_extensions = [
("Examples added: Test message", "examples"),
("Info added: Examples added: Test message", "info"),
("Enhanced: Info added: Examples added: Test message", "enhance"),
]
for content, ext_type in expected_extensions:
assert any(
isinstance(args[0], PromptExtensionAction)
and args[0].content == content
and args[0].extension_type == ext_type
for args, _ in actual_calls
), f"Missing extension: {ext_type}"



await controller.close()


@pytest.mark.asyncio
async def test_prompt_manager_extensions_disabled(mock_agent, mock_event_stream):
"""Test that prompt extensions are not added when disabled."""
# Mock the prompt manager but disable extensions
mock_agent.config.enable_prompt_extensions = False
mock_prompt_manager = MagicMock(spec=PromptManager)
mock_prompt_manager.get_system_message.return_value = "Test system message"
mock_agent.get_prompt_manager.return_value = mock_prompt_manager

# Create controller
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)

# Send a user message
message_action = MessageAction(content="Test message")
message_action._source = EventSource.USER
await controller._on_event(message_action)

# Get all calls to add_event
actual_calls = mock_event_stream.add_event.call_args_list

# Verify that system message was added
assert any(
isinstance(args[0], SystemMessageAction) and args[0].content == "Test system message"
for args, _ in actual_calls
)

# Verify that no prompt extensions were added
assert not any(
isinstance(args[0], PromptExtensionAction)
for args, _ in actual_calls
)

# Verify that extension methods were not called
mock_prompt_manager.add_examples_to_initial_message.assert_not_called()
mock_prompt_manager.add_info_to_initial_message.assert_not_called()
mock_prompt_manager.enhance_message.assert_not_called()

await controller.close()


@pytest.mark.asyncio
async def test_prompt_manager_extensions_delegate(mock_agent, mock_event_stream):
"""Test that prompt extensions are not added for delegate controllers."""
# Mock the prompt manager
mock_prompt_manager = MagicMock(spec=PromptManager)
mock_prompt_manager.get_system_message.return_value = "Test system message"
mock_agent.get_prompt_manager.return_value = mock_prompt_manager

# Create delegate controller
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
is_delegate=True, # This should prevent system message and extensions
)

# Send a user message
message_action = MessageAction(content="Test message")
message_action._source = EventSource.USER
await controller._on_event(message_action)

# Get all calls to add_event
actual_calls = mock_event_stream.add_event.call_args_list

# Verify that only the original message was added (plus state changes)
print("Message action:", message_action)
print("Actual calls:", actual_calls)
assert any(
args[0] == message_action and args[1] == EventSource.USER
for args, _ in actual_calls
)

# Verify that no system message or prompt extensions were added
assert not any(
isinstance(args[0], SystemMessageAction) or isinstance(args[0], PromptExtensionAction)
for args, _ in actual_calls
)

# Verify that system message and extension methods were not called
mock_prompt_manager.get_system_message.assert_not_called()
mock_prompt_manager.add_examples_to_initial_message.assert_not_called()
mock_prompt_manager.add_info_to_initial_message.assert_not_called()
mock_prompt_manager.enhance_message.assert_not_called()

await controller.close()


@pytest.mark.asyncio
async def test_prompt_manager_not_initialized(mock_agent, mock_event_stream):
"""Test that no system message is sent if prompt manager is not initialized."""
Expand Down

0 comments on commit e82aa13

Please sign in to comment.