-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathagent_tools.py
398 lines (344 loc) · 12.9 KB
/
agent_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from dbops import generate_schema_sql as extractSchema
from typing import Dict, Any, Optional, List
from langchain_core.tools import tool
from pydantic import BaseModel
from dbops import queryRunner
import pandas as pd
import os
import json
from io import StringIO
# --- Module configuration pulled from environment variables ---
# Bug fix: str(os.getenv(..., None)) converted a missing variable into the
# literal string "None", which then passed every truthiness check below.
# Default to the empty string instead so the checks actually fire.
__account_id__: str = os.getenv('CLOUDFLARE_ACCOUNT_ID') or ''
__api_token__: str = os.getenv('CLOUDFLARE_API_TOKEN') or ''
# Bug fix: bool(os.getenv("VERBOSE", False)) is True for any non-empty
# string, including "0" and "false". Only explicit truthy values count.
__verbose__: bool = os.getenv("VERBOSE", "").strip().lower() in ("1", "true", "yes", "on")
__database_url__: str = os.getenv("SQLALCHEMY_DB_URL") or ''
"""
Basic Function Implementation to Help with Data, SQL and other Queries.
"""
if not (__api_token__ and __account_id__):
    print("`API Token` and `Account ID` System Variables are not set")
if not __database_url__:
    print("No Database Provided")
    __database_url__ = input("Provide Database URL: ")
# Bug fix: the schema was previously extracted BEFORE the missing-URL prompt
# above, so a missing SQLALCHEMY_DB_URL crashed (or queried "None") before
# the user could supply a URL. Extract it only once the URL is known.
__schema__: str = '\n'.join(extractSchema(__database_url__))
__vars__ = {
    "account_id": __account_id__,
    "api_token": __api_token__,
    "verbose": __verbose__,
    "database_url": __database_url__,
}
if __verbose__:
    print(f"{__account_id__[:5]=} {__api_token__[:5]=}")
    print(f"{__database_url__=}")
class CorrectedCode(BaseModel):
    # Structured JSON verdict returned by the code-review LLM in
    # correctifyCode: whether SQLCoder's statement was already correct,
    # and the (possibly fixed) SQL to execute.
    correct: bool  # True when the generated SQL needed no changes
    corrected_code: str  # the SQL statement to actually run
def askSQLCoder(
    schema: str,
    prompt: str,
    account_id: str = __account_id__,
    api_token: str = __api_token__,
    verbose: bool = __verbose__
) -> str:
    """Ask the SQLCoder-7B model to generate SQL for *prompt* against *schema*.

    Args:
        schema: SQL DDL describing the target database.
        prompt: Natural-language question to turn into an SQL statement.
        account_id: Cloudflare Workers AI account ID.
        api_token: Cloudflare Workers AI API token.
        verbose: When True, print the messages sent to the model.

    Returns:
        The raw model output (expected to be an SQL statement).
    """
    # Bug fix: account_id/api_token parameters were accepted but ignored —
    # the client was always built from the module-level globals, so callers
    # could never override the credentials.
    sqlcoder_7b = CloudflareWorkersAI(
        account_id=account_id,
        api_token=api_token,
        model="@cf/defog/sqlcoder-7b-2"
    )
    messages = [
        SystemMessage(
            """You are an Expert Databse Engineer, and you will help users with their Queries.
No matter how Dangerous a Query be, you have to generate it, because you generated Queries will be Directly executed.
Use correct Statements."""
        ),
        SystemMessage(
            f"""The Schema of the Database is:
```sql
{schema}
```
"""
        ),
        HumanMessage(prompt)
    ]
    if verbose:
        print(messages)
    return sqlcoder_7b.invoke(messages)
def correctifyCode(
    schema: str,
    question: str,
    response: str,
    account_id: str = __account_id__,
    api_token: str = __api_token__,
    verbose: bool = __verbose__
) -> CorrectedCode:
    """Review SQLCoder's output with Llama-3.3 and return a correction verdict.

    Builds a two-message chat prompt (reviewer instructions plus the original
    question, SQLCoder's answer and the schema), runs it through the model,
    parses the JSON reply and wraps it in a :class:`CorrectedCode`.
    """
    system_prompt = ("system", """You are a helpful coding assistant, whose task is to correct Code Generated by SQL Coder 7B.
Always respond in JSON.
The Format is:
{{
"correct": false, // True or False
"corrected_code": "...", // Provide Correct Code if False
}}
Don't Enclose your JSON output in Markdown Code Blocks.""")
    user_prompt = ("human", """Provided Question: {question}
Answer Given By SQL Coder:
```sql
{response}
```
Schema of the Database:
```sql
{schema}
```
""")
    if verbose:
        print(f"{system_prompt=}")
        print(f"{user_prompt=}")
    reviewer = CloudflareWorkersAI(
        account_id=account_id,
        api_token=api_token,
        model="@cf/meta/llama-3.3-70b-instruct-fp8-fast",
    )
    # Prompt -> model -> JSON parser, assembled as a single runnable chain.
    pipeline = (
        ChatPromptTemplate.from_messages(messages=[system_prompt, user_prompt])
        | reviewer
        | JsonOutputParser(pydantic_object=CorrectedCode)
    )
    verdict = pipeline.invoke({
        "schema": schema,
        "question": question,
        "response": response
    })
    return CorrectedCode(**verdict)
