diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index 02cb80b01a74..5913fd3bf992 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import ANY, AsyncMock, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock, call from uuid import uuid4 import pytest @@ -16,6 +16,7 @@ ChangeAgentStateAction, CmdRunAction, MessageAction, + PromptExtensionAction, SystemMessageAction, ) from openhands.events.observation import ( @@ -750,6 +751,169 @@ async def test_prompt_manager_initialization(mock_agent, mock_event_stream): await controller.close() +@pytest.mark.asyncio +async def test_prompt_manager_extensions(mock_agent, mock_event_stream): + """Test that prompt extensions are properly added to event stream.""" + # Mock the prompt manager and enable extensions + mock_agent.config.enable_prompt_extensions = True + mock_prompt_manager = MagicMock(spec=PromptManager) + mock_prompt_manager.get_system_message.return_value = "Test system message" + + # Mock the prompt extension methods + def add_examples(msg): + msg.content[0].text = "Examples added: " + msg.content[0].text + mock_prompt_manager.add_examples_to_initial_message.side_effect = add_examples + + def add_info(msg): + msg.content[0].text = "Info added: " + msg.content[0].text + mock_prompt_manager.add_info_to_initial_message.side_effect = add_info + + def enhance(msg): + msg.content[0].text = "Enhanced: " + msg.content[0].text + mock_prompt_manager.enhance_message.side_effect = enhance + + mock_agent.get_prompt_manager.return_value = mock_prompt_manager + + # Create controller + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test', + confirmation_mode=False, + headless_mode=True, + ) + + # Send a user message + message_action = MessageAction(content="Test message") + message_action._source = EventSource.USER + await controller._on_event(message_action) + + # Get all calls to add_event + actual_calls = mock_event_stream.add_event.call_args_list + + # Verify that system message was added + assert any( + isinstance(args[0], SystemMessageAction) and args[0].content == "Test system message" + for args, _ in actual_calls + ) + + # Verify that prompt extensions were added + expected_extensions = [ + ("Examples added: Test message", "examples"), + ("Info added: Examples added: Test message", "info"), + ("Enhanced: Info added: Examples added: Test message", "enhance"), + ] + for content, ext_type in expected_extensions: + assert any( + isinstance(args[0], PromptExtensionAction) + and args[0].content == content + and args[0].extension_type == ext_type + for args, _ in actual_calls + ), f"Missing extension: {ext_type}" + + + + await controller.close() + + +@pytest.mark.asyncio +async def test_prompt_manager_extensions_disabled(mock_agent, mock_event_stream): + """Test that prompt extensions are not added when disabled.""" + # Mock the prompt manager but disable extensions + mock_agent.config.enable_prompt_extensions = False + mock_prompt_manager = MagicMock(spec=PromptManager) + mock_prompt_manager.get_system_message.return_value = "Test system message" + mock_agent.get_prompt_manager.return_value = mock_prompt_manager + + # Create controller + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test', + confirmation_mode=False, + headless_mode=True, + ) + + # Send a user message + message_action = MessageAction(content="Test message") + message_action._source = EventSource.USER + await controller._on_event(message_action) + + # Get all calls to add_event + actual_calls = mock_event_stream.add_event.call_args_list + + # Verify that system message was added + assert any( + isinstance(args[0], SystemMessageAction) and args[0].content == "Test system message" + for args, _ in actual_calls + ) + + # Verify that no prompt extensions were added + assert not any( + isinstance(args[0], PromptExtensionAction) + for args, _ in actual_calls + ) + + # Verify that extension methods were not called + mock_prompt_manager.add_examples_to_initial_message.assert_not_called() + mock_prompt_manager.add_info_to_initial_message.assert_not_called() + mock_prompt_manager.enhance_message.assert_not_called() + + await controller.close() + + +@pytest.mark.asyncio +async def test_prompt_manager_extensions_delegate(mock_agent, mock_event_stream): + """Test that prompt extensions are not added for delegate controllers.""" + # Mock the prompt manager + mock_prompt_manager = MagicMock(spec=PromptManager) + mock_prompt_manager.get_system_message.return_value = "Test system message" + mock_agent.get_prompt_manager.return_value = mock_prompt_manager + + # Create delegate controller + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test', + confirmation_mode=False, + headless_mode=True, + is_delegate=True, # This should prevent system message and extensions + ) + + # Send a user message + message_action = MessageAction(content="Test message") + message_action._source = EventSource.USER + await controller._on_event(message_action) + + # Get all calls to add_event + actual_calls = mock_event_stream.add_event.call_args_list + + # Verify that only the original message was added (plus state changes) + print("Message action:", message_action) + print("Actual calls:", actual_calls) + assert any( + args[0] == message_action and args[1] == EventSource.USER + for args, _ in actual_calls + ) + + # Verify that no system message or prompt extensions were added + assert not any( + isinstance(args[0], SystemMessageAction) or isinstance(args[0], PromptExtensionAction) + for args, _ in actual_calls + ) + + # Verify that system message and extension methods were not called + mock_prompt_manager.get_system_message.assert_not_called() + mock_prompt_manager.add_examples_to_initial_message.assert_not_called() + mock_prompt_manager.add_info_to_initial_message.assert_not_called() + mock_prompt_manager.enhance_message.assert_not_called() + + await controller.close() + + @pytest.mark.asyncio async def test_prompt_manager_not_initialized(mock_agent, mock_event_stream): """Test that no system message is sent if prompt manager is not initialized."""