Skip to content

Commit

Permalink
add output handler to llama_index agent
Browse files Browse the repository at this point in the history
  • Loading branch information
User committed Jun 21, 2024
1 parent af3eafd commit 13d9b17
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 3 deletions.
19 changes: 18 additions & 1 deletion examples/old/single_llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,37 @@
from motleycrew.agents.llama_index import ReActLlamaIndexMotleyAgent
from motleycrew.common import configure_logging
from motleycrew.tasks import SimpleTask
from motleycrew.common.exceptions import InvalidOutput
from motleycrew.common import AsyncBackend

from langchain_core.tools import StructuredTool


def main():
"""Main function of running the example."""
search_tool = DuckDuckGoSearchRun()

def check_output(output: str):
    """Output-handler callback: accept the agent's draft answer only if it
    mentions medicine.

    On success returns the lowercased answer under the ``"checked_output"``
    key; otherwise raises ``InvalidOutput`` with a correction hint that is
    fed back to the agent.
    """
    lowered = output.lower()
    if "medicine" in lowered:
        return {"checked_output": lowered}
    raise InvalidOutput("Add more information about AI applications in medicine.")

output_handler = StructuredTool.from_function(
name="output_handler",
description="Output handler",
func=check_output,
)

# TODO: add LlamaIndex native tools
researcher = ReActLlamaIndexMotleyAgent(
description="Your goal is to uncover cutting-edge developments in AI and data science",
tools=[search_tool],
output_handler=output_handler,
verbose=True,
)

crew = MotleyCrew()
crew = MotleyCrew(async_backend=AsyncBackend.NONE)

# Create tasks for your agents
task = SimpleTask(
Expand Down
58 changes: 57 additions & 1 deletion motleycrew/agents/llama_index/llama_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Module description """

from typing import Any, Optional, Sequence
import uuid

try:
from llama_index.core.agent import AgentRunner
Expand All @@ -9,11 +10,60 @@

from langchain_core.runnables import RunnableConfig

from motleycrew.agents.parent import MotleyAgentParent
from motleycrew.agents.parent import MotleyAgentParent, DirectOutput
from motleycrew.common import MotleySupportedTool
from motleycrew.common import MotleyAgentFactory
from motleycrew.common.utils import ensure_module_is_installed

from llama_index.core.chat_engine.types import ChatResponseMode
from llama_index.core.agent.types import TaskStep, TaskStepOutput
from llama_index.core.chat_engine.types import AgentChatResponse


def run_step_decorator(agent, output_handler = None):
    """Decorator that splices the output-handler tool into the agent's step loop.

    ``func`` is expected to be the *unbound* ``AgentRunner._run_step`` taken
    from the agent's class (see ``LlamaIndexMotleyAgent.materialize``), so
    ``agent`` is passed explicitly as its ``self``. When the agent produces
    what it believes is its final step, the wrapper forces one extra step that
    instructs the agent to call ``output_handler``; the handler signals
    completion by raising ``DirectOutput``, which is converted here into the
    real final ``TaskStepOutput``.
    """
    def decorator(func):
        # Remembers the synthetic "call the output handler" step across wrapper
        # invocations, so the DirectOutput response can reference it as its
        # task_step. None until the first forced step is created.
        output_task_step = None
        def wrapper(task_id: str,
                    step: Optional[TaskStep] = None,
                    input: Optional[str] = None,
                    mode: ChatResponseMode = ChatResponseMode.WAIT,
                    **kwargs: Any):

            nonlocal output_task_step

            try:
                # agent is passed as `self` because func is the unbound
                # class-level _run_step.
                cur_step_output = func(agent, task_id, step, input,mode, **kwargs)
            except DirectOutput as e:
                # The output handler accepted the answer: turn its payload into
                # the genuine final step output. e.output is the handler's
                # return value (a {"checked_output": ...} dict by convention —
                # NOTE(review): a handler returning a different shape yields a
                # None response here; confirm against MotleyAgentParent).
                output = AgentChatResponse(e.output.get("checked_output"))
                cur_step_output = TaskStepOutput(
                    output = output,
                    is_last = True,
                    next_steps = [],
                    task_step = output_task_step
                )
                return cur_step_output

            # No handler configured: behave exactly like the undecorated method.
            if output_handler is None:
                return cur_step_output

            if cur_step_output.is_last:
                # Intercept the would-be final step: mark it non-final and
                # enqueue one more step that tells the agent to finish via the
                # output-handler tool.
                cur_step_output.is_last = False
                task_id = cur_step_output.task_step.task_id
                output_task_step = TaskStep(task_id=task_id,
                                            step_id=str(uuid.uuid4()),
                                            input="For finish answer use tool {}".format(output_handler.name))

                cur_step_output.next_steps.append(output_task_step)

                # Push the forced step onto the runner's queue ourselves: the
                # inner _run_step already drained next_steps (empty at the time)
                # before we appended. next_steps held only the new step here.
                step_queue = agent.state.get_step_queue(task_id)
                step_queue.extend(cur_step_output.next_steps)

            return cur_step_output

        return wrapper
    return decorator


class LlamaIndexMotleyAgent(MotleyAgentParent):
def __init__(
Expand All @@ -22,6 +72,7 @@ def __init__(
name: str | None = None,
agent_factory: MotleyAgentFactory[AgentRunner] | None = None,
tools: Sequence[MotleySupportedTool] | None = None,
output_handler: MotleySupportedTool | None = None,
verbose: bool = False,
):
"""Description
Expand All @@ -38,9 +89,14 @@ def __init__(
name=name,
agent_factory=agent_factory,
tools=tools,
output_handler=output_handler,
verbose=verbose,
)

    def materialize(self):
        """Build the underlying llama_index ``AgentRunner``, then monkey-patch
        its ``_run_step`` so the configured output handler participates in the
        step loop (see ``run_step_decorator``)."""
        super(LlamaIndexMotleyAgent, self).materialize()
        # Wrap the *unbound* class-level _run_step and assign it as an instance
        # attribute: run_step_decorator binds self._agent explicitly, so the
        # wrapper shadows the bound method only for this agent instance.
        self._agent._run_step = run_step_decorator(self._agent, self.output_handler)(self._agent.__class__._run_step)

def invoke(
self,
task_dict: dict,
Expand Down
2 changes: 2 additions & 0 deletions motleycrew/agents/llama_index/llama_index_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
name: str | None = None,
tools: Sequence[MotleySupportedTool] | None = None,
llm: LLM | None = None,
output_handler: MotleySupportedTool | None = None,
verbose: bool = False,
):
"""Description
Expand Down Expand Up @@ -56,5 +57,6 @@ def agent_factory(tools: dict[str, MotleyTool]) -> ReActAgent:
name=name,
agent_factory=agent_factory,
tools=tools,
output_handler=output_handler,
verbose=verbose,
)
2 changes: 1 addition & 1 deletion motleycrew/tracking/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _on_function_call_end(
params = self._get_initial_track_event_params(
LunaryRunType.TOOL, LunaryEventName.END, event_id
)
params["output"] = payload.get(EventPayload.FUNCTION_OUTPUT)
params["output"] = payload.get(EventPayload.FUNCTION_OUTPUT) if payload is not None else ""
return params

def _on_agent_step_start(
Expand Down

0 comments on commit 13d9b17

Please sign in to comment.