-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsql.py
194 lines (166 loc) · 6.04 KB
/
sql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import json
import os
import sqlite3
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, List, Tuple, Union
import pandas as pd
from crewai import Agent, Crew, Process, Task
from crewai_tools import tool
from langchain.schema import AgentFinish
from langchain.schema.output import LLMResult
from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDataBaseTool,
)
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
# Never commit secrets: the Groq key that used to be hard-coded here is exposed
# in the repository history and should be revoked. ChatGroq reads GROQ_API_KEY
# from the environment, so require it there and fail early with a clear message
# (before any CSV/database work) when it is missing.
if not os.environ.get("GROQ_API_KEY"):
    raise RuntimeError(
        "GROQ_API_KEY is not set; export it in the environment before running."
    )

# Load the CSV export and (re)build the `sqd` table that the agents query.
# if_exists='replace' drops any previous table so each run starts from the CSV.
df = pd.read_csv("sdq2.csv")
connection = sqlite3.connect("sqd.db")
df.to_sql(name="sqd", con=connection, if_exists='replace')
@dataclass()
class Event:
    """A single JSON-lines record appended by the LLM callback handler."""

    # Kind of record, e.g. "llm_start" or "llm_end".
    event: str
    # ISO-8601 string produced by `_current_time()`.
    timestamp: str
    # Prompt text (on start) or generated text (on end).
    text: str