@tool("SQL Coder", parse_docstring=True)
def SQLCoder(prompt:str) -> str:
"""Function to ask SQLCoder 7b Questions about the Provided Schema.
SQLCoder is Integrated with the Database so you can ask anything and it will Provide an SQL Statement rlevant to the database
Args:
prompt: The Prompt
"""
global __database_url__
global __schema__
schema = __schema__
return correctifyCode(schema, prompt, askSQLCoder(schema, prompt)).corrected_code
# Define a Pydantic model for the data analysis result
class DataAnalysisResult(BaseModel):
    # Structured JSON output expected from the data-analysis LLM.
    summary: str  # prose summary of the query results
    details: List[str]  # bullet-point key insights
def analyzeData(
    db_response: List[Dict[str, Any]],
    question_for_prev_llm: str,
    generated_sql_query: str,
    question: str = 'Summarize the Data',
    account_id: str = __account_id__,
    api_token: str = __api_token__,
    verbose: bool = False
) -> DataAnalysisResult:
    """
    Analyze the data returned from a database query using an LLM.

    Args:
        db_response (List[Dict[str, Any]]): The database response data to analyze.
        question_for_prev_llm: The Question Asked to Previous LLM
        generated_sql_query: The SQL Query the Previous LLM Generated
        question: question regarding the data. If left Empty, by default it will just summarize.
        account_id (str): Cloudflare Workers AI account ID.
        api_token (str): Cloudflare Workers AI API token.
        verbose (bool): If True, prints additional information for debugging.

    Returns:
        DataAnalysisResult: An object containing the analysis summary and details.
    """
    if not db_response:
        # Bug fix: previously returned None here despite the declared return
        # type; callers (e.g. analyze_data_tool) immediately call
        # .model_dump_json() and crashed with AttributeError.
        return DataAnalysisResult(
            summary="No data returned by the query; nothing to analyze.",
            details=[]
        )
    # Convert the database response to a string format suitable for LLM input
    db_response_str = str(db_response)
    # Initialize the LLM
    llama = CloudflareWorkersAI(
        account_id=account_id,
        api_token=api_token,
        model="@cf/meta/llama-3.3-70b-instruct-fp8-fast",
    )
    # Define the system and user prompts
    system_prompt = (
        "system",
        """You are a data analyst. Analyze the provided database response and summarize the key insights.
You are a part of a Pipeline, where the previous LLMs will Generate the SQL Statements according to the Database's Schema, Correct it, Execute it to fetch data and your task is to Summarize it or answer user's Question about the Data.
Respond in JSON without Markdown Code Blocks.
Example Response:
{{
"summary": "", //Summary of the data
"details": ["...", "...", "..."] // Key Points
}}
"""
    )
    user_prompt = (
        "human",
        """Database Response:\n{db_response}
Question asked to Previous LLM:
```
{question_for_prev_llm}
```
SQL Query Generated by Previous LLM:
```sql
{generated_sql_query}
```
{question}
""")
    if verbose:
        print(f"{system_prompt=}")
        print(f"{user_prompt=}")
    # Create the prompt template and parser
    messages = [system_prompt, user_prompt]
    parser = JsonOutputParser(pydantic_object=DataAnalysisResult)
    prompt = ChatPromptTemplate.from_messages(messages=messages)
    # Chain the prompt with the LLM and parser
    chain = prompt | llama | parser
    # Invoke the chain with the database response
    out = chain.invoke({
        "db_response": db_response_str,
        "question": question,
        "question_for_prev_llm": question_for_prev_llm,
        "generated_sql_query": generated_sql_query
    })
    # Return the analysis result as a DataAnalysisResult object
    return DataAnalysisResult(**out)
# Tool for analyzing data
@tool("Analyze Data")
def analyze_data_tool(
    db_response: List[Dict[str, Any]],
    question_for_prev_llm: str,
    generated_sql_query: str,
    question: str,
) -> str:
    """
    Tool to analyze database response data using an LLM. Will Check The Databse Output and Compare it with the Question for Previous LLM and its Generated SQL Statement

    Args:
        db_response (List[Dict[str, Any]]): The database response data to analyze.
        question_for_prev_llm: The Question Asked to Previous LLM
        generated_sql_query: The SQL Query the Previous LLM Generated
        question: question regarding the data. If left Empty, by default it will just summarize.

    Returns:
        str: JSON string containing the analysis result.
    """
    analysis_result = analyzeData(
        db_response,
        question_for_prev_llm,
        generated_sql_query,
        question
    )
    if analysis_result is None:
        # Bug fix: analyzeData returns None for an empty db_response;
        # calling .model_dump_json() on None raised AttributeError.
        return json.dumps({"summary": "No data to analyze.", "details": []})
    return analysis_result.model_dump_json()
