Skip to content

Commit

Permalink
Gradio Adapter implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
karllu3 committed May 28, 2024
1 parent 2fb275f commit 653f1fd
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 0 deletions.
40 changes: 40 additions & 0 deletions docs/how-to/visualize_views_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import asyncio
import dotenv
import os

import dbally
from dbally.audit import CLIEventHandler
from dbally.embeddings import LiteLLMEmbeddingClient
from dbally.similarity import SimilarityIndex, SimpleSqlAlchemyFetcher, FaissStore
from dbally.llms.litellm import LiteLLM
from dbally.utils.gradio_adapter import GradioAdapter
from sandbox.quickstart2 import CandidateView, engine, Candidate

# Load settings from a local .env file (python-dotenv) so that
# OPENAI_API_KEY is available in os.environ when it is read below.
dotenv.load_dotenv()

# Similarity index over the values of the Candidate.country column.
# The fetcher reads the column through `engine`; the store keeps the
# embeddings in a Faiss index named "country_similarity" under
# ./similarity_indexes, embedding values with the LiteLLM client.
# NOTE: os.environ[...] raises KeyError at import time when the key is unset.
country_similarity = SimilarityIndex(
    fetcher=SimpleSqlAlchemyFetcher(
        engine,
        table=Candidate,
        column=Candidate.country,
    ),
    store=FaissStore(
        index_dir="./similarity_indexes",
        index_name="country_similarity",
        embedding_client=LiteLLMEmbeddingClient(
            api_key=os.environ["OPENAI_API_KEY"],
        ),
    ),
)


async def main():
    """Build the "recruitment" collection and serve it through a Gradio UI.

    Creates a db-ally collection backed by a LiteLLM model, registers
    ``CandidateView`` in it, wraps the collection in a ``GradioAdapter`` and
    launches the resulting interface.
    """
    llm = LiteLLM(model_name="gpt-3.5-turbo")
    collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
    # The view is registered with a factory so each use gets a view bound to `engine`.
    collection.add(CandidateView, lambda: CandidateView(engine))

    gradio_adapter = GradioAdapter(similarity_store=country_similarity)
    gradio_interface = gradio_adapter.create_interface(collection)
    # create_interface() returns None when the collection exposes no views;
    # calling .launch() on that would fail with an AttributeError.
    if gradio_interface is None:
        return
    gradio_interface.launch()


# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ install_requires =
tabulate>=0.9.0
click~=8.1.7
numpy>=1.24.0
gradio>=4.31.5

[options.extras_require]
litellm =
Expand Down
13 changes: 13 additions & 0 deletions src/dbally/utils/dbcon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pandas as pd
from sqlalchemy import create_engine, inspect


def main(database_url: str = r"sqlite:////home/karllu/projects/db-ally/candidates.db"):
    """Print the table names of a database and load its first table.

    Args:
        database_url: SQLAlchemy connection URL. The default keeps the
            original developer-local SQLite path; pass your own URL to
            inspect a different database.

    Returns:
        A DataFrame with the contents of the first table reported by the
        inspector, or None when the database contains no tables.
    """
    engine = create_engine(database_url)
    # Query the schema once and reuse the result for printing and loading.
    table_names = inspect(engine).get_table_names()
    print(table_names)
    if not table_names:
        # Nothing to read; indexing [0] here would raise IndexError.
        return None
    return pd.read_sql_table(table_names[0], engine)


# Run the inspection only when this file is executed as a script.
if __name__ == "__main__":
    main()
104 changes: 104 additions & 0 deletions src/dbally/utils/gradio_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import Optional, Tuple

import gradio
import pandas as pd

from dbally.collection import Collection
from dbally.similarity import SimilarityIndex
from dbally.utils.errors import UnsupportedQueryError


