Skip to content

Commit

Permalink
refactor add LangchainOutputHandlerMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
User committed Jun 25, 2024
1 parent 7f4491b commit ddcbb6f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 97 deletions.
100 changes: 3 additions & 97 deletions motleycrew/agents/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,18 @@
from pydantic.v1.config import BaseConfig

from langchain.agents import AgentExecutor
from langchain_core.agents import AgentFinish, AgentAction
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory, GetSessionHistoryCallable
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.tools import BaseTool
from langchain_core.tools import Tool
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.messages import AIMessage

from motleycrew.agents.parent import MotleyAgentParent
from motleycrew.tracking import add_default_callbacks_to_langchain_config
from motleycrew.common import MotleySupportedTool, logger
from motleycrew.common import MotleyAgentFactory
from motleycrew.agents.parent import DirectOutput
from motleycrew.agents.mixins import LangchainOutputHandlerMixin


class LangchainMotleyAgent(MotleyAgentParent):
class LangchainMotleyAgent(MotleyAgentParent, LangchainOutputHandlerMixin):
def __init__(
self,
description: str | None = None,
Expand Down Expand Up @@ -62,94 +57,6 @@ def __init__(
else:
self.get_session_history_callable = chat_history

self._agent_finish_blocker_tool = self._create_agent_finish_blocker_tool()

def _create_agent_finish_blocker_tool(self) -> BaseTool:
"""Create a tool that will force the agent to retry if it attempts to return the output
bypassing the output handler.
"""

def create_agent_finish_blocking_message(input: Any) -> str:
return f"{input}\n\nYou must use {self.output_handler.name} to return the final output."

return Tool.from_function(
name="agent_finish_blocker",
description="",
func=create_agent_finish_blocking_message,
)

def _block_agent_finish(self, input: Any):
"""Intercept AgentFinish for forcing output via output handler.
If the agent attempts to return the output bypassing the output handler,
a tool call to the agent_finish_blocker_tool will be made
so that one more AgentExecutor iteration is forced.
"""
if isinstance(input, AgentFinish) and self.output_handler:
return [
AgentAction(
tool=self._agent_finish_blocker_tool.name,
tool_input={"input": input.return_values},
log="\nDetected AgentFinish, blocking it to force output via output handler.\n",
)
]
return input

def agent_plane_decorator(self):
"""Decorator for inclusion in the call chain of the agent, the output handler tool"""

def decorator(func: Callable):

def wrapper(
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: "Callbacks" = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
step = func(intermediate_steps, callbacks, **kwargs)

if not isinstance(step, AgentFinish):
return step

if self.output_handler is not None:
return AgentAction(
tool=self._agent_finish_blocker_tool.name,
tool_input={"input": self._agent_finish_blocker_tool},
log="\nDetected AgentFinish, blocking it to force output via output handler.\n",
)
return step

return wrapper

return decorator

def take_next_step_decorator(self):
"""DirectOutput exception interception decorator"""

def decorator(func: Callable):
def wrapper(
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:

try:
step = func(
name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager
)
except DirectOutput as direct_ex:
message = "Final answer\n" + str(direct_ex.output)
return AgentFinish(
return_values={"output": direct_ex.output},
messages=[AIMessage(content=message)],
log=message,
)
return step

return wrapper

return decorator

def materialize(self):
"""Materialize the agent and wrap it in RunnableWithMessageHistory if needed."""
if self.is_materialized:
Expand All @@ -159,7 +66,6 @@ def materialize(self):
assert isinstance(self._agent, AgentExecutor)

if self.output_handler:
self._agent.tools += [self._agent_finish_blocker_tool]

plan = ModelField(
name="plan", type_=Callable, class_validators={}, model_config=BaseConfig
Expand Down
68 changes: 68 additions & 0 deletions motleycrew/agents/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Optional, Callable, Union, Dict, List, Tuple

from langchain_core.agents import AgentFinish, AgentAction
from langchain_core.tools import BaseTool
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.messages import AIMessage
from motleycrew.agents.parent import DirectOutput


class LangchainOutputHandlerMixin:

def agent_plane_decorator(self):
"""Decorator for inclusion in the call chain of the agent, the output handler tool"""

def decorator(func: Callable):

def wrapper(
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: "Callbacks" = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
step = func(intermediate_steps, callbacks, **kwargs)

if not isinstance(step, AgentFinish):
return step

if self.output_handler is not None:
return AgentAction(
tool=self.output_handler.name,
tool_input=step.return_values,
log="Use tool: {}\nInput: {}".format(
self.output_handler.name, step.return_values
),
)
return step

return wrapper

return decorator

def take_next_step_decorator(self):
"""DirectOutput exception interception decorator"""

def decorator(func: Callable):
def wrapper(
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:

try:
step = func(
name_to_tool_map, color_mapping, inputs, intermediate_steps, run_manager
)
except DirectOutput as direct_ex:
message = "Final answer\n" + str(direct_ex.output)
return AgentFinish(
return_values={"output": direct_ex.output},
messages=[AIMessage(content=message)],
log=message,
)
return step

return wrapper

return decorator

0 comments on commit ddcbb6f

Please sign in to comment.