From ae3ecf68d424dca062d8ffb629b0c1c7f71a7a41 Mon Sep 17 00:00:00 2001 From: whimo Date: Thu, 20 Jun 2024 00:48:19 +0300 Subject: [PATCH] Advanced output handlers --- examples/output_handler.py | 21 ++++----- motleycrew/agents/__init__.py | 2 + motleycrew/agents/abstract_parent.py | 4 +- motleycrew/agents/crewai/crewai.py | 9 ++-- motleycrew/agents/langchain/langchain.py | 7 ++- motleycrew/agents/llama_index/llama_index.py | 7 ++- motleycrew/agents/output_handler.py | 46 ++++++++++++++++++++ motleycrew/agents/parent.py | 39 ++++++++++++++--- motleycrew/tools/tool.py | 3 ++ 9 files changed, 107 insertions(+), 31 deletions(-) diff --git a/examples/output_handler.py b/examples/output_handler.py index aace45ea..f5434c40 100644 --- a/examples/output_handler.py +++ b/examples/output_handler.py @@ -3,6 +3,7 @@ from motleycrew import MotleyCrew from motleycrew.agents.langchain.tool_calling_react import ReActToolCallingAgent +from motleycrew.agents import MotleyOutputHandler from motleycrew.common import configure_logging from motleycrew.tasks import SimpleTask @@ -17,23 +18,23 @@ def main(): tools = [search_tool] - def check_output(output: str): - if "medicine" not in output.lower(): - raise InvalidOutput("Add more information about AI applications in medicine.") + class ReportOutputHandler(MotleyOutputHandler): + def handle_output(self, output: str): + if "medicine" not in output.lower(): + raise InvalidOutput("Add more information about AI applications in medicine.") - return {"checked_output": output} + if "2024" in self.last_agent_input["prompt"]: + output += "\n\nThis report is up-to-date for 2024." - output_handler = StructuredTool.from_function( - name="output_handler", - description="Output handler", - func=check_output, - ) + output += f"\n\nBrought to you by motleycrew's {self.agent}." + + return {"checked_output": output} researcher = ReActToolCallingAgent( tools=tools, verbose=True, chat_history=True, - output_handler=output_handler, + output_handler=ReportOutputHandler(), ) crew = MotleyCrew() diff --git a/motleycrew/agents/__init__.py b/motleycrew/agents/__init__.py index 8de419eb..facbfb22 100644 --- a/motleycrew/agents/__init__.py +++ b/motleycrew/agents/__init__.py @@ -1,3 +1,5 @@ from .langchain import LangchainMotleyAgent from .crewai import CrewAIMotleyAgent from .llama_index import LlamaIndexMotleyAgent + +from .output_handler import MotleyOutputHandler diff --git a/motleycrew/agents/abstract_parent.py b/motleycrew/agents/abstract_parent.py index d11ffcfc..df38c5f3 100644 --- a/motleycrew/agents/abstract_parent.py +++ b/motleycrew/agents/abstract_parent.py @@ -12,14 +12,14 @@ class MotleyAgentAbstractParent(ABC): @abstractmethod def invoke( self, - task_dict: dict, + input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: """ Description Args: - task_dict (dict): + input (dict): config (:obj:`RunnableConfig`, optional): **kwargs: diff --git a/motleycrew/agents/crewai/crewai.py b/motleycrew/agents/crewai/crewai.py index d93ff00c..349c736c 100644 --- a/motleycrew/agents/crewai/crewai.py +++ b/motleycrew/agents/crewai/crewai.py @@ -46,22 +46,21 @@ def __init__( def invoke( self, - task_dict: dict, + input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: """Description Args: - task_dict (dict): + input (dict): config (:obj:`RunnableConfig`, optional): **kwargs: Returns: Any: """ - self.materialize() - prompt = self.compose_prompt(task_dict, task_dict.get("prompt")) + prompt = self.prepare_for_invocation(input=input) langchain_tools = [tool.to_langchain_tool() for tool in self.tools.values()] config = add_default_callbacks_to_langchain_config(config) @@ -69,7 +68,7 @@ def invoke( crewai_task = CrewAI__Task(description=prompt) output = self.agent.execute_task( - task=crewai_task, context=task_dict.get("context"), tools=langchain_tools, config=config + task=crewai_task, context=input.get("context"), tools=langchain_tools, config=config ) return output diff --git a/motleycrew/agents/langchain/langchain.py b/motleycrew/agents/langchain/langchain.py index bcdc516c..c81b468d 100644 --- a/motleycrew/agents/langchain/langchain.py +++ b/motleycrew/agents/langchain/langchain.py @@ -140,22 +140,21 @@ def materialize(self): def invoke( self, - task_dict: dict, + input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: """Description Args: - task_dict (dict): + input (dict): config (:obj:`RunnableConfig`, optional): **kwargs: Returns: """ - self.materialize() - prompt = self.compose_prompt(task_dict, task_dict.get("prompt")) + prompt = self.prepare_for_invocation(input=input) config = add_default_callbacks_to_langchain_config(config) if self.get_session_history_callable: diff --git a/motleycrew/agents/llama_index/llama_index.py b/motleycrew/agents/llama_index/llama_index.py index 7d0dbf44..6413fc19 100644 --- a/motleycrew/agents/llama_index/llama_index.py +++ b/motleycrew/agents/llama_index/llama_index.py @@ -43,22 +43,21 @@ def __init__( def invoke( self, - task_dict: dict, + input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: """Description Args: - task_dict (dict): + input (dict): config (:obj:`RunnableConfig`, optional): **kwargs: Returns: Any: """ - self.materialize() - prompt = self.compose_prompt(task_dict, task_dict.get("prompt")) + prompt = self.prepare_for_invocation(input=input) output = self.agent.chat(prompt) return output.response diff --git a/motleycrew/agents/output_handler.py b/motleycrew/agents/output_handler.py index e69de29b..4a3624c1 100644 --- a/motleycrew/agents/output_handler.py +++ b/motleycrew/agents/output_handler.py @@ -0,0 +1,46 @@ +from typing import Optional +from abc import ABC, abstractmethod +from langchain_core.tools import StructuredTool, BaseTool +from langchain_core.pydantic_v1 import BaseModel + +from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent +from motleycrew.common.exceptions import InvalidOutput +from motleycrew.tools import MotleyTool + + +class MotleyOutputHandler(MotleyTool, ABC): + _name: str = "output_handler" + """Name of the output handler tool.""" + + _description: str = "Output handler. ONLY RETURN THE FINAL RESULT USING THIS TOOL!" + """Description of the output handler tool.""" + + _args_schema: Optional[BaseModel] = None + """Pydantic schema for the arguments of the output handler tool. + Inferred from the `handle_output` method if not provided.""" + + _exceptions_to_handle: tuple[Exception] = (InvalidOutput,) + """Exceptions that should be returned to the agent when raised in the `handle_output` method.""" + + def __init__(self): + langchain_tool = self._create_langchain_tool() + super().__init__(langchain_tool) + + self.agent: Optional[MotleyAgentAbstractParent] = None + self.last_agent_input: Optional[dict] = None + + @property + def exceptions_to_handle(self): + return self._exceptions_to_handle + + def _create_langchain_tool(self): + return StructuredTool.from_function( + name=self._name, + description=self._description, + args_schema=self._args_schema, + func=self.handle_output, + ) + + @abstractmethod + def handle_output(self, *args, **kwargs): + pass diff --git a/motleycrew/agents/parent.py b/motleycrew/agents/parent.py index 8b0710bd..4ea371ea 100644 --- a/motleycrew/agents/parent.py +++ b/motleycrew/agents/parent.py @@ -11,6 +11,7 @@ from pydantic import BaseModel from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent +from motleycrew.agents.output_handler import MotleyOutputHandler from motleycrew.tools import MotleyTool from motleycrew.common import MotleyAgentFactory, MotleySupportedTool from motleycrew.common.exceptions import ( @@ -63,7 +64,7 @@ def __init__( self.add_tools(tools) def __repr__(self): - return f"Agent(name={self.name})" + return f"{self.__class__.__name__}(name={self.name})" def __str__(self): return self.__repr__() @@ -121,19 +122,25 @@ def _prepare_output_handler(self) -> Optional[MotleyTool]: if not self.output_handler: return None + # TODO: make this neater by constructing MotleyOutputHandler from tools? + if isinstance(self.output_handler, MotleyOutputHandler): + exceptions_to_handle = self.output_handler.exceptions_to_handle + description = self.output_handler.description + else: + exceptions_to_handle = (InvalidOutput,) + description = self.output_handler.description or f"Output handler" + assert isinstance(description, str) + description += "\n ONLY RETURN THE FINAL RESULT USING THIS TOOL!" + def handle_agent_output(*args, **kwargs): assert self.output_handler try: output = self.output_handler._run(*args, **kwargs) - except InvalidOutput as exc: + except exceptions_to_handle as exc: return f"{exc.__class__.__name__}: {str(exc)}" raise DirectOutput(output) - description = self.output_handler.description or f"Output handler for {self.name}" - assert isinstance(description, str) - description += "\n ONLY RETURN THE FINAL RESULT USING THIS TOOL!" - prepared_output_handler = StructuredTool.from_function( name=self.output_handler.name, description=description, @@ -188,6 +195,26 @@ def materialize(self): else: self._agent = self.agent_factory(tools=self.tools) + def prepare_for_invocation(self, input: dict) -> str: + """Prepares the agent for invocation by materializing it and composing the prompt. + + Should be called in the beginning of the agent's invoke method. + + Args: + input (dict): the input to the agent + + Returns: + str: the composed prompt + """ + self.materialize() + + if isinstance(self.output_handler, MotleyOutputHandler): + self.output_handler.agent = self + self.output_handler.last_agent_input = input + + prompt = self.compose_prompt(input, input.get("prompt")) + return prompt + def add_tools(self, tools: Sequence[MotleySupportedTool]): """Description diff --git a/motleycrew/tools/tool.py b/motleycrew/tools/tool.py index afe47f9a..ec03b3c7 100644 --- a/motleycrew/tools/tool.py +++ b/motleycrew/tools/tool.py @@ -56,6 +56,9 @@ def invoke( ) -> Any: return self.tool.invoke(input=input, config=config, **kwargs) + def _run(self, *args: tuple, **kwargs: Dict[str, Any]) -> Any: + return self.tool._run(*args, **kwargs) + @staticmethod def from_langchain_tool(langchain_tool: BaseTool) -> "MotleyTool": """Description