Skip to content

Commit

Permalink
fix InlineAgent & ComprehendAgent & add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cornelcroi committed Jan 28, 2025
1 parent 1f17341 commit 0b97e1f
Show file tree
Hide file tree
Showing 5 changed files with 480 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = multi_agent_orchestrator
version = 0.1.6
version = 0.1.7
author = Anthony Bernabeu, Corneliu Croitoru
author_email = [email protected], [email protected]
description = Multi-agent orchestrator framework
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# BedrockInlineAgentOptions Dataclass
@dataclass
class BedrockInlineAgentOptions(AgentOptions):
model_id: Optional[str] = None
region: Optional[str] = None
inference_config: Optional[Dict[str, Any]] = None
client: Optional[Any] = None
bedrock_agent_client: Optional[Any] = None
Expand Down Expand Up @@ -71,6 +73,8 @@ def __init__(self, options: BedrockInlineAgentOptions):
else:
self.client = boto3.client('bedrock-runtime')

self.model_id: str = options.model_id or BEDROCK_MODEL_ID_CLAUDE_3_HAIKU

# Initialize bedrock agent client
if options.bedrock_agent_client:
self.bedrock_agent_client = options.bedrock_agent_client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,39 @@
from .agent import Agent, AgentOptions
import boto3
from botocore.config import Config
import os
from dataclasses import dataclass


# Type alias for CheckFunction
CheckFunction = Callable[[str], str]

@dataclass
class ComprehendFilterAgentOptions(AgentOptions):
def __init__(self,
enable_sentiment_check: bool = True,
enable_pii_check: bool = True,
enable_toxicity_check: bool = True,
sentiment_threshold: float = 0.7,
toxicity_threshold: float = 0.7,
allow_pii: bool = False,
language_code: str = 'en',
**kwargs):
super().__init__(**kwargs)
self.enable_sentiment_check = enable_sentiment_check
self.enable_pii_check = enable_pii_check
self.enable_toxicity_check = enable_toxicity_check
self.sentiment_threshold = sentiment_threshold
self.toxicity_threshold = toxicity_threshold
self.allow_pii = allow_pii
self.language_code = language_code
enable_sentiment_check: bool = True
enable_pii_check: bool = True
enable_toxicity_check: bool = True
sentiment_threshold: float = 0.7
toxicity_threshold: float = 0.7
allow_pii: bool = False
language_code: str = 'en'
region: Optional[str] = None
client: Optional[Any] = None

class ComprehendFilterAgent(Agent):
def __init__(self, options: ComprehendFilterAgentOptions):
super().__init__(options)

config = Config(region_name=options.region) if options.region else None
self.comprehend_client = boto3.client('comprehend', config=config)
if options.client:
self.comprehend_client = options.client
else:
if options.region:
self.client = boto3.client(
'comprehend',
region_name=options.region or os.environ.get('AWS_REGION')
)
else:
self.client = boto3.client('comprehend')

self.custom_checks: List[CheckFunction] = []

Expand Down
232 changes: 232 additions & 0 deletions python/src/tests/agents/test_bedrock_inline_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import unittest
from unittest.mock import Mock
import json
from typing import Dict, Any

from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole
from multi_agent_orchestrator.agents import BedrockInlineAgent, BedrockInlineAgentOptions

