From 432aba1289eb241db3fe7e9d12d93ddb9d8e39ef Mon Sep 17 00:00:00 2001 From: User Date: Fri, 5 Jul 2024 14:02:54 +0300 Subject: [PATCH] add max_iterations parameter for output_handler --- motleycrew/agents/output_handler.py | 4 +++- motleycrew/agents/parent.py | 15 +++++++++++++-- motleycrew/common/defaults.py | 6 +++++- motleycrew/common/exceptions.py | 11 +++++++++++ 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/motleycrew/agents/output_handler.py b/motleycrew/agents/output_handler.py index 3a1d5c03..63bec84d 100644 --- a/motleycrew/agents/output_handler.py +++ b/motleycrew/agents/output_handler.py @@ -5,6 +5,7 @@ from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent from motleycrew.common.exceptions import InvalidOutput +from motleycrew.common import Defaults from motleycrew.tools import MotleyTool @@ -22,7 +23,8 @@ class MotleyOutputHandler(MotleyTool, ABC): _exceptions_to_handle: tuple[Exception] = (InvalidOutput,) """Exceptions that should be returned to the agent when raised in the `handle_output` method.""" - def __init__(self): + def __init__(self, max_iterations: int = Defaults.DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS): + self.max_iterations = max_iterations # number of iterations of catching an exception langchain_tool = self._create_langchain_tool() super().__init__(langchain_tool) diff --git a/motleycrew/agents/parent.py b/motleycrew/agents/parent.py index 4c95af35..4c587c17 100644 --- a/motleycrew/agents/parent.py +++ b/motleycrew/agents/parent.py @@ -12,11 +12,12 @@ from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent from motleycrew.common import MotleyAgentFactory, MotleySupportedTool -from motleycrew.common import logger +from motleycrew.common import logger, Defaults from motleycrew.common.exceptions import ( AgentNotMaterialized, CannotModifyMaterializedAgent, InvalidOutput, + OutputHandlerMaxIterationsExceeded, ) from motleycrew.tools import MotleyTool @@ -131,18 +132,28 @@ def _prepare_output_handler(self) -> Optional[MotleyTool]: if isinstance(self.output_handler, MotleyOutputHandler): exceptions_to_handle = self.output_handler.exceptions_to_handle description = self.output_handler.description + max_iterations = self.output_handler.max_iterations + else: exceptions_to_handle = (InvalidOutput,) description = self.output_handler.description or f"Output handler" assert isinstance(description, str) description += "\n ONLY RETURN THE FINAL RESULT USING THIS TOOL!" + max_iterations = Defaults.DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS + + iteration = 0 def handle_agent_output(*args, **kwargs): assert self.output_handler + nonlocal iteration + try: + iteration += 1 output = self.output_handler._run(*args, **kwargs) except exceptions_to_handle as exc: - return f"{exc.__class__.__name__}: {str(exc)}" + if iteration <= max_iterations: + return f"{exc.__class__.__name__}: {str(exc)}" + raise OutputHandlerMaxIterationsExceeded(*args, **kwargs) raise DirectOutput(output) diff --git a/motleycrew/common/defaults.py b/motleycrew/common/defaults.py index e6795f16..d2e3fdb0 100644 --- a/motleycrew/common/defaults.py +++ b/motleycrew/common/defaults.py @@ -1,10 +1,11 @@ """ Module description """ + from motleycrew.common import LLMFamily from motleycrew.common import GraphStoreType class Defaults: - """ Description + """Description Attributes: DEFAULT_LLM_FAMILY (str): @@ -15,8 +16,10 @@ class Defaults: MODULE_INSTALL_COMMANDS (dict): DEFAULT_NUM_THREADS (int): DEFAULT_EVENT_LOOP_SLEEP (int): + DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS (int): """ + DEFAULT_LLM_FAMILY = LLMFamily.OPENAI DEFAULT_LLM_NAME = "gpt-4o" DEFAULT_LLM_TEMPERATURE = 0.0 @@ -35,3 +38,4 @@ class Defaults: DEFAULT_NUM_THREADS = 4 DEFAULT_EVENT_LOOP_SLEEP = 1 + DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS = 5 diff --git a/motleycrew/common/exceptions.py b/motleycrew/common/exceptions.py index 7eaefba1..78a5ebd0 100644 --- a/motleycrew/common/exceptions.py +++ b/motleycrew/common/exceptions.py @@ -142,3 +142,14 @@ class InvalidOutput(Exception): """Raised in output handlers when an agent's output is not accepted""" pass + + +class OutputHandlerMaxIterationsExceeded(BaseException): + """Raised when the output handlers iteration limit is exceeded""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + def __str__(self): + return "\n args: {}\n kwargs: {}".format(self.args, self.kwargs)