From 653f1fd43773899c323a0b9973bdab2c8cdb99ec Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Fri, 24 May 2024 01:32:54 +0200 Subject: [PATCH] Gradio Adater implementation --- docs/how-to/visualize_views_code.py | 40 +++++++++++ setup.cfg | 1 + src/dbally/utils/dbcon.py | 13 ++++ src/dbally/utils/gradio_adapter.py | 104 ++++++++++++++++++++++++++++ 4 files changed, 158 insertions(+) create mode 100644 docs/how-to/visualize_views_code.py create mode 100644 src/dbally/utils/dbcon.py create mode 100644 src/dbally/utils/gradio_adapter.py diff --git a/docs/how-to/visualize_views_code.py b/docs/how-to/visualize_views_code.py new file mode 100644 index 00000000..2e5c8e4c --- /dev/null +++ b/docs/how-to/visualize_views_code.py @@ -0,0 +1,40 @@ +import asyncio +import dotenv +import os + +import dbally +from dbally.audit import CLIEventHandler +from dbally.embeddings import LiteLLMEmbeddingClient +from dbally.similarity import SimilarityIndex, SimpleSqlAlchemyFetcher, FaissStore +from dbally.llms.litellm import LiteLLM +from dbally.utils.gradio_adapter import GradioAdapter +from sandbox.quickstart2 import CandidateView, engine, Candidate + +dotenv.load_dotenv() +country_similarity = SimilarityIndex( + fetcher=SimpleSqlAlchemyFetcher( + engine, + table=Candidate, + column=Candidate.country, + ), + store=FaissStore( + index_dir="./similarity_indexes", + index_name="country_similarity", + embedding_client=LiteLLMEmbeddingClient( + api_key=os.environ["OPENAI_API_KEY"], + ), + ), +) + + +async def main(): + llm = LiteLLM(model_name="gpt-3.5-turbo") + collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) + collection.add(CandidateView, lambda: CandidateView(engine)) + gradio_adapter = GradioAdapter(similarity_store=country_similarity) + gradio_interface = gradio_adapter.create_interface(collection) + gradio_interface.launch() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/setup.cfg b/setup.cfg index 62d3bef8..b05cc593 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ install_requires = tabulate>=0.9.0 click~=8.1.7 numpy>=1.24.0 + gradio>=4.31.5 [options.extras_require] litellm = diff --git a/src/dbally/utils/dbcon.py b/src/dbally/utils/dbcon.py new file mode 100644 index 00000000..5e28eade --- /dev/null +++ b/src/dbally/utils/dbcon.py @@ -0,0 +1,13 @@ +import pandas as pd +from sqlalchemy import create_engine, inspect + + +def main(): + engine = create_engine(r"sqlite:////home/karllu/projects/db-ally/candidates.db") + insp = inspect(engine) + print(insp.get_table_names()) + pd.read_sql_table(insp.get_table_names()[0], engine) + + +if __name__ == "__main__": + main() diff --git a/src/dbally/utils/gradio_adapter.py b/src/dbally/utils/gradio_adapter.py new file mode 100644 index 00000000..2f36ee50 --- /dev/null +++ b/src/dbally/utils/gradio_adapter.py @@ -0,0 +1,104 @@ +from typing import Optional, Tuple + +import gradio +import pandas as pd + +from dbally.collection import Collection +from dbally.similarity import SimilarityIndex +from dbally.utils.errors import UnsupportedQueryError + + +class GradioAdapter: + def __init__(self, similarity_store: SimilarityIndex = None): + """Initializes the GradioAdapter with an optional similarity store. + + Args: + similarity_store: An instance of SimilarityIndex for similarity operations. Defaults to None. + """ + self.collection = None + self.similarity_store = similarity_store + self.loaded_dataframe = None + + async def load_data(self, input_dataframe: pd.DataFrame) -> str: + """Loads data into the adapter from a given DataFrame. + + Args: + input_dataframe: The DataFrame to load. + + Returns: + A message indicating the data has been loaded. + """ + if self.similarity_store: + await self.similarity_store.update() + self.loaded_dataframe = input_dataframe + return "Frame data loaded." + + async def load_selected_data(self, selected_view: str) -> Tuple[pd.DataFrame, str]: + """Loads selected view data into the adapter. + + Args: + selected_view: The name of the view to load. + + Returns: + A tuple containing the loaded DataFrame and a message indicating the view data has been loaded. + """ + if self.similarity_store: + await self.similarity_store.update() + self.loaded_dataframe = pd.DataFrame.from_records(self.collection.get(selected_view).execute().results) + return self.loaded_dataframe, f"{selected_view} data loaded." + + async def execute_query(self, query: str) -> Tuple[str, Optional[pd.DataFrame]]: + """Executes a query against the collection. + + Args: + query: The question to ask. + + Returns: + A tuple containing the generated SQL (str) and the resulting DataFrame (pd.DataFrame). + If the query is unsupported, returns a message indicating this and None. + """ + try: + execution_result = await self.collection.ask(query) + result = execution_result.context.get("sql"), pd.DataFrame.from_records(execution_result.results) + except UnsupportedQueryError: + result = "Unsupported query", None + return result + + def create_interface(self, user_collection: Collection) -> Optional[gradio.Interface]: + """Creates a Gradio interface for the provided user collection. + + Args: + user_collection: The user collection to create an interface for. + + Returns: + The created Gradio interface, or None if no views are available in the collection. + """ + view_list = user_collection.list() + if not view_list: + print("There is no data to be loaded") + return None + + self.collection = user_collection + + with gradio.Blocks() as demo: + with gradio.Row(): + with gradio.Column(): + view_dropdown = gradio.Dropdown(label="Available views", choices=view_list) + load_info = gradio.Label(value="No data loaded.") + with gradio.Column(): + loaded_data_frame = gradio.Dataframe(interactive=True, col_count=(4, "Fixed")) + load_data_button = gradio.Button("Load new data") + with gradio.Row(): + query = gradio.Text(label="Ask question") + query_button = gradio.Button("Proceed") + with gradio.Row(): + query_sql_result = gradio.Text(label="Generated SQL") + query_result_frame = gradio.Dataframe(interactive=False) + + view_dropdown.change( + fn=self.load_selected_data, inputs=view_dropdown, outputs=[loaded_data_frame, load_info] + ) + load_data_button.click(fn=self.load_data, inputs=loaded_data_frame, outputs=load_info) + query_button.click(fn=self.execute_query, inputs=[query], outputs=[query_sql_result, query_result_frame]) + + return demo