diff --git a/README.md b/README.md index a999d6fe..36e4bc9a 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ class CandidateView(SqlAlchemyBaseView): """ return Candidate.country == country -engine = create_engine('sqlite:///candidates.db') +engine = create_engine('sqlite:///examples/recruiting/data/candidates.db') llm = LiteLLM(model_name="gpt-3.5-turbo") my_collection = create_collection("collection_name", llm) my_collection.add(CandidateView, lambda: CandidateView(engine)) diff --git a/docs/how-to/sql_views.md b/docs/how-to/sql_views.md index 85411180..75b39fb6 100644 --- a/docs/how-to/sql_views.md +++ b/docs/how-to/sql_views.md @@ -77,7 +77,7 @@ You need to connect to the database using SQLAlchemy before you can use your vie ```python from sqlalchemy import create_engine -engine = create_engine('sqlite:///candidates.db') +engine = create_engine('sqlite:///examples/recruiting/data/candidates.db') ``` ## Registering the view diff --git a/docs/how-to/use_elastic_vector_store_code.py b/docs/how-to/use_elastic_vector_store_code.py index c325fa2c..4817fcf4 100644 --- a/docs/how-to/use_elastic_vector_store_code.py +++ b/docs/how-to/use_elastic_vector_store_code.py @@ -17,7 +17,7 @@ from dbally.similarity.elastic_vector_search import ElasticVectorStore load_dotenv() -engine = create_engine("sqlite:///candidates.db") +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") Base = automap_base() diff --git a/docs/how-to/use_elasticsearch_store_code.py b/docs/how-to/use_elasticsearch_store_code.py index 39258c44..1f690c35 100644 --- a/docs/how-to/use_elasticsearch_store_code.py +++ b/docs/how-to/use_elasticsearch_store_code.py @@ -18,7 +18,7 @@ from dbally.similarity.elasticsearch_store import ElasticsearchStore load_dotenv() -engine = create_engine("sqlite:///candidates.db") +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") Base = automap_base() diff --git a/docs/how-to/visualize_views.md b/docs/how-to/visualize_views.md new file mode 100644 index 00000000..6553e07b --- /dev/null +++ b/docs/how-to/visualize_views.md @@ -0,0 +1,44 @@ +# How-To: Visualize Views + +To create simple UI interface use [create_gradio_interface function](https://github.com/deepsense-ai/db-ally/tree/main/src/dbally/gradio/gradio_interface.py) It allows to display Data Preview related to Views +and execute user queries. + +## Installation +```bash +pip install dbally["gradio"] +``` +When You plan to use some other feature like faiss similarity store install them as well. + +```bash +pip install dbally["faiss"] +``` + +## Create own gradio interface +Define collection with implemented views + +```python +llm = LiteLLM(model_name="gpt-3.5-turbo") +await country_similarity.update() +collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) +collection.add(CandidateView, lambda: CandidateView(engine)) +collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) +``` + +>_**NOTE**_: The following code requires environment variables to proceed with LLM queries. For the example below, set the +> ```OPENAI_API_KEY``` environment variable. + +Create gradio interface +```python +gradio_interface = await create_gradio_interface(user_collection=collection) +``` + +Launch the gradio interface. To publish public interface pass argument `share=True` +```python +gradio_interface.launch() +``` + +The endpoint is set by triggering python module with Gradio Adapter launch command. +Private endpoint is set to http://127.0.0.1:7860/ by default. + +## Links +* [Example Gradio Interface](https://github.com/deepsense-ai/db-ally/tree/main/examples/visualize_views_code.py) \ No newline at end of file diff --git a/docs/quickstart/index.md b/docs/quickstart/index.md index 2852e966..9754fe08 100644 --- a/docs/quickstart/index.md +++ b/docs/quickstart/index.md @@ -30,14 +30,14 @@ pip install dbally[litellm] ## Database Configuration -In this guide, we will use an example SQLAlchemy database containing a single table named `candidates`. This table includes columns such as `id`, `name`, `country`, `years_of_experience`, `position`, `university`, `skills`, and `tags`. You can download the example database from [candidates.db](candidates.db). Alternatively, you can use your own database and models. +In this guide, we will use an example SQLAlchemy database containing a single table named `candidates`. This table includes columns such as `id`, `name`, `country`, `years_of_experience`, `position`, `university`, `skills`, and `tags`. You can download the example database from [candidates.db](https://github.com/deepsense-ai/db-ally/tree/main/examples/recruiting/candidates.db). Alternatively, you can use your own database and models. To connect to the database using SQLAlchemy, you need an engine and your database models. Start by creating an engine: ```python from sqlalchemy import create_engine -engine = create_engine('sqlite:///candidates.db') +engine = create_engine('sqlite:///examples/recruiting/data/candidates.db') ``` Next, define an SQLAlchemy model for the `candidates` table. You can either declare the `Candidate` model using [declarative mapping](https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#declarative-mapping) or generate it using [automap](https://docs.sqlalchemy.org/en/20/orm/extensions/automap.html). For simplicity, we'll use automap: diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index eab1e38e..593e7b4a 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -16,7 +16,7 @@ from dbally.llms.litellm import LiteLLM load_dotenv() -engine = create_engine("sqlite:///candidates.db") +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") Base = automap_base() Base.prepare(autoload_with=engine) diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index 2ef8f1df..f0c9270a 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -14,7 +14,7 @@ from dbally.embeddings.litellm import LiteLLMEmbeddingClient from dbally.llms.litellm import LiteLLM -engine = create_engine("sqlite:///candidates.db") +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") Base = automap_base() Base.prepare(autoload_with=engine) diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 8cdbd446..34ee9765 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -11,17 +11,19 @@ from dbally.llms.litellm import LiteLLM -engine = create_engine('sqlite:///candidates.db') +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") Base = automap_base() Base.prepare(autoload_with=engine) Candidate = Base.classes.candidates + class CandidateView(SqlAlchemyBaseView): """ A view for retrieving candidates from the database. """ + def get_select(self) -> sqlalchemy.Select: """ Creates the initial SqlAlchemy select object, which will be used to build the query. @@ -52,6 +54,7 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: """ return Candidate.country == country + async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/docs/quickstart/similarity_indexes/country_similarity.index b/docs/quickstart/similarity_indexes/country_similarity.index deleted file mode 100644 index d141a443..00000000 Binary files a/docs/quickstart/similarity_indexes/country_similarity.index and /dev/null differ diff --git a/docs/quickstart/similarity_indexes/country_similarity.npy b/docs/quickstart/similarity_indexes/country_similarity.npy deleted file mode 100644 index e9b8f814..00000000 Binary files a/docs/quickstart/similarity_indexes/country_similarity.npy and /dev/null differ diff --git a/examples/recruiting.py b/examples/recruiting.py index fedb2ea0..a4813b41 100644 --- a/examples/recruiting.py +++ b/examples/recruiting.py @@ -2,8 +2,8 @@ from dataclasses import dataclass from typing import List -from recruting.db import ENGINE, fill_candidate_table, get_recruitment_db_description -from recruting.views import RecruitmentView +from recruiting.db import ENGINE, fill_candidate_table, get_recruitment_db_description +from recruiting.views import RecruitmentView import dbally from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler diff --git a/examples/recruting/__init__.py b/examples/recruiting/__init__.py similarity index 100% rename from examples/recruting/__init__.py rename to examples/recruiting/__init__.py diff --git a/examples/recruiting/candidate_view_with_similarity_store.py b/examples/recruiting/candidate_view_with_similarity_store.py new file mode 100644 index 00000000..f50c4545 --- /dev/null +++ b/examples/recruiting/candidate_view_with_similarity_store.py @@ -0,0 +1,68 @@ +# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring + +import sqlalchemy +from sqlalchemy import create_engine +from sqlalchemy.ext.automap import automap_base +from typing_extensions import Annotated + +from dbally import SqlAlchemyBaseView, decorators +from dbally.embeddings.litellm import LiteLLMEmbeddingClient +from dbally.similarity import FaissStore, SimilarityIndex, SimpleSqlAlchemyFetcher + +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") + +Base = automap_base() +Base.prepare(autoload_with=engine) + +Candidate = Base.classes.candidates + +country_similarity = SimilarityIndex( + fetcher=SimpleSqlAlchemyFetcher( + engine, + table=Candidate, + column=Candidate.country, + ), + store=FaissStore( + index_dir="./similarity_indexes", + index_name="country_similarity", + embedding_client=LiteLLMEmbeddingClient( + model="text-embedding-3-small", # to use openai embedding model + ), + ), +) + + +class CandidateView(SqlAlchemyBaseView): + """ + A view for retrieving candidates from the database. + """ + + def get_select(self) -> sqlalchemy.Select: + """ + Creates the initial SqlAlchemy select object, which will be used to build the query. + """ + return sqlalchemy.select(Candidate) + + @decorators.view_filter() + def at_least_experience(self, years: int) -> sqlalchemy.ColumnElement: + """ + Filters candidates with at least `years` of experience. + """ + return Candidate.years_of_experience >= years + + @decorators.view_filter() + def senior_data_scientist_position(self) -> sqlalchemy.ColumnElement: + """ + Filters candidates that can be considered for a senior data scientist position. + """ + return sqlalchemy.and_( + Candidate.position.in_(["Data Scientist", "Machine Learning Engineer", "Data Engineer"]), + Candidate.years_of_experience >= 3, + ) + + @decorators.view_filter() + def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchemy.ColumnElement: + """ + Filters candidates from a specific country. + """ + return Candidate.country == country diff --git a/examples/recruiting/cypher_text2sql_view.py b/examples/recruiting/cypher_text2sql_view.py new file mode 100644 index 00000000..fb76f7b5 --- /dev/null +++ b/examples/recruiting/cypher_text2sql_view.py @@ -0,0 +1,42 @@ +# pylint: disable=missing-return-doc, missing-function-docstring, missing-class-docstring, missing-return-type-doc +from typing import List + +import sqlalchemy +from sqlalchemy import text + +from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig + + +class SampleText2SQLViewCyphers(BaseText2SQLView): + def get_tables(self) -> List[TableConfig]: + return [ + TableConfig( + name="security_specialists", + columns=[ + ColumnConfig("id", "SERIAL PRIMARY KEY"), + ColumnConfig("name", "VARCHAR(255)"), + ColumnConfig("cypher", "VARCHAR(255)"), + ], + description="Knowledge base", + ) + ] + + +def create_freeform_memory_engine() -> sqlalchemy.Engine: + freeform_engine = sqlalchemy.create_engine("sqlite:///:memory:") + + statements = [ + "CREATE TABLE security_specialists (id INTEGER PRIMARY KEY, name TEXT, cypher TEXT)", + "INSERT INTO security_specialists (name, cypher) VALUES ('Alice', 'HAMAC')", + "INSERT INTO security_specialists (name, cypher) VALUES ('Bob', 'AES')", + "INSERT INTO security_specialists (name, cypher) VALUES ('Charlie', 'RSA')", + "INSERT INTO security_specialists (name, cypher) VALUES ('David', 'SHA2')", + ] + + with freeform_engine.connect() as conn: + for statement in statements: + conn.execute(text(statement)) + + conn.commit() + + return freeform_engine diff --git a/examples/recruting/data/application.csv b/examples/recruiting/data/application.csv similarity index 100% rename from examples/recruting/data/application.csv rename to examples/recruiting/data/application.csv diff --git a/docs/quickstart/candidates.db b/examples/recruiting/data/candidates.db similarity index 100% rename from docs/quickstart/candidates.db rename to examples/recruiting/data/candidates.db diff --git a/examples/recruting/data/offers.csv b/examples/recruiting/data/offers.csv similarity index 100% rename from examples/recruting/data/offers.csv rename to examples/recruiting/data/offers.csv diff --git a/examples/recruting/data/recruiting.csv b/examples/recruiting/data/recruiting.csv similarity index 100% rename from examples/recruting/data/recruiting.csv rename to examples/recruiting/data/recruiting.csv diff --git a/examples/recruting/db.py b/examples/recruiting/db.py similarity index 100% rename from examples/recruting/db.py rename to examples/recruiting/db.py diff --git a/examples/recruting/views.py b/examples/recruiting/views.py similarity index 100% rename from examples/recruting/views.py rename to examples/recruiting/views.py diff --git a/examples/visualize_views_code.py b/examples/visualize_views_code.py new file mode 100644 index 00000000..31806fe8 --- /dev/null +++ b/examples/visualize_views_code.py @@ -0,0 +1,24 @@ +# pylint: disable=missing-function-docstring +import asyncio + +from recruiting.candidate_view_with_similarity_store import CandidateView, country_similarity, engine +from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine + +import dbally +from dbally.audit import CLIEventHandler +from dbally.gradio import create_gradio_interface +from dbally.llms.litellm import LiteLLM + + +async def main(): + await country_similarity.update() + llm = LiteLLM(model_name="gpt-3.5-turbo") + collection = dbally.create_collection("candidates", llm, event_handlers=[CLIEventHandler()]) + collection.add(CandidateView, lambda: CandidateView(engine)) + collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) + gradio_interface = await create_gradio_interface(user_collection=collection) + gradio_interface.launch() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index 9e1b971f..27c764b6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -29,6 +29,7 @@ nav: - how-to/use_elastic_store.md - how-to/use_custom_similarity_store.md - how-to/update_similarity_indexes.md + - how-to/visualize_views.md - how-to/log_runs_to_langsmith.md - how-to/create_custom_event_handler.md - how-to/openai_assistants_integration.md diff --git a/setup.cfg b/setup.cfg index 34e18232..4a571dbd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -64,7 +64,11 @@ benchmark = pydantic-settings~=2.0.3 psycopg2-binary~=2.9.9 elasticsearch = - elasticsearch==8.13.1 + elasticsearch~=8.13.1 +gradio = + gradio~=4.31.5 + gradio_client~=0.16.4 + [options.packages.find] where = src diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 15583f1d..f738f90b 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -1,24 +1,30 @@ +import re +from io import StringIO +from sys import stdout from typing import Optional, Union try: from rich import print as pprint from rich.console import Console from rich.syntax import Syntax + from rich.text import Text RICH_OUTPUT = True except ImportError: RICH_OUTPUT = False - # TODO: remove color tags from bare print pprint = print # type: ignore from dbally.audit.event_handlers.base import EventHandler from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent +_RICH_FORMATING_KEYWORD_SET = {"green", "orange", "grey", "bold", "cyan"} +_RICH_FORMATING_PATTERN = rf"\[.*({'|'.join(_RICH_FORMATING_KEYWORD_SET)}).*\]" + class CLIEventHandler(EventHandler): """ This handler displays all interactions between LLM and user happening during `Collection.ask`\ - execution inside the terminal. + execution inside the terminal or store them in the given buffer. ### Usage @@ -34,16 +40,23 @@ class CLIEventHandler(EventHandler): ![Example output from CLIEventHandler](../../assets/event_handler_example.png) """ - def __init__(self) -> None: + def __init__(self, buffer: StringIO = None) -> None: super().__init__() - self._console = Console() if RICH_OUTPUT else None - def _print_syntax(self, content: str, lexer: str) -> None: + self.buffer = buffer + out = self.buffer if buffer else stdout + self._console = Console(file=out, record=True) if RICH_OUTPUT else None + + def _print_syntax(self, content: str, lexer: str = None) -> None: if self._console: - console_content = Syntax(content, lexer, word_wrap=True) + if lexer: + console_content = Syntax(content, lexer, word_wrap=True) + else: + console_content = Text.from_markup(content) self._console.print(console_content) else: - print(content) + content_without_formatting = re.sub(_RICH_FORMATING_PATTERN, "", content) + print(content_without_formatting) async def request_start(self, user_request: RequestStart) -> None: """ @@ -52,10 +65,9 @@ async def request_start(self, user_request: RequestStart) -> None: Args: user_request: Object containing name of collection and asked query """ - - pprint(f"[orange3 bold]Request starts... \n[orange3 bold]MESSAGE: [grey53]{user_request.question}") - pprint("[grey53]\n=======================================") - pprint("[grey53]=======================================\n") + self._print_syntax(f"[orange3 bold]Request starts... \n[orange3 bold]MESSAGE: [grey53]{user_request.question}") + self._print_syntax("[grey53]\n=======================================") + self._print_syntax("[grey53]=======================================\n") async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: None) -> None: """ @@ -68,16 +80,18 @@ async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_con """ if isinstance(event, LLMEvent): - pprint(f"[cyan bold]LLM event starts... \n[cyan bold]LLM EVENT PROMPT TYPE: [grey53]{event.type}") + self._print_syntax( + f"[cyan bold]LLM event starts... \n[cyan bold]LLM EVENT PROMPT TYPE: [grey53]{event.type}" + ) if isinstance(event.prompt, tuple): for msg in event.prompt: - pprint(f"\n[orange3]{msg['role']}") + self._print_syntax(f"\n[orange3]{msg['role']}") self._print_syntax(msg["content"], "text") else: self._print_syntax(f"{event.prompt}", "text") elif isinstance(event, SimilarityEvent): - pprint( + self._print_syntax( f"[cyan bold]Similarity event starts... \n" f"[cyan bold]INPUT: [grey53]{event.input_value}\n" f"[cyan bold]STORE: [grey53]{event.store}\n" @@ -95,15 +109,14 @@ async def event_end( request_context: Optional context passed from request_start method event_context: Optional context passed from event_start method """ - if isinstance(event, LLMEvent): - pprint(f"\n[green bold]RESPONSE: {event.response}") - pprint("[grey53]\n=======================================") - pprint("[grey53]=======================================\n") + self._print_syntax(f"\n[green bold]RESPONSE: {event.response}") + self._print_syntax("[grey53]\n=======================================") + self._print_syntax("[grey53]=======================================\n") elif isinstance(event, SimilarityEvent): - pprint(f"[green bold]OUTPUT: {event.output_value}") - pprint("[grey53]\n=======================================") - pprint("[grey53]=======================================\n") + self._print_syntax(f"[green bold]OUTPUT: {event.output_value}") + self._print_syntax("[grey53]\n=======================================") + self._print_syntax("[grey53]=======================================\n") async def request_end(self, output: RequestEnd, request_context: Optional[dict] = None) -> None: """ @@ -113,8 +126,8 @@ async def request_end(self, output: RequestEnd, request_context: Optional[dict] output: The output of the request. request_context: Optional context passed from request_start method """ + self._print_syntax("[green bold]REQUEST OUTPUT:") + self._print_syntax(f"Number of rows: {len(output.result.results)}") - pprint("[green bold]REQUEST OUTPUT:") - pprint(f"Number of rows: {len(output.result.results)}") if "sql" in output.result.context: self._print_syntax(f"{output.result.context['sql']}", "psql") diff --git a/src/dbally/collection.py b/src/dbally/collection.py index 922a6365..089ce0e2 100644 --- a/src/dbally/collection.py +++ b/src/dbally/collection.py @@ -128,6 +128,15 @@ def build_dogs_df_view(): self._views[name] = view self._builders[name] = builder + def add_event_handler(self, event_handler: EventHandler): + """ + Adds an event handler to the list of event handlers. + + Args: + event_handler: The event handler to be added. + """ + self._event_handlers.append(event_handler) + def get(self, name: str) -> BaseView: """ Returns an instance of the view with the given name diff --git a/src/dbally/gradio/__init__.py b/src/dbally/gradio/__init__.py new file mode 100644 index 00000000..41d84c3c --- /dev/null +++ b/src/dbally/gradio/__init__.py @@ -0,0 +1,3 @@ +from dbally.gradio.gradio_interface import create_gradio_interface + +__all__ = ["create_gradio_interface"] diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py new file mode 100644 index 00000000..a6669685 --- /dev/null +++ b/src/dbally/gradio/gradio_interface.py @@ -0,0 +1,249 @@ +from io import StringIO +from typing import Tuple + +import gradio +import pandas as pd + +from dbally import BaseStructuredView +from dbally.audit import CLIEventHandler +from dbally.collection import Collection +from dbally.prompts import PromptTemplateError +from dbally.utils.errors import NoViewFoundError, UnsupportedQueryError + + +async def create_gradio_interface(user_collection: Collection, preview_limit: int = 10) -> gradio.Interface: + """Adapt and integrate data collection and query execution with Gradio interface components. + + Args: + user_collection: The user's collection to interact with. + preview_limit: The maximum number of preview data records to display. Default is 10. + + Returns: + The created Gradio interface. + """ + adapter = GradioAdapter() + gradio_interface = await adapter.create_interface(user_collection, preview_limit) + return gradio_interface + + +class GradioAdapter: + """ + A class to adapt and integrate data collection and query execution with Gradio interface components. + """ + + def __init__(self): + """ + Initializes the GradioAdapter with a preview limit. + + """ + self.preview_limit = None + self.selected_view_name = None + self.collection = None + self.log = StringIO() + + def _load_gradio_data(self, preview_dataframe, label, empty_warning=None) -> Tuple[gradio.DataFrame, gradio.Label]: + if not empty_warning: + empty_warning = "Preview not available" + + if preview_dataframe.empty: + gradio_preview_dataframe = gradio.DataFrame(label=label, value=preview_dataframe, visible=False) + empty_frame_label = gradio.Label(value=f"{label} not available", visible=True, show_label=False) + else: + gradio_preview_dataframe = gradio.DataFrame(label=label, value=preview_dataframe, visible=True) + empty_frame_label = gradio.Label(value=f"{label} not available", visible=False, show_label=False) + return gradio_preview_dataframe, empty_frame_label + + async def _ui_load_preview_data( + self, selected_view_name: str + ) -> Tuple[gradio.DataFrame, gradio.Label, None, None, None]: + """ + Asynchronously loads preview data for a selected view name. + + Args: + selected_view_name: The name of the selected view to load preview data for. + + Returns: + A tuple containing the preview dataframe, load status text, and four None values to clean gradio fields. + """ + self.selected_view_name = selected_view_name + preview_dataframe = self._load_preview_data(selected_view_name) + gradio_preview_dataframe, empty_frame_label = self._load_gradio_data(preview_dataframe, "Preview") + + return gradio_preview_dataframe, empty_frame_label, None, None, None + + def _load_preview_data(self, selected_view_name: str) -> pd.DataFrame: + """ + Loads preview data for a selected view name. + + Args: + selected_view_name: The name of the selected view to load preview data for. + + Returns: + A tuple containing the preview dataframe + """ + selected_view = self.collection.get(selected_view_name) + if issubclass(type(selected_view), BaseStructuredView): + selected_view_results = selected_view.execute() + preview_dataframe = pd.DataFrame.from_records(selected_view_results.results).head(self.preview_limit) + else: + preview_dataframe = pd.DataFrame() + + return preview_dataframe + + async def _ui_ask_query( + self, question_query: str, natural_language_flag: bool + ) -> Tuple[gradio.DataFrame, gradio.Label, gradio.Text, gradio.Text, str]: + """ + Asynchronously processes a query and returns the results. + + Args: + question_query: The query to process. + natural_language_flag: Flag to indicate if the natural language shall be returned + + Returns: + A tuple containing the generated query context, the query results as a dataframe, and the log output. + """ + self.log.seek(0) + self.log.truncate(0) + textual_response = "" + try: + execution_result = await self.collection.ask( + question=question_query, return_natural_response=natural_language_flag + ) + generated_query = str(execution_result.context) + data = pd.DataFrame.from_records(execution_result.results) + textual_response = str(execution_result.textual_response) if natural_language_flag else textual_response + except UnsupportedQueryError: + generated_query = {"Query": "unsupported"} + data = pd.DataFrame() + except NoViewFoundError: + generated_query = {"Query": "No view matched to query"} + data = pd.DataFrame() + except PromptTemplateError: + generated_query = {"Query": "No view matched to query"} + data = pd.DataFrame() + finally: + self.log.seek(0) + log_content = self.log.read() + + gradio_dataframe, empty_dataframe_warning = self._load_gradio_data(data, "Results", "No matching results found") + return ( + gradio_dataframe, + empty_dataframe_warning, + gradio.Text(value=generated_query, visible=True), + gradio.Text(value=textual_response, visible=natural_language_flag), + log_content, + ) + + def _clear_results(self) -> Tuple[gradio.DataFrame, gradio.Label, gradio.Text, gradio.Text]: + preview_dataframe = self._load_preview_data(self.selected_view_name) + gradio_preview_dataframe, empty_frame_label = self._load_gradio_data(preview_dataframe, "Preview") + + return ( + gradio_preview_dataframe, + empty_frame_label, + gradio.Text(visible=False), + gradio.Text(visible=False), + ) + + async def create_interface(self, user_collection: Collection, preview_limit: int) -> gradio.Interface: + """ + Creates a Gradio interface for interacting with the user collection and similarity stores. + + Args: + user_collection: The user's collection to interact with. + preview_limit: The maximum number of preview data records to display. + + Returns: + The created Gradio interface. + """ + + self.preview_limit = preview_limit + self.collection = user_collection + self.collection.add_event_handler(CLIEventHandler(self.log)) + + data_preview_frame = pd.DataFrame() + question_interactive = False + + view_list = [*user_collection.list()] + if view_list: + self.selected_view_name = view_list[0] + data_preview_frame = self._load_preview_data(self.selected_view_name) + question_interactive = True + + with gradio.Blocks() as demo: + with gradio.Row(): + with gradio.Column(): + view_dropdown = gradio.Dropdown( + label="Data View preview", choices=view_list, value=self.selected_view_name + ) + query = gradio.Text(label="Ask question", interactive=question_interactive) + query_button = gradio.Button("Ask db-ally", interactive=question_interactive) + clear_button = gradio.ClearButton(components=[query], interactive=question_interactive) + natural_language_response_checkbox = gradio.Checkbox( + label="Return natural language answer", interactive=question_interactive + ) + + with gradio.Column(): + if not data_preview_frame.empty: + loaded_data_frame = gradio.Dataframe( + label="Preview", value=data_preview_frame, interactive=False + ) + empty_frame_label = gradio.Label(value="Preview not available", visible=False) + else: + loaded_data_frame = gradio.Dataframe(interactive=False, visible=False) + empty_frame_label = gradio.Label(value="Preview not available", visible=True) + + query_sql_result = gradio.Text(label="Generated query context", visible=False) + generated_natural_language_answer = gradio.Text( + label="Generated answer in natural language:", visible=False + ) + + with gradio.Row(): + log_console = gradio.Code(label="Logs", language="shell") + + clear_button.add( + [ + natural_language_response_checkbox, + loaded_data_frame, + query_sql_result, + generated_natural_language_answer, + log_console, + ] + ) + + clear_button.click( + fn=self._clear_results, + inputs=[], + outputs=[ + loaded_data_frame, + empty_frame_label, + query_sql_result, + generated_natural_language_answer, + ], + ) + + view_dropdown.change( + fn=self._ui_load_preview_data, + inputs=view_dropdown, + outputs=[ + loaded_data_frame, + empty_frame_label, + query, + query_sql_result, + log_console, + ], + ) + query_button.click( + fn=self._ui_ask_query, + inputs=[query, natural_language_response_checkbox], + outputs=[ + loaded_data_frame, + empty_frame_label, + query_sql_result, + generated_natural_language_answer, + log_console, + ], + ) + + return demo