Skip to content

Commit b6201a8

Browse files
feat: add bubbling of events
1 parent aac9151 commit b6201a8

File tree

8 files changed

+141
-67
lines changed

8 files changed

+141
-67
lines changed

.env

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
MODEL_PROVIDER=openai
66

77
# The name of LLM model to use.
8-
MODEL=gpt-4o
8+
MODEL=gpt-4o-mini
99

1010
# Name of the embedding model to use.
1111
EMBEDDING_MODEL=text-embedding-3-large

.vscode/launch.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"name": "Python Debugger: Current File",
9+
"type": "debugpy",
10+
"request": "launch",
11+
"program": "main.py",
12+
"console": "integratedTerminal",
13+
"justMyCode": false
14+
}
15+
]
16+
}

app/core/agent_call.py

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,54 @@
11
import asyncio
22
from typing import Any, List
33

4-
from llama_index.core.tools import FunctionTool
4+
from llama_index.core.tools.types import ToolMetadata, ToolOutput
5+
from llama_index.core.tools.utils import create_schema_from_function
6+
from llama_index.core.workflow import Context, Workflow
57

8+
from app.core.function_call import (
9+
AgentRunResult,
10+
ContextAwareTool,
11+
FunctionCallingAgent,
12+
)
613
from app.core.planner_agent import StructuredPlannerAgent
7-
from app.core.prefix import PrintPrefix
8-
from app.core.function_call import AgentRunResult, FunctionCallingAgent
914

10-
import textwrap
15+
16+
class AgentCallTool(ContextAwareTool):
17+
def __init__(self, agent: Workflow) -> None:
18+
self.agent = agent
19+
# create the schema without the context
20+
name = f"call_{agent.name}"
21+
22+
async def schema_call(input: str) -> str:
23+
pass
24+
25+
# create the schema without the Context
26+
fn_schema = create_schema_from_function(name, schema_call)
27+
self._metadata = ToolMetadata(
28+
name=name,
29+
description=(
30+
f"Use this tool to delegate a sub task to the {agent.name} agent."
31+
+ (f" The agent is an {agent.role}." if agent.role else "")
32+
),
33+
fn_schema=fn_schema,
34+
)
35+
36+
# override the acall function with the ctx argument, which is needed for bubbling the events
37+
async def acall(self, ctx: Context, input: str) -> ToolOutput:
38+
# FIXME: reset contexts, not needed after https://github.com/run-llama/llama_index/pull/15776
39+
self.agent._contexts = set()
40+
task = asyncio.create_task(self.agent.run(input=input))
41+
# bubble all events while running the agent to the calling agent
42+
async for ev in self.agent.stream_events():
43+
ctx.write_event_to_stream(ev)
44+
ret: AgentRunResult = await task
45+
response = ret.response.message.content
46+
return ToolOutput(
47+
content=str(response),
48+
tool_name=self.metadata.name,
49+
raw_input={"args": input, "kwargs": {}},
50+
raw_output=response,
51+
)
1152

1253

