-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathgradio_adapter.py
106 lines (84 loc) · 4.16 KB
/
gradio_adapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from typing import Optional, Tuple
import gradio
import pandas as pd
from dbally.collection import Collection
from dbally.similarity import SimilarityIndex
from dbally.utils.errors import UnsupportedQueryError
class GradioAdapter:
    """A class to adapt Gradio interface with a similarity store and data operations.

    The adapter remembers the collection passed to ``create_interface`` and the most
    recently loaded DataFrame; its async methods are wired up as Gradio event handlers.
    """

    def __init__(self, similarity_store: Optional[SimilarityIndex] = None):
        """Initializes the GradioAdapter with an optional similarity store.

        Args:
            similarity_store: An instance of SimilarityIndex for similarity operations. Defaults to None,
                in which case no similarity index is refreshed when data is loaded.
        """
        # Assigned by create_interface(); the load/query handlers rely on it being set.
        self.collection: Optional[Collection] = None
        self.similarity_store = similarity_store
        # Most recently loaded data, either from a selected view or from the editable frame.
        self.loaded_dataframe: Optional[pd.DataFrame] = None

    async def _refresh_similarity_store(self) -> None:
        """Awaits ``update()`` on the similarity store, if one was provided."""
        if self.similarity_store is not None:
            await self.similarity_store.update()

    async def load_data(self, input_dataframe: pd.DataFrame) -> str:
        """Loads data into the adapter from a given DataFrame.

        The similarity store (if any) is updated before the frame is stored.

        Args:
            input_dataframe: The DataFrame to load.

        Returns:
            A message indicating the data has been loaded.
        """
        await self._refresh_similarity_store()
        self.loaded_dataframe = input_dataframe
        return "Frame data loaded."

    async def load_selected_data(self, selected_view: str) -> Tuple[pd.DataFrame, str]:
        """Loads selected view data into the adapter.

        The similarity store (if any) is updated, then the named view is fetched from the
        collection, executed, and its result records are converted to a DataFrame.

        Args:
            selected_view: The name of the view to load.

        Returns:
            A tuple containing the loaded DataFrame and a message indicating the view data has been loaded.
        """
        await self._refresh_similarity_store()
        view_result = self.collection.get(selected_view).execute()
        self.loaded_dataframe = pd.DataFrame.from_records(view_result.results)
        return self.loaded_dataframe, f"{selected_view} data loaded."

    async def execute_query(self, query: str) -> Tuple[Optional[str], Optional[pd.DataFrame]]:
        """Executes a query against the collection.

        Args:
            query: The question to ask.

        Returns:
            A tuple containing the generated SQL and the resulting DataFrame. The SQL element is
            None when the execution context has no "sql" entry.
            If the query is unsupported, returns a message indicating this and None.
        """
        try:
            execution_result = await self.collection.ask(query)
        except UnsupportedQueryError:
            return "Unsupported query", None
        # .get() yields None when the context carries no "sql" key.
        generated_sql = execution_result.context.get("sql")
        return generated_sql, pd.DataFrame.from_records(execution_result.results)

    def create_interface(self, user_collection: Collection) -> Optional[gradio.Blocks]:
        """Creates a Gradio interface for the provided user collection.

        Args:
            user_collection: The user collection to create an interface for.

        Returns:
            The created Gradio Blocks app, or None if no views are available in the collection.
        """
        view_list = user_collection.list()
        if not view_list:
            print("There is no data to be loaded")
            return None

        self.collection = user_collection

        with gradio.Blocks() as demo:
            with gradio.Row():
                with gradio.Column():
                    # Explicit list of view names (iterating the result of list() yields the names).
                    view_dropdown = gradio.Dropdown(label="Available views", choices=list(view_list))
                    load_info = gradio.Label(value="No data loaded.")

                with gradio.Column():
                    # NOTE(review): Gradio documents the second tuple element as "fixed" or "dynamic"
                    # (lowercase); "Fixed" may not be honoured as fixed. Confirm the intent before
                    # changing it, since loaded views can have more than 4 columns.
                    loaded_data_frame = gradio.Dataframe(interactive=True, col_count=(4, "Fixed"))
                    load_data_button = gradio.Button("Load new data")

            with gradio.Row():
                query = gradio.Text(label="Ask question")
                query_button = gradio.Button("Proceed")

            with gradio.Row():
                query_sql_result = gradio.Text(label="Generated SQL")
                query_result_frame = gradio.Dataframe(interactive=False)

            # Event wiring: selecting a view loads it, the button stores the edited frame,
            # and "Proceed" runs the natural-language query.
            view_dropdown.change(
                fn=self.load_selected_data, inputs=view_dropdown, outputs=[loaded_data_frame, load_info]
            )
            load_data_button.click(fn=self.load_data, inputs=loaded_data_frame, outputs=load_info)
            query_button.click(fn=self.execute_query, inputs=[query], outputs=[query_sql_result, query_result_frame])

        return demo