Skip to content

Commit

Permalink
gradio adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
karllu3 committed Jun 4, 2024
1 parent 849b4f7 commit 46be6d8
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 45 deletions.
36 changes: 36 additions & 0 deletions docs/how-to/visualize_data.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# How-To: Visualize Views

There has been implemented Gradio Adapter class to create simple UI interface. It allows to display Data Preview related to Views
and execute user queries.

## Installation
```bash
pip install dbally[gradio]
```

## Create own gradio interface
Define collection with implemented views

```python
llm = LiteLLM(model_name="gpt-3.5-turbo")
collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
collection.add(CandidateView, lambda: CandidateView(engine))
collection.add(SampleText2SQLView, lambda: SampleText2SQLView(prepare_freeform_enginge()))
```

Create gradio interface
```python
gradio_adapter = GradioAdapter()
gradio_interface = await gradio_adapter.create_interface(collection, similarity_store_list=[country_similarity])
```

Launch the gradio interface. To publish public interface pass argument `share=True`
```python
gradio_interface.launch()
```

The endpoint is set by gradio server by triggering python module with Gradio Adapter launch command.
Private endpoint is set to http://127.0.0.1:7860/ by default.

## Links
* [Example Gradio Interface](visualize_views_code.py)
6 changes: 5 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ benchmark =
pydantic-settings~=2.0.3
psycopg2-binary~=2.9.9
elasticsearch =
elasticsearch==8.13.1
elasticsearch~=8.13.1
gradio =
gradio~=4.31.5
gradio_client~=0.16.4


[options.packages.find]
where = src
Expand Down
50 changes: 24 additions & 26 deletions src/dbally/utils/gradio_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Optional, Tuple, Dict
from typing import Optional, Tuple, Dict, List

import gradio
import pandas as pd
Expand All @@ -8,10 +8,10 @@
from dbally.collection import Collection
from dbally.similarity import SimilarityIndex
from dbally.utils.errors import UnsupportedQueryError
from dbally.utils.gradio_log_redirect import Logger
from dbally.utils.log_to_file import FileLogger


sys.stdout = Logger("console.log")
CONSOLE_FILE_NAME = "console.log"
sys.stdout = FileLogger(CONSOLE_FILE_NAME)


class GradioAdapter:
Expand All @@ -20,15 +20,11 @@ class GradioAdapter:
SQL_RESULT = "sql"
PANDAS_RESULT = "filter_mask"

def __init__(self, similarity_store: SimilarityIndex = None, engine=None):
"""Initializes the GradioAdapter with an optional similarity store.
Args:
similarity_store: An instance of SimilarityIndex for similarity operations. Defaults to None.
"""
def __init__(self, preview_limit: int = 20):
"""Initializes the GradioAdapter with an optional similarity store."""
self.preview_limit = preview_limit
self.similarity_store_list = []
self.collection = None
self.similarity_store = similarity_store
self.loaded_dataframe = None
sys.stdout.flush()

async def ui_load_preview_data(self, selected_view_name: str) -> Tuple[pd.DataFrame, str, None, None, None, None]:
Expand All @@ -40,6 +36,7 @@ async def ui_load_preview_data(self, selected_view_name: str) -> Tuple[pd.DataFr
Returns:
A tuple containing the loaded DataFrame and a message indicating the view data has been loaded.
"""

preview_dataframe, load_status_text = self.load_preview_data(selected_view_name)
return preview_dataframe, load_status_text, None, None, None, None

Expand All @@ -48,7 +45,7 @@ def load_preview_data(self, selected_view_name: str):
text_to_display = "No data preview available"
if issubclass(type(selected_view), BaseStructuredView):
selected_view_results = selected_view.execute()
preview_dataframe = pd.DataFrame.from_records(selected_view_results.results)
preview_dataframe = pd.DataFrame.from_records(selected_view_results.results).head(self.preview_limit)
text_to_display = "Data preview loaded"
else:
preview_dataframe = pd.DataFrame()
Expand All @@ -66,40 +63,38 @@ async def ui_ask_query(self, question_query: str) -> Tuple[Dict, Optional[pd.Dat
If the query is unsupported, returns a message indicating this and None.
"""
try:
if self.similarity_store:
await self.similarity_store.update()
for similarity_store in self.similarity_store_list:
await similarity_store.update()

execution_result = await self.collection.ask(question=question_query)
if self.SQL_RESULT in execution_result.context:
generated_query = execution_result.context.get(self.SQL_RESULT)
elif self.PANDAS_RESULT in execution_result.context:
generated_query = execution_result.context.get(self.PANDAS_RESULT)
else:
generated_query = "Unsupported generated query"

generated_query = str(execution_result.context)
data = pd.DataFrame.from_records(execution_result.results)
except UnsupportedQueryError:
generated_query = {"Query": "unsupported"}
data = pd.DataFrame()
finally:
sys.stdout.flush()
with open("output.log", "r") as f:
with open(CONSOLE_FILE_NAME, "r") as f:
log = f.read()

return generated_query, data, log

async def create_interface(self, user_collection: Collection) -> Optional[gradio.Interface]:
async def create_interface(
self, user_collection: Collection, similarity_store_list: List[SimilarityIndex] = []
) -> Optional[gradio.Interface]:
"""Creates a Gradio interface for the provided user collection.
Args:
user_collection: The user collection to create an interface for.
similarity_store_list: SimilarityIndex
Returns:
The created Gradio interface, or None if no views are available in the collection.
"""
self.collection = user_collection
self.similarity_store_list = similarity_store_list

view_list = [*user_collection.list()]
print(view_list[0])
if view_list:
default_selected_view_name = view_list[0]
else:
Expand All @@ -119,7 +114,10 @@ async def create_interface(self, user_collection: Collection) -> Optional[gradio
data_preview_frame, data_preview_status = self.load_preview_data(view_list[0])

data_preview_info = gradio.Text(label="Data preview", value=data_preview_status)
loaded_data_frame = gradio.Dataframe(data_preview_frame)
if not data_preview_frame.empty:
loaded_data_frame = gradio.Dataframe(value=data_preview_frame, interactive=False)
else:
loaded_data_frame = gradio.Dataframe(interactive=False)

with gradio.Row():
query = gradio.Text(label="Ask question")
Expand Down
18 changes: 0 additions & 18 deletions src/dbally/utils/gradio_log_redirect.py

This file was deleted.

18 changes: 18 additions & 0 deletions src/dbally/utils/log_to_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import sys


class FileLogger:
def __init__(self, filename):
self.logFile = open(filename, "w")
self.console = sys.stdout

def write(self, message):
self.logFile.write(message)
self.console.write(message)

def flush(self):
self.logFile.flush()
self.console.flush()

def isatty(self):
return False

0 comments on commit 46be6d8

Please sign in to comment.