Skip to content

Commit

Permalink
Supply prompt prefix as a list of messages
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Aug 25, 2024
1 parent 552fae4 commit 16b0ab1
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 15 deletions.
7 changes: 6 additions & 1 deletion motleycrew/agents/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
tools: Sequence[MotleySupportedTool] | None = None,
output_handler: MotleySupportedTool | None = None,
chat_history: bool | GetSessionHistoryCallable = True,
input_as_messages: bool = False,
verbose: bool = False,
):
"""
Expand Down Expand Up @@ -65,6 +66,8 @@ def __init__(
See :class:`langchain_core.runnables.history.RunnableWithMessageHistory`
for more details.
input_as_messages: Whether the agent expects a list of messages as input instead of a single string.
verbose: Whether to log verbose output.
"""
super().__init__(
Expand All @@ -85,6 +88,8 @@ def __init__(
else:
self.get_session_history_callable = chat_history

self.input_as_messages = input_as_messages

def materialize(self):
"""Materialize the agent and wrap it in RunnableWithMessageHistory if needed."""
if self.is_materialized:
Expand Down Expand Up @@ -141,7 +146,7 @@ def invoke(
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
prompt = self.prepare_for_invocation(input=input)
prompt = self.prepare_for_invocation(input=input, prompt_as_messages=self.input_as_messages)

config = add_default_callbacks_to_langchain_config(config)
if self.get_session_history_callable:
Expand Down
25 changes: 21 additions & 4 deletions motleycrew/agents/langchain/tool_calling_react.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

from typing import Sequence, Optional
from typing import Sequence, Optional, Callable

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableLambda
from langchain_core.runnables.history import GetSessionHistoryCallable
from langchain_core.tools import BaseTool

from motleycrew.common.utils import print_passthrough

try:
from langchain_anthropic import ChatAnthropic
except ImportError:
Expand All @@ -21,7 +23,7 @@
ToolCallingReActPromptsForOpenAI,
ToolCallingReActPromptsForAnthropic,
)
from motleycrew.common import LLMFramework
from motleycrew.common import LLMFramework, Defaults
from motleycrew.common import MotleySupportedTool
from motleycrew.common.llms import init_llm
from motleycrew.tools import MotleyTool
Expand Down Expand Up @@ -63,6 +65,7 @@ def create_tool_calling_react_agent(
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
output_handler: BaseTool | None = None,
intermediate_steps_processor: Callable | None = None,
) -> Runnable:
prompt = prompt.partial(
tools=render_text_description(list(tools)),
Expand All @@ -76,12 +79,18 @@ def create_tool_calling_react_agent(

llm_with_tools = llm.bind_tools(tools=tools_for_llm)

if not intermediate_steps_processor:
intermediate_steps_processor = lambda x: x

agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"]),
agent_scratchpad=lambda x: format_to_tool_messages(
intermediate_steps_processor(x["intermediate_steps"])
),
additional_notes=lambda x: x.get("additional_notes") or [],
)
| prompt
| RunnableLambda(print_passthrough)
| llm_with_tools
| ToolsAgentOutputParser()
)
Expand All @@ -108,6 +117,8 @@ def __init__(
handle_parsing_errors: bool = True,
handle_tool_errors: bool = True,
llm: BaseChatModel | None = None,
max_iterations: int | None = Defaults.DEFAULT_REACT_AGENT_MAX_ITERATIONS,
intermediate_steps_processor: Callable | None = None,
verbose: bool = False,
):
"""
Expand All @@ -128,6 +139,9 @@ def __init__(
handle_tool_errors: Whether to handle tool errors.
If True, `handle_tool_error` and `handle_validation_error` in all tools
are set to True.
max_iterations: The maximum number of agent iterations.
intermediate_steps_processor: Function that modifies the intermediate steps array
in some way before each agent iteration.
llm: Language model to use.
verbose: Whether to log verbose output.
Expand Down Expand Up @@ -162,6 +176,7 @@ def agent_factory(
tools=tools_for_langchain,
prompt=prompt,
output_handler=output_handler_for_langchain,
intermediate_steps_processor=intermediate_steps_processor,
)

if output_handler_for_langchain:
Expand All @@ -177,6 +192,7 @@ def agent_factory(
tools=tools_for_langchain,
handle_parsing_errors=handle_parsing_errors,
verbose=verbose,
max_iterations=max_iterations,
)
return agent_executor

Expand All @@ -188,5 +204,6 @@ def agent_factory(
tools=tools,
output_handler=output_handler,
chat_history=chat_history,
input_as_messages=True,
verbose=verbose,
)
6 changes: 3 additions & 3 deletions motleycrew/agents/langchain/tool_calling_react_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self):
MessagesPlaceholder(variable_name="chat_history", optional=True),
("system", self.main_instruction),
MessagesPlaceholder(variable_name="example_messages", optional=True),
("user", "{input}"),
MessagesPlaceholder(variable_name="input"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
MessagesPlaceholder(variable_name="additional_notes", optional=True),
]
Expand Down Expand Up @@ -64,14 +64,14 @@ class ToolCallingReActPromptsForOpenAI(ToolCallingReActPrompts):
Begin!
"""

output_instruction_with_output_handler = """
output_instruction_without_output_handler = """
If you have sufficient information to answer the question, your reply must look like
```
Final Answer: [the final answer to the original input question]
```
but without the backticks."""

output_instruction_without_output_handler = """
output_instruction_with_output_handler = """
If you have sufficient information to answer the question, you must call the output handler tool.
NEVER return the final answer directly, but always do it by CALLING this tool:
Expand Down
15 changes: 11 additions & 4 deletions motleycrew/agents/parent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Union,
)

from langchain_core.messages import BaseMessage
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import StructuredTool
Expand Down Expand Up @@ -108,13 +109,14 @@ def __str__(self):
return self.__repr__()

def compose_prompt(
self, input_dict: dict, prompt: ChatPromptTemplate | str
) -> Union[str, ChatPromptTemplate]:
self, input_dict: dict, prompt: ChatPromptTemplate | str, as_messages: bool = False
) -> Union[str, list[BaseMessage]]:
"""Compose the agent's prompt from the prompt prefix and the provided prompt.
Args:
input_dict: The input dictionary to the agent.
prompt: The prompt to be added to the agent's prompt.
as_messages: Whether the prompt should be returned as a Langchain messages list instead of a single string.
Returns:
The composed prompt.
Expand Down Expand Up @@ -145,6 +147,9 @@ def compose_prompt(
else:
raise ValueError("Prompt must be a string or a ChatPromptTemplate")

if as_messages:
return prompt_messages

# TODO: pass the unformatted messages list to agents that can handle it
prompt = "\n\n".join([m.content for m in prompt_messages]) + "\n"
return prompt
Expand Down Expand Up @@ -237,13 +242,15 @@ def materialize(self):
else:
self._agent = self.agent_factory(tools=self.tools)

def prepare_for_invocation(self, input: dict) -> str:
def prepare_for_invocation(self, input: dict, prompt_as_messages: bool = False) -> str:
"""Prepare the agent for invocation by materializing it and composing the prompt.
Should be called in the beginning of the agent's invoke method.
Args:
input: the input to the agent
prompt_as_messages: Whether the prompt should be returned as a Langchain messages list
instead of a single string.
Returns:
str: the composed prompt
Expand All @@ -254,7 +261,7 @@ def prepare_for_invocation(self, input: dict) -> str:
self.output_handler.agent = self
self.output_handler.agent_input = input

prompt = self.compose_prompt(input, input.get("prompt"))
prompt = self.compose_prompt(input, input.get("prompt"), as_messages=prompt_as_messages)
return prompt

def add_tools(self, tools: Sequence[MotleySupportedTool]):
Expand Down
1 change: 1 addition & 0 deletions motleycrew/common/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
class Defaults:
"""Default values for various settings."""

DEFAULT_REACT_AGENT_MAX_ITERATIONS = 15
DEFAULT_LLM_FAMILY = LLMFamily.OPENAI
DEFAULT_LLM_NAME = "gpt-4o"
DEFAULT_LLM_TEMPERATURE = 0.0
Expand Down
7 changes: 4 additions & 3 deletions motleycrew/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Various helpers and utility functions used throughout the project."""

import hashlib
import sys
from typing import Optional, Sequence
import hashlib
from urllib.parse import urlparse

from langchain_core.messages import BaseMessage

from motleycrew.common.exceptions import ModuleNotInstalled
Expand Down Expand Up @@ -47,9 +49,8 @@ def generate_hex_hash(data: str, length: Optional[int] = None):
def print_passthrough(x):
"""A helper function useful for debugging LCEL chains. It just returns the input value.
You can put a breakpoint in this function to debug the chain.
You can put a breakpoint in this function to debug a chain.
"""

return x


Expand Down

0 comments on commit 16b0ab1

Please sign in to comment.