class GradioAdapter:
    """Adapter that exposes a db-ally collection through a Gradio UI.

    The adapter keeps a reference to the collection passed to
    ``create_interface`` and to the most recently loaded DataFrame, and
    provides the async callbacks that the Gradio components are wired to.
    """

    def __init__(self, similarity_store: Optional[SimilarityIndex] = None):
        """Initializes the GradioAdapter with an optional similarity store.

        Args:
            similarity_store: An instance of SimilarityIndex for similarity operations.
                When given, it is updated every time data is loaded. Defaults to None.
        """
        # Set by create_interface(); the callbacks below rely on it.
        self.collection: Optional[Collection] = None
        self.similarity_store = similarity_store
        # Last frame loaded through load_data() or load_selected_data().
        self.loaded_dataframe: Optional[pd.DataFrame] = None

    async def load_data(self, input_dataframe: pd.DataFrame) -> str:
        """Loads data into the adapter from a given DataFrame.

        Args:
            input_dataframe: The DataFrame to load.

        Returns:
            A message indicating the data has been loaded.
        """
        if self.similarity_store:
            await self.similarity_store.update()
        self.loaded_dataframe = input_dataframe
        return "Frame data loaded."

    async def load_selected_data(self, selected_view: str) -> Tuple[pd.DataFrame, str]:
        """Loads selected view data into the adapter.

        Args:
            selected_view: The name of the view to load.

        Returns:
            A tuple containing the loaded DataFrame and a message indicating the view data has been loaded.
        """
        if self.similarity_store:
            await self.similarity_store.update()
        # Execute the view as-is and turn its result records into a frame.
        view = self.collection.get(selected_view)
        self.loaded_dataframe = pd.DataFrame.from_records(view.execute().results)
        return self.loaded_dataframe, f"{selected_view} data loaded."

    async def execute_query(self, query: str) -> Tuple[str, Optional[pd.DataFrame]]:
        """Executes a query against the collection.

        Args:
            query: The question to ask.

        Returns:
            A tuple containing the generated SQL (str) and the resulting DataFrame (pd.DataFrame).
            If the query is unsupported, returns a message indicating this and None.
        """
        try:
            execution_result = await self.collection.ask(query)
        except UnsupportedQueryError:
            return "Unsupported query", None
        generated_sql = execution_result.context.get("sql")
        return generated_sql, pd.DataFrame.from_records(execution_result.results)

    def create_interface(self, user_collection: Collection) -> Optional[gradio.Blocks]:
        """Creates a Gradio interface for the provided user collection.

        Args:
            user_collection: The user collection to create an interface for.

        Returns:
            The created Gradio Blocks app, or None if no views are available in the collection.
        """
        view_list = user_collection.list()
        if not view_list:
            print("There is no data to be loaded")
            return None

        self.collection = user_collection

        with gradio.Blocks() as demo:
            # Top row: view selection and status on the left, editable data on the right.
            with gradio.Row():
                with gradio.Column():
                    # list(...) passes the view names explicitly as dropdown choices.
                    view_dropdown = gradio.Dropdown(label="Available views", choices=list(view_list))
                    load_info = gradio.Label(value="No data loaded.")
                with gradio.Column():
                    # The column-count mode must be "fixed" or "dynamic" (lowercase).
                    loaded_data_frame = gradio.Dataframe(interactive=True, col_count=(4, "fixed"))
                    load_data_button = gradio.Button("Load new data")
            # Middle row: natural-language question input.
            with gradio.Row():
                query = gradio.Text(label="Ask question")
                query_button = gradio.Button("Proceed")
            # Bottom row: generated SQL and the query results.
            with gradio.Row():
                query_sql_result = gradio.Text(label="Generated SQL")
                query_result_frame = gradio.Dataframe(interactive=False)

            # Event wiring has to happen inside the Blocks context.
            view_dropdown.change(
                fn=self.load_selected_data, inputs=view_dropdown, outputs=[loaded_data_frame, load_info]
            )
            load_data_button.click(fn=self.load_data, inputs=loaded_data_frame, outputs=load_info)
            query_button.click(fn=self.execute_query, inputs=[query], outputs=[query_sql_result, query_result_frame])

        return demo

0 comments on commit 653f1fd

Please sign in to comment.