Add ParallelAgent class #103
This PR adds a new file, `parallel_agent.py`:

```python
import asyncio
from typing import Any, AsyncIterable

from multi_agent_orchestrator.agents import (
    Agent,
    AgentOptions,
    BedrockLLMAgent,
)
from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole
from multi_agent_orchestrator.utils.logger import Logger


# Extend AgentOptions for the ParallelAgent class:
class ParallelAgentOptions(AgentOptions):
    def __init__(
        self,
        agents: list[str],
        default_output: str = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.agents = agents
        self.default_output = default_output


# Create a new custom agent that allows for parallel processing:
class ParallelAgent(Agent):
    def __init__(self, options: ParallelAgentOptions):
        super().__init__(options)
        self.agents = options.agents
        self.default_output = (
            options.default_output or "No output generated from the ParallelAgent."
        )
        if len(self.agents) == 0:
            raise ValueError("ParallelAgent requires at least 1 agent to initiate!")

    async def _get_llm_response(
        self,
        agent: BedrockLLMAgent,
        input_text: str,
        user_id: str,
        session_id: str,
        chat_history: list[ConversationMessage],
        additional_params: dict[str, str] = None,
    ) -> ConversationMessage:
        # Get response from LLM agent:
        final_response: ConversationMessage | AsyncIterable[Any]

        try:
            response = await agent.process_request(
                input_text, user_id, session_id, chat_history, additional_params
            )
            if self.is_conversation_message(response):
                if response.content and "text" in response.content[0]:
                    final_response = response
                else:
                    Logger.warn(f"Agent {agent.name} returned no text content.")
                    return self.create_default_response()
            elif self.is_async_iterable(response):
                Logger.warn("Streaming is not allowed for ParallelAgents!")
                return self.create_default_response()
            else:
                Logger.warn(f"Agent {agent.name} returned an invalid response type.")
                return self.create_default_response()

        except Exception as error:
            Logger.error(
                f"Error processing request with agent {agent.name}: {str(error)}"
            )
            # Re-raise as a proper exception so the failure surfaces to the caller:
            raise RuntimeError(
                f"Error processing request with agent {agent.name}: {str(error)}"
            ) from error

        return final_response

    async def process_request(
        self,
        input_text: str,
        user_id: str,
        session_id: str,
        chat_history: list[ConversationMessage],
        additional_params: dict[str, str] = None,
    ) -> ConversationMessage:
        # Create tasks for all LLMs to run in parallel:
        tasks = []
        for agent in self.agents:
            tasks.append(
                self._get_llm_response(
                    agent,
                    input_text,
                    user_id,
                    session_id,
                    chat_history,
                    additional_params,
                )
            )

        # Run all tasks concurrently and wait for results:
        responses = await asyncio.gather(*tasks)

        # Create dictionary of responses:
        response_dict = {
            agent.name: response.content[0]["text"]
            for agent, response in zip(self.agents, responses)
            if response  # Only include non-empty responses!
        }

        # Convert dictionary to string representation:
        combined_response = str(response_dict)

        return ConversationMessage(
            role=ParticipantRole.ASSISTANT.value,
            content=[{"text": combined_response}],
        )

    @staticmethod
    def is_async_iterable(obj: Any) -> bool:
        return hasattr(obj, "__aiter__")

    @staticmethod
    def is_conversation_message(response: Any) -> bool:
        return (
            isinstance(response, ConversationMessage)
            and hasattr(response, "role")
            and hasattr(response, "content")
            and isinstance(response.content, list)
        )

    def create_default_response(self) -> ConversationMessage:
        return ConversationMessage(
            role=ParticipantRole.ASSISTANT.value,
            content=[{"text": self.default_output}],
        )
```

Review comments on this commit:

On `agent: BedrockLLMAgent` in `_get_llm_response`:

Reviewer: Change from BedrockLLMAgent to Agent.
Author: Fixed in new commit below: Update parallel_agent.py after initial PR review.

On the `_get_llm_response` method:

Reviewer: Why add another method? Can't you just call agent.process_request()?
Author: Wanted to include some of the …
Reviewer: I see, ok. Well, I'd suggest changing the method name from `_get_llm_response` to `agent_process_request()`. The framework is not only about LLMs.
On the `agents` option:

Reviewer: That should be a list of Agents.
Author: Fixed in new commit below: Update parallel_agent.py after initial PR review.
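For context, a minimal usage sketch (not part of the PR) might look like the following. It assumes `BedrockLLMAgentOptions` is available from `multi_agent_orchestrator.agents`, and, in line with the review discussion, passes Agent instances to `agents` (the code calls `agent.process_request()` and `agent.name` on each entry despite the `list[str]` annotation). The agent names, descriptions, and the import path of the new module are placeholders.

```python
import asyncio

from multi_agent_orchestrator.agents import BedrockLLMAgent, BedrockLLMAgentOptions

# Path depends on where the new module lands in the package:
from parallel_agent import ParallelAgent, ParallelAgentOptions

# Two hypothetical sub-agents; any Agent subclass that returns a
# non-streaming ConversationMessage should work here.
tech_agent = BedrockLLMAgent(BedrockLLMAgentOptions(
    name="Tech Agent",
    description="Answers technology questions.",
))
travel_agent = BedrockLLMAgent(BedrockLLMAgentOptions(
    name="Travel Agent",
    description="Answers travel questions.",
))

parallel_agent = ParallelAgent(ParallelAgentOptions(
    name="Parallel Agent",
    description="Fans a single request out to several agents at once.",
    agents=[tech_agent, travel_agent],
))


async def main():
    result = await parallel_agent.process_request(
        input_text="What should I pack for a trip to Tokyo?",
        user_id="user-1",
        session_id="session-1",
        chat_history=[],
    )
    # The combined response is one ConversationMessage whose text is a
    # stringified {agent name: response text} dictionary.
    print(result.content[0]["text"])


asyncio.run(main())
```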