-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
158 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import asyncio | ||
import dotenv | ||
import os | ||
|
||
import dbally | ||
from dbally.audit import CLIEventHandler | ||
from dbally.embeddings import LiteLLMEmbeddingClient | ||
from dbally.similarity import SimilarityIndex, SimpleSqlAlchemyFetcher, FaissStore | ||
from dbally.llms.litellm import LiteLLM | ||
from dbally.utils.gradio_adapter import GradioAdapter | ||
from sandbox.quickstart2 import CandidateView, engine, Candidate | ||
|
||
# Load variables from a .env file so the OPENAI_API_KEY lookup below can succeed.
dotenv.load_dotenv()

# Similarity index over Candidate.country: values are fetched through the
# SQLAlchemy engine and stored in a Faiss-backed store using LiteLLM embeddings.
country_similarity = SimilarityIndex(
    fetcher=SimpleSqlAlchemyFetcher(engine, table=Candidate, column=Candidate.country),
    store=FaissStore(
        index_dir="./similarity_indexes",
        index_name="country_similarity",
        # A missing key raises KeyError at import time, exactly as before.
        embedding_client=LiteLLMEmbeddingClient(api_key=os.environ["OPENAI_API_KEY"]),
    ),
)
|
||
|
||
async def main():
    """Build the recruitment collection and serve it through a Gradio UI.

    Creates a LiteLLM-backed collection named "recruitment", registers
    ``CandidateView`` on it, and launches the interface produced by
    ``GradioAdapter.create_interface``.
    """
    llm = LiteLLM(model_name="gpt-3.5-turbo")
    collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
    collection.add(CandidateView, lambda: CandidateView(engine))

    gradio_adapter = GradioAdapter(similarity_store=country_similarity)
    gradio_interface = gradio_adapter.create_interface(collection)
    # create_interface returns None when the collection lists no views;
    # calling launch() on None would fail with AttributeError.
    if gradio_interface is None:
        return
    gradio_interface.launch()
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the async main() to completion.
    asyncio.run(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import pandas as pd | ||
from sqlalchemy import create_engine, inspect | ||
|
||
|
||
def main(db_url=r"sqlite:////home/karllu/projects/db-ally/candidates.db"):
    """Print the table names of a database and read its first table.

    Args:
        db_url: SQLAlchemy connection URL. Defaults to the location this
            script was originally hard-coded to use.
    """
    engine = create_engine(db_url)
    # Query the table list once and reuse it for both the print and the read.
    table_names = inspect(engine).get_table_names()
    print(table_names)
    if not table_names:
        # An empty database has no first table; indexing [0] would raise IndexError.
        return
    # The resulting frame is not used; the read result is discarded as before.
    pd.read_sql_table(table_names[0], engine)
|
||
|
||
if __name__ == "__main__":
    # Only run the inspection when executed as a script, not on import.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from typing import Optional, Tuple | ||
|
||
import gradio | ||
import pandas as pd | ||
|
||
from dbally.collection import Collection | ||
from dbally.similarity import SimilarityIndex | ||
from dbally.utils.errors import UnsupportedQueryError | ||
|
||
|
||
class GradioAdapter:
    """Connects a collection to a small Gradio UI for loading data and asking questions."""

    def __init__(self, similarity_store: Optional[SimilarityIndex] = None):
        """Initializes the GradioAdapter with an optional similarity store.

        Args:
            similarity_store: An instance of SimilarityIndex that is updated whenever data is
                (re)loaded. Defaults to None, in which case no update is performed.
        """
        # Set by create_interface(); the query/load callbacks rely on it being populated.
        self.collection = None
        self.similarity_store = similarity_store
        self.loaded_dataframe = None

    async def _refresh_similarity_store(self) -> None:
        """Updates the similarity store, if one was configured."""
        # Compare against None explicitly so a store object that happens to be
        # falsy (e.g. one defining __len__) is still updated.
        if self.similarity_store is not None:
            await self.similarity_store.update()

    async def load_data(self, input_dataframe: pd.DataFrame) -> str:
        """Loads data into the adapter from a given DataFrame.

        Args:
            input_dataframe: The DataFrame to load.

        Returns:
            A message indicating the data has been loaded.
        """
        await self._refresh_similarity_store()
        self.loaded_dataframe = input_dataframe
        return "Frame data loaded."

    async def load_selected_data(self, selected_view: str) -> Tuple[pd.DataFrame, str]:
        """Loads selected view data into the adapter.

        Args:
            selected_view: The name of the view to load.

        Returns:
            A tuple containing the loaded DataFrame and a message indicating the view data has been loaded.
        """
        await self._refresh_similarity_store()
        view_results = self.collection.get(selected_view).execute().results
        self.loaded_dataframe = pd.DataFrame.from_records(view_results)
        return self.loaded_dataframe, f"{selected_view} data loaded."

    async def execute_query(self, query: str) -> Tuple[str, Optional[pd.DataFrame]]:
        """Executes a query against the collection.

        Args:
            query: The question to ask.

        Returns:
            A tuple containing the generated SQL (str) and the resulting DataFrame (pd.DataFrame).
            If the query is unsupported, returns a message indicating this and None.
        """
        # Only the ask() call is guarded; building the DataFrame happens outside the try.
        try:
            execution_result = await self.collection.ask(query)
        except UnsupportedQueryError:
            return "Unsupported query", None
        return execution_result.context.get("sql"), pd.DataFrame.from_records(execution_result.results)

    def create_interface(self, user_collection: Collection) -> Optional[gradio.Blocks]:
        """Creates a Gradio interface for the provided user collection.

        Args:
            user_collection: The user collection to create an interface for.

        Returns:
            The created Gradio Blocks app, or None if no views are available in the collection.
        """
        view_list = user_collection.list()
        if not view_list:
            print("There is no data to be loaded")
            return None

        self.collection = user_collection

        with gradio.Blocks() as demo:
            with gradio.Row():
                with gradio.Column():
                    view_dropdown = gradio.Dropdown(label="Available views", choices=view_list)
                    load_info = gradio.Label(value="No data loaded.")
                with gradio.Column():
                    # NOTE(review): Gradio documents the second tuple element as "fixed" or
                    # "dynamic" (lowercase); confirm whether "Fixed" is honoured as intended.
                    loaded_data_frame = gradio.Dataframe(interactive=True, col_count=(4, "Fixed"))
                    load_data_button = gradio.Button("Load new data")
            with gradio.Row():
                query = gradio.Text(label="Ask question")
                query_button = gradio.Button("Proceed")
            with gradio.Row():
                query_sql_result = gradio.Text(label="Generated SQL")
                query_result_frame = gradio.Dataframe(interactive=False)

            # Wire the UI events to the adapter's async callbacks.
            view_dropdown.change(
                fn=self.load_selected_data, inputs=view_dropdown, outputs=[loaded_data_frame, load_info]
            )
            load_data_button.click(fn=self.load_data, inputs=loaded_data_frame, outputs=load_info)
            query_button.click(fn=self.execute_query, inputs=[query], outputs=[query_sql_result, query_result_frame])

        return demo