1354
class AgentCallingAgent(FunctionCallingAgent):
@@ -19,7 +60,7 @@ def __init__(
1960
**kwargs: Any,
2061
) -> None:
2162
agents = agents or []
22-
tools = [_create_call_workflow_fn(self, agent) for agent in agents]
63+
tools = [AgentCallTool(agent=agent) for agent in agents]
2364
super().__init__(*args, name=name, tools=tools, **kwargs)
2465
# call add_workflows so agents will get detected by llama agents automatically
2566
self.add_workflows(**{agent.name: agent for agent in agents})
@@ -29,12 +70,12 @@ class AgentOrchestrator(StructuredPlannerAgent):
2970
def __init__(
3071
self,
3172
*args: Any,
32-
agents: List[FunctionCallingAgent] | None = None,
3373
name: str = "orchestrator",
74+
agents: List[FunctionCallingAgent] | None = None,
3475
**kwargs: Any,
3576
) -> None:
3677
agents = agents or []
37-
tools = [_create_call_workflow_fn(self, agent) for agent in agents]
78+
tools = [AgentCallTool(agent=agent) for agent in agents]
3879
super().__init__(
3980
*args,
4081
name=name,
@@ -43,35 +84,3 @@ def __init__(
4384
)
4485
# call add_workflows so agents will get detected by llama agents automatically
4586
self.add_workflows(**{agent.name: agent for agent in agents})
46-
47-
48-
def _create_call_workflow_fn(
49-
caller: FunctionCallingAgent, agent: FunctionCallingAgent
50-
) -> FunctionTool:
51-
def info(prefix: str, text: str) -> None:
52-
truncated = textwrap.shorten(text, width=255, placeholder="...")
53-
print(f"{prefix}: '{truncated}'")
54-
55-
async def acall_workflow_fn(input: str) -> str:
56-
# info(f"[{caller_name}->{agent.name}]", input)
57-
task = asyncio.create_task(agent.run(input=input))
58-
# bubble all events while running the agent to the calling agent
59-
if len(caller._sessions) > 1:
60-
print("XXX: Bubbling events only works with single-session agents")
61-
else:
62-
session = next(iter(caller._sessions))
63-
async for ev in agent.stream_events():
64-
session.write_event_to_stream(ev)
65-
ret: AgentRunResult = await task
66-
response = ret.response.message.content
67-
# info(f"[{caller_name}<-{agent.name}]", response)
68-
return response
69-
70-
return FunctionTool.from_defaults(
71-
async_fn=acall_workflow_fn,
72-
name=f"call_{agent.name}",
73-
description=(
74-
f"Use this tool to delegate a sub task to the {agent.name} agent."
75-
+ (f" The agent is an {agent.role}." if agent.role else "")
76-
),
77-
)

app/core/function_call.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from abc import abstractmethod
12
from typing import Any, List, Optional
23

34
from llama_index.core.llms import ChatMessage, ChatResponse
@@ -6,6 +7,8 @@
67
from llama_index.core.settings import Settings
78
from llama_index.core.tools import ToolOutput, ToolSelection
89
from llama_index.core.tools.types import BaseTool
10+
from llama_index.core.tools import FunctionTool
11+
912
from llama_index.core.workflow import (
1013
Context,
1114
Event,
@@ -25,11 +28,30 @@ class ToolCallEvent(Event):
2528
tool_calls: list[ToolSelection]
2629

2730

31+
class AgentRunEvent(Event):
32+
name: str
33+
_msg: str
34+
35+
@property
36+
def msg(self):
37+
return self._msg
38+
39+
@msg.setter
40+
def msg(self, value):
41+
self._msg = value
42+
43+
2844
class AgentRunResult(BaseModel):
2945
response: ChatResponse
3046
sources: list[ToolOutput]
3147

3248

49+
class ContextAwareTool(FunctionTool):
50+
@abstractmethod
51+
async def acall(self, ctx: Context, input: Any) -> ToolOutput:
52+
pass
53+
54+
3355
class FunctionCallingAgent(Workflow):
3456
def __init__(
3557
self,
@@ -40,13 +62,15 @@ def __init__(
4062
verbose: bool = False,
4163
timeout: float = 360.0,
4264
name: str,
65+
write_events: bool = True,
4366
role: Optional[str] = None,
4467
**kwargs: Any,
4568
) -> None:
4669
super().__init__(*args, verbose=verbose, timeout=timeout, **kwargs)
4770
self.tools = tools or []
4871
self.name = name
4972
self.role = role
73+
self.write_events = write_events
5074

5175
if llm is None:
5276
llm = Settings.llm
@@ -72,9 +96,10 @@ async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent
7296
user_input = ev.input
7397
user_msg = ChatMessage(role="user", content=user_input)
7498
self.memory.put(user_msg)
75-
ctx.session.write_event_to_stream(
76-
Event(msg=f"[{self.name}] Start to work on: {user_input}")
77-
)
99+
if self.write_events:
100+
ctx.write_event_to_stream(
101+
AgentRunEvent(name=self.name, msg=f"Start to work on: {user_input}")
102+
)
78103

79104
# get chat history
80105
chat_history = self.memory.get()
@@ -96,7 +121,10 @@ async def handle_llm_input(
96121
)
97122

98123
if not tool_calls:
99-
ctx.session.write_event_to_stream(Event(msg=f"[{self.name}] Finished task"))
124+
if self.write_events:
125+
ctx.write_event_to_stream(
126+
AgentRunEvent(name=self.name, msg="Finished task")
127+
)
100128
return StopEvent(
101129
result=AgentRunResult(response=response, sources=[*self.sources])
102130
)
@@ -128,7 +156,11 @@ async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> InputEvent
128156
continue
129157

130158
try:
131-
tool_output = await tool.acall(**tool_call.tool_kwargs)
159+
if isinstance(tool, ContextAwareTool):
160+
# inject context for calling a context-aware tool
161+
tool_output = await tool.acall(ctx=ctx, **tool_call.tool_kwargs)
162+
else:
163+
tool_output = await tool.acall(**tool_call.tool_kwargs)
132164
self.sources.append(tool_output)
133165
tool_msgs.append(
134166
ChatMessage(

app/core/planner_agent.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import Any, List
23

34
from llama_index.core.llms.function_calling import FunctionCallingLLM
@@ -10,7 +11,7 @@
1011
step,
1112
)
1213

13-
from app.core.function_call import AgentRunResult, FunctionCallingAgent
14+
from app.core.function_call import AgentRunEvent, AgentRunResult, FunctionCallingAgent
1415
from app.core.planner import Planner, SubTask, Plan
1516
from llama_index.core.tools import BaseTool
1617

@@ -35,8 +36,7 @@ class PlanEventType(Enum):
3536
REFINED = "refined"
3637

3738

38-
class PlanEvent(Event):
39-
39+
class PlanEvent(AgentRunEvent):
4040
event_type: PlanEventType
4141
plan: Plan
4242

@@ -68,9 +68,11 @@ def __init__(
6868
name="executor",
6969
llm=llm,
7070
tools=self.tools,
71+
write_events=False,
7172
# it's important to instruct to just return the tool call, otherwise the executor will interpret and change the result
7273
system_prompt="You are an expert in completing given tasks by calling the right tool for the task. Just return the result of the tool call. Don't add any information yourself",
7374
)
75+
self.add_workflows(executor=self.executor)
7476

7577
@step()
7678
async def create_plan(
@@ -80,8 +82,8 @@ async def create_plan(
8082
ctx.data["task"] = ev.input
8183
ctx.data["act_plan_id"] = plan_id
8284
# inform about the new plan
83-
ctx.session.write_event_to_stream(
84-
PlanEvent(event_type=PlanEventType.CREATED, plan=plan)
85+
ctx.write_event_to_stream(
86+
PlanEvent(name=self.name, event_type=PlanEventType.CREATED, plan=plan)
8587
)
8688
if self._verbose:
8789
print("=== Executing plan ===\n")
@@ -97,7 +99,7 @@ async def execute_plan(self, ctx: Context, ev: ExecutePlanEvent) -> SubTaskEvent
9799
# send an event per sub task
98100
events = [SubTaskEvent(sub_task=sub_task) for sub_task in upcoming_sub_tasks]
99101
for event in events:
100-
ctx.session.send_event(event)
102+
ctx.send_event(event)
101103

102104
return None
103105

@@ -107,7 +109,13 @@ async def execute_sub_task(
107109
) -> SubTaskResultEvent:
108110
if self._verbose:
109111
print(f"=== Executing sub task: {ev.sub_task.name} ===")
110-
result: AgentRunResult = await self.executor.run(input=ev.sub_task.input)
112+
# FIXME: reset contexts, not needed after https://github.com/run-llama/llama_index/pull/15776
113+
self.executor._contexts = set()
114+
task = asyncio.create_task(self.executor.run(input=ev.sub_task.input))
115+
# bubble all events while running the executor to the planner
116+
async for event in self.executor.stream_events():
117+
ctx.write_event_to_stream(event)
118+
result: AgentRunResult = await task
111119
if self._verbose:
112120
print("=== Done executing sub task ===\n")
113121
self.planner.state.add_completed_sub_task(ctx.data["act_plan_id"], ev.sub_task)
@@ -141,8 +149,10 @@ async def gather_results(
141149
)
142150
# inform about the new plan
143151
if new_plan is not None:
144-
ctx.session.write_event_to_stream(
145-
PlanEvent(event_type=PlanEventType.REFINED, plan=new_plan)
152+
ctx.write_event_to_stream(
153+
PlanEvent(
154+
name=self.name, event_type=PlanEventType.REFINED, plan=new_plan
155+
)
146156
)
147157

148158
# continue executing plan

main.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# flake8: noqa: E402
22
import asyncio
33
import os
4+
import textwrap
45
from dotenv import load_dotenv
56
from app.core.agent_call import AgentCallingAgent, AgentOrchestrator
6-
from app.core.function_call import AgentRunResult, FunctionCallingAgent
7+
from app.core.function_call import AgentRunEvent, AgentRunResult, FunctionCallingAgent
78
from app.engine.index import get_index
89
from app.settings import init_settings
910
from llama_index.core.tools import QueryEngineTool, ToolMetadata
@@ -80,6 +81,11 @@ def create_orchestrator():
8081
)
8182

8283

84+
def info(prefix: str, text: str) -> None:
85+
truncated = textwrap.shorten(text, width=255, placeholder="...")
86+
print(f"[{prefix}] {truncated}")
87+
88+
8389
async def main():
8490
# agent = create_choreography()
8591
agent = create_orchestrator()
@@ -88,10 +94,11 @@ async def main():
8894
)
8995

9096
async for ev in agent.stream_events():
91-
print(ev.msg)
97+
if isinstance(ev, AgentRunEvent):
98+
info(ev.name, ev.msg)
9299

93100
ret: AgentRunResult = await task
94-
print(ret.response.message.content)
101+
print(f"\n\nResult:\n\n{ret.response.message.content}")
95102

96103

97104
if __name__ == "__main__":

0 commit comments

Comments
 (0)