Skip to content

Commit

Permalink
Clearer AgentFinish blocker messages (#60)
Browse files Browse the repository at this point in the history
* Clearer AgentFinish blocker messages

* fix tests

* update tests data

* fix unit tests
  • Loading branch information
whimo authored Jul 4, 2024
1 parent 1938a62 commit 2a767f0
Show file tree
Hide file tree
Showing 25 changed files with 29 additions and 19 deletions.
5 changes: 4 additions & 1 deletion motleycrew/agents/langchain/tool_calling_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
MessagesPlaceholder(variable_name="chat_history", optional=True),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
MessagesPlaceholder(variable_name="additional_notes", optional=True),
]
)

Expand All @@ -78,6 +79,7 @@
MessagesPlaceholder(variable_name="chat_history", optional=True),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
MessagesPlaceholder(variable_name="additional_notes", optional=True),
]
)
default_act_prompt_without_output_handler = default_act_prompt.partial(
Expand Down Expand Up @@ -233,7 +235,8 @@ def create_tool_calling_react_agent(
RunnablePassthrough.assign(
agent_scratchpad=lambda x: merge_consecutive_messages(
format_to_tool_messages(x["intermediate_steps"])
)
),
additional_notes=lambda x: x.get("additional_notes") or [],
)
| {"thought": think_chain, "background": RunnablePassthrough()}
| RunnableLambda(print_passthrough)
Expand Down
2 changes: 1 addition & 1 deletion motleycrew/agents/llama_index/llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def wrapper(
output_task_step = TaskStep(
task_id=task_id,
step_id=str(uuid.uuid4()),
input="You must call the {} tool to return the output.".format(
input="You must call the `{}` tool to return the output.".format(
self.output_handler.name
),
)
Expand Down
18 changes: 11 additions & 7 deletions motleycrew/agents/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ class LangchainOutputHandlingAgentMixin:
"""A mixin for Langchain-based agents that support output handlers."""

output_handler: Optional[MotleyTool] = None
_agent_finish_blocker_tool: Optional[MotleyTool] = None

def _create_agent_finish_blocker_tool(self) -> BaseTool:
"""Create a tool that will force the agent to retry if it attempts to return the output
bypassing the output handler.
"""

def create_agent_finish_blocking_message(input: Any = None) -> str:
return f"You must use {self.output_handler.name} to return the final output.\n"
return (
f"You must call the `{self.output_handler.name}` tool to return the final output.\n"
)

return Tool.from_function(
name="agent_finish_blocker",
Expand All @@ -37,27 +40,28 @@ def _is_blocker_action(self, action: AgentAction) -> bool:
def agent_plan_decorator(self, func: Callable):
"""Decorator for Agent.plan() method that intercepts AgentFinish events"""

additional_inputs = set()

def wrapper(
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: "Callbacks" = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
additional_notes = []

if self.output_handler:
to_remove_steps = []
for intermediate_step in intermediate_steps:
action, action_output = intermediate_step
if self._is_blocker_action(action):
additional_inputs.add(action_output)
# Add the interaction telling the LLM that it must use the output handler
additional_notes.append(("ai", action.tool_input))
additional_notes.append(("user", action_output))
to_remove_steps.append(intermediate_step)

for to_remove_step in to_remove_steps:
intermediate_steps.remove(to_remove_step)

if additional_inputs:
kwargs["input"] = kwargs["input"] + "\n{}".format("\n".join(additional_inputs))
if additional_notes:
kwargs["additional_notes"] = additional_notes

step = func(intermediate_steps, callbacks, **kwargs)

Expand All @@ -67,7 +71,7 @@ def wrapper(
if self.output_handler is not None:
return AgentAction(
tool=self._agent_finish_blocker_tool.name,
tool_input=step.return_values,
tool_input=step.log,
log="\nDetected AgentFinish, blocking it to force output via output handler.\n",
)
return step
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"Agent stopped due to iteration limit or time limit."
"\\[\n\\begin{aligned}\nx &= \\frac{367}{71} \\\\\ny &= -\\frac{25}{49} \\\\\nx - y &= 2\n\\end{aligned}\n\\]"
10 changes: 4 additions & 6 deletions tests/test_agents/test_langchain_output_handler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import pytest

from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.agents import AgentFinish, AgentAction

from motleycrew.agents.langchain.tool_calling_react import ReActToolCallingAgent
from motleycrew.agents import MotleyOutputHandler
from motleycrew.agents.langchain.tool_calling_react import ReActToolCallingAgent
from motleycrew.agents.parent import DirectOutput
from motleycrew.common.exceptions import InvalidOutput


invalid_output = "Add more information about AI applications in medicine."


Expand All @@ -20,7 +18,7 @@ def handle_output(self, output: str):
return {"checked_output": output}


def fake_agent_plan(intermediate_steps, step):
def fake_agent_plan(intermediate_steps, step, **kwargs):
return step


Expand Down Expand Up @@ -65,12 +63,12 @@ def test_agent_plan(agent):
assert agent_action == step

return_values = {"output": "test_output"}
agent_finish = AgentFinish(return_values=return_values, log="agent finish log")
agent_finish = AgentFinish(return_values=return_values, log="test_output")

step = agent_executor.plan([], agent_finish)
assert isinstance(step, AgentAction)
assert step.tool == agent._agent_finish_blocker_tool.name
assert step.tool_input == return_values
assert step.tool_input == "test_output"


def test_agent_take_next_step(agent):
Expand Down
11 changes: 8 additions & 3 deletions tests/test_agents/test_llama_index_output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def test_run_step(agent):
return

task = Task(input="User input", memory=agent._agent.memory)
task_step = TaskStep(task_id=task.task_id, step_id=str(uuid.uuid4()), input="Test input")
task_step = TaskStep(
task_id=task.task_id, step_id=str(uuid.uuid4()), input="Test input"
)

task_state = TaskState(
task=task,
Expand Down Expand Up @@ -98,8 +100,11 @@ def test_run_step(agent):
_task_step = step_queue.pop()

assert _task_step.task_id == task.task_id
assert _task_step.input == "You must call the {} tool to return the output.".format(
agent.output_handler.name
assert (
_task_step.input
== "You must call the `{}` tool to return the output.".format(
agent.output_handler.name
)
)

# test direct output
Expand Down

0 comments on commit 2a767f0

Please sign in to comment.