Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Gradio Adater implementation #39

Merged
merged 26 commits into from
Jun 11, 2024
Merged
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions docs/how-to/visualize_views.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# How-To: Visualize Views

To create a simple UI, use the [GradioAdapter class](../../src/dbally/utils/gradio_adapter.py). It lets you display a data preview for the views
and execute user queries.

## Installation
```bash
pip install dbally[gradio]
```

## Create own gradio interface
Define collection with implemented views

```python
llm = LiteLLM(model_name="gpt-3.5-turbo")
collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
collection.add(CandidateView, lambda: CandidateView(engine))
collection.add(SampleText2SQLView, lambda: SampleText2SQLView(prepare_freeform_enginge()))
```

Create gradio interface
```python
gradio_adapter = GradioAdapter()
gradio_interface = await gradio_adapter.create_interface(collection)
```

Launch the Gradio interface. To publish a publicly accessible interface, pass the argument `share=True`.
```python
gradio_interface.launch()
```

The endpoint is created when you run the Python module that calls the Gradio interface's `launch` method.
By default, the private endpoint is http://127.0.0.1:7860/.

## Links
* [Example Gradio Interface](visualize_views_code.py)
80 changes: 80 additions & 0 deletions docs/how-to/visualize_views_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import asyncio
import dotenv
import os

import sqlalchemy
from sqlalchemy import text

import dbally
from dbally.audit import CLIEventHandler
from dbally.embeddings import LiteLLMEmbeddingClient
from dbally.similarity import SimilarityIndex, SimpleSqlAlchemyFetcher, FaissStore
from dbally.llms.litellm import LiteLLM
from dbally.utils.gradio_adapter import GradioAdapter
from dbally.views.freeform.text2sql import BaseText2SQLView, TableConfig, ColumnConfig
from sandbox.quickstart2 import CandidateView, engine, Candidate

# Load variables from a local .env file into the process environment
# before OPENAI_API_KEY is read below.
dotenv.load_dotenv()
# Similarity index over the `country` column of the `Candidate` table.
# Values are fetched through `engine` and stored in a Faiss index under
# ./similarity_indexes.
# NOTE(review): this name is not referenced again in this script — confirm
# whether it is still needed here.
country_similarity = SimilarityIndex(
fetcher=SimpleSqlAlchemyFetcher(
engine,
table=Candidate,
column=Candidate.country,
),
store=FaissStore(
index_dir="./similarity_indexes",
index_name="country_similarity",
embedding_client=LiteLLMEmbeddingClient(
# Raises KeyError at import time when OPENAI_API_KEY is not set.
api_key=os.environ["OPENAI_API_KEY"],
),
),
)


class SampleText2SQLViewCyphers(BaseText2SQLView):
    """
    Example freeform (text-to-SQL) view exposing the `security_specialists` table.
    """

    def get_tables(self):
        """
        Describes the tables this view makes available.

        Returns:
            A single-element list with the configuration of the `security_specialists` table.
        """
        # (column name, declared SQL type) pairs, in table order.
        column_specs = (
            ("id", "SERIAL PRIMARY KEY"),
            ("name", "VARCHAR(255)"),
            ("cypher", "VARCHAR(255)"),
        )
        specialists_table = TableConfig(
            name="security_specialists",
            columns=[ColumnConfig(column_name, column_type) for column_name, column_type in column_specs],
            description="Knowledge base",
        )
        return [specialists_table]


def prepare_freeform_enginge():
    """
    Creates an in-memory SQLite engine pre-populated with demo data.

    Returns:
        SQLAlchemy engine whose database contains the `security_specialists`
        table filled with four sample rows.
    """
    memory_engine = sqlalchemy.create_engine("sqlite:///:memory:")

    create_table_sql = "CREATE TABLE security_specialists (id INTEGER PRIMARY KEY, name TEXT, cypher TEXT)"
    insert_rows_sql = [
        "INSERT INTO security_specialists (name, cypher) VALUES ('Alice', 'HAMAC')",
        "INSERT INTO security_specialists (name, cypher) VALUES ('Bob', 'AES')",
        "INSERT INTO security_specialists (name, cypher) VALUES ('Charlie', 'RSA')",
        "INSERT INTO security_specialists (name, cypher) VALUES ('David', 'SHA2')",
    ]

    with memory_engine.connect() as connection:
        # The table must exist before the rows are inserted, so the DDL goes first.
        for sql in [create_table_sql, *insert_rows_sql]:
            connection.execute(text(sql))

        # Persist the schema and rows before the connection is released.
        connection.commit()

    return memory_engine


async def main():
    """
    Builds a demo collection with two views and serves it through a Gradio interface.
    """
    collection = dbally.create_collection(
        "new_one",
        LiteLLM(model_name="gpt-3.5-turbo"),
        event_handlers=[CLIEventHandler()],
    )

    # Each view is registered together with a builder callable; the second
    # builder creates a new in-memory SQLite engine every time it is invoked.
    collection.add(CandidateView, lambda: CandidateView(engine))
    collection.add(
        SampleText2SQLViewCyphers,
        lambda: SampleText2SQLViewCyphers(prepare_freeform_enginge()),
    )

    interface = await GradioAdapter().create_interface(collection)
    interface.launch()


# Script entry point: run the async `main` coroutine when executed directly.
if __name__ == "__main__":
asyncio.run(main())
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ nav:
- how-to/use_elastic_store.md
- how-to/use_custom_similarity_store.md
- how-to/update_similarity_indexes.md
- how-to/visualize_views.md
- how-to/log_runs_to_langsmith.md
- how-to/create_custom_event_handler.md
- how-to/openai_assistants_integration.md
6 changes: 5 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -64,7 +64,11 @@ benchmark =
pydantic-settings~=2.0.3
psycopg2-binary~=2.9.9
elasticsearch =
elasticsearch==8.13.1
elasticsearch~=8.13.1
gradio =
gradio~=4.31.5
gradio_client~=0.16.4


[options.packages.find]
where = src
59 changes: 36 additions & 23 deletions src/dbally/audit/event_handlers/cli_event_handler.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import re
from io import StringIO
from sys import stdout
from typing import Optional, Union

try:
from rich import print as pprint
from rich.console import Console
from rich.syntax import Syntax
from rich.text import Text

RICH_OUTPUT = True
except ImportError:
RICH_OUTPUT = False
# TODO: remove color tags from bare print
pprint = print # type: ignore

from dbally.audit.event_handlers.base import EventHandler
from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent

# Keywords that identify a rich markup tag such as "[orange3 bold]" or "[grey53]".
_RICH_FORMATING_KEYWORD_SET = {"green", "orange", "grey", "bold", "cyan"}
# Matches ONE bracketed tag containing at least one of the keywords above.
# The character class [^\[\]] keeps the match inside a single pair of brackets;
# a greedy ".*" would span from the first "[" to the last "]" on a line and
# swallow the text between two tags (e.g. the "MESSAGE: " label).
# Keywords are sorted so the generated pattern is identical between runs.
_RICH_FORMATING_PATTERN = rf"\[[^\[\]]*({'|'.join(sorted(_RICH_FORMATING_KEYWORD_SET))})[^\[\]]*\]"


class CLIEventHandler(EventHandler):
"""
This handler displays all interactions between LLM and user happening during `Collection.ask`\
execution inside the terminal.
execution inside the terminal or store them in the given buffer.

### Usage

@@ -34,16 +40,23 @@ class CLIEventHandler(EventHandler):
![Example output from CLIEventHandler](../../assets/event_handler_example.png)
"""

def __init__(self) -> None:
def __init__(self, buffer: StringIO = None) -> None:
super().__init__()
self._console = Console() if RICH_OUTPUT else None

def _print_syntax(self, content: str, lexer: str) -> None:
self.buffer = buffer
out = self.buffer if buffer else stdout
self._console = Console(file=out, record=True) if RICH_OUTPUT else None

def _print_syntax(self, content: str, lexer: str = None) -> None:
if self._console:
console_content = Syntax(content, lexer, word_wrap=True)
if lexer:
console_content = Syntax(content, lexer, word_wrap=True)
else:
console_content = Text.from_markup(content)
self._console.print(console_content)
else:
print(content)
content_without_formatting = re.sub(_RICH_FORMATING_PATTERN, "", content)
print(content_without_formatting)

async def request_start(self, user_request: RequestStart) -> None:
"""
@@ -52,10 +65,9 @@ async def request_start(self, user_request: RequestStart) -> None:
Args:
user_request: Object containing name of collection and asked query
"""

pprint(f"[orange3 bold]Request starts... \n[orange3 bold]MESSAGE: [grey53]{user_request.question}")
pprint("[grey53]\n=======================================")
pprint("[grey53]=======================================\n")
self._print_syntax(f"[orange3 bold]Request starts... \n[orange3 bold]MESSAGE: [grey53]{user_request.question}")
self._print_syntax("[grey53]\n=======================================")
self._print_syntax("[grey53]=======================================\n")

async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: None) -> None:
"""
@@ -68,16 +80,18 @@ async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_con
"""

if isinstance(event, LLMEvent):
pprint(f"[cyan bold]LLM event starts... \n[cyan bold]LLM EVENT PROMPT TYPE: [grey53]{event.type}")
self._print_syntax(
f"[cyan bold]LLM event starts... \n[cyan bold]LLM EVENT PROMPT TYPE: [grey53]{event.type}"
)

if isinstance(event.prompt, tuple):
for msg in event.prompt:
pprint(f"\n[orange3]{msg['role']}")
self._print_syntax(f"\n[orange3]{msg['role']}")
self._print_syntax(msg["content"], "text")
else:
self._print_syntax(f"{event.prompt}", "text")
elif isinstance(event, SimilarityEvent):
pprint(
self._print_syntax(
f"[cyan bold]Similarity event starts... \n"
f"[cyan bold]INPUT: [grey53]{event.input_value}\n"
f"[cyan bold]STORE: [grey53]{event.store}\n"
@@ -95,15 +109,14 @@ async def event_end(
request_context: Optional context passed from request_start method
event_context: Optional context passed from event_start method
"""

if isinstance(event, LLMEvent):
pprint(f"\n[green bold]RESPONSE: {event.response}")
pprint("[grey53]\n=======================================")
pprint("[grey53]=======================================\n")
self._print_syntax(f"\n[green bold]RESPONSE: {event.response}")
self._print_syntax("[grey53]\n=======================================")
self._print_syntax("[grey53]=======================================\n")
elif isinstance(event, SimilarityEvent):
pprint(f"[green bold]OUTPUT: {event.output_value}")
pprint("[grey53]\n=======================================")
pprint("[grey53]=======================================\n")
self._print_syntax(f"[green bold]OUTPUT: {event.output_value}")
self._print_syntax("[grey53]\n=======================================")
self._print_syntax("[grey53]=======================================\n")

async def request_end(self, output: RequestEnd, request_context: Optional[dict] = None) -> None:
"""
@@ -113,8 +126,8 @@ async def request_end(self, output: RequestEnd, request_context: Optional[dict]
output: The output of the request.
request_context: Optional context passed from request_start method
"""
self._print_syntax("[green bold]REQUEST OUTPUT:")
self._print_syntax(f"Number of rows: {len(output.result.results)}")

pprint("[green bold]REQUEST OUTPUT:")
pprint(f"Number of rows: {len(output.result.results)}")
if "sql" in output.result.context:
self._print_syntax(f"{output.result.context['sql']}", "psql")
9 changes: 9 additions & 0 deletions src/dbally/collection.py
Original file line number Diff line number Diff line change
@@ -128,6 +128,15 @@ def build_dogs_df_view():
self._views[name] = view
self._builders[name] = builder

def add_event_handler(self, event_handler: EventHandler):
    """
    Registers one more event handler on this collection.

    Args:
        event_handler: The handler appended to the collection's list of event handlers.
    """
    registered_handlers = self._event_handlers
    registered_handlers.append(event_handler)

def get(self, name: str) -> BaseView:
"""
Returns an instance of the view with the given name
Loading