From 46be6d81d7b7f2b9d9ec9bd739d346631868d80d Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski Date: Tue, 4 Jun 2024 02:40:18 +0200 Subject: [PATCH] gradio adapter --- docs/how-to/visualize_data.md | 36 ++++++++++++++++++ setup.cfg | 6 ++- src/dbally/utils/gradio_adapter.py | 50 ++++++++++++------------- src/dbally/utils/gradio_log_redirect.py | 18 --------- src/dbally/utils/log_to_file.py | 18 +++++++++ 5 files changed, 83 insertions(+), 45 deletions(-) create mode 100644 docs/how-to/visualize_data.md delete mode 100644 src/dbally/utils/gradio_log_redirect.py create mode 100644 src/dbally/utils/log_to_file.py diff --git a/docs/how-to/visualize_data.md b/docs/how-to/visualize_data.md new file mode 100644 index 00000000..6a891c76 --- /dev/null +++ b/docs/how-to/visualize_data.md @@ -0,0 +1,36 @@ +# How-To: Visualize Views + +There has been implemented Gradio Adapter class to create simple UI interface. It allows to display Data Preview related to Views +and execute user queries. + +## Installation +```bash +pip install dbally[gradio] +``` + +## Create own gradio interface +Define collection with implemented views + +```python + llm = LiteLLM(model_name="gpt-3.5-turbo") + collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) + collection.add(CandidateView, lambda: CandidateView(engine)) + collection.add(SampleText2SQLView, lambda: SampleText2SQLView(prepare_freeform_enginge())) +``` + +Create gradio interface +```python + gradio_adapter = GradioAdapter() + gradio_interface = await gradio_adapter.create_interface(collection, similarity_store_list=[country_similarity]) +``` + +Launch the gradio interface. To publish public interface pass argument `share=True` +```python + gradio_interface.launch() +``` + +The endpoint is set by gradio server by triggering python module with Gradio Adapter launch command. +Private endpoint is set to http://127.0.0.1:7860/ by default. + +## Links +* [Example Gradio Interface](visualize_views_code.py) \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 0109094d..42da19cd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,7 +65,11 @@ benchmark = pydantic-settings~=2.0.3 psycopg2-binary~=2.9.9 elasticsearch = - elasticsearch==8.13.1 + elasticsearch~=8.13.1 +gradio = + gradio~=4.31.5 + gradio_client~=0.16.4 + [options.packages.find] where = src diff --git a/src/dbally/utils/gradio_adapter.py b/src/dbally/utils/gradio_adapter.py index 4c0057e1..2a0c1dd9 100644 --- a/src/dbally/utils/gradio_adapter.py +++ b/src/dbally/utils/gradio_adapter.py @@ -1,5 +1,5 @@ import sys -from typing import Optional, Tuple, Dict +from typing import Optional, Tuple, Dict, List import gradio import pandas as pd @@ -8,10 +8,10 @@ from dbally.collection import Collection from dbally.similarity import SimilarityIndex from dbally.utils.errors import UnsupportedQueryError -from dbally.utils.gradio_log_redirect import Logger +from dbally.utils.log_to_file import FileLogger - -sys.stdout = Logger("console.log") +CONSOLE_FILE_NAME = "console.log" +sys.stdout = FileLogger(CONSOLE_FILE_NAME) class GradioAdapter: @@ -20,15 +20,11 @@ class GradioAdapter: SQL_RESULT = "sql" PANDAS_RESULT = "filter_mask" - def __init__(self, similarity_store: SimilarityIndex = None, engine=None): - """Initializes the GradioAdapter with an optional similarity store. - - Args: - similarity_store: An instance of SimilarityIndex for similarity operations. Defaults to None. - """ + def __init__(self, preview_limit: int = 20): + """Initializes the GradioAdapter with an optional similarity store.""" + self.preview_limit = preview_limit + self.similarity_store_list = [] self.collection = None - self.similarity_store = similarity_store - self.loaded_dataframe = None sys.stdout.flush() async def ui_load_preview_data(self, selected_view_name: str) -> Tuple[pd.DataFrame, str, None, None, None, None]: @@ -40,6 +36,7 @@ async def ui_load_preview_data(self, selected_view_name: str) -> Tuple[pd.DataFr Returns: A tuple containing the loaded DataFrame and a message indicating the view data has been loaded. """ + preview_dataframe, load_status_text = self.load_preview_data(selected_view_name) return preview_dataframe, load_status_text, None, None, None, None @@ -48,7 +45,7 @@ def load_preview_data(self, selected_view_name: str): text_to_display = "No data preview available" if issubclass(type(selected_view), BaseStructuredView): selected_view_results = selected_view.execute() - preview_dataframe = pd.DataFrame.from_records(selected_view_results.results) + preview_dataframe = pd.DataFrame.from_records(selected_view_results.results).head(self.preview_limit) text_to_display = "Data preview loaded" else: preview_dataframe = pd.DataFrame() @@ -66,40 +63,38 @@ async def ui_ask_query(self, question_query: str) -> Tuple[Dict, Optional[pd.Dat If the query is unsupported, returns a message indicating this and None. """ try: - if self.similarity_store: - await self.similarity_store.update() + for similarity_store in self.similarity_store_list: + await similarity_store.update() execution_result = await self.collection.ask(question=question_query) - if self.SQL_RESULT in execution_result.context: - generated_query = execution_result.context.get(self.SQL_RESULT) - elif self.PANDAS_RESULT in execution_result.context: - generated_query = execution_result.context.get(self.PANDAS_RESULT) - else: - generated_query = "Unsupported generated query" - + generated_query = str(execution_result.context) data = pd.DataFrame.from_records(execution_result.results) except UnsupportedQueryError: generated_query = {"Query": "unsupported"} data = pd.DataFrame() finally: sys.stdout.flush() - with open("output.log", "r") as f: + with open(CONSOLE_FILE_NAME, "r") as f: log = f.read() return generated_query, data, log - async def create_interface(self, user_collection: Collection) -> Optional[gradio.Interface]: + async def create_interface( + self, user_collection: Collection, similarity_store_list: List[SimilarityIndex] = [] + ) -> Optional[gradio.Interface]: """Creates a Gradio interface for the provided user collection. Args: user_collection: The user collection to create an interface for. + similarity_store_list: SimilarityIndex Returns: The created Gradio interface, or None if no views are available in the collection. """ self.collection = user_collection + self.similarity_store_list = similarity_store_list + view_list = [*user_collection.list()] - print(view_list[0]) if view_list: default_selected_view_name = view_list[0] else: @@ -119,7 +114,10 @@ async def create_interface(self, user_collection: Collection) -> Optional[gradio data_preview_frame, data_preview_status = self.load_preview_data(view_list[0]) data_preview_info = gradio.Text(label="Data preview", value=data_preview_status) - loaded_data_frame = gradio.Dataframe(data_preview_frame) + if not data_preview_frame.empty: + loaded_data_frame = gradio.Dataframe(value=data_preview_frame, interactive=False) + else: + loaded_data_frame = gradio.Dataframe(interactive=False) with gradio.Row(): query = gradio.Text(label="Ask question") diff --git a/src/dbally/utils/gradio_log_redirect.py b/src/dbally/utils/gradio_log_redirect.py deleted file mode 100644 index 4fff387f..00000000 --- a/src/dbally/utils/gradio_log_redirect.py +++ /dev/null @@ -1,18 +0,0 @@ -import sys - - -class Logger: - def __init__(self, filename): - self.terminal = sys.stdout - self.log = open(filename, "w") - - def write(self, message): - self.terminal.write(message) - self.log.write(message) - - def flush(self): - self.terminal.flush() - self.log.flush() - - def isatty(self): - return False diff --git a/src/dbally/utils/log_to_file.py b/src/dbally/utils/log_to_file.py new file mode 100644 index 00000000..1c7de7dd --- /dev/null +++ b/src/dbally/utils/log_to_file.py @@ -0,0 +1,18 @@ +import sys + + +class FileLogger: + def __init__(self, filename): + self.logFile = open(filename, "w") + self.console = sys.stdout + + def write(self, message): + self.logFile.write(message) + self.console.write(message) + + def flush(self): + self.logFile.flush() + self.console.flush() + + def isatty(self): + return False