Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(gradio): enhance gradio interface #90

Merged
merged 12 commits into from
Sep 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
refactor gradio interface
micpst committed Sep 4, 2024
commit 0aa89d6d92f5f7d112f660c8b934c7553efaaa4e
421 changes: 242 additions & 179 deletions src/dbally/gradio/gradio_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# pylint: disable=too-many-locals,unused-variable
# flake8: noqa: F841

import json
from io import StringIO
from typing import Any, Dict, List, Optional, Tuple

import gradio
@@ -9,181 +13,163 @@
from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler
from dbally.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.prompt.template import PromptTemplateError
from dbally.views.exceptions import ViewExecutionError


async def create_gradio_interface(user_collection: Collection, preview_limit: int = 10) -> gradio.Interface:
"""Adapt and integrate data collection and query execution with Gradio interface components.
def create_gradio_interface(collection: Collection, preview_limit: int = 10) -> gradio.Interface:
"""
Creates a Gradio interface for interacting with the user collection and similarity stores.
Args:
user_collection: The user's collection to interact with.
collection: The collection to interact with.
preview_limit: The maximum number of preview data records to display. Default is 10.
Returns:
The created Gradio interface.
"""
adapter = GradioAdapter()
gradio_interface = await adapter.create_interface(user_collection, preview_limit)
return gradio_interface


def find_event_buffer() -> Optional[BufferEventHandler]:
"""
Searches through global event handlers to find an instance of BufferEventHandler.
This function iterates over the list of global event handlers stored in `dbally.event_handlers`.
It checks the type of each handler, and if it finds one that is an instance of `BufferEventHandler`, it
returns that handler. If no such handler is found, the function returns `None`.
Returns:
The first instance of `BufferEventHandler` found in the list, or `None` if no such handler is found.
"""
for handler in dbally.event_handlers:
if isinstance(handler, BufferEventHandler):
return handler
return None
adapter = GradioAdapter(collection=collection, preview_limit=preview_limit)
return adapter.create_interface()


class GradioAdapter:
"""
A class to adapt and integrate data collection and query execution with Gradio interface components.
"""

def __init__(self):
def __init__(self, collection: Collection, preview_limit: int = 10) -> None:
"""
Initializes the GradioAdapter with a preview limit.
Creates the gradio adapter.
Args:
collection: The collection to interact with.
preview_limit: The maximum number of preview data records to display.
"""
self.collection = collection
self.preview_limit = preview_limit
self.log = self._setup_event_buffer()

def _setup_event_buffer(self) -> StringIO:
"""
self.preview_limit = None
self.selected_view_name = None
self.collection = None
Setup the event buffer for the gradio interface.
Returns:
The buffer event handler.
"""
buffer_event_handler = None
for handler in dbally.event_handlers:
if isinstance(handler, BufferEventHandler):
buffer_event_handler = handler

buffer_event_handler = find_event_buffer()
if not buffer_event_handler:
buffer_event_handler = BufferEventHandler()
dbally.event_handlers.append(buffer_event_handler)

self.log: BufferEventHandler = buffer_event_handler.buffer # pylint: disable=no-member
return buffer_event_handler.buffer

def _load_gradio_data(self, preview_dataframe, label) -> Tuple[gradio.DataFrame, gradio.Label]:
def _render_dataframe(
self, df: pd.DataFrame, message: Optional[str] = None
) -> Tuple[gradio.Dataframe, gradio.Label]:
"""
Load data into Gradio components for preview.
This function takes a DataFrame and a label, and returns a tuple containing a Gradio DataFrame
and a Gradio Label. The visibility of these components is determined by whether the input
DataFrame is empty.
Renders the dataframe and label for the gradio interface.
Args:
preview_dataframe: The DataFrame to be loaded into the Gradio DataFrame component.
label: The label to be associated with the Gradio components.
df: The dataframe to render.
message: The message to display if the dataframe is empty.
Returns:
A tuple containing the Gradio DataFrame component with the provided data and label and A Gradio Label
indicating the availability of data.
A tuple containing the dataframe and label.
"""
if preview_dataframe.empty:
gradio_preview_dataframe = gradio.DataFrame(label=label, value=preview_dataframe, visible=False)
empty_frame_label = gradio.Label(value=f"{label} not available", visible=True, show_label=False)
else:
gradio_preview_dataframe = gradio.DataFrame(label=label, value=preview_dataframe, visible=True)
empty_frame_label = gradio.Label(value=f"{label} not available", visible=False, show_label=False)
return gradio_preview_dataframe, empty_frame_label
return (
gradio.Dataframe(value=df, visible=not df.empty, height=325),
gradio.Label(value=message, visible=df.empty, show_label=False),
)

async def _ui_load_preview_data(
self, selected_view_name: str
) -> Tuple[gradio.DataFrame, gradio.Label, None, None, None]:
def _render_view_preview(self, view_name: str) -> Tuple[gradio.Dataframe, gradio.Label]:
"""
Asynchronously loads preview data for a selected view name.
Loads preview data for a selected view name.
Args:
selected_view_name: The name of the selected view to load preview data for.
view_name: The name of the selected view to load preview data for.
Returns:
A tuple containing the preview dataframe, load status text, and four None values to clean gradio fields.
"""
self.selected_view_name = selected_view_name
preview_dataframe = self._load_preview_data(selected_view_name)
gradio_preview_dataframe, empty_frame_label = self._load_gradio_data(preview_dataframe, "Preview")
data = pd.DataFrame()
view = self.collection.get(view_name)

return gradio_preview_dataframe, empty_frame_label, None, None, None
if isinstance(view, BaseStructuredView):
results = view.execute().results
data = self._load_results_into_dataframe(results)
data = data.head(self.preview_limit)

def _load_preview_data(self, selected_view_name: str) -> pd.DataFrame:
"""
Loads preview data for a selected view name.
Args:
selected_view_name: The name of the selected view to load preview data for.
return self._render_dataframe(data, "Preview not available")

