-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathquickstart2_code.py
96 lines (77 loc) · 3.05 KB
/
quickstart2_code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring
import os
import asyncio
from typing_extensions import Annotated
from dotenv import load_dotenv
import sqlalchemy
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base
import dbally
from dbally import decorators, SqlAlchemyBaseView
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
from dbally.embeddings.litellm import LiteLLMEmbeddingClient
from dbally.llms.litellm import LiteLLM
# Load variables from a local .env file (if any) into the process environment;
# OPENAI_API_KEY is read from os.environ further down in this module.
load_dotenv()

# SQLite URL with a relative path, so the script is expected to be started from
# a directory where this path resolves (presumably the repository root).
_DATABASE_URL = "sqlite:///examples/recruiting/data/candidates.db"
engine = create_engine(_DATABASE_URL)

# Reflect the existing database schema and generate mapped classes from it.
Base = automap_base()
Base.prepare(autoload_with=engine)

# Mapped class generated for the `candidates` table.
Candidate = Base.classes.candidates
# Source of the values to index: the `country` column of the candidates table.
_country_fetcher = SimpleSqlAlchemyFetcher(
    engine,
    table=Candidate,
    column=Candidate.country,
)

# Embedding client backed by an OpenAI embedding model.
# Raises KeyError at import time when OPENAI_API_KEY is not set in the environment.
_country_embeddings = LiteLLMEmbeddingClient(
    model="text-embedding-3-small",
    api_key=os.environ["OPENAI_API_KEY"],
)

# Faiss-based store configured with the ./similarity_indexes directory.
_country_store = FaissStore(
    index_dir="./similarity_indexes",
    index_name="country_similarity",
    embedding_client=_country_embeddings,
)

# Similarity index used to annotate the `country` argument of the
# CandidateView.from_country filter.
country_similarity = SimilarityIndex(
    fetcher=_country_fetcher,
    store=_country_store,
)
class CandidateView(SqlAlchemyBaseView):
    """
    A view for retrieving candidates from the database.
    """

    # NOTE(review): the class and filter docstrings are presumably surfaced by
    # dbally to the LLM as descriptions, so their wording is intentionally left
    # unchanged - confirm before rewording any of them.

    def get_select(self) -> sqlalchemy.Select:
        """
        Creates the initial SqlAlchemy select object, which will be used to build the query.
        """
        base_query = sqlalchemy.select(Candidate)
        return base_query

    @decorators.view_filter()
    def at_least_experience(self, years: int) -> sqlalchemy.ColumnElement:
        """
        Filters candidates with at least `years` of experience.
        """
        experience = Candidate.years_of_experience
        return experience >= years

    @decorators.view_filter()
    def senior_data_scientist_position(self) -> sqlalchemy.ColumnElement:
        """
        Filters candidates that can be considered for a senior data scientist position.
        """
        # Positions accepted for the role, combined with a minimum of 3 years of experience.
        eligible_positions = ["Data Scientist", "Machine Learning Engineer", "Data Engineer"]
        has_eligible_position = Candidate.position.in_(eligible_positions)
        has_enough_experience = Candidate.years_of_experience >= 3
        return sqlalchemy.and_(has_eligible_position, has_enough_experience)

    @decorators.view_filter()
    def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchemy.ColumnElement:
        """
        Filters candidates from a specific country.
        """
        # The parameter is annotated with the module-level country_similarity index.
        country_column = Candidate.country
        return country_column == country
async def main():
    """
    Runs the example: updates the country similarity index, registers CandidateView
    in a collection, asks a natural-language question and prints the outcome.
    """
    await country_similarity.update()

    collection = dbally.create_collection(
        "recruitment",
        LiteLLM(model_name="gpt-3.5-turbo"),
        event_handlers=[CLIEventHandler()],
    )

    def build_candidate_view():
        # Builder passed to the collection; creates a view bound to the module-level engine.
        return CandidateView(engine)

    collection.add(CandidateView, build_candidate_view)

    result = await collection.ask("Find someone from the United States with more than 2 years of experience.")

    generated_sql = result.context.get("sql")
    print(f"The generated SQL query is: {generated_sql}")
    print()

    candidates = result.results
    print(f"Retrieved {len(candidates)} candidates:")
    for row in candidates:
        print(row)


if __name__ == "__main__":
    # Run the example only when the file is executed as a script.
    asyncio.run(main())