-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathquickstart3_code.py
145 lines (115 loc) · 4.66 KB
/
quickstart3_code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring
import dbally
import os
import asyncio
from typing_extensions import Annotated
import sqlalchemy
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base
import pandas as pd
from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult
from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
from dbally.embeddings.litellm import LiteLLMEmbeddingClient
from dbally.llms.litellm import LiteLLM
# Engine for the SQLite database of candidates. The URL uses a relative path, so it is
# resolved against the current working directory of the process.
engine = create_engine("sqlite:///examples/recruiting/data/candidates.db")

# Reflect the existing database schema into mapped classes instead of declaring models by hand.
Base = automap_base()
Base.prepare(autoload_with=engine)  # reflection has to happen before Base.classes is read

# Mapped class generated for the `candidates` table.
Candidate = Base.classes.candidates
# Similarity index over the values of the `country` column of the candidates table.
# It is attached to the `country` argument of `CandidateView.from_country` through `Annotated`.
# NOTE(review): presumably this lets db-ally map a loosely worded country in the user's
# question onto a value that actually exists in the table - confirm against the db-ally docs.
country_similarity = SimilarityIndex(
    # Where the indexed values come from: the `country` column of the reflected table.
    fetcher=SimpleSqlAlchemyFetcher(
        engine,
        table=Candidate,
        column=Candidate.country,
    ),
    # Where the index is kept: a Faiss store named "country_similarity" under ./similarity_indexes.
    store=FaissStore(
        index_dir="./similarity_indexes",
        index_name="country_similarity",
        embedding_client=LiteLLMEmbeddingClient(
            model="text-embedding-3-small",  # OpenAI embedding model, accessed through LiteLLM
            # Looked up when this module is loaded; raises KeyError if the variable is not set.
            api_key=os.environ["OPENAI_API_KEY"],
        ),
    ),
)
class CandidateView(SqlAlchemyBaseView):
    """
    A view for retrieving candidates from the database.
    """

    # NOTE(review): db-ally appears to hand the view and filter docstrings to the LLM as
    # descriptions, so their wording is deliberately left untouched - confirm before rewording.

    def get_select(self) -> sqlalchemy.Select:
        """
        Creates the initial SqlAlchemy select object, which will be used to build the query.
        """
        # Plain SELECT over the reflected candidates table; the filters below narrow it down.
        base_query = sqlalchemy.select(Candidate)
        return base_query

    @decorators.view_filter()
    def at_least_experience(self, years: int) -> sqlalchemy.ColumnElement:
        """
        Filters candidates with at least `years` of experience.
        """
        # Inclusive lower bound on the years_of_experience column.
        has_enough_experience = Candidate.years_of_experience >= years
        return has_enough_experience

    @decorators.view_filter()
    def senior_data_scientist_position(self) -> sqlalchemy.ColumnElement:
        """
        Filters candidates that can be considered for a senior data scientist position.
        """
        # Positions regarded as relevant background for the role.
        eligible_positions = ["Data Scientist", "Machine Learning Engineer", "Data Engineer"]
        holds_eligible_position = Candidate.position.in_(eligible_positions)
        # Seniority threshold: three or more years of experience.
        has_enough_experience = Candidate.years_of_experience >= 3
        return sqlalchemy.and_(holds_eligible_position, has_enough_experience)

    @decorators.view_filter()
    def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchemy.ColumnElement:
        """
        Filters candidates from a specific country.
        """
        # Exact equality on the country column; the argument is annotated with the
        # module-level `country_similarity` index.
        matches_country = Candidate.country == country
        return matches_country
# In-memory table of sample job offers backing `JobView`: one row per offer with the
# columns title, company, location and salary.
jobs_data = pd.DataFrame.from_records(
    [
        # Pair every row tuple with the column names to obtain one record dict per offer.
        dict(zip(("title", "company", "location", "salary"), offer))
        for offer in (
            ("Data Scientist", "Company A", "New York", 100000),
            ("Data Engineer", "Company B", "San Francisco", 120000),
            ("Machine Learning Engineer", "Company C", "Berlin", 90000),
            ("Data Scientist", "Company D", "London", 110000),
            ("Data Scientist", "Company E", "Warsaw", 80000),
        )
    ]
)
class JobView(DataFrameBaseView):
    """
    View for retrieving information about job offers.
    """

    # Every filter returns a boolean pandas Series computed from `self.df`.
    # NOTE(review): `self.df` is presumably the DataFrame passed to the constructor
    # (`JobView(jobs_data)` in `main`) - confirm against DataFrameBaseView.
    # NOTE(review): db-ally appears to hand the view and filter docstrings to the LLM as
    # descriptions, so their wording is deliberately left untouched - confirm before rewording.

    @decorators.view_filter()
    def with_salary_at_least(self, salary: int) -> pd.Series:
        """
        Filters job offers with a salary of at least `salary`.
        """
        # Inclusive lower bound on the salary column.
        offered_salaries = self.df.salary
        return offered_salaries >= salary

    @decorators.view_filter()
    def in_location(self, location: str) -> pd.Series:
        """
        Filters job offers in a specific location.
        """
        # Exact, case-sensitive comparison against the location column.
        offer_locations = self.df.location
        return offer_locations == location

    @decorators.view_filter()
    def from_company(self, company: str) -> pd.Series:
        """
        Filters job offers from a specific company.
        """
        # Exact, case-sensitive comparison against the company column.
        offering_companies = self.df.company
        return offering_companies == company
def display_results(result: ExecutionResult):
    """
    Print a short, human-readable summary of a query result to stdout.

    The output starts with a header line holding the number of rows, followed by one
    line per row. The row format depends on which view produced the result:

    * "CandidateView": ``<name> - <skills>``
    * "JobView": ``<title> at <company> in <location>``
    * any other view: the raw row, so the result is never dropped silently

    Args:
        result: Execution result to display. Only its ``view_name`` and ``results``
            attributes are read; rows of the two known views are indexed by column name.

    Raises:
        KeyError: If a row of a known view lacks one of the columns listed above.
    """
    rows = result.results

    if result.view_name == "CandidateView":
        print(f"{len(rows)} Candidates:")
        for candidate in rows:
            print(f"{candidate['name']} - {candidate['skills']}")
    elif result.view_name == "JobView":
        print(f"{len(rows)} Job Offers:")
        for job in rows:
            print(f"{job['title']} at {job['company']} in {job['location']}")
    else:
        # Without this branch a result from any other view produced no output at all,
        # which is indistinguishable from the question never having been answered.
        print(f"{len(rows)} results from {result.view_name}:")
        for row in rows:
            print(row)
async def main():
    """
    Run the quickstart: update the country index, register both views and ask two questions.

    The answers are printed with `display_results`, separated by an empty line.
    """
    # NOTE(review): presumably (re)builds the index from the current `country` values so it
    # is ready before the first question is asked - confirm against SimilarityIndex.update.
    await country_similarity.update()

    collection = dbally.create_collection("recruitment", LiteLLM(model_name="gpt-3.5-turbo"))

    # Each view is registered together with a zero-argument factory that builds it.
    view_factories = (
        (CandidateView, lambda: CandidateView(engine)),
        (JobView, lambda: JobView(jobs_data)),
    )
    for view_class, build_view in view_factories:
        collection.add(view_class, build_view)

    questions = (
        "Find me job offers in New York with a salary of at least 100000.",
        "Find me candidates from Poland.",
    )
    for position, question in enumerate(questions):
        if position:
            print()  # empty line between consecutive answers
        display_results(await collection.ask(question))
# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())