Returns:
A tuple containing the preview dataframe
async def _ask_collection(
self,
question: str,
return_natural_response: bool,
) -> Tuple[str, str, str, gradio.Text, gradio.DataFrame, str]:
"""
selected_view = self.collection.get(selected_view_name)
if issubclass(type(selected_view), BaseStructuredView):
selected_view_results = selected_view.execute()
preview_dataframe = self._load_results_into_dataframe(selected_view_results.results).head(
self.preview_limit
)
else:
preview_dataframe = pd.DataFrame()

return preview_dataframe

async def _ui_ask_query(
self, question_query: str, natural_language_flag: bool
) -> Tuple[gradio.DataFrame, gradio.Label, gradio.Text, gradio.Text, str]:
"""
Asynchronously processes a query and returns the results.
Processes the question and returns the results.
Args:
question_query: The query to process.
natural_language_flag: Flag to indicate if the natural language shall be returned
question: The question to ask the collection.
return_natural_response: Flag to indicate if the natural language shall be returned.
Returns:
A tuple containing the generated query context, the query results as a dataframe, and the log output.
"""
self.log.seek(0)
self.log.truncate(0)
textual_response = ""

try:
execution_result = await self.collection.ask(
question=question_query, return_natural_response=natural_language_flag
result = await self.collection.ask(
question=question,
return_natural_response=return_natural_response,
)
generated_query = str(execution_result.context)
data = self._load_results_into_dataframe(execution_result.results)
textual_response = str(execution_result.textual_response) if natural_language_flag else textual_response

except UnsupportedQueryError:
generated_query = {"Query": "unsupported"}
data = pd.DataFrame()
except NoViewFoundError:
generated_query = {"Query": "No view matched to query"}
data = pd.DataFrame()
except PromptTemplateError:
generated_query = {"Query": "No view matched to query"}
data = pd.DataFrame()
finally:
self.log.seek(0)
log_content = self.log.read()

gradio_dataframe, empty_dataframe_warning = self._load_gradio_data(data, "Results")
except (NoViewFoundError, ViewExecutionError):
sql = ""
iql_filters = ""
iql_aggregation = ""
retrieved_rows = pd.DataFrame()
textual_response = ""
else:
sql = result.context.get("sql", "")
iql_filters = result.context.get("iql", {}).get("filters", "")
iql_aggregation = result.context.get("iql", {}).get("aggregation", "")
retrieved_rows = self._load_results_into_dataframe(result.results)
textual_response = result.textual_response or ""

retrieved_rows, empty_retrieved_rows_warning = self._render_dataframe(retrieved_rows, "No rows retrieved")

self.log.seek(0)
log_content = self.log.read()

return (
gradio_dataframe,
empty_dataframe_warning,
gradio.Text(value=generated_query, visible=True),
gradio.Text(value=textual_response, visible=natural_language_flag),
gradio.Code(value=iql_filters, visible=bool(iql_filters)),
gradio.Code(value=iql_aggregation, visible=bool(iql_aggregation)),
gradio.Code(value=sql, visible=bool(sql)),
gradio.Textbox(value=textual_response, visible=return_natural_response),
retrieved_rows,
empty_retrieved_rows_warning,
log_content,
)

def _clear_results(self) -> Tuple[gradio.DataFrame, gradio.Label, gradio.Text, gradio.Text]:
preview_dataframe = self._load_preview_data(self.selected_view_name)
gradio_preview_dataframe, empty_frame_label = self._load_gradio_data(preview_dataframe, "Preview")
"""
Clears the results from the gradio interface.
Returns:
A tuple containing the cleared results.
"""
retrieved_rows, retrieved_rows_label = self._render_dataframe(pd.DataFrame(), "No rows retrieved")
return (
gradio_preview_dataframe,
empty_frame_label,
gradio.Text(visible=False),
gradio.Text(visible=False),
gradio.Textbox(visible=False),
gradio.Code(visible=False),
gradio.Code(visible=False),
gradio.Code(visible=False),
retrieved_rows,
retrieved_rows_label,
)

@staticmethod
@@ -199,101 +185,178 @@ def _load_results_into_dataframe(results: List[Dict[str, Any]]) -> pd.DataFrame:
"""
return pd.DataFrame(json.loads(json.dumps(results, default=str)))

