From 0b97e1f941306669445adb83fcbcb48305245049 Mon Sep 17 00:00:00 2001 From: Corneliu Croitoru Date: Tue, 28 Jan 2025 14:14:47 +0100 Subject: [PATCH] fix InlineAgent & ComprehendAgent & add unit tests --- python/setup.cfg | 2 +- .../agents/bedrock_inline_agent.py | 4 + .../agents/comprehend_filter_agent.py | 42 ++-- .../tests/agents/test_bedrock_inline_agent.py | 232 ++++++++++++++++++ .../src/tests/agents/test_comprehend_agent.py | 220 +++++++++++++++++ 5 files changed, 480 insertions(+), 20 deletions(-) create mode 100644 python/src/tests/agents/test_bedrock_inline_agent.py create mode 100644 python/src/tests/agents/test_comprehend_agent.py diff --git a/python/setup.cfg b/python/setup.cfg index aae56b5d..83273e78 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = multi_agent_orchestrator -version = 0.1.6 +version = 0.1.7 author = Anthony Bernabeu, Corneliu Croitoru author_email = brnaba@amazon.com, ccroito@amazon.com description = Multi-agent orchestrator framework diff --git a/python/src/multi_agent_orchestrator/agents/bedrock_inline_agent.py b/python/src/multi_agent_orchestrator/agents/bedrock_inline_agent.py index 2540a9b0..a137fcce 100644 --- a/python/src/multi_agent_orchestrator/agents/bedrock_inline_agent.py +++ b/python/src/multi_agent_orchestrator/agents/bedrock_inline_agent.py @@ -15,6 +15,8 @@ # BedrockInlineAgentOptions Dataclass @dataclass class BedrockInlineAgentOptions(AgentOptions): + model_id: Optional[str] = None + region: Optional[str] = None inference_config: Optional[Dict[str, Any]] = None client: Optional[Any] = None bedrock_agent_client: Optional[Any] = None @@ -71,6 +73,8 @@ def __init__(self, options: BedrockInlineAgentOptions): else: self.client = boto3.client('bedrock-runtime') + self.model_id: str = options.model_id or BEDROCK_MODEL_ID_CLAUDE_3_HAIKU + # Initialize bedrock agent client if options.bedrock_agent_client: self.bedrock_agent_client = options.bedrock_agent_client diff --git a/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py b/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py index 4d7f07f0..833a53a3 100644 --- a/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py +++ b/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py @@ -4,35 +4,39 @@ from .agent import Agent, AgentOptions import boto3 from botocore.config import Config +import os +from dataclasses import dataclass + # Type alias for CheckFunction CheckFunction = Callable[[str], str] +@dataclass class ComprehendFilterAgentOptions(AgentOptions): - def __init__(self, - enable_sentiment_check: bool = True, - enable_pii_check: bool = True, - enable_toxicity_check: bool = True, - sentiment_threshold: float = 0.7, - toxicity_threshold: float = 0.7, - allow_pii: bool = False, - language_code: str = 'en', - **kwargs): - super().__init__(**kwargs) - self.enable_sentiment_check = enable_sentiment_check - self.enable_pii_check = enable_pii_check - self.enable_toxicity_check = enable_toxicity_check - self.sentiment_threshold = sentiment_threshold - self.toxicity_threshold = toxicity_threshold - self.allow_pii = allow_pii - self.language_code = language_code + enable_sentiment_check: bool = True + enable_pii_check: bool = True + enable_toxicity_check: bool = True + sentiment_threshold: float = 0.7 + toxicity_threshold: float = 0.7 + allow_pii: bool = False + language_code: str = 'en' + region: Optional[str] = None + client: Optional[Any] = None class ComprehendFilterAgent(Agent): def __init__(self, options: ComprehendFilterAgentOptions): super().__init__(options) - config = Config(region_name=options.region) if options.region else None - self.comprehend_client = boto3.client('comprehend', config=config) + if options.client: + self.comprehend_client = options.client + else: + if options.region: + self.client = boto3.client( + 'comprehend', + region_name=options.region or os.environ.get('AWS_REGION') + ) + else: + self.client = boto3.client('comprehend') self.custom_checks: List[CheckFunction] = [] diff --git a/python/src/tests/agents/test_bedrock_inline_agent.py b/python/src/tests/agents/test_bedrock_inline_agent.py new file mode 100644 index 00000000..8dcfd3ef --- /dev/null +++ b/python/src/tests/agents/test_bedrock_inline_agent.py @@ -0,0 +1,232 @@ +import unittest +from unittest.mock import Mock +import json +from typing import Dict, Any + +from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole +from multi_agent_orchestrator.agents import BedrockInlineAgent, BedrockInlineAgentOptions + +class TestBedrockInlineAgent(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + # Mock clients + self.mock_bedrock_client = Mock() + self.mock_bedrock_agent_client = Mock() + + # Sample action groups and knowledge bases + self.action_groups = [ + { + "actionGroupName": "TestActionGroup1", + "description": "Test action group 1 description" + }, + { + "actionGroupName": "TestActionGroup2", + "description": "Test action group 2 description" + } + ] + + self.knowledge_bases = [ + { + "knowledgeBaseId": "kb1", + "description": "Test knowledge base 1" + }, + { + "knowledgeBaseId": "kb2", + "description": "Test knowledge base 2" + } + ] + + # Create agent instance + self.agent = BedrockInlineAgent( + BedrockInlineAgentOptions( + name="Test Agent", + description="Test agent description", + client=self.mock_bedrock_client, + bedrock_agent_client=self.mock_bedrock_agent_client, + action_groups_list=self.action_groups, + knowledge_bases=self.knowledge_bases + ) + ) + + async def test_initialization(self): + """Test agent initialization and configuration""" + self.assertEqual(self.agent.name, "Test Agent") + self.assertEqual(self.agent.description, "Test agent description") + self.assertEqual(len(self.agent.action_groups_list), 2) + self.assertEqual(len(self.agent.knowledge_bases), 2) + self.assertEqual(self.agent.tool_config['toolMaxRecursions'], 1) + + async def test_process_request_without_tool_use(self): + """Test processing a request that doesn't require tool use""" + # Mock the converse response + mock_response = { + 'output': { + 'message': { + 'role': 'assistant', + 'content': [{'text': 'Test response'}] + } + } + } + self.mock_bedrock_client.converse.return_value = mock_response + + # Test input + input_text = "Hello" + chat_history = [] + + # Process request + response = await self.agent.process_request( + input_text=input_text, + user_id='test_user', + session_id='test_session', + chat_history=chat_history + ) + + # Verify response + self.assertIsInstance(response, ConversationMessage) + self.assertEqual(response.role, ParticipantRole.ASSISTANT.value) + self.assertEqual(response.content[0]['text'], 'Test response') + + async def test_process_request_with_tool_use(self): + """Test processing a request that requires tool use""" + # Mock the converse response with tool use + tool_use_response = { + 'output': { + 'message': { + 'role': 'assistant', + 'content': [{ + 'toolUse': { + 'name': 'inline_agent_creation', + 'input': { + 'action_group_names': ['TestActionGroup1'], + 'knowledge_bases': ['kb1'], + 'description': 'Test description', + 'user_request': 'Test request' + } + } + }] + } + } + } + self.mock_bedrock_client.converse.return_value = tool_use_response + + # Mock the inline agent response + mock_completion = { + 'chunk': { + 'bytes': b'Inline agent response' + } + } + self.mock_bedrock_agent_client.invoke_inline_agent.return_value = { + 'completion': [mock_completion] + } + + # Test input + input_text = "Use inline agent" + chat_history = [] + + # Process request + response = await self.agent.process_request( + input_text=input_text, + user_id='test_user', + session_id='test_session', + chat_history=chat_history + ) + + # Verify response + self.assertIsInstance(response, ConversationMessage) + self.assertEqual(response.role, ParticipantRole.ASSISTANT.value) + self.assertEqual(response.content[0]['text'], 'Inline agent response') + + # Verify inline agent was called with correct parameters + self.mock_bedrock_agent_client.invoke_inline_agent.assert_called_once() + call_kwargs = self.mock_bedrock_agent_client.invoke_inline_agent.call_args[1] + self.assertEqual(len(call_kwargs['actionGroups']), 1) + self.assertEqual(len(call_kwargs['knowledgeBases']), 1) + self.assertEqual(call_kwargs['inputText'], 'Test request') + + async def test_error_handling(self): + """Test error handling in process_request""" + # Mock the converse method to raise an exception + self.mock_bedrock_client.converse.side_effect = Exception("Test error") + + # Test input + input_text = "Hello" + chat_history = [] + + # Verify exception is raised + with self.assertRaises(Exception) as context: + await self.agent.process_request( + input_text=input_text, + user_id='test_user', + session_id='test_session', + chat_history=chat_history + ) + + self.assertTrue("Test error" in str(context.exception)) + + async def test_system_prompt_formatting(self): + """Test system prompt formatting and template replacement""" + # Test with custom variables + test_variables = { + 'test_var': 'test_value' + } + self.agent.set_system_prompt( + template="Test template with {{test_var}}", + variables=test_variables + ) + + self.assertEqual(self.agent.system_prompt, "Test template with test_value") + + async def test_inline_agent_tool_handler(self): + """Test the inline agent tool handler""" + # Mock response content + response = ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{ + 'toolUse': { + 'name': 'inline_agent_creation', + 'input': { + 'action_group_names': ['TestActionGroup1'], + 'knowledge_bases': ['kb1'], + 'description': 'Test description', + 'user_request': 'Test request' + } + } + }] + ) + + # Mock inline agent response + mock_completion = { + 'chunk': { + 'bytes': b'Handler test response' + } + } + self.mock_bedrock_agent_client.invoke_inline_agent.return_value = { + 'completion': [mock_completion] + } + + # Call handler + result = await self.agent.inline_agent_tool_handler( + session_id='test_session', + response=response, + conversation=[] + ) + + # Verify result + self.assertIsInstance(result, ConversationMessage) + self.assertEqual(result.content[0]['text'], 'Handler test response') + + async def test_custom_prompt_template(self): + """Test custom prompt template setup""" + custom_template = "Custom template {{test_var}}" + custom_variables = {"test_var": "test_value"} + + self.agent.set_system_prompt( + template=custom_template, + variables=custom_variables + ) + + self.assertEqual(self.agent.prompt_template, custom_template) + self.assertEqual(self.agent.custom_variables, custom_variables) + self.assertEqual(self.agent.system_prompt, "Custom template test_value") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/python/src/tests/agents/test_comprehend_agent.py b/python/src/tests/agents/test_comprehend_agent.py new file mode 100644 index 00000000..ac6ebf56 --- /dev/null +++ b/python/src/tests/agents/test_comprehend_agent.py @@ -0,0 +1,220 @@ +import unittest +from unittest.mock import Mock +from typing import Dict, Any + +from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole +from multi_agent_orchestrator.agents import ComprehendFilterAgent, ComprehendFilterAgentOptions + +class TestComprehendFilterAgent(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + # Create mock comprehend client + self.mock_comprehend_client = Mock() + + # Setup default positive responses + self.mock_comprehend_client.detect_sentiment.return_value = { + 'Sentiment': 'POSITIVE', + 'SentimentScore': { + 'Positive': 0.9, + 'Negative': 0.1, + 'Neutral': 0.0, + 'Mixed': 0.0 + } + } + + self.mock_comprehend_client.detect_pii_entities.return_value = { + 'Entities': [] + } + + self.mock_comprehend_client.detect_toxic_content.return_value = { + 'ResultList': [{ + 'Labels': [] + }] + } + + # Create agent instance + self.agent = ComprehendFilterAgent( + ComprehendFilterAgentOptions( + name="Test Filter Agent", + description="Test agent for filtering content", + client=self.mock_comprehend_client + ) + ) + + async def test_initialization(self): + """Test agent initialization and configuration""" + self.assertEqual(self.agent.name, "Test Filter Agent") + self.assertEqual(self.agent.description, "Test agent for filtering content") + self.assertTrue(self.agent.enable_sentiment_check) + self.assertTrue(self.agent.enable_pii_check) + self.assertTrue(self.agent.enable_toxicity_check) + self.assertEqual(self.agent.language_code, "en") + + async def test_process_clean_content(self): + """Test processing clean content passes through filters""" + input_text = "Hello, this is a friendly message!" + + response = await self.agent.process_request( + input_text=input_text, + user_id="test_user", + session_id="test_session", + chat_history=[] + ) + + self.assertIsNotNone(response) + self.assertIsInstance(response, ConversationMessage) + self.assertEqual(response.role, ParticipantRole.ASSISTANT.value) + self.assertEqual(response.content[0]["text"], input_text) + + async def test_negative_sentiment_blocking(self): + """Test that highly negative content is blocked""" + # Configure mock for negative sentiment + self.mock_comprehend_client.detect_sentiment.return_value = { + 'Sentiment': 'NEGATIVE', + 'SentimentScore': { + 'Positive': 0.0, + 'Negative': 0.9, + 'Neutral': 0.1, + 'Mixed': 0.0 + } + } + + response = await self.agent.process_request( + input_text="I hate everything!", + user_id="test_user", + session_id="test_session", + chat_history=[] + ) + + self.assertIsNone(response) + self.mock_comprehend_client.detect_sentiment.assert_called_once() + + async def test_pii_detection_blocking(self): + """Test that content with PII is blocked""" + # Configure mock for PII detection + self.mock_comprehend_client.detect_pii_entities.return_value = { + 'Entities': [ + {'Type': 'EMAIL', 'Score': 0.99}, + {'Type': 'PHONE', 'Score': 0.95} + ] + } + + response = await self.agent.process_request( + input_text="Contact me at test@email.com", + user_id="test_user", + session_id="test_session", + chat_history=[] + ) + + self.assertIsNone(response) + self.mock_comprehend_client.detect_pii_entities.assert_called_once() + + async def test_toxic_content_blocking(self): + """Test that toxic content is blocked""" + # Configure mock for toxic content + self.mock_comprehend_client.detect_toxic_content.return_value = { + 'ResultList': [{ + 'Labels': [ + {'Name': 'HATE_SPEECH', 'Score': 0.95} + ] + }] + } + + response = await self.agent.process_request( + input_text="Some toxic content here", + user_id="test_user", + session_id="test_session", + chat_history=[] + ) + + self.assertIsNone(response) + self.mock_comprehend_client.detect_toxic_content.assert_called_once() + + async def test_custom_check(self): + """Test custom check functionality""" + async def custom_check(text: str) -> str: + if "banned" in text.lower(): + return "Contains banned word" + return None + + self.agent.add_custom_check(custom_check) + + response = await self.agent.process_request( + input_text="This contains a banned word", + user_id="test_user", + session_id="test_session", + chat_history=[] + ) + + self.assertIsNone(response) + + async def test_language_code_validation(self): + """Test language code validation and setting""" + # Test valid language code + self.agent.set_language_code("es") + self.assertEqual(self.agent.language_code, "es") + + # Test invalid language code + with self.assertRaises(ValueError): + self.agent.set_language_code("invalid") + + async def test_allow_pii_configuration(self): + """Test PII allowance configuration""" + # Create new agent instance with PII allowed + agent_with_pii = ComprehendFilterAgent( + ComprehendFilterAgentOptions( + name="Test Filter Agent", + description="Test agent for filtering content", + client=self.mock_comprehend_client, + allow_pii=True + ) + ) + + # Configure mock for PII detection + self.mock_comprehend_client.detect_pii_entities.return_value = { + 'Entities': [ + {'Type': 'EMAIL', 'Score': 0.99} + ] + } + + response = await agent_with_pii.process_request( + input_text="Contact me at test@email.com", + user_id="test_user", + session_id="test_session", + chat_history=[] + ) + + self.assertIsNotNone(response) + self.assertEqual(response.content[0]["text"], "Contact me at test@email.com") + + async def test_error_handling(self): + """Test error handling in process_request""" + # Configure mock to raise an exception + self.mock_comprehend_client.detect_sentiment.side_effect = Exception("Test error") + + with self.assertRaises(Exception) as context: + await self.agent.process_request( + input_text="Hello", + user_id="test_user", + session_id="test_session", + chat_history=[] + ) + + self.assertTrue("Test error" in str(context.exception)) + + async def test_threshold_configuration(self): + """Test custom threshold configurations""" + agent = ComprehendFilterAgent( + ComprehendFilterAgentOptions( + name="Test Filter Agent", + description="Test agent for filtering content", + client=self.mock_comprehend_client, + sentiment_threshold=0.5, + toxicity_threshold=0.8 + ) + ) + + self.assertEqual(agent.sentiment_threshold, 0.5) + self.assertEqual(agent.toxicity_threshold, 0.8) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file