Skip to content

Commit

Permalink
Allow any kind of history for Langchain agents (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo committed Jun 13, 2024
1 parent 02023b6 commit b61f46f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 18 deletions.
2 changes: 1 addition & 1 deletion examples/tool_calling_with_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main():
researcher = ReActToolCallingAgent(
tools=tools,
verbose=True,
with_history=True,
chat_history=True,
# llm=init_llm(
# llm_framework=LLMFramework.LANGCHAIN,
# llm_family=LLMFamily.ANTHROPIC,
Expand Down
33 changes: 21 additions & 12 deletions motleycrew/agents/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

from langchain.agents import AgentExecutor
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory, GetSessionHistoryCallable
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate


from motleycrew.agents.parent import MotleyAgentParent
from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent

from motleycrew.tools import MotleyTool
from motleycrew.tracking import add_default_callbacks_to_langchain_config
Expand All @@ -29,8 +28,7 @@ def __init__(
agent_factory: MotleyAgentFactory | None = None,
tools: Sequence[MotleySupportedTool] | None = None,
verbose: bool = False,
with_history: bool = False,
chat_history: BaseChatMessageHistory | None = None,
chat_history: bool | GetSessionHistoryCallable = True,
):
"""Description
Expand All @@ -40,6 +38,10 @@ def __init__(
agent_factory (:obj:`MotleyAgentFactory`, optional):
tools (:obj:`Sequence[MotleySupportedTool]`, optional):
verbose (bool):
chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`):
Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`.
If a callable is passed, it is used to get the chat history by session_id.
See Langchain `RunnableWithMessageHistory` get_session_history param for more details.
"""
super().__init__(
description=description,
Expand All @@ -49,21 +51,24 @@ def __init__(
verbose=verbose,
)

self.with_history = with_history
self.chat_history = chat_history or InMemoryChatMessageHistory()
if chat_history is True:
chat_history = InMemoryChatMessageHistory()
self.get_session_history_callable = lambda _: chat_history
else:
self.get_session_history_callable = chat_history

def materialize(self):
"""Materialize the agent and wrap it in RunnableWithMessageHistory if needed."""
if self.is_materialized:
return

super().materialize()
if self.with_history:
if self.get_session_history_callable:
if isinstance(self._agent, RunnableWithMessageHistory):
return
self._agent = RunnableWithMessageHistory(
runnable=self._agent,
get_session_history=lambda _: self.chat_history,
get_session_history=self.get_session_history_callable,
input_messages_key="input",
history_messages_key="chat_history",
)
Expand All @@ -88,7 +93,7 @@ def invoke(
prompt = self.compose_prompt(task_dict, task_dict.get("prompt"))

config = add_default_callbacks_to_langchain_config(config)
if self.with_history:
if self.get_session_history_callable:
config["configurable"] = config.get("configurable") or {}
config["configurable"]["session_id"] = (
config["configurable"].get("session_id") or "default"
Expand All @@ -110,7 +115,7 @@ def from_function(
tools: Sequence[MotleySupportedTool] | None = None,
prompt: ChatPromptTemplate | Sequence[ChatPromptTemplate] | None = None,
require_tools: bool = False,
with_history: bool = False,
chat_history: bool | GetSessionHistoryCallable = True,
verbose: bool = False,
) -> "LangchainMotleyAgent":
"""Description
Expand All @@ -123,6 +128,10 @@ def from_function(
tools (:obj:`Sequence[MotleySupportedTool]`, optional):
prompt (:obj:`ChatPromptTemplate`, :obj:`Sequence[ChatPromptTemplate]`, optional):
require_tools (bool):
chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`):
Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`.
If a callable is passed, it is used to get the chat history by session_id.
See Langchain `RunnableWithMessageHistory` get_session_history param for more details.
verbose (bool):
Returns:
Expand All @@ -149,7 +158,7 @@ def agent_factory(tools: dict[str, MotleyTool]):
name=name,
agent_factory=agent_factory,
tools=tools,
with_history=with_history,
chat_history=chat_history,
verbose=verbose,
)

Expand Down
19 changes: 14 additions & 5 deletions motleycrew/agents/langchain/tool_calling_react.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Sequence, List, Union

from langchain_core.messages import BaseMessage, HumanMessage, ChatMessage
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableLambda
from langchain_core.runnables.history import GetSessionHistoryCallable
from langchain_core.tools import BaseTool
from langchain_core.agents import AgentFinish, AgentActionMessageLog
from langchain_core.prompts import MessagesPlaceholder
Expand Down Expand Up @@ -147,7 +148,11 @@ def merge_consecutive_messages(messages: Sequence[BaseMessage]) -> List[BaseMess
"""
merged_messages = []
for message in messages:
if not merged_messages or type(merged_messages[-1]) != type(message):
if (
not merged_messages
or type(merged_messages[-1]) != type(message)
or isinstance(message, ToolMessage)
):
merged_messages.append(message)
else:
merged_messages[-1].content = merge_content(
Expand Down Expand Up @@ -223,8 +228,8 @@ def __new__(
description: str | None = None,
name: str | None = None,
prompt: ChatPromptTemplate | Sequence[ChatPromptTemplate] | None = None,
with_history: bool = False,
llm: BaseChatModel | None = None,
chat_history: bool | GetSessionHistoryCallable = True,
verbose: bool = False,
):
"""Description
Expand All @@ -233,7 +238,11 @@ def __new__(
tools (Sequence[MotleySupportedTool]):
description (:obj:`str`, optional):
name (:obj:`str`, optional):
prompt (:obj:ChatPromptTemplate`, :obj:`Sequence[ChatPromptTemplate]', optional):
prompt (:obj:ChatPromptTemplate`, :obj:`Sequence[ChatPromptTemplate]`, optional):
chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`):
Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`.
If a callable is passed, it is used to get the chat history by session_id.
See Langchain `RunnableWithMessageHistory` get_session_history param for more details.
llm (:obj:`BaseLanguageModel`, optional):
verbose (:obj:`bool`, optional):
"""
Expand All @@ -245,6 +254,6 @@ def __new__(
prompt=prompt,
function=create_tool_calling_react_agent,
require_tools=True,
with_history=with_history,
chat_history=chat_history,
verbose=verbose,
)

0 comments on commit b61f46f

Please sign in to comment.