Skip to content

Commit

Permalink
add tests for output handler max iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
User committed Jul 5, 2024
1 parent 432aba1 commit b8d52b1
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 31 deletions.
38 changes: 27 additions & 11 deletions tests/test_agents/test_langchain_output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from motleycrew.agents import MotleyOutputHandler
from motleycrew.agents.langchain.tool_calling_react import ReActToolCallingAgent
from motleycrew.agents.parent import DirectOutput
from motleycrew.common.exceptions import InvalidOutput
from motleycrew.common.exceptions import InvalidOutput, OutputHandlerMaxIterationsExceeded

invalid_output = "Add more information about AI applications in medicine."

Expand Down Expand Up @@ -41,7 +41,7 @@ def agent():
tools=[DuckDuckGoSearchRun()],
verbose=True,
chat_history=True,
output_handler=ReportOutputHandler(),
output_handler=ReportOutputHandler(max_iterations=5),
)
agent.materialize()
object.__setattr__(agent._agent, "plan", fake_agent_plan)
Expand All @@ -56,6 +56,19 @@ def agent():
return agent


@pytest.fixture
def run_kwargs(agent):
agent_executor = agent.agent.bound.bound.steps[1].bound

run_kwargs = {
"name_to_tool_map": {tool.name: tool for tool in agent_executor.tools},
"color_mapping": {},
"inputs": {},
"intermediate_steps": [],
}
return run_kwargs


def test_agent_plan(agent):
agent_executor = agent.agent
agent_action = AgentAction("tool", "tool_input", "tool_log")
Expand All @@ -71,15 +84,7 @@ def test_agent_plan(agent):
assert step.tool_input == "test_output"


def test_agent_take_next_step(agent):
agent_executor = agent.agent.bound.bound.steps[1].bound

run_kwargs = {
"name_to_tool_map": {tool.name: tool for tool in agent_executor.tools},
"color_mapping": {},
"inputs": {},
"intermediate_steps": [],
}
def test_agent_take_next_step(agent, run_kwargs):

# test wrong output
input_data = "Latest advancements in AI in 2024."
Expand All @@ -95,3 +100,14 @@ def test_agent_take_next_step(agent):
assert isinstance(step_result.return_values, dict)
output_result = step_result.return_values.get("output")
assert output_result == {"checked_output": input_data}


def test_output_handler_max_iteration(agent, run_kwargs):
input_data = "Latest advancements in AI in 2024."
run_kwargs["inputs"] = input_data

with pytest.raises(OutputHandlerMaxIterationsExceeded):
for iteration in range(agent.output_handler.max_iterations + 1):
agent.agent._take_next_step(**run_kwargs)

assert iteration == agent.output_handler.max_iterations
75 changes: 55 additions & 20 deletions tests/test_agents/test_llama_index_output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import StructuredTool

try:
from llama_index.core.agent.types import Task, TaskStep, TaskStepOutput
Expand All @@ -17,7 +18,11 @@

from motleycrew.agents.llama_index import ReActLlamaIndexMotleyAgent
from motleycrew.agents import MotleyOutputHandler
from motleycrew.common.exceptions import InvalidOutput, ModuleNotInstalled
from motleycrew.common.exceptions import (
InvalidOutput,
ModuleNotInstalled,
OutputHandlerMaxIterationsExceeded,
)


invalid_output = "Add more information about AI applications in medicine."
Expand Down Expand Up @@ -49,7 +54,7 @@ def agent():
agent = ReActLlamaIndexMotleyAgent(
description="Your goal is to uncover cutting-edge developments in AI and data science",
tools=[search_tool],
output_handler=ReportOutputHandler(),
output_handler=ReportOutputHandler(max_iterations=5),
verbose=True,
)
agent.materialize()
Expand All @@ -58,17 +63,17 @@ def agent():

except ModuleNotInstalled:
return

return agent


def test_run_step(agent):
@pytest.fixture
def task_data(agent):
if agent is None:
return

task = Task(input="User input", memory=agent._agent.memory)
task_step = TaskStep(
task_id=task.task_id, step_id=str(uuid.uuid4()), input="Test input"
)
task_step = TaskStep(task_id=task.task_id, step_id=str(uuid.uuid4()), input="Test input")

task_state = TaskState(
task=task,
Expand All @@ -82,6 +87,24 @@ def test_run_step(agent):
output=AgentChatResponse(response="Test response"),
next_steps=[],
)
return task, task_step_output


def find_output_handler(agent: ReActLlamaIndexMotleyAgent) -> StructuredTool:
agent_worker = agent.agent.agent_worker
output_handler = None
for tool in agent_worker._get_tools(""):
if tool.metadata.name == "output_handler":
output_handler = tool.to_langchain_tool()
break
return output_handler


def test_run_step(agent, task_data):
if agent is None:
return

task, task_step_output = task_data

# test not last output
cur_step_output = agent._agent._run_step("", task_step_output=task_step_output)
Expand All @@ -100,23 +123,12 @@ def test_run_step(agent):
_task_step = step_queue.pop()

assert _task_step.task_id == task.task_id
assert (
_task_step.input
== "You must call the `{}` tool to return the output.".format(
agent.output_handler.name
)
assert _task_step.input == "You must call the `{}` tool to return the output.".format(
agent.output_handler.name
)

# test direct output

# find output handler
agent_worker = agent.agent.agent_worker
output_handler = None
for tool in agent_worker._get_tools(""):
if tool.metadata.name == "output_handler":
output_handler = tool.to_langchain_tool()
break

output_handler = find_output_handler(agent)
if output_handler is None:
return

Expand Down Expand Up @@ -144,3 +156,26 @@ def test_run_step(agent):
)
assert hasattr(agent, "direct_output")
assert agent.direct_output.output == {"checked_output": output_handler_input}


def test_output_handler_max_iteration(agent, task_data):
if agent is None:
return

task, task_step_output = task_data

output_handler = find_output_handler(agent)
if output_handler is None:
return

output_handler_input = "Latest advancements in AI in 2024."
with pytest.raises(OutputHandlerMaxIterationsExceeded):
for iteration in range(agent.output_handler.max_iterations + 1):

agent._agent._run_step(
"",
task_step_output=task_step_output,
output_handler=output_handler,
output_handler_input=output_handler_input,
)
assert iteration == agent.output_handler.max_iterations

0 comments on commit b8d52b1

Please sign in to comment.