@@ -1,14 +1,19 @@
 from typing import Any, List, Optional
 
+from llama_index.core.llms import ChatMessage, ChatResponse
 from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.core.memory import ChatMemoryBuffer
-from llama_index.core.tools.types import BaseTool
-from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step
-
-from llama_index.core.llms import ChatMessage, ChatResponse
-from llama_index.core.tools import ToolSelection, ToolOutput
-from llama_index.core.workflow import Event
 from llama_index.core.settings import Settings
+from llama_index.core.tools import ToolOutput, ToolSelection
+from llama_index.core.tools.types import BaseTool
+from llama_index.core.workflow import (
+    Context,
+    Event,
+    StartEvent,
+    StopEvent,
+    Workflow,
+    step,
+)
 from pydantic import BaseModel
 
 
@@ -54,7 +59,7 @@ def __init__(
         self.sources = []
 
     @step()
-    async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:
+    async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
         # clear sources
         self.sources = []
 
@@ -67,13 +72,18 @@ async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:
         user_input = ev.input
         user_msg = ChatMessage(role="user", content=user_input)
         self.memory.put(user_msg)
+        ctx.session.write_event_to_stream(
+            Event(msg=f"[{self.name}] Start to work on: {user_input}")
+        )
 
         # get chat history
         chat_history = self.memory.get()
         return InputEvent(input=chat_history)
 
     @step()
-    async def handle_llm_input(self, ev: InputEvent) -> ToolCallEvent | StopEvent:
+    async def handle_llm_input(
+        self, ctx: Context, ev: InputEvent
+    ) -> ToolCallEvent | StopEvent:
         chat_history = ev.input
 
         response = await self.llm.achat_with_tools(
@@ -86,14 +96,15 @@ async def handle_llm_input(self, ev: InputEvent) -> ToolCallEvent | StopEvent:
         )
 
         if not tool_calls:
+            ctx.session.write_event_to_stream(Event(msg=f"[{self.name}] Finished task"))
             return StopEvent(
                 result=AgentRunResult(response=response, sources=[*self.sources])
             )
         else:
             return ToolCallEvent(tool_calls=tool_calls)
 
     @step()
-    async def handle_tool_calls(self, ev: ToolCallEvent) -> InputEvent:
+    async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> InputEvent:
         tool_calls = ev.tool_calls
         tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}
 
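For context: the ctx.session.write_event_to_stream(...) calls added above push progress events onto the workflow's event stream, which a caller can drain while the run is still in flight. Below is a minimal consumption sketch, not part of this commit; it assumes the handler.stream_events() API of recent llama_index releases (older versions exposed streaming differently), and "agent" stands in for an instance of the workflow class this diff modifies, whose name falls outside the hunks shown.

# Consumption sketch (illustrative, not part of this commit). Assumes
# Workflow.run() returns a handler exposing stream_events(), as in recent
# llama_index releases; "agent" is any instance of the class patched above.
async def run_and_stream(agent, user_input: str):
    handler = agent.run(input=user_input)        # StartEvent carries .input
    async for ev in handler.stream_events():     # events emitted via write_event_to_stream
        if hasattr(ev, "msg"):
            print(ev.msg)                        # e.g. "[name] Start to work on: ..."
    return await handler                         # AgentRunResult carried by the StopEvent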