Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow any kind of history for Langchain agents #45

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/tool_calling_with_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main():
researcher = ReActToolCallingAgent(
tools=tools,
verbose=True,
with_history=True,
chat_history=True,
# llm=init_llm(
# llm_framework=LLMFramework.LANGCHAIN,
# llm_family=LLMFamily.ANTHROPIC,
Expand Down
33 changes: 21 additions & 12 deletions motleycrew/agents/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

from langchain.agents import AgentExecutor
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory, GetSessionHistoryCallable
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate


from motleycrew.agents.parent import MotleyAgentParent
from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent

from motleycrew.tools import MotleyTool
from motleycrew.tracking import add_default_callbacks_to_langchain_config
Expand All @@ -29,8 +28,7 @@ def __init__(
agent_factory: MotleyAgentFactory | None = None,
tools: Sequence[MotleySupportedTool] | None = None,
verbose: bool = False,
with_history: bool = False,
chat_history: BaseChatMessageHistory | None = None,
chat_history: bool | GetSessionHistoryCallable = True,
):
"""Description

Expand All @@ -40,6 +38,10 @@ def __init__(
agent_factory (:obj:`MotleyAgentFactory`, optional):
tools (:obj:`Sequence[MotleySupportedTool]`, optional):
verbose (bool):
chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`):
Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`.
If a callable is passed, it is used to get the chat history by session_id.
See Langchain `RunnableWithMessageHistory` get_session_history param for more details.
"""
super().__init__(
description=description,
Expand All @@ -49,21 +51,24 @@ def __init__(
verbose=verbose,
)

self.with_history = with_history
self.chat_history = chat_history or InMemoryChatMessageHistory()
if chat_history is True:
chat_history = InMemoryChatMessageHistory()
self.get_session_history_callable = lambda _: chat_history
else:
self.get_session_history_callable = chat_history

def materialize(self):
"""Materialize the agent and wrap it in RunnableWithMessageHistory if needed."""
if self.is_materialized:
return

super().materialize()
if self.with_history:
if self.get_session_history_callable:
if isinstance(self._agent, RunnableWithMessageHistory):
return
self._agent = RunnableWithMessageHistory(
runnable=self._agent,
get_session_history=lambda _: self.chat_history,
get_session_history=self.get_session_history_callable,
input_messages_key="input",
history_messages_key="chat_history",
)
Expand All @@ -88,7 +93,7 @@ def invoke(
prompt = self.compose_prompt(task_dict, task_dict.get("prompt"))

config = add_default_callbacks_to_langchain_config(config)
if self.with_history:
if self.get_session_history_callable:
config["configurable"] = config.get("configurable") or {}
config["configurable"]["session_id"] = (
config["configurable"].get("session_id") or "default"
Expand All @@ -110,7 +115,7 @@ def from_function(
tools: Sequence[MotleySupportedTool] | None = None,
prompt: ChatPromptTemplate | Sequence[ChatPromptTemplate] | None = None,
require_tools: bool = False,
with_history: bool = False,
chat_history: bool | GetSessionHistoryCallable = True,
verbose: bool = False,
) -> "LangchainMotleyAgent":
"""Description
Expand All @@ -123,6 +128,10 @@ def from_function(
tools (:obj:`Sequence[MotleySupportedTool]`, optional):
prompt (:obj:`ChatPromptTemplate`, :obj:`Sequence[ChatPromptTemplate]`, optional):
require_tools (bool):
chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`):
Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`.
If a callable is passed, it is used to get the chat history by session_id.
See Langchain `RunnableWithMessageHistory` get_session_history param for more details.
verbose (bool):

Returns:
Expand All @@ -149,7 +158,7 @@ def agent_factory(tools: dict[str, MotleyTool]):
name=name,
agent_factory=agent_factory,
tools=tools,
with_history=with_history,
chat_history=chat_history,
verbose=verbose,
)

Expand Down
19 changes: 14 additions & 5 deletions motleycrew/agents/langchain/tool_calling_react.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Sequence, List, Union

from langchain_core.messages import BaseMessage, HumanMessage, ChatMessage
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableLambda
from langchain_core.runnables.history import GetSessionHistoryCallable
from langchain_core.tools import BaseTool
from langchain_core.agents import AgentFinish, AgentActionMessageLog
from langchain_core.prompts import MessagesPlaceholder
Expand Down Expand Up @@ -147,7 +148,11 @@ def merge_consecutive_messages(messages: Sequence[BaseMessage]) -> List[BaseMess
"""
merged_messages = []
for message in messages:
if not merged_messages or type(merged_messages[-1]) != type(message):
if (
not merged_messages
or type(merged_messages[-1]) != type(message)
or isinstance(message, ToolMessage)
):
merged_messages.append(message)
else:
merged_messages[-1].content = merge_content(
Expand Down Expand Up @@ -223,8 +228,8 @@ def __new__(
description: str | None = None,
name: str | None = None,
prompt: ChatPromptTemplate | Sequence[ChatPromptTemplate] | None = None,
with_history: bool = False,
llm: BaseChatModel | None = None,
chat_history: bool | GetSessionHistoryCallable = True,
verbose: bool = False,
):
"""Description
Expand All @@ -233,7 +238,11 @@ def __new__(
tools (Sequence[MotleySupportedTool]):
description (:obj:`str`, optional):
name (:obj:`str`, optional):
prompt (:obj:ChatPromptTemplate`, :obj:`Sequence[ChatPromptTemplate]', optional):
prompt (:obj:ChatPromptTemplate`, :obj:`Sequence[ChatPromptTemplate]`, optional):
chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`):
Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`.
If a callable is passed, it is used to get the chat history by session_id.
See Langchain `RunnableWithMessageHistory` get_session_history param for more details.
llm (:obj:`BaseLanguageModel`, optional):
verbose (:obj:`bool`, optional):
"""
Expand All @@ -245,6 +254,6 @@ def __new__(
prompt=prompt,
function=create_tool_calling_react_agent,
require_tools=True,
with_history=with_history,
chat_history=chat_history,
verbose=verbose,
)
Loading