From b843afdce347383cb38267ce4af59a6975a7a364 Mon Sep 17 00:00:00 2001 From: whimo Date: Thu, 13 Jun 2024 22:43:00 +0300 Subject: [PATCH] Allow any kind of history for Langchain agents --- examples/tool_calling_with_memory.py | 2 +- motleycrew/agents/langchain/langchain.py | 33 ++++++++++++------- .../agents/langchain/tool_calling_react.py | 19 ++++++++--- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/examples/tool_calling_with_memory.py b/examples/tool_calling_with_memory.py index 3416d87d..54669be9 100644 --- a/examples/tool_calling_with_memory.py +++ b/examples/tool_calling_with_memory.py @@ -18,7 +18,7 @@ def main(): researcher = ReActToolCallingAgent( tools=tools, verbose=True, - with_history=True, + chat_history=True, # llm=init_llm( # llm_framework=LLMFramework.LANGCHAIN, # llm_family=LLMFamily.ANTHROPIC, diff --git a/motleycrew/agents/langchain/langchain.py b/motleycrew/agents/langchain/langchain.py index 1ea53c62..e71b422c 100644 --- a/motleycrew/agents/langchain/langchain.py +++ b/motleycrew/agents/langchain/langchain.py @@ -4,14 +4,13 @@ from langchain.agents import AgentExecutor from langchain_core.runnables import RunnableConfig -from langchain_core.runnables.history import RunnableWithMessageHistory -from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory +from langchain_core.runnables.history import RunnableWithMessageHistory, GetSessionHistoryCallable +from langchain_core.chat_history import InMemoryChatMessageHistory from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts.chat import ChatPromptTemplate from motleycrew.agents.parent import MotleyAgentParent -from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent from motleycrew.tools import MotleyTool from motleycrew.tracking import add_default_callbacks_to_langchain_config @@ -29,8 +28,7 @@ def __init__( agent_factory: MotleyAgentFactory | None = None, tools: Sequence[MotleySupportedTool] | None = None, verbose: bool = False, - with_history: bool = False, - chat_history: BaseChatMessageHistory | None = None, + chat_history: bool | GetSessionHistoryCallable = True, ): """Description @@ -40,6 +38,10 @@ def __init__( agent_factory (:obj:`MotleyAgentFactory`, optional): tools (:obj:`Sequence[MotleySupportedTool]`, optional): verbose (bool): + chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`): + Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`. + If a callable is passed, it is used to get the chat history by session_id. + See Langchain `RunnableWithMessageHistory` get_session_history param for more details. """ super().__init__( description=description, @@ -49,8 +51,11 @@ def __init__( verbose=verbose, ) - self.with_history = with_history - self.chat_history = chat_history or InMemoryChatMessageHistory() + if chat_history is True: + chat_history = InMemoryChatMessageHistory() + self.get_session_history_callable = lambda _: chat_history + else: + self.get_session_history_callable = chat_history def materialize(self): """Materialize the agent and wrap it in RunnableWithMessageHistory if needed.""" @@ -58,12 +63,12 @@ def materialize(self): return super().materialize() - if self.with_history: + if self.get_session_history_callable: if isinstance(self._agent, RunnableWithMessageHistory): return self._agent = RunnableWithMessageHistory( runnable=self._agent, - get_session_history=lambda _: self.chat_history, + get_session_history=self.get_session_history_callable, input_messages_key="input", history_messages_key="chat_history", ) @@ -88,7 +93,7 @@ def invoke( prompt = self.compose_prompt(task_dict, task_dict.get("prompt")) config = add_default_callbacks_to_langchain_config(config) - if self.with_history: + if self.get_session_history_callable: config["configurable"] = config.get("configurable") or {} config["configurable"]["session_id"] = ( config["configurable"].get("session_id") or "default" @@ -110,7 +115,7 @@ def from_function( tools: Sequence[MotleySupportedTool] | None = None, prompt: ChatPromptTemplate | Sequence[ChatPromptTemplate] | None = None, require_tools: bool = False, - with_history: bool = False, + chat_history: bool | GetSessionHistoryCallable = True, verbose: bool = False, ) -> "LangchainMotleyAgent": """Description @@ -123,6 +128,10 @@ def from_function( tools (:obj:`Sequence[MotleySupportedTool]`, optional): prompt (:obj:`ChatPromptTemplate`, :obj:`Sequence[ChatPromptTemplate]`, optional): require_tools (bool): + chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`): + Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`. + If a callable is passed, it is used to get the chat history by session_id. + See Langchain `RunnableWithMessageHistory` get_session_history param for more details. verbose (bool): Returns: @@ -149,7 +158,7 @@ def agent_factory(tools: dict[str, MotleyTool]): name=name, agent_factory=agent_factory, tools=tools, - with_history=with_history, + chat_history=chat_history, verbose=verbose, ) diff --git a/motleycrew/agents/langchain/tool_calling_react.py b/motleycrew/agents/langchain/tool_calling_react.py index 94716501..b24d21bd 100644 --- a/motleycrew/agents/langchain/tool_calling_react.py +++ b/motleycrew/agents/langchain/tool_calling_react.py @@ -1,9 +1,10 @@ from typing import Sequence, List, Union -from langchain_core.messages import BaseMessage, HumanMessage, ChatMessage +from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage from langchain_core.language_models import BaseChatModel from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableLambda +from langchain_core.runnables.history import GetSessionHistoryCallable from langchain_core.tools import BaseTool from langchain_core.agents import AgentFinish, AgentActionMessageLog from langchain_core.prompts import MessagesPlaceholder @@ -147,7 +148,11 @@ def merge_consecutive_messages(messages: Sequence[BaseMessage]) -> List[BaseMess """ merged_messages = [] for message in messages: - if not merged_messages or type(merged_messages[-1]) != type(message): + if ( + not merged_messages + or type(merged_messages[-1]) != type(message) + or isinstance(message, ToolMessage) + ): merged_messages.append(message) else: merged_messages[-1].content = merge_content( @@ -223,8 +228,8 @@ def __new__( description: str | None = None, name: str | None = None, prompt: ChatPromptTemplate | Sequence[ChatPromptTemplate] | None = None, - with_history: bool = False, llm: BaseChatModel | None = None, + chat_history: bool | GetSessionHistoryCallable = True, verbose: bool = False, ): """Description @@ -233,7 +238,11 @@ def __new__( tools (Sequence[MotleySupportedTool]): description (:obj:`str`, optional): name (:obj:`str`, optional): - prompt (:obj:ChatPromptTemplate`, :obj:`Sequence[ChatPromptTemplate]', optional): + prompt (:obj:ChatPromptTemplate`, :obj:`Sequence[ChatPromptTemplate]`, optional): + chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`): + Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`. + If a callable is passed, it is used to get the chat history by session_id. + See Langchain `RunnableWithMessageHistory` get_session_history param for more details. llm (:obj:`BaseLanguageModel`, optional): verbose (:obj:`bool`, optional): """ @@ -245,6 +254,6 @@ def __new__( prompt=prompt, function=create_tool_calling_react_agent, require_tools=True, - with_history=with_history, + chat_history=chat_history, verbose=verbose, )