diff --git a/motleycrew/agents/langchain/langchain.py b/motleycrew/agents/langchain/langchain.py index 33a8510d..cfd394e1 100644 --- a/motleycrew/agents/langchain/langchain.py +++ b/motleycrew/agents/langchain/langchain.py @@ -26,6 +26,7 @@ def __init__( tools: Sequence[MotleySupportedTool] | None = None, output_handler: MotleySupportedTool | None = None, chat_history: bool | GetSessionHistoryCallable = True, + input_as_messages: bool = False, verbose: bool = False, ): """ @@ -65,6 +66,8 @@ def __init__( See :class:`langchain_core.runnables.history.RunnableWithMessageHistory` for more details. + input_as_messages: Whether the agent expects a list of messages as input instead of a single string. + verbose: Whether to log verbose output. """ super().__init__( @@ -85,6 +88,8 @@ def __init__( else: self.get_session_history_callable = chat_history + self.input_as_messages = input_as_messages + def materialize(self): """Materialize the agent and wrap it in RunnableWithMessageHistory if needed.""" if self.is_materialized: @@ -141,7 +146,7 @@ def invoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - prompt = self.prepare_for_invocation(input=input) + prompt = self.prepare_for_invocation(input=input, prompt_as_messages=self.input_as_messages) config = add_default_callbacks_to_langchain_config(config) if self.get_session_history_callable: diff --git a/motleycrew/agents/langchain/tool_calling_react.py b/motleycrew/agents/langchain/tool_calling_react.py index 3d4a7efd..579cc415 100644 --- a/motleycrew/agents/langchain/tool_calling_react.py +++ b/motleycrew/agents/langchain/tool_calling_react.py @@ -1,16 +1,18 @@ from __future__ import annotations -from typing import Sequence, Optional +from typing import Sequence, Optional, Callable from langchain.agents import AgentExecutor from langchain.agents.format_scratchpad.tools import format_to_tool_messages from langchain.agents.output_parsers.tools import ToolsAgentOutputParser from 
langchain_core.language_models import BaseChatModel from langchain_core.prompts.chat import ChatPromptTemplate -from langchain_core.runnables import Runnable, RunnablePassthrough +from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableLambda from langchain_core.runnables.history import GetSessionHistoryCallable from langchain_core.tools import BaseTool +from motleycrew.common.utils import print_passthrough + try: from langchain_anthropic import ChatAnthropic except ImportError: @@ -21,7 +23,7 @@ ToolCallingReActPromptsForOpenAI, ToolCallingReActPromptsForAnthropic, ) -from motleycrew.common import LLMFramework +from motleycrew.common import LLMFramework, Defaults from motleycrew.common import MotleySupportedTool from motleycrew.common.llms import init_llm from motleycrew.tools import MotleyTool @@ -63,6 +65,7 @@ def create_tool_calling_react_agent( tools: Sequence[BaseTool], prompt: ChatPromptTemplate, output_handler: BaseTool | None = None, + intermediate_steps_processor: Callable | None = None, ) -> Runnable: prompt = prompt.partial( tools=render_text_description(list(tools)), @@ -76,12 +79,18 @@ def create_tool_calling_react_agent( llm_with_tools = llm.bind_tools(tools=tools_for_llm) + if not intermediate_steps_processor: + intermediate_steps_processor = lambda x: x + agent = ( RunnablePassthrough.assign( - agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"]), + agent_scratchpad=lambda x: format_to_tool_messages( + intermediate_steps_processor(x["intermediate_steps"]) + ), additional_notes=lambda x: x.get("additional_notes") or [], ) | prompt + | RunnableLambda(print_passthrough) | llm_with_tools | ToolsAgentOutputParser() ) @@ -108,6 +117,8 @@ def __init__( handle_parsing_errors: bool = True, handle_tool_errors: bool = True, llm: BaseChatModel | None = None, + max_iterations: int | None = Defaults.DEFAULT_REACT_AGENT_MAX_ITERATIONS, + intermediate_steps_processor: Callable | None = None, verbose: bool = False, ): """ 
@@ -128,6 +139,9 @@ def __init__( handle_tool_errors: Whether to handle tool errors. If True, `handle_tool_error` and `handle_validation_error` in all tools are set to True. + max_iterations: The maximum number of agent iterations. + intermediate_steps_processor: Function that modifies the intermediate steps array + in some way before each agent iteration. llm: Language model to use. verbose: Whether to log verbose output. @@ -162,6 +176,7 @@ def agent_factory( tools=tools_for_langchain, prompt=prompt, output_handler=output_handler_for_langchain, + intermediate_steps_processor=intermediate_steps_processor, ) if output_handler_for_langchain: @@ -177,6 +192,7 @@ def agent_factory( tools=tools_for_langchain, handle_parsing_errors=handle_parsing_errors, verbose=verbose, + max_iterations=max_iterations, ) return agent_executor @@ -188,5 +204,6 @@ def agent_factory( tools=tools, output_handler=output_handler, chat_history=chat_history, + input_as_messages=True, verbose=verbose, ) diff --git a/motleycrew/agents/langchain/tool_calling_react_prompts.py b/motleycrew/agents/langchain/tool_calling_react_prompts.py index 4f92296d..40be639b 100644 --- a/motleycrew/agents/langchain/tool_calling_react_prompts.py +++ b/motleycrew/agents/langchain/tool_calling_react_prompts.py @@ -25,7 +25,7 @@ def __init__(self): MessagesPlaceholder(variable_name="chat_history", optional=True), ("system", self.main_instruction), MessagesPlaceholder(variable_name="example_messages", optional=True), - ("user", "{input}"), + MessagesPlaceholder(variable_name="input"), MessagesPlaceholder(variable_name="agent_scratchpad"), MessagesPlaceholder(variable_name="additional_notes", optional=True), ] @@ -64,14 +64,14 @@ class ToolCallingReActPromptsForOpenAI(ToolCallingReActPrompts): Begin! 
""" - output_instruction_with_output_handler = """ + output_instruction_without_output_handler = """ If you have sufficient information to answer the question, your reply must look like ``` Final Answer: [the final answer to the original input question] ``` but without the backticks.""" - output_instruction_without_output_handler = """ + output_instruction_with_output_handler = """ If you have sufficient information to answer the question, you must call the output handler tool. NEVER return the final answer directly, but always do it by CALLING this tool: diff --git a/motleycrew/agents/parent.py b/motleycrew/agents/parent.py index b492bcba..5b7d4ced 100644 --- a/motleycrew/agents/parent.py +++ b/motleycrew/agents/parent.py @@ -10,6 +10,7 @@ Union, ) +from langchain_core.messages import BaseMessage from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessage, SystemMessage from langchain_core.runnables import RunnableConfig from langchain_core.tools import StructuredTool @@ -108,13 +109,14 @@ def __str__(self): return self.__repr__() def compose_prompt( - self, input_dict: dict, prompt: ChatPromptTemplate | str - ) -> Union[str, ChatPromptTemplate]: + self, input_dict: dict, prompt: ChatPromptTemplate | str, as_messages: bool = False + ) -> Union[str, list[BaseMessage]]: """Compose the agent's prompt from the prompt prefix and the provided prompt. Args: input_dict: The input dictionary to the agent. prompt: The prompt to be added to the agent's prompt. + as_messages: Whether the prompt should be returned as a Langchain messages list instead of a single string. Returns: The composed prompt. 
@@ -145,6 +147,9 @@ def compose_prompt( else: raise ValueError("Prompt must be a string or a ChatPromptTemplate") + if as_messages: + return prompt_messages + # TODO: pass the unformatted messages list to agents that can handle it prompt = "\n\n".join([m.content for m in prompt_messages]) + "\n" return prompt @@ -237,13 +242,15 @@ def materialize(self): else: self._agent = self.agent_factory(tools=self.tools) - def prepare_for_invocation(self, input: dict) -> str: + def prepare_for_invocation(self, input: dict, prompt_as_messages: bool = False) -> Union[str, list[BaseMessage]]: """Prepare the agent for invocation by materializing it and composing the prompt. Should be called in the beginning of the agent's invoke method. Args: input: the input to the agent + prompt_as_messages: Whether the prompt should be returned as a Langchain messages list + instead of a single string. Returns: str: the composed prompt @@ -254,7 +261,7 @@ def prepare_for_invocation(self, input: dict) -> str: self.output_handler.agent = self self.output_handler.agent_input = input - prompt = self.compose_prompt(input, input.get("prompt")) + prompt = self.compose_prompt(input, input.get("prompt"), as_messages=prompt_as_messages) return prompt def add_tools(self, tools: Sequence[MotleySupportedTool]): diff --git a/motleycrew/common/defaults.py b/motleycrew/common/defaults.py index 8e8849cc..5081336f 100644 --- a/motleycrew/common/defaults.py +++ b/motleycrew/common/defaults.py @@ -5,6 +5,7 @@ class Defaults: """Default values for various settings.""" + DEFAULT_REACT_AGENT_MAX_ITERATIONS = 15 DEFAULT_LLM_FAMILY = LLMFamily.OPENAI DEFAULT_LLM_NAME = "gpt-4o" DEFAULT_LLM_TEMPERATURE = 0.0 diff --git a/motleycrew/common/utils.py b/motleycrew/common/utils.py index bb96fca6..e8de4027 100644 --- a/motleycrew/common/utils.py +++ b/motleycrew/common/utils.py @@ -1,8 +1,10 @@ """Various helpers and utility functions used throughout the project.""" + +import hashlib import sys from typing import Optional, Sequence -import hashlib
from urllib.parse import urlparse + from langchain_core.messages import BaseMessage from motleycrew.common.exceptions import ModuleNotInstalled @@ -47,9 +49,9 @@ def generate_hex_hash(data: str, length: Optional[int] = None): def print_passthrough(x): """A helper function useful for debugging LCEL chains. It just returns the input value. - You can put a breakpoint in this function to debug the chain. + You can put a breakpoint in this function to debug a chain. """ return x