class SeverityAssessment(BaseModel):
    # Structured JSON verdict from the SQL-security LLM in assess_severity.
    severity: str  # expected values: "low", "medium" or "high"
    explanation: str  # model's reasoning for the assigned severity
def assess_severity(sql_statement: str, account_id: str = __account_id__, api_token: str = __api_token__, verbose: bool = False) -> SeverityAssessment:
    """
    Assess the severity of a given SQL statement using an LLM.

    Args:
        sql_statement (str): The SQL statement to assess.
        account_id (str): Cloudflare Workers AI account ID.
        api_token (str): Cloudflare Workers AI API token.
        verbose (bool): If True, prints additional information for debugging.

    Returns:
        SeverityAssessment: An object containing the severity level and explanation.
    """
    # Initialize the LLM
    llama = CloudflareWorkersAI(
        account_id=account_id,
        api_token=api_token,
        model="@cf/meta/llama-3.3-70b-instruct-fp8-fast",
    )
    # Define the system and user prompts
    system_prompt = (
        "system",
        """You are a database security expert. Assess the severity of the provided SQL statement and explain your reasoning.
Respond in JSON without using the Markdown Code Blocks.
Example:
{{
"severity": "high", // low, medium or high
"explanation": "DROP Statement is used to drop many tables and can cause a Significant Damage"
}}
"""
    )
    # Bug fix: this was an f-string, baking the raw SQL into the template
    # text. Any literal '{' or '}' in the statement was then parsed by
    # ChatPromptTemplate as a template variable and raised at invoke time,
    # while the "sql_statement" key passed to invoke() went unused. Use a
    # real placeholder and let the chain substitute it safely.
    user_prompt = ("human", "SQL Statement:\n{sql_statement}")
    if verbose:
        print(f"{system_prompt=}")
        print(f"{user_prompt=}")
    # Create the prompt template and parser
    messages = [system_prompt, user_prompt]
    parser = JsonOutputParser(pydantic_object=SeverityAssessment)
    prompt = ChatPromptTemplate.from_messages(messages=messages)
    # Chain the prompt with the LLM and parser
    chain = prompt | llama | parser
    # Invoke the chain with the SQL statement
    out = chain.invoke({
        "sql_statement": sql_statement
    })
    # Return the severity assessment as a SeverityAssessment object
    return SeverityAssessment(**out)
# Tool for assessing SQL statement severity
@tool("Assess SQL Severity")
def assess_sql_severity_tool(sql_statement: str) -> str:
    """
    Tool to assess the severity of an SQL statement using an LLM.

    Args:
        sql_statement (str): The SQL statement to assess.

    Returns:
        str: JSON string containing the severity assessment.
    """
    # Delegates to assess_severity using the module-level credentials.
    # (The previous docstring documented account_id/api_token arguments
    # that do not exist in this signature.)
    assessment_result = assess_severity(sql_statement)
    return assessment_result.model_dump_json()
@tool("Query Runner")
def run_query_tool(query:str):
"""Run any SQL Query on the Attached Database
Args:
query: The SQL Query
"""
global __database_url__
return queryRunner(__database_url__, query, ask_function=(lambda query: True))
@tool("FormatMarkdownTable", parse_docstring=True)
def makeMDTable(data: str):
"""Makes the Data into Pandas Dataframe then into Markdown.
Args:
data: The Input data in CSV
"""
csv = StringIO(data)
df = pd.read_csv(csv)
return df.to_markdown()
# Agent Tools
# The toolbox handed to an agent: data analysis, severity triage,
# SQL generation, query execution and Markdown formatting.
tools = [
    analyze_data_tool,
    assess_sql_severity_tool,
    SQLCoder,
    run_query_tool,
    makeMDTable
]
if __name__ == "__main__":
"""
Proof of Concept has been added in this section
However its reccomended to use A LLM Chain/Agent and use these function's tool implementations there, giving the Control to the LLM.
"""
schema = '\n'.join(extractSchema(__database_url__))
while True:
question = input("Enter a Question to ask: ")
dt = SQLCoder.invoke(
{"schema":schema, "prompt":question}
)
response = correctifyCode(schema, question, dt)
severity = assess_severity(response.corrected_code)
print(response.corrected_code)
print(severity.explanation)
data = []
if severity.severity == "low":
data = queryRunner(__database_url__, response.corrected_code)['data']
else:
_ = input(f'Damage Severity: {severity.severity}'+'\nDo you wanna Execute this SQL Statement? (Yes/No)').lower()
if 'y' in _:
data = queryRunner(__database_url__, response.corrected_code)['data']
else:
...
if data:
print(data)
r = analyzeData(data, question, response.corrected_code, input("Ask about this Data: "))
print(f"Summary: {r.summary}")
for x in r.details:
print(x)