class TestBedrockInlineAgent(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
# Mock clients
self.mock_bedrock_client = Mock()
self.mock_bedrock_agent_client = Mock()

# Sample action groups and knowledge bases
self.action_groups = [
{
"actionGroupName": "TestActionGroup1",
"description": "Test action group 1 description"
},
{
"actionGroupName": "TestActionGroup2",
"description": "Test action group 2 description"
}
]

self.knowledge_bases = [
{
"knowledgeBaseId": "kb1",
"description": "Test knowledge base 1"
},
{
"knowledgeBaseId": "kb2",
"description": "Test knowledge base 2"
}
]

# Create agent instance
self.agent = BedrockInlineAgent(
BedrockInlineAgentOptions(
name="Test Agent",
description="Test agent description",
client=self.mock_bedrock_client,
bedrock_agent_client=self.mock_bedrock_agent_client,
action_groups_list=self.action_groups,
knowledge_bases=self.knowledge_bases
)
)

async def test_initialization(self):
"""Test agent initialization and configuration"""
self.assertEqual(self.agent.name, "Test Agent")
self.assertEqual(self.agent.description, "Test agent description")
self.assertEqual(len(self.agent.action_groups_list), 2)
self.assertEqual(len(self.agent.knowledge_bases), 2)
self.assertEqual(self.agent.tool_config['toolMaxRecursions'], 1)

async def test_process_request_without_tool_use(self):
"""Test processing a request that doesn't require tool use"""
# Mock the converse response
mock_response = {
'output': {
'message': {
'role': 'assistant',
'content': [{'text': 'Test response'}]
}
}
}
self.mock_bedrock_client.converse.return_value = mock_response

# Test input
input_text = "Hello"
chat_history = []

# Process request
response = await self.agent.process_request(
input_text=input_text,
user_id='test_user',
session_id='test_session',
chat_history=chat_history
)

# Verify response
self.assertIsInstance(response, ConversationMessage)
self.assertEqual(response.role, ParticipantRole.ASSISTANT.value)
self.assertEqual(response.content[0]['text'], 'Test response')

async def test_process_request_with_tool_use(self):
"""Test processing a request that requires tool use"""
# Mock the converse response with tool use
tool_use_response = {
'output': {
'message': {
'role': 'assistant',
'content': [{
'toolUse': {
'name': 'inline_agent_creation',
'input': {
'action_group_names': ['TestActionGroup1'],
'knowledge_bases': ['kb1'],
'description': 'Test description',
'user_request': 'Test request'
}
}
}]
}
}
}
self.mock_bedrock_client.converse.return_value = tool_use_response

# Mock the inline agent response
mock_completion = {
'chunk': {
'bytes': b'Inline agent response'
}
}
self.mock_bedrock_agent_client.invoke_inline_agent.return_value = {
'completion': [mock_completion]
}

# Test input
input_text = "Use inline agent"
chat_history = []

# Process request
response = await self.agent.process_request(
input_text=input_text,
user_id='test_user',
session_id='test_session',
chat_history=chat_history
)

# Verify response
self.assertIsInstance(response, ConversationMessage)
self.assertEqual(response.role, ParticipantRole.ASSISTANT.value)
self.assertEqual(response.content[0]['text'], 'Inline agent response')

# Verify inline agent was called with correct parameters
self.mock_bedrock_agent_client.invoke_inline_agent.assert_called_once()
call_kwargs = self.mock_bedrock_agent_client.invoke_inline_agent.call_args[1]
self.assertEqual(len(call_kwargs['actionGroups']), 1)
self.assertEqual(len(call_kwargs['knowledgeBases']), 1)
self.assertEqual(call_kwargs['inputText'], 'Test request')

async def test_error_handling(self):
"""Test error handling in process_request"""
# Mock the converse method to raise an exception
self.mock_bedrock_client.converse.side_effect = Exception("Test error")

# Test input
input_text = "Hello"
chat_history = []

# Verify exception is raised
with self.assertRaises(Exception) as context:
await self.agent.process_request(
input_text=input_text,
user_id='test_user',
session_id='test_session',
chat_history=chat_history
)

self.assertTrue("Test error" in str(context.exception))

async def test_system_prompt_formatting(self):
"""Test system prompt formatting and template replacement"""
# Test with custom variables
test_variables = {
'test_var': 'test_value'
}
self.agent.set_system_prompt(
template="Test template with {{test_var}}",
variables=test_variables
)

self.assertEqual(self.agent.system_prompt, "Test template with test_value")

async def test_inline_agent_tool_handler(self):
"""Test the inline agent tool handler"""
# Mock response content
response = ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=[{
'toolUse': {
'name': 'inline_agent_creation',
'input': {
'action_group_names': ['TestActionGroup1'],
'knowledge_bases': ['kb1'],
'description': 'Test description',
'user_request': 'Test request'
}
}
}]
)

# Mock inline agent response
mock_completion = {
'chunk': {
'bytes': b'Handler test response'
}
}
self.mock_bedrock_agent_client.invoke_inline_agent.return_value = {
'completion': [mock_completion]
}

# Call handler
result = await self.agent.inline_agent_tool_handler(
session_id='test_session',
response=response,
conversation=[]
)

# Verify result
self.assertIsInstance(result, ConversationMessage)
self.assertEqual(result.content[0]['text'], 'Handler test response')

async def test_custom_prompt_template(self):
"""Test custom prompt template setup"""
custom_template = "Custom template {{test_var}}"
custom_variables = {"test_var": "test_value"}

self.agent.set_system_prompt(
template=custom_template,
variables=custom_variables
)

self.assertEqual(self.agent.prompt_template, custom_template)
self.assertEqual(self.agent.custom_variables, custom_variables)
self.assertEqual(self.agent.system_prompt, "Custom template test_value")

if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 0b97e1f

Please sign in to comment.