def _current_time() -> str:
return datetime.now(timezone.utc).isoformat()
class LLMCallbackHandler(BaseCallbackHandler):
    """Append LLM start/end events to ``log_path`` as JSON lines.

    Each line is ``json.dumps(asdict(Event(...)))`` with the keys
    ``event``, ``timestamp`` and ``text``.
    """

    def __init__(self, log_path: Path):
        super().__init__()
        self.log_path = log_path

    def _write_event(self, event_name: str, text: Any) -> None:
        """Append one ``Event`` record to the log file as a single JSON line."""
        event = Event(event=event_name, timestamp=_current_time(), text=text)
        with self.log_path.open("a", encoding="utf-8") as file:
            file.write(json.dumps(asdict(event)) + "\n")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Run when LLM starts running; log one ``llm_start`` event per prompt."""
        # `prompts` is a list and may hold several entries (e.g. batched calls).
        # The previous `assert len(prompts) == 1` rejected those outright, and
        # asserts vanish under `python -O`, so record every prompt instead.
        for prompt in prompts:
            self._write_event("llm_start", prompt)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        """Run when LLM ends running; log the text of the last generation."""
        # Nothing to record if the model returned no generations.
        if not response.generations or not response.generations[-1]:
            return
        generation = response.generations[-1][-1]
        # Chat models yield generations carrying a `.message`; plain completion
        # models only expose `.text`, so fall back to it.
        message = getattr(generation, "message", None)
        text = message.content if message is not None else generation.text
        self._write_event("llm_end", text)
# Shared Groq chat model used by all three agents and by `check_sql`.
# temperature=0 requests the most deterministic output; the callback handler
# logs prompts/responses to prompts.jsonl.
_prompt_log_handler = LLMCallbackHandler(Path("prompts.jsonl"))
llm = ChatGroq(
    model_name="llama3-70b-8192",
    temperature=0,
    callbacks=[_prompt_log_handler],
)
# Open the database read-only for the agents: `execute_sql` runs SQL written
# by the LLM, and none of the tools below need to write, so a stray or
# injected DROP/DELETE/INSERT/UPDATE on sqd.db is rejected by SQLite instead
# of being applied. `uri=true` makes the sqlite3 driver parse the path as a
# SQLite URI so that `mode=ro` takes effect. sqd.db itself is (re)built by the
# `df.to_sql` call earlier in this script.
db = SQLDatabase.from_uri("sqlite:///file:sqd.db?mode=ro&uri=true")
@tool("list_tables")
def list_tables() -> str:
    """List the available tables in the database"""
    # The docstring above is the tool description shown to the agent.
    # ListSQLDatabaseTool takes no meaningful input, hence the empty string.
    lister = ListSQLDatabaseTool(db=db)
    return lister.invoke("")
@tool("tables_schema")
def tables_schema(tables: str) -> str:
    """
    Input is a comma-separated list of tables, output is the schema and sample rows
    for those tables. Be sure that the tables actually exist by calling `list_tables` first!
    Example Input: table1, table2, table3
    """
    # Named `info_tool` (not `tool`) so the imported `tool` decorator is not
    # shadowed inside this function.
    info_tool = InfoSQLDatabaseTool(db=db)
    schema_text = info_tool.invoke(tables)
    return schema_text
@tool("execute_sql")
def execute_sql(sql_query: str) -> str:
    """Execute a SQL query against the database. Returns the result"""
    # Delegates straight to LangChain's query tool bound to the shared `db`.
    runner = QuerySQLDataBaseTool(db=db)
    query_result = runner.invoke(sql_query)
    return query_result
@tool("check_sql")
def check_sql(sql_query: str) -> str:
    """
    Use this tool to double check if your query is correct before executing it. Always use this
    tool before executing a query with `execute_sql`.
    """
    # The checker asks `llm` to review the query; it expects a {"query": ...} payload.
    checker = QuerySQLCheckerTool(db=db, llm=llm)
    payload = {"query": sql_query}
    return checker.invoke(payload)
# First agent in the pipeline and the only one with database tools.
# max_iter=5 caps how many iterations the agent may take on a task.
sql_dev = Agent(
    role="Senior Database Developer",
    goal="Construct and execute SQL queries based on a request",
    llm=llm,
    tools=[list_tables, tables_schema, execute_sql, check_sql],
    allow_delegation=False,
    max_iter=5,
    backstory=dedent(
        """
        You are an experienced database engineer who is master at creating efficient and complex SQL queries.
        You have a deep understanding of how different databases work and how to optimize queries.
        Use the `list_tables` to find available tables.
        Use the `tables_schema` to understand the metadata for the tables.
        Use the `check_sql` to check your queries for correctness.
        Use the `execute_sql` to execute queries against the database.
        """
    ),
)
# Second agent: turns the SQL developer's results into a written analysis.
# It has no tools, so it works only from the task context it receives.
data_analyst = Agent(
    role="Senior Data Analyst",
    goal="You receive data from the database developer and analyze it",
    llm=llm,
    allow_delegation=False,
    verbose=True,
    backstory=dedent(
        """
        You have deep experience with analyzing datasets using Python.
        Your work is always based on the provided data and is clear,
        easy-to-understand and to the point. You have attention
        to detail and always produce very detailed work (as long as you need).
        """
    ),
)
# Last agent in the pipeline: condenses the analyst's output into a short
# executive summary. It has no tools.
report_writer = Agent(
    role="Senior Report Editor",
    goal="Write an executive summary type of report based on the work of the analyst",
    # The backstory is sent to the model as part of the prompt; the earlier
    # text read "Your writing still is ...", a typo for "style".
    backstory=dedent(
        """
        Your writing style is well known for clear and effective communication.
        You always summarize long texts into bullet points that contain the most
        important details.
        """
    ),
    llm=llm,
    allow_delegation=False,
    verbose=True,
)
# Step 1: the SQL developer pulls the data needed for `{query}`, a placeholder
# filled from the `inputs` passed to `crew.kickoff`.
extract_data = Task(
    agent=sql_dev,
    description="Extract data that is required for the query {query}.",
    expected_output="Database result for the query",
    verbose=True,
)
# Step 2: the analyst works from extract_data's output, supplied via `context`.
analyze_data = Task(
    agent=data_analyst,
    context=[extract_data],
    description="Analyze the data from the database and write an analysis for {query}.",
    expected_output="Detailed analysis text",
    verbose=True,
)
# Step 3: the editor turns analyze_data's output into a short Markdown summary.
write_report = Task(
    agent=report_writer,
    context=[analyze_data],
    expected_output="Markdown report",
    description=dedent(
        """
        Write an executive summary of the report from the analysis. The report
        must be less than 100 words.
        """
    ),
    verbose=True,
)
# Run the three tasks one after another (extract -> analyze -> report), each
# handled by its own agent; crew activity is also written to crew.log.
crew = Crew(
    process=Process.sequential,
    agents=[sql_dev, data_analyst, report_writer],
    tasks=[extract_data, analyze_data, write_report],
    memory=False,
    verbose=True,
    output_log_file="crew.log",
)
# The user request; its "query" value is substituted into the `{query}`
# placeholders of the task descriptions when the crew is kicked off.
inputs = dict(
    query="Get the important metrics for the company Bonafide Health like conversion rate and hits for the most recent week number (max) and highlight interesting facts about affiliates performance",
)
result = crew.kickoff(inputs=inputs)