Skip to content

Commit

Permalink
Raise NotImplementedError when using output handlers with CrewAI
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo committed Jul 1, 2024
1 parent 2ca00a4 commit a840165
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 43 deletions.
7 changes: 5 additions & 2 deletions examples/llama_index_output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ def main():

def check_output(output: str):
if "medicine" not in output.lower():
raise InvalidOutput("Add more information about AI applications in medicine.")
raise InvalidOutput(
"Add more information about AI applications in medicine."
)

return {"checked_output": output.lower()}
return {"checked_output": output}

output_handler = StructuredTool.from_function(
name="output_handler",
Expand All @@ -34,6 +36,7 @@ def check_output(output: str):
tools=[search_tool],
output_handler=output_handler,
verbose=True,
max_iterations=16, # default is 10, we add more because the output handler may reject the output
)

crew = MotleyCrew(async_backend=AsyncBackend.NONE)
Expand Down
54 changes: 13 additions & 41 deletions motleycrew/agents/crewai/crewai.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
""" Module description """

from typing import Any, Optional, Sequence, Callable
from typing import Any, Optional, Sequence

from langchain_core.runnables import RunnableConfig

from motleycrew.agents.crewai import CrewAIAgentWithConfig
from motleycrew.agents.mixins import LangchainOutputHandlingAgentMixin
from motleycrew.agents.parent import MotleyAgentParent
from motleycrew.common import MotleyAgentFactory
from motleycrew.common import MotleySupportedTool
Expand All @@ -19,7 +18,7 @@
pass


class CrewAIMotleyAgentParent(MotleyAgentParent, LangchainOutputHandlingAgentMixin):
class CrewAIMotleyAgentParent(MotleyAgentParent):
def __init__(
self,
goal: str,
Expand All @@ -38,6 +37,13 @@ def __init__(
tools (:obj:`Sequence[MotleySupportedTool]`, optional:
verbose (bool):
"""

if output_handler:
raise NotImplementedError(
"Output handler is not supported for CrewAI agents "
"because of the specificity of CrewAi's prompts."
)

ensure_module_is_installed("crewai")
super().__init__(
description=goal,
Expand All @@ -48,12 +54,6 @@ def __init__(
verbose=verbose,
)

if self.output_handler:
self._agent_finish_blocker_tool = self._create_agent_finish_blocker_tool()
self.tools[self._agent_finish_blocker_tool.name] = MotleyTool.from_langchain_tool(self._agent_finish_blocker_tool)
output_handler_tool = self._prepare_output_handler()
self.tools[output_handler_tool.name] = output_handler_tool

def invoke(
self,
input: dict,
Expand All @@ -78,44 +78,16 @@ def invoke(
crewai_task = CrewAI__Task(description=prompt)

output = self.agent.execute_task(
task=crewai_task, context=input.get("context"), tools=langchain_tools, config=config
task=crewai_task,
context=input.get("context"),
tools=langchain_tools,
config=config,
)
return output

def _create_agent_executor_decorator(self):
"""Decorator adding logic for working with output_handler when creating agent_executor"""

def decorator(func: Callable):
def wrapper(tools=None):
result = func(tools)

object.__setattr__(
self._agent.agent_executor.agent,
"plan",
self.agent_plan_decorator()(self._agent.agent_executor.agent.plan),
)

object.__setattr__(
self._agent.agent_executor,
"_take_next_step",
self.take_next_step_decorator()(self._agent.agent_executor._take_next_step),
)
return result

return wrapper

return decorator

def materialize(self):
super().materialize()

if self.output_handler:
object.__setattr__(
self._agent,
"create_agent_executor",
self._create_agent_executor_decorator()(self._agent.create_agent_executor),
)

# TODO: what do these do?
def set_cache_handler(self, cache_handler: Any) -> None:
"""Description
Expand Down
2 changes: 2 additions & 0 deletions motleycrew/agents/llama_index/llama_index_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
llm: LLM | None = None,
output_handler: MotleySupportedTool | None = None,
verbose: bool = False,
max_iterations: int = 10,
):
"""Description
Expand All @@ -48,6 +49,7 @@ def agent_factory(tools: dict[str, MotleyTool]) -> ReActAgent:
tools=llama_index_tools,
llm=llm,
verbose=verbose,
max_iterations=max_iterations,
callback_manager=CallbackManager(callbacks),
)
return agent
Expand Down

0 comments on commit a840165

Please sign in to comment.