recruiting.py

import asyncio
from dataclasses import dataclass
from typing import List

from recruiting.db import ENGINE, fill_candidate_table, get_recruitment_db_description
from recruiting.views import RecruitmentView

import dbally
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.audit.event_tracker import EventTracker
from dbally.llms.litellm import LiteLLM
from dbally.prompts import PromptTemplate

TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate(
    (
        {
            "role": "system",
            "content": (
                "You are given the following SQL tables:"
                "\n\n{schema}\n\n"
                "Your job is to write queries given a user’s request. "
                "Please return only the query, do not provide any extra text or explanation."
            ),
        },
        {
            "role": "user",
            "content": "{question}",
        },
    )
)
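
# The "{schema}" and "{question}" placeholders are filled at call time in
# recruiting_example below: "{schema}" receives the database description and
# "{question}" the user's request. A rendered system message might look like
# the following (the table definition is illustrative, not the real schema):
#
#   You are given the following SQL tables:
#
#   CREATE TABLE candidates (id INTEGER, name TEXT, years_of_experience INTEGER)
#
#   Your job is to write queries given a user's request. ...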


@dataclass
class Question:
    """
    A question to be asked to the recruitment database.
    """

    dbally_question: str
    gpt_question: str = ""
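
# When gpt_question is left empty, recruiting_example below falls back to
# dbally_question, so the same phrasing is sent to both systems.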


class Benchmark:
    """
    A benchmark comparing db-ally with end-to-end GPT-based text2sql.
    """

    def __init__(self) -> None:
        self._questions: List[Question] = []

    @property
    def questions(self) -> List[Question]:
        """List of benchmark questions.

        Raises:
            ValueError: If no questions have been added to the benchmark

        Returns:
            List[Question]: List of benchmark questions
        """
        if self._questions:
            return self._questions
        raise ValueError("No questions added to the benchmark")

    def add_question(self, question: Question) -> None:
        """Adds a question to the benchmark.

        Args:
            question (Question): A question to be added to the benchmark
        """
        self._questions.append(question)
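
# Reading `questions` before any `add_question` call raises ValueError, so an
# accidentally empty benchmark fails fast instead of silently running zero
# comparisons:
#
#   Benchmark().questions  # raises ValueError("No questions added to the benchmark")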


example_benchmark = Benchmark()
example_benchmark.add_question(Question(dbally_question="Give candidates with more than 5 years of experience"))
example_benchmark.add_question(
    Question(
        dbally_question="Return candidates available for senior positions",
        gpt_question=(
            "Return candidates available for senior positions. "
            "Seniors have more than 5 years of experience."
        ),
    )
)
example_benchmark.add_question(Question(dbally_question="List candidates from Europe"))
example_benchmark.add_question(Question(dbally_question="Who studied at Stanford?"))
example_benchmark.add_question(
    Question(dbally_question="Do we have any perfect fits for data scientist positions?")
)


async def recruiting_example(db_description: str, benchmark: Benchmark = example_benchmark) -> None:
    """Runs a recruiting example which compares db-ally with end-to-end GPT-based text2sql.

    Args:
        db_description (str): Database schema description, used to generate prompts for GPT.
        benchmark (Benchmark, optional): Benchmark containing the set of questions. Defaults to example_benchmark.
    """
    recruitment_db = dbally.create_collection(
        "recruitment",
        llm=LiteLLM(),
        event_handlers=[CLIEventHandler()],
    )
    recruitment_db.add(RecruitmentView, lambda: RecruitmentView(ENGINE))

    event_tracker = EventTracker()
    llm = LiteLLM("gpt-4")

    for question in benchmark.questions:
        # The CLIEventHandler attached above prints the db-ally side of the
        # comparison as each query is processed.
        await recruitment_db.ask(question.dbally_question, return_natural_response=True)

        gpt_question = question.gpt_question if question.gpt_question else question.dbally_question
        gpt_response = await llm.generate_text(
            TEXT2SQL_PROMPT_TEMPLATE, {"schema": db_description, "question": gpt_question}, event_tracker=event_tracker
        )
        print(f"GPT response: {gpt_response}")
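
# For a one-off query outside the benchmark loop, the same collection can be
# asked directly (a sketch; the question text is illustrative):
#
#   result = await recruitment_db.ask("How many candidates are there?")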


def run_recruiting_example(db_description: str = "", benchmark: Benchmark = example_benchmark) -> None:
    """Runs the recruiting example.

    Args:
        db_description (str, optional): Database schema description, used to generate prompts for GPT. Defaults to "".
        benchmark (Benchmark, optional): Benchmark containing the set of questions. Defaults to example_benchmark.
    """
    fill_candidate_table()
    asyncio.run(recruiting_example(db_description, benchmark))


if __name__ == "__main__":
    db_desc = get_recruitment_db_description()
    run_recruiting_example(db_desc)
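
# A minimal sketch of reusing this harness with custom questions (the question
# text is illustrative):
#
#   custom = Benchmark()
#   custom.add_question(Question(dbally_question="List candidates who know Python"))
#   run_recruiting_example(get_recruitment_db_description(), benchmark=custom)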