Skip to content

Commit aac9151

Browse files
bubble events
1 parent f011db7 commit aac9151

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

app/core/agent_call.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import Any, List
23

34
from llama_index.core.tools import FunctionTool
@@ -18,7 +19,7 @@ def __init__(
1819
**kwargs: Any,
1920
) -> None:
2021
agents = agents or []
21-
tools = [_create_call_workflow_fn(name, agent) for agent in agents]
22+
tools = [_create_call_workflow_fn(self, agent) for agent in agents]
2223
super().__init__(*args, name=name, tools=tools, **kwargs)
2324
# call add_workflows so agents will get detected by llama agents automatically
2425
self.add_workflows(**{agent.name: agent for agent in agents})
@@ -33,7 +34,7 @@ def __init__(
3334
**kwargs: Any,
3435
) -> None:
3536
agents = agents or []
36-
tools = [_create_call_workflow_fn(name, agent) for agent in agents]
37+
tools = [_create_call_workflow_fn(self, agent) for agent in agents]
3738
super().__init__(
3839
*args,
3940
name=name,
@@ -45,18 +46,25 @@ def __init__(
4546

4647

4748
def _create_call_workflow_fn(
48-
caller_name: str, agent: FunctionCallingAgent
49+
caller: FunctionCallingAgent, agent: FunctionCallingAgent
4950
) -> FunctionTool:
5051
def info(prefix: str, text: str) -> None:
5152
truncated = textwrap.shorten(text, width=255, placeholder="...")
5253
print(f"{prefix}: '{truncated}'")
5354

5455
async def acall_workflow_fn(input: str) -> str:
55-
info(f"[{caller_name}->{agent.name}]", input)
56-
with PrintPrefix(f"[{agent.name}]"):
57-
ret: AgentRunResult = await agent.run(input=input)
58-
response = ret.response.message.content
59-
info(f"[{caller_name}<-{agent.name}]", response)
56+
# info(f"[{caller_name}->{agent.name}]", input)
57+
task = asyncio.create_task(agent.run(input=input))
58+
# bubble all events while running the agent to the calling agent
59+
if len(caller._sessions) > 1:
60+
print("XXX: Bubbling events only works with single-session agents")
61+
else:
62+
session = next(iter(caller._sessions))
63+
async for ev in agent.stream_events():
64+
session.write_event_to_stream(ev)
65+
ret: AgentRunResult = await task
66+
response = ret.response.message.content
67+
# info(f"[{caller_name}<-{agent.name}]", response)
6068
return response
6169

6270
return FunctionTool.from_defaults(

app/core/function_call.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
from typing import Any, List, Optional
22

3+
from llama_index.core.llms import ChatMessage, ChatResponse
34
from llama_index.core.llms.function_calling import FunctionCallingLLM
45
from llama_index.core.memory import ChatMemoryBuffer
5-
from llama_index.core.tools.types import BaseTool
6-
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
7-
8-
from llama_index.core.llms import ChatMessage, ChatResponse
9-
from llama_index.core.tools import ToolSelection, ToolOutput
10-
from llama_index.core.workflow import Event
116
from llama_index.core.settings import Settings
7+
from llama_index.core.tools import ToolOutput, ToolSelection
8+
from llama_index.core.tools.types import BaseTool
9+
from llama_index.core.workflow import (
10+
Context,
11+
Event,
12+
StartEvent,
13+
StopEvent,
14+
Workflow,
15+
step,
16+
)
1217
from pydantic import BaseModel
1318

1419

@@ -54,7 +59,7 @@ def __init__(
5459
self.sources = []
5560

5661
@step()
57-
async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:
62+
async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
5863
# clear sources
5964
self.sources = []
6065

@@ -67,13 +72,18 @@ async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:
6772
user_input = ev.input
6873
user_msg = ChatMessage(role="user", content=user_input)
6974
self.memory.put(user_msg)
75+
ctx.session.write_event_to_stream(
76+
Event(msg=f"[{self.name}] Start to work on: {user_input}")
77+
)
7078

7179
# get chat history
7280
chat_history = self.memory.get()
7381
return InputEvent(input=chat_history)
7482

7583
@step()
76-
async def handle_llm_input(self, ev: InputEvent) -> ToolCallEvent | StopEvent:
84+
async def handle_llm_input(
85+
self, ctx: Context, ev: InputEvent
86+
) -> ToolCallEvent | StopEvent:
7787
chat_history = ev.input
7888

7989
response = await self.llm.achat_with_tools(
@@ -86,14 +96,15 @@ async def handle_llm_input(self, ev: InputEvent) -> ToolCallEvent | StopEvent:
8696
)
8797

8898
if not tool_calls:
99+
ctx.session.write_event_to_stream(Event(msg=f"[{self.name}] Finished task"))
89100
return StopEvent(
90101
result=AgentRunResult(response=response, sources=[*self.sources])
91102
)
92103
else:
93104
return ToolCallEvent(tool_calls=tool_calls)
94105

95106
@step()
96-
async def handle_tool_calls(self, ev: ToolCallEvent) -> InputEvent:
107+
async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> InputEvent:
97108
tool_calls = ev.tool_calls
98109
tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}
99110

0 commit comments

Comments
 (0)