Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supply prompt prefix as a list of messages #70

Merged
merged 6 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion motleycrew/agents/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
tools: Sequence[MotleySupportedTool] | None = None,
output_handler: MotleySupportedTool | None = None,
chat_history: bool | GetSessionHistoryCallable = True,
input_as_messages: bool = False,
verbose: bool = False,
):
"""
Expand Down Expand Up @@ -66,6 +67,8 @@ def __init__(
See :class:`langchain_core.runnables.history.RunnableWithMessageHistory`
for more details.

input_as_messages: Whether the agent expects a list of messages as input instead of a single string.

verbose: Whether to log verbose output.
"""
super().__init__(
Expand All @@ -86,6 +89,8 @@ def __init__(
else:
self.get_session_history_callable = chat_history

self.input_as_messages = input_as_messages

def materialize(self):
"""Materialize the agent and wrap it in RunnableWithMessageHistory if needed."""
if self.is_materialized:
Expand Down Expand Up @@ -142,7 +147,7 @@ def invoke(
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
prompt = self.prepare_for_invocation(input=input)
prompt = self.prepare_for_invocation(input=input, prompt_as_messages=self.input_as_messages)

config = add_default_callbacks_to_langchain_config(config)
if self.get_session_history_callable:
Expand Down
25 changes: 21 additions & 4 deletions motleycrew/agents/langchain/tool_calling_react.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

from typing import Sequence, Optional
from typing import Sequence, Optional, Callable

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableLambda
from langchain_core.runnables.history import GetSessionHistoryCallable
from langchain_core.tools import BaseTool

from motleycrew.common.utils import print_passthrough

try:
from langchain_anthropic import ChatAnthropic
except ImportError:
Expand All @@ -21,7 +23,7 @@
ToolCallingReActPromptsForOpenAI,
ToolCallingReActPromptsForAnthropic,
)
from motleycrew.common import LLMFramework
from motleycrew.common import LLMFramework, Defaults
from motleycrew.common import MotleySupportedTool
from motleycrew.common.llms import init_llm
from motleycrew.tools import MotleyTool
Expand Down Expand Up @@ -63,6 +65,7 @@ def create_tool_calling_react_agent(
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
output_handler: BaseTool | None = None,
intermediate_steps_processor: Callable | None = None,
) -> Runnable:
prompt = prompt.partial(
tools=render_text_description(list(tools)),
Expand All @@ -76,12 +79,18 @@ def create_tool_calling_react_agent(

llm_with_tools = llm.bind_tools(tools=tools_for_llm)

if not intermediate_steps_processor:
intermediate_steps_processor = lambda x: x

agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"]),
agent_scratchpad=lambda x: format_to_tool_messages(
intermediate_steps_processor(x["intermediate_steps"])
),
additional_notes=lambda x: x.get("additional_notes") or [],
)
| prompt
| RunnableLambda(print_passthrough)
| llm_with_tools
| ToolsAgentOutputParser()
)
Expand All @@ -108,6 +117,8 @@ def __init__(
handle_parsing_errors: bool = True,
handle_tool_errors: bool = True,
llm: BaseChatModel | None = None,
max_iterations: int | None = Defaults.DEFAULT_REACT_AGENT_MAX_ITERATIONS,
intermediate_steps_processor: Callable | None = None,
verbose: bool = False,
):
"""
Expand All @@ -128,6 +139,9 @@ def __init__(
handle_tool_errors: Whether to handle tool errors.
If True, `handle_tool_error` and `handle_validation_error` in all tools
are set to True.
max_iterations: The maximum number of agent iterations.
intermediate_steps_processor: Function that transforms the list of intermediate steps
before it is formatted into the agent's scratchpad on each iteration.
llm: Language model to use.
verbose: Whether to log verbose output.

Expand Down Expand Up @@ -162,6 +176,7 @@ def agent_factory(
tools=tools_for_langchain,
prompt=prompt,
output_handler=output_handler_for_langchain,
intermediate_steps_processor=intermediate_steps_processor,
)

if output_handler_for_langchain:
Expand All @@ -177,6 +192,7 @@ def agent_factory(
tools=tools_for_langchain,
handle_parsing_errors=handle_parsing_errors,
verbose=verbose,
max_iterations=max_iterations,
)
return agent_executor

Expand All @@ -188,5 +204,6 @@ def agent_factory(
tools=tools,
output_handler=output_handler,
chat_history=chat_history,
input_as_messages=True,
verbose=verbose,
)
6 changes: 3 additions & 3 deletions motleycrew/agents/langchain/tool_calling_react_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self):
MessagesPlaceholder(variable_name="chat_history", optional=True),
("system", self.main_instruction),
MessagesPlaceholder(variable_name="example_messages", optional=True),
("user", "{input}"),
MessagesPlaceholder(variable_name="input"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
MessagesPlaceholder(variable_name="additional_notes", optional=True),
]
Expand Down Expand Up @@ -64,14 +64,14 @@ class ToolCallingReActPromptsForOpenAI(ToolCallingReActPrompts):
Begin!
"""

output_instruction_with_output_handler = """
output_instruction_without_output_handler = """
If you have sufficient information to answer the question, your reply must look like
```
Final Answer: [the final answer to the original input question]
```
but without the backticks."""

output_instruction_without_output_handler = """
output_instruction_with_output_handler = """
If you have sufficient information to answer the question, you must call the output handler tool.

NEVER return the final answer directly, but always do it by CALLING this tool:
Expand Down
15 changes: 11 additions & 4 deletions motleycrew/agents/parent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Union,
)

from langchain_core.messages import BaseMessage
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import StructuredTool
Expand Down Expand Up @@ -108,13 +109,14 @@ def __str__(self):
return self.__repr__()

def compose_prompt(
self, input_dict: dict, prompt: ChatPromptTemplate | str
) -> Union[str, ChatPromptTemplate]:
self, input_dict: dict, prompt: ChatPromptTemplate | str, as_messages: bool = False
) -> Union[str, list[BaseMessage]]:
"""Compose the agent's prompt from the prompt prefix and the provided prompt.

Args:
input_dict: The input dictionary to the agent.
prompt: The prompt to be added to the agent's prompt.
as_messages: Whether the prompt should be returned as a LangChain messages list instead of a single string.

Returns:
The composed prompt.
Expand Down Expand Up @@ -145,6 +147,9 @@ def compose_prompt(
else:
raise ValueError("Prompt must be a string or a ChatPromptTemplate")

if as_messages:
return prompt_messages

# TODO: pass the unformatted messages list to agents that can handle it
prompt = "\n\n".join([m.content for m in prompt_messages]) + "\n"
return prompt
Expand Down Expand Up @@ -237,13 +242,15 @@ def materialize(self):
else:
self._agent = self.agent_factory(tools=self.tools)

def prepare_for_invocation(self, input: dict) -> str:
def prepare_for_invocation(self, input: dict, prompt_as_messages: bool = False) -> str:
"""Prepare the agent for invocation by materializing it and composing the prompt.

Should be called in the beginning of the agent's invoke method.

Args:
input: the input to the agent
prompt_as_messages: Whether the prompt should be returned as a LangChain messages list
instead of a single string.

Returns:
str: the composed prompt
Expand All @@ -254,7 +261,7 @@ def prepare_for_invocation(self, input: dict) -> str:
self.output_handler.agent = self
self.output_handler.agent_input = input

prompt = self.compose_prompt(input, input.get("prompt"))
prompt = self.compose_prompt(input, input.get("prompt"), as_messages=prompt_as_messages)
return prompt

def add_tools(self, tools: Sequence[MotleySupportedTool]):
Expand Down
1 change: 1 addition & 0 deletions motleycrew/common/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
class Defaults:
"""Default values for various settings."""

DEFAULT_REACT_AGENT_MAX_ITERATIONS = 15
DEFAULT_LLM_FAMILY = LLMFamily.OPENAI
DEFAULT_LLM_NAME = "gpt-4o"
DEFAULT_LLM_TEMPERATURE = 0.0
Expand Down
7 changes: 4 additions & 3 deletions motleycrew/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Various helpers and utility functions used throughout the project."""

import hashlib
import sys
from typing import Optional, Sequence
import hashlib
from urllib.parse import urlparse

from langchain_core.messages import BaseMessage

from motleycrew.common.exceptions import ModuleNotInstalled
Expand Down Expand Up @@ -47,9 +49,8 @@ def generate_hex_hash(data: str, length: Optional[int] = None):
def print_passthrough(x):
"""A helper function useful for debugging LCEL chains. It just returns the input value.

You can put a breakpoint in this function to debug a chain.
"""

return x


Expand Down
Loading
Loading