Skip to content

Commit

Permalink
Advanced output handlers (#52)
Browse files Browse the repository at this point in the history
* Advanced output handlers

* minor refactor
  • Loading branch information
whimo authored Jun 20, 2024
1 parent 721ba41 commit fbc1dd9
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 35 deletions.
24 changes: 11 additions & 13 deletions examples/output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,35 @@

from motleycrew import MotleyCrew
from motleycrew.agents.langchain.tool_calling_react import ReActToolCallingAgent
from motleycrew.agents import MotleyOutputHandler
from motleycrew.common import configure_logging
from motleycrew.tasks import SimpleTask

from motleycrew.common.exceptions import InvalidOutput


from langchain_core.tools import StructuredTool


def main():
search_tool = DuckDuckGoSearchRun()

tools = [search_tool]

def check_output(output: str):
if "medicine" not in output.lower():
raise InvalidOutput("Add more information about AI applications in medicine.")
class ReportOutputHandler(MotleyOutputHandler):
def handle_output(self, output: str):
if "medicine" not in output.lower():
raise InvalidOutput("Add more information about AI applications in medicine.")

return {"checked_output": output}
if "2024" in self.agent_input["prompt"]:
output += "\n\nThis report is up-to-date for 2024."

output_handler = StructuredTool.from_function(
name="output_handler",
description="Output handler",
func=check_output,
)
output += f"\n\nBrought to you by motleycrew's {self.agent}."

return {"checked_output": output}

researcher = ReActToolCallingAgent(
tools=tools,
verbose=True,
chat_history=True,
output_handler=output_handler,
output_handler=ReportOutputHandler(),
)

crew = MotleyCrew()
Expand Down
2 changes: 2 additions & 0 deletions motleycrew/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .langchain import LangchainMotleyAgent
from .crewai import CrewAIMotleyAgent
from .llama_index import LlamaIndexMotleyAgent

from .output_handler import MotleyOutputHandler
4 changes: 2 additions & 2 deletions motleycrew/agents/abstract_parent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class MotleyAgentAbstractParent(ABC):
@abstractmethod
def invoke(
self,
task_dict: dict,
input: dict,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
""" Description
Args:
task_dict (dict):
input (dict):
config (:obj:`RunnableConfig`, optional):
**kwargs:
Expand Down
9 changes: 4 additions & 5 deletions motleycrew/agents/crewai/crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,30 +46,29 @@ def __init__(

def invoke(
self,
task_dict: dict,
input: dict,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
"""Description
Args:
task_dict (dict):
input (dict):
config (:obj:`RunnableConfig`, optional):
**kwargs:
Returns:
Any:
"""
self.materialize()
prompt = self.compose_prompt(task_dict, task_dict.get("prompt"))
prompt = self.prepare_for_invocation(input=input)

langchain_tools = [tool.to_langchain_tool() for tool in self.tools.values()]
config = add_default_callbacks_to_langchain_config(config)

crewai_task = CrewAI__Task(description=prompt)

output = self.agent.execute_task(
task=crewai_task, context=task_dict.get("context"), tools=langchain_tools, config=config
task=crewai_task, context=input.get("context"), tools=langchain_tools, config=config
)
return output

Expand Down
7 changes: 3 additions & 4 deletions motleycrew/agents/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,22 +140,21 @@ def materialize(self):

def invoke(
self,
task_dict: dict,
input: dict,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
"""Description
Args:
task_dict (dict):
input (dict):
config (:obj:`RunnableConfig`, optional):
**kwargs:
Returns:
"""
self.materialize()
prompt = self.compose_prompt(task_dict, task_dict.get("prompt"))
prompt = self.prepare_for_invocation(input=input)

config = add_default_callbacks_to_langchain_config(config)
if self.get_session_history_callable:
Expand Down
7 changes: 3 additions & 4 deletions motleycrew/agents/llama_index/llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,21 @@ def __init__(

def invoke(
self,
task_dict: dict,
input: dict,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
"""Description
Args:
task_dict (dict):
input (dict):
config (:obj:`RunnableConfig`, optional):
**kwargs:
Returns:
Any:
"""
self.materialize()
prompt = self.compose_prompt(task_dict, task_dict.get("prompt"))
prompt = self.prepare_for_invocation(input=input)

output = self.agent.chat(prompt)
return output.response
Expand Down
46 changes: 46 additions & 0 deletions motleycrew/agents/output_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional
from abc import ABC, abstractmethod
from langchain_core.tools import StructuredTool, BaseTool
from langchain_core.pydantic_v1 import BaseModel

from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent
from motleycrew.common.exceptions import InvalidOutput
from motleycrew.tools import MotleyTool


class MotleyOutputHandler(MotleyTool, ABC):
_name: str = "output_handler"
"""Name of the output handler tool."""

_description: str = "Output handler. ONLY RETURN THE FINAL RESULT USING THIS TOOL!"
"""Description of the output handler tool."""

_args_schema: Optional[BaseModel] = None
"""Pydantic schema for the arguments of the output handler tool.
Inferred from the `handle_output` method if not provided."""

_exceptions_to_handle: tuple[Exception] = (InvalidOutput,)
"""Exceptions that should be returned to the agent when raised in the `handle_output` method."""

def __init__(self):
langchain_tool = self._create_langchain_tool()
super().__init__(langchain_tool)

self.agent: Optional[MotleyAgentAbstractParent] = None
self.agent_input: Optional[dict] = None

@property
def exceptions_to_handle(self):
return self._exceptions_to_handle

def _create_langchain_tool(self):
return StructuredTool.from_function(
name=self._name,
description=self._description,
args_schema=self._args_schema,
func=self.handle_output,
)

@abstractmethod
def handle_output(self, *args, **kwargs):
pass
41 changes: 34 additions & 7 deletions motleycrew/agents/parent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pydantic import BaseModel

from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent
from motleycrew.agents.output_handler import MotleyOutputHandler
from motleycrew.tools import MotleyTool
from motleycrew.common import MotleyAgentFactory, MotleySupportedTool
from motleycrew.common.exceptions import (
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(
self.add_tools(tools)

def __repr__(self):
return f"Agent(name={self.name})"
return f"{self.__class__.__name__}(name={self.name})"

def __str__(self):
return self.__repr__()
Expand Down Expand Up @@ -121,20 +122,26 @@ def _prepare_output_handler(self) -> Optional[MotleyTool]:
if not self.output_handler:
return None

# TODO: make this neater by constructing MotleyOutputHandler from tools?
if isinstance(self.output_handler, MotleyOutputHandler):
exceptions_to_handle = self.output_handler.exceptions_to_handle
description = self.output_handler.description
else:
exceptions_to_handle = (InvalidOutput,)
description = self.output_handler.description or f"Output handler"
assert isinstance(description, str)
description += "\n ONLY RETURN THE FINAL RESULT USING THIS TOOL!"

def handle_agent_output(*args, **kwargs):
assert self.output_handler
try:
output = self.output_handler._run(*args, **kwargs)
except InvalidOutput as exc:
except exceptions_to_handle as exc:
return f"{exc.__class__.__name__}: {str(exc)}"

raise DirectOutput(output)

description = self.output_handler.description or f"Output handler for {self.name}"
assert isinstance(description, str)
description += "\n ONLY RETURN THE FINAL RESULT USING THIS TOOL!"

prepared_output_handler = StructuredTool.from_function(
prepared_output_handler = StructuredTool(
name=self.output_handler.name,
description=description,
func=handle_agent_output,
Expand Down Expand Up @@ -188,6 +195,26 @@ def materialize(self):
else:
self._agent = self.agent_factory(tools=self.tools)

def prepare_for_invocation(self, input: dict) -> str:
"""Prepares the agent for invocation by materializing it and composing the prompt.
Should be called in the beginning of the agent's invoke method.
Args:
input (dict): the input to the agent
Returns:
str: the composed prompt
"""
self.materialize()

if isinstance(self.output_handler, MotleyOutputHandler):
self.output_handler.agent = self
self.output_handler.agent_input = input

prompt = self.compose_prompt(input, input.get("prompt"))
return prompt

def add_tools(self, tools: Sequence[MotleySupportedTool]):
"""Description
Expand Down
3 changes: 3 additions & 0 deletions motleycrew/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def invoke(
) -> Any:
return self.tool.invoke(input=input, config=config, **kwargs)

def _run(self, *args: tuple, **kwargs: Dict[str, Any]) -> Any:
return self.tool._run(*args, **kwargs)

@staticmethod
def from_langchain_tool(langchain_tool: BaseTool) -> "MotleyTool":
"""Description
Expand Down

0 comments on commit fbc1dd9

Please sign in to comment.