async def create_interface(self, user_collection: Collection, preview_limit: int) -> gradio.Interface:
def create_interface(self) -> gradio.Interface:
"""
Creates a Gradio interface for interacting with the user collection and similarity stores.
Args:
user_collection: The user's collection to interact with.
preview_limit: The maximum number of preview data records to display.
Creates a Gradio interface for interacting with the collection.
Returns:
The created Gradio interface.
The Gradio interface.
"""

self.preview_limit = preview_limit
self.collection = user_collection

data_preview_frame = pd.DataFrame()
question_interactive = False

view_list = [*user_collection.list()]
view_list = [*self.collection.list()]
if view_list:
self.selected_view_name = view_list[0]
data_preview_frame = self._load_preview_data(self.selected_view_name)
selected_view_name = view_list[0]
question_interactive = True
else:
selected_view_name = None
question_interactive = False

with gradio.Blocks(title="db-ally lab") as demo:
gradio.Markdown("# 🔍 db-ally lab")

with gradio.Tab("Collection"):
with gradio.Row():
with gradio.Column():
api_key = gradio.Textbox(
label="API Key",
placeholder="Enter your API Key",
type="password",
interactive=question_interactive,
)
model_name = gradio.Textbox(
label="Model Name",
placeholder="Enter your model name",
value="gpt-3.5-turbo",
interactive=question_interactive,
max_lines=1,
)
query = gradio.Textbox(
label="Question",
placeholder="Enter your question",
interactive=question_interactive,
max_lines=1,
)
natural_language_response_checkbox = gradio.Checkbox(
label="Use Natural Language Responder",
interactive=question_interactive,
)
query_button = gradio.Button(
value="Ask",
interactive=question_interactive,
variant="primary",
)
clear_button = gradio.ClearButton(
value="Reset",
components=[query],
interactive=question_interactive,
)

with gradio.Blocks() as demo:
with gradio.Row():
with gradio.Column():
view_dropdown = gradio.Dropdown(
label="Data View preview", choices=view_list, value=self.selected_view_name
)
query = gradio.Text(label="Ask question", interactive=question_interactive)
query_button = gradio.Button("Ask db-ally", interactive=question_interactive)
clear_button = gradio.ClearButton(components=[query], interactive=question_interactive)
natural_language_response_checkbox = gradio.Checkbox(
label="Return natural language answer", interactive=question_interactive
with gradio.Column():
view_dropdown = gradio.Dropdown(
label="View Preview",
choices=view_list,
value=selected_view_name,
interactive=question_interactive,
)
if selected_view_name:
view_preview, view_preview_label = self._render_view_preview(selected_view_name)
else:
view_preview, view_preview_label = self._render_dataframe(
pd.DataFrame(), "No view selected"
)

with gradio.Tab("Logs"):
log_console = gradio.Code(label="Logs", language="shell")

with gradio.Tab("Results"):
natural_language_response = gradio.Textbox(
label="Natural Language Response",
visible=False,
)

with gradio.Column():
if not data_preview_frame.empty:
loaded_data_frame = gradio.Dataframe(
label="Preview", value=data_preview_frame, interactive=False
with gradio.Row():
iql_fitlers_result = gradio.Code(
label="IQL Filters Query",
lines=1,
language="python",
visible=False,
)
empty_frame_label = gradio.Label(value="Preview not available", visible=False)
else:
loaded_data_frame = gradio.Dataframe(interactive=False, visible=False)
empty_frame_label = gradio.Label(value="Preview not available", visible=True)

query_sql_result = gradio.Text(label="Generated query context", visible=False)
generated_natural_language_answer = gradio.Text(
label="Generated answer in natural language:", visible=False
iql_aggregation_result = gradio.Code(
label="IQL Aggreagation Query",
lines=1,
language="python",
visible=False,
)

sql_result = gradio.Code(
label="SQL Query",
lines=3,
language="sql",
visible=False,
)

with gradio.Row():
log_console = gradio.Code(label="Logs", language="shell")
with gradio.Accordion("See Retrieved Rows", open=False):
retrieved_rows = gradio.Dataframe(
interactive=False,
height=325,
visible=False,
)
retrieved_rows_label = gradio.Label(
value="No rows retrieved",
visible=True,
show_label=False,
)

with gradio.Tab("Help"):
gradio.Markdown(
"""
## How to use this app:
1. Enter your API Key for the LLM you want to use in the provided field.
2. Choose the [model](https://docs.litellm.ai/docs/providers) you want to use.
3. Type your question in the textbox.
4. Click on `Ask`. The retrieval results will appear in the `Results` tab.
## Learn more:
Want to learn more about db-ally? Check out our resources:
- [Website](https://deepsense.ai/db-ally)
- [GitHub](https://github.com/deepsense-ai/db-ally)
- [Documentation](https://db-ally.deepsense.ai)
"""
)

clear_button.add(
[
natural_language_response_checkbox,
loaded_data_frame,
query_sql_result,
generated_natural_language_answer,
natural_language_response,
iql_fitlers_result,
iql_aggregation_result,
sql_result,
retrieved_rows,
retrieved_rows_label,
log_console,
]
)

clear_button.click(
fn=self._clear_results,
inputs=[],
outputs=[
loaded_data_frame,
empty_frame_label,
query_sql_result,
generated_natural_language_answer,
natural_language_response,
iql_fitlers_result,
iql_aggregation_result,
sql_result,
retrieved_rows,
retrieved_rows_label,
],
)

view_dropdown.change(
fn=self._ui_load_preview_data,
fn=self._render_view_preview,
inputs=view_dropdown,
outputs=[
loaded_data_frame,
empty_frame_label,
query,
query_sql_result,
log_console,
view_preview,
view_preview_label,
],
)
query_button.click(
fn=self._ui_ask_query,
inputs=[query, natural_language_response_checkbox],
fn=self._ask_collection,
inputs=[
query,
natural_language_response_checkbox,
],
outputs=[
loaded_data_frame,
empty_frame_label,
query_sql_result,
generated_natural_language_answer,
iql_fitlers_result,
iql_aggregation_result,
sql_result,
natural_language_response,
retrieved_rows,
retrieved_rows_label,
log_console,
],
)