-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathagent.py
158 lines (129 loc) · 5.74 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from agent_tools import tools, __schema__
from langchain_groq import ChatGroq
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from pydantic import SecretStr, BaseModel
import os
import json
import random
from typing import Annotated, List
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import MessagesState
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.tools import tool
print("Module Imported")

# LLM configuration: Groq-hosted tool-use model. The API key is read from the
# GROQ_API_KEY environment variable (falls back to an empty string if unset).
opts = {
    'api_key': SecretStr(os.getenv('GROQ_API_KEY', '')),
    "model": "llama3-groq-70b-8192-tool-use-preview"
}
llm = ChatGroq(
    **opts
)

# System prompt injected at the start of every turn (see the __main__ block).
# Fixes over the original text: "Incharge" -> "in charge", "You have provided"
# -> "You have been provided", the broken sentence "so you can just it" ->
# "so you can just use it", and "Databse" -> "Database". The f-string prefix
# was dropped — the template contained no placeholders.
system_prompt = SystemMessage(content="""You are a Database Admin that is in charge of User's SQL Database.
Make sure that you always stay relevant to the User's Input.
You have been provided certain tools and here are the Use cases:
- SQL Coder tool generates Appropriate SQL Queries as per the Prompt and the provided Database schema. In this case, the Database Schema has been already provided, so Basically SQL Coder knows everything about the Database, so you can just use it. Do not Bother User about the Database, just straight away ask SQL Coder that question
- Query Runner tool will run your Raw SQL Query and provide response from database.
- Analyze Data tool analyzes the Data Provided to it. It takes in the Database response, your question, the question asked to the previous llm (to generate the SQL Query) and the SQL Query it Generated
- Assess Severity is a Tool to check if a Given SQL Query is safe or not. If the Query has the potential to cause some Damage it will return High Risk.
- FormatMarkdownTable is a Tool to show data in proper Markdown Tables
Note:
- Do run the Generated SQL Query by SQL Coder, don't forget this Step.
- Remember to Show the Data in a Proper Markdown Table for user to look at.
- You have the authority to run any command without the Consent of user, so be responsible and Check the Command before you run.
""")

# Copy so local mutations never leak back into agent_tools.tools.
tools = tools.copy()
# Bind the tool schemas to the model so it can emit structured tool calls.
model = llm.bind_tools(tools)
class State(TypedDict):
    """Graph state: a single append-only conversation history."""
    # Messages have the type "list". The `add_messages` function
    # in the annotation defines how this state key should be updated
    # (in this case, it appends messages to the list, rather than overwriting them)
    messages: Annotated[List[AnyMessage], add_messages]


# Mutable builder; nodes and edges are registered below, then the graph is
# compiled (with a checkpointer) in the __main__ block.
graph_builder = StateGraph(State)
def chatbot(state: State):
    """Run the tool-bound LLM over the accumulated conversation.

    Returns a partial state update; `add_messages` on State.messages
    appends the new AI message rather than replacing the history.
    """
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}
class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage.

    Each requested tool is looked up by name, invoked with the call's
    arguments, and its result is wrapped in a ToolMessage so the graph can
    route the output back to the chatbot node.
    """

    def __init__(self, tools: list) -> None:
        # Index tools by name once for O(1) dispatch in __call__.
        self.tools_by_name = {t.name: t for t in tools}

    def __call__(self, inputs: dict):
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        message = messages[-1]
        outputs = []
        # getattr guard: the original raised AttributeError when the last
        # message carried no `tool_calls` attribute; now it cleanly returns
        # an empty update instead.
        for tool_call in getattr(message, "tool_calls", []):
            # Unknown tool names still raise KeyError, same as before.
            tool_result = self.tools_by_name[tool_call["name"]].invoke(
                tool_call["args"]
            )
            outputs.append(
                ToolMessage(
                    content=json.dumps(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}
def route_tools(
    state: State,
):
    """Conditional-edge router for the chatbot node.

    Returns "tools" when the most recent message carries pending tool
    calls; otherwise returns END to finish the turn.
    """
    # Accept either a bare message list or the usual {"messages": [...]}
    # state mapping.
    if isinstance(state, list):
        ai_message = state[-1]
    else:
        history = state.get("messages", [])
        if not history:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")
        ai_message = history[-1]
    has_calls = hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0
    return "tools" if has_calls else END
# Wire the graph: START -> chatbot, with a conditional hop to the tool node
# whenever the model emits tool calls, and a return edge so tool output is
# fed back to the chatbot for the next decision.
tool_node = BasicToolNode(tools=tools)
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", tool_node)
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    # The following dictionary lets you tell the graph to interpret the condition's outputs as a specific node
    # It defaults to the identity function, but if you
    # want to use a node named something else apart from "tools",
    # You can update the value of the dictionary to something else
    # e.g., "tools": "my_tools"
    {"tools": "tools", END: END},
)
# Any time a tool is called, we return to the chatbot to decide the next step
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_edge(START, "chatbot")
class Configurable(BaseModel):
    """The `configurable` payload of a LangGraph run config."""
    # thread_id selects the checkpoint thread used by MemorySaver below.
    thread_id: str
    # NOTE(review): session_id is not read by the graph in this file —
    # presumably consumed elsewhere; verify before removing.
    session_id: str = ''
class Config(BaseModel):
    """Top-level run-config shape passed to graph.stream()/invoke()."""
    configurable: Configurable
# Runtime config handed to graph.stream(); thread_id keys the MemorySaver
# checkpoint so conversation state persists across turns.
_configurable = Configurable(thread_id="1", session_id="abc1")
config = Config(configurable=_configurable).model_dump()
if __name__ == "__main__":
    # Interactive REPL: compile the graph with in-memory checkpointing and
    # chat until the user interrupts (Ctrl-C) or stdin closes (Ctrl-D).
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    def _run_turn(user_text: str) -> None:
        """Stream one conversation turn and pretty-print each new message."""
        # NOTE(review): system_prompt is re-sent on every turn; combined with
        # add_messages + the checkpointer this appends a duplicate system
        # message to the thread each time — confirm this is intended.
        events = graph.stream(
            # The config is the **second positional argument** to stream() or invoke()!
            {"messages": [system_prompt, ("user", user_text)]}, config, stream_mode="values"
        )
        for event in events:
            event["messages"][-1].pretty_print()

    _run_turn("Hello")
    while True:
        try:
            user_input = input("Prompt: ")
            _run_turn(user_input)
        # The original caught only KeyboardInterrupt; a closed stdin (piped
        # input, Ctrl-D) crashed with an uncaught EOFError.
        except (KeyboardInterrupt, EOFError):
            break