Skip to content

Commit

Permalink
Update example
Browse files Browse the repository at this point in the history
  • Loading branch information
mhordynski committed May 21, 2024
1 parent ebe57d4 commit 58cfb69
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 25 deletions.
74 changes: 50 additions & 24 deletions examples/freeform.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,70 @@
import asyncio
from typing import List

import sqlalchemy

import dbally
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.llm_client.openai_client import OpenAIClient
from dbally.views.freeform.text2sql import Text2SQLConfig, Text2SQLFreeformView, Text2SQLTableConfig
from dbally.llms import LiteLLM
from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig


async def main():
"""Main function to run the example."""
config = Text2SQLConfig(
tables={
"customers": Text2SQLTableConfig(
ddl="CREATE TABLE customers (id INTEGER, name TEXT, city TEXT, country TEXT, age INTEGER)",
description="Table of customers",
similarity={"city": "city", "country": "country"},
class MyText2SqlView(BaseText2SQLView):
"""
A Text2SQL view for the example.
"""

def get_tables(self) -> List[TableConfig]:
"""
Get the tables used by the view.
Returns:
A list of tables.
"""
return [
TableConfig(
name="customers",
columns=[
ColumnConfig("id", "SERIAL PRIMARY KEY"),
ColumnConfig("name", "VARCHAR(255)"),
ColumnConfig("city", "VARCHAR(255)"),
ColumnConfig("country", "VARCHAR(255)"),
ColumnConfig("age", "INTEGER"),
],
),
"products": Text2SQLTableConfig(
ddl="CREATE TABLE products (id INTEGER, name TEXT, category TEXT, price REAL)",
description="Table of products",
similarity={"name": "name", "category": "category"},
TableConfig(
name="products",
columns=[
ColumnConfig("id", "SERIAL PRIMARY KEY"),
ColumnConfig("name", "VARCHAR(255)"),
ColumnConfig("category", "VARCHAR(255)"),
ColumnConfig("price", "REAL"),
],
),
"purchases": Text2SQLTableConfig(
ddl="CREATE TABLE purchases (customer_id INTEGER, product_id INTEGER, quantity INTEGER, date TEXT)",
description="Table of purchases",
similarity={},
TableConfig(
name="purchases",
columns=[
ColumnConfig("customer_id", "INTEGER"),
ColumnConfig("product_id", "INTEGER"),
ColumnConfig("quantity", "INTEGER"),
ColumnConfig("date", "TEXT"),
],
),
}
)
]


async def main():
"""Main function to run the example."""
engine = sqlalchemy.create_engine("sqlite:///:memory:")

# Create tables from config
with engine.connect() as connection:
for _, table_config in config.tables.items():
for table_config in MyText2SqlView(engine).get_tables():
connection.execute(sqlalchemy.text(table_config.ddl))

llm_client = OpenAIClient()
collection = dbally.create_collection("text2sql", llm_client=llm_client, event_handlers=[CLIEventHandler()])
collection.add(Text2SQLFreeformView, lambda: Text2SQLFreeformView(engine, config))
llm = LiteLLM()
collection = dbally.create_collection("text2sql", llm=llm, event_handlers=[CLIEventHandler()])
collection.add(MyText2SqlView, lambda: MyText2SqlView(engine))

await collection.ask("What are the names of products bought by customers from London?")
await collection.ask("Which customers bought products from the category 'electronics'?")
Expand Down
2 changes: 1 addition & 1 deletion src/dbally/views/freeform/text2sql/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def ddl(self) -> str:
The DDL for the table.
"""
return (
f"CREATE TABLE {self.name} )"
f"CREATE TABLE {self.name} ("
+ ", ".join(f"{column.name} {column.data_type}" for column in self.columns)
+ ");"
)
Expand Down

0 comments on commit 58cfb69

Please sign in to comment.