Skip to content

Commit

Permalink
Add ability to build pkg(.whl) (#6)
Browse files Browse the repository at this point in the history
* Build .whl
* Model path correction
* Fix model access all-MiniLM-L6-v2
* Add control to adjust log levels
* Model cache path fix
  • Loading branch information
pramitchoudhary authored Jun 9, 2023
1 parent 063977e commit 743d7d0
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 51 deletions.
6 changes: 2 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
sentence_transformer = s3cmd get --recursive --skip-existing s3://h2o-model-gym/models/nlp/sentence_trasnsformer/all-MiniLM-L6-v1/ ./models/sentence_transformer/all-MiniLM-L6-v1
sentence_transformer = s3cmd get --recursive --skip-existing s3://h2o-model-gym/models/nlp/sentence_trasnsformer/all-MiniLM-L6-v2/ ./models/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2

.PHONY: download-models

all: download-models

setup: download_models ## Setup
mkdir -p ./var/lib/tmp/data
mkdir -p ./var/lib/tmp/.cache
python3 -m venv .sidekickvenv
./.sidekickvenv/bin/python3 -m pip install --upgrade pip
./.sidekickvenv/bin/python3 -m pip install wheel
./.sidekickvenv/bin/python3 -m pip install -r requirements.txt

download_models:
mkdir -p ./models/sentence_transformer/all-MiniLM-L6-v1
mkdir -p ./models/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2
$(sentence_transformer)
11 changes: 7 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@ version = "0.0.1"
license = "Proprietary"
description = "An AI assistant for SQL"
authors = [
"Pramit Choudhary <[email protected]>",
"Pramit Choudhary <[email protected]>",
"Michal Malohlava <[email protected]>"
]
readme = "README.md"
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Console",
"Intended Audience :: Developers",
"Development Status :: Alpha",
"Environment :: CLI",
"Intended Audience :: Developers, Analysts",
"License :: Other/Proprietary License",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8+",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
packages = [{include = "sidekick"}]

[tool.poetry.dependencies]
python = ">=3.8.16,<3.10"
Expand All @@ -36,6 +37,8 @@ sentence-transformers = "^2.2.2"
torch = "^2.0.1"
sqlalchemy-utils = "^0.41.1"

[tool.poetry.scripts]
sql-sidekick = "sidekick.prompter:cli"

[tool.poetry.dev-dependencies]
pylint = { version = "^2.12.2", allow-prereleases = true }
Expand Down
3 changes: 3 additions & 0 deletions .env.toml → sidekick/configs/.env.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ HOST_NAME = "localhost"
USER_NAME = "postgres"
PASSWORD = "abc"
DB_NAME = "querydb"

[LOGGING]
LOG-LEVEL = "INFO"
11 changes: 5 additions & 6 deletions sidekick/examples/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Column Name: resource_type
Column Type: TEXT NOT NULL, -- or int?
Sample Values: ['FEATURE_STORE', 'PROJECT', 'MLOPS_EXPERIMENT', 'APP', 'APP_INSTANCE', 'MLOPS_DEPLOYMENT',
'MLOPS_DATASET', 'MLOPS_USER', 'RESOURCE_TYPE_UNSPECIFIED', 'SCORING'], 'DAI_ENGINE', 'MLOPS_MODEL']
'MLOPS_DATASET', 'MLOPS_USER', 'RESOURCE_TYPE_UNSPECIFIED', 'SCORING', 'DAI_ENGINE', 'MLOPS_MODEL']
Column Name: resource_id
Column Type: TEXT
Expand Down Expand Up @@ -66,7 +66,6 @@
Column Name: source
Column Type: TEXT NOT NULL
Column Name: payload
Column Type: jsonb NOT NULL
Sample Values:
Expand Down Expand Up @@ -105,14 +104,14 @@

# For few shot prompting
samples_queries = [
"""
"""
# query: Total number of CPUs used?
# answer:
SELECT sum((payload->'engineEvent'-> 'pausing' -> 'engine'->> 'cpu')::integer) AS total_cpus_used
FROM telemetry
WHERE payload->'engineEvent'-> 'pausing' -> 'engine'->> 'cpu' IS NOT NULL;
""",
"""
"""
# query: Find the number of AI units for each user using stream for each resource type (overall)
# answer:
SELECT user_id, user_name, resource_type, date_trunc('day', ts) as start_day,
Expand All @@ -124,7 +123,7 @@
) sub GROUP BY user_id, user_name, resource_type, start_day
ORDER BY start_day DESC NULLS LAST;
""",
"""
"""
# query: Compute global usage over time
# answer:
SELECT
Expand All @@ -144,5 +143,5 @@
WHERE stream = 'gauage_resources'
) AS internal
ORDER BY 1, 2 DESC;
"""
""",
]
76 changes: 49 additions & 27 deletions sidekick/prompter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import json
import os
from pathlib import Path
Expand All @@ -8,25 +9,37 @@
from colorama import Back as B
from colorama import Fore as F
from colorama import Style
from db_config import DBConfig
from loguru import logger
from memory import EntityMemory
from query import SQLGenerator
from utils import save_query
from sidekick.db_config import DBConfig
from sidekick.memory import EntityMemory
from sidekick.query import SQLGenerator
from sidekick.utils import save_query, setup_dir

# Load the config file and initialize required paths
base_path = (Path(__file__).parent / "../").resolve()
env_settings = toml.load(f"{base_path}/.env.toml")
env_settings = toml.load(f"{base_path}/sidekick/configs/.env.toml")
os.environ["TOKENIZERS_PARALLELISM"] = "False"
__version__ = "0.0.1"


def color(fore="", back="", text=None):
return f"{fore}{back}{text}{Style.RESET_ALL}"


@click.group()
@click.version_option()
msg = """Welcome to the SQL Sidekick!\nI am AI assistant that helps you with SQL queries.
I can help you with the following:\n
1. Configure a local database(for schema validation and syntax checking): `sql-sidekick configure db-setup`.\n
2. Learn contextual query/answer pairs: `sql-sidekick learn add-samples`.\n
3. Simply add context: `sql-sidekick learn update-context`.\n
4. Ask a question: `sql-sidekick query`.
"""


@click.group(help=msg)
@click.version_option("-V", "--version", message=f"sql-sidekick - {__version__}")
def cli():
""" """
# Book-keeping
setup_dir(base_path)


@cli.group("configure")
Expand All @@ -39,6 +52,16 @@ def enter_table_name():
return val


@configure.command("log", help="Adjust log settings")
@click.option("--set_level", "-l", help="Set log level (Default: INFO)")
def set_loglevel(set_level):
env_settings["LOGGING"]["LOG-LEVEL"] = set_level
# Update settings file for future use.
f = open(f"{base_path}/sidekick/configs/.env.toml", "w")
toml.dump(env_settings, f)
f.close()


@configure.command("db-setup", help="Enter information to configure postgres database locally")
@click.option("--db_name", "-n", default="querydb", help="Database name", prompt="Enter Database name")
@click.option("--hostname", "-h", default="localhost", help="Database hostname", prompt="Enter hostname name")
Expand All @@ -62,7 +85,7 @@ def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: i
env_settings["LOCAL_DB_CONFIG"]["PORT"] = port
env_settings["LOCAL_DB_CONFIG"]["DB_NAME"] = db_name
# Update settings file for future use.
f = open(f"{base_path}/.env.toml", "w")
f = open(f"{base_path}/sidekick/configs/.env.toml", "w")
toml.dump(env_settings, f)
f.close()

Expand Down Expand Up @@ -139,6 +162,9 @@ def add_query_response():
@learn.command("update-context", help="Update context in memory for future use")
def update_context():
"""Helps learn context for generation."""
# Book-keeping
setup_dir(base_path)

context_dict = """{\n"<new_context_key>": "<new_context_value>"\n}
"""
content_file_path = f"{base_path}/var/lib/tmp/data/context.json"
Expand All @@ -164,20 +190,24 @@ def update_context():
@click.option("--question", "-q", help="Database name", prompt="Ask a question")
def query(question: str):
"""Asks question and returns SQL."""
# Book-keeping
setup_dir(base_path)

# Check if table exists
path = f"{base_path}/var/lib/tmp/data/"
path = f"{base_path}/var/lib/tmp/data"
table_context_file = f"{path}/table_context.json"
table_context = json.load(open(table_context_file, "r")) if Path(table_context_file).exists() else {}
table_names = []
if table_context:
table_name = table_context.get("tables_in_use", None)
table_name = [_t.replace(" ", "_") for _t in table_name]
table_names = [_t.replace(" ", "_") for _t in table_name]
else:
table_name = [click.prompt("Which table to use?")]
table_context["tables_in_use"] = table_name.replace(" ", "_")
# Ask for table name only when more than one table exists.
table_names = [click.prompt("Which table to use?")]
table_context["tables_in_use"] = [_t.replace(" ", "_") for _t in table_names]
with open(f"{path}/table_context.json", "w") as outfile:
json.dump(table_context, outfile, indent=4, sort_keys=False)
logger.info(f"Table in use: {table_name}")
logger.info(f"Table in use: {table_names}")
# Check if .env.toml file exists
api_key = env_settings["OPENAI"]["OPENAI_API_KEY"]
if api_key is None or api_key == "":
Expand All @@ -186,11 +216,12 @@ def query(question: str):
color(F.GREEN, "", "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):")
)
if val.lower() == "y":
api_key = input(color(F.GREEN, "", "Enter OPENAI_API_KEY:"))
api_key = input(color(F.GREEN, "", "Enter OPENAI_API_KEY :"))
os.environ["OPENAI_API_KEY"] = api_key
env_settings["OPENAI"]["OPENAI_API_KEY"] = api_key

# Update settings file for future use.
f = open(f"{base_path}/.env.toml", "w")
f = open(f"{base_path}/sidekick/configs/.env.toml", "w")
toml.dump(env_settings, f)
f.close()
openai.api_key = api_key
Expand All @@ -209,7 +240,7 @@ def query(question: str):
)

sql_g = SQLGenerator(db_url, api_key, path=base_path)
sql_g._tasks = sql_g.generate_tasks(table_name, question)
sql_g._tasks = sql_g.generate_tasks(table_names, question)
click.echo(sql_g._tasks)

updated_tasks = None
Expand All @@ -222,7 +253,7 @@ def query(question: str):
click.echo("Skipping edit...")
if updated_tasks is not None:
sql_g._tasks = updated_tasks
res = sql_g.generate_sql(table_name, question)
res = sql_g.generate_sql(table_names, question)
logger.info(f"Generated response:\n\n{res}")

if res is not None:
Expand All @@ -238,13 +269,4 @@ def query(question: str):


if __name__ == "__main__":
click.echo(
"""Welcome to the SQL Sidekick!\nI am AI assistant that helps you with SQL queries.
I can help you with the following:
1. Configure a local database(for schema validation and syntax checking): `python sidekick/prompter.py configure db-setup`.
2. Learn contextual query/answer pairs: `python sidekick/prompter.py learn add-samples`.
3. Simply add context: `python sidekick/prompter.py learn update-context`.
4. Ask a question: `python sidekick/prompter.py query`.\n
"""
)
cli()
24 changes: 15 additions & 9 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
import numpy as np
import openai
import sqlglot
from configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
from examples.sample_data import sample_values, samples_queries
import toml
from langchain import OpenAI
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
LLMPredictor, ServiceContext, SQLDatabase)
from llama_index.indices.struct_store import SQLContextContainerBuilder
from loguru import logger
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
TASK_PROMPT)
from sidekick.examples.sample_data import sample_values, samples_queries
from sidekick.utils import remove_duplicates
from sqlalchemy import create_engine
from utils import remove_duplicates

logger.remove()
logger.add(sys.stderr, level="INFO")
base_path = (Path(__file__).parent / "../").resolve()
env_settings = toml.load(f"{base_path}/sidekick/configs/.env.toml")
logger.add(sys.stderr, level=env_settings['LOGGING']['LOG-LEVEL'])


class SQLGenerator:
Expand Down Expand Up @@ -76,6 +81,7 @@ def _query_tasks(self, question_str, data_info, sample_queries, table_name: list
# Role and content
query_txt = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
logger.debug(f"Query Text:\n {query_txt}")

# TODO ADD local model
completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0301",
Expand Down Expand Up @@ -120,27 +126,27 @@ def generate_response(self, context_container, sql_index, input_prompt, attempt_
res = qry_txt
return res

def generate_tasks(self, table_name: str, input_question: str):
def generate_tasks(self, table_names: list, input_question: str):
try:
# Step 1: Given a question, generate tasks to possibly answer the question and persist the result -> tasks.txt
# Step 2: Append task list to 'query_prompt_template', generate SQL code to answer the question and persist the result -> sql.txt
context_queries: list = self.update_context_queries()

# Remove duplicates from the context queries
m_path = f"{self.path}/models/sentence_transformer/all-MiniLM-L6-v1/"
m_path = f"{self.path}/var/lib/tmp/.cache/models"
duplicates_idx = remove_duplicates(context_queries, m_path)
updated_context = np.delete(np.array(context_queries), duplicates_idx).tolist()

_queries = "\n".join(updated_context)
self.content_queries = _queries
task_list = self._query_tasks(input_question, sample_values, _queries.lower(), table_name)
task_list = self._query_tasks(input_question, sample_values, _queries.lower(), table_names)
with open(f"{self.path}/var/lib/tmp/data/tasks.txt", "w") as f:
f.write(task_list)
return task_list
except Exception as se:
raise se

def generate_sql(self, table_name: str, input_question: str, _dialect: str = "postgres"):
def generate_sql(self, table_name: list, input_question: str, _dialect: str = "postgres"):
_tasks = self.task_formatter(self._tasks)
context_file = f"{self.path}/var/lib/tmp/data/context.json"
additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
Expand Down
26 changes: 25 additions & 1 deletion sidekick/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import os
from pathlib import Path
from typing import Optional

import numpy as np
Expand All @@ -13,11 +15,21 @@ def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, devic
# 3. Model Card: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
# 4. Reference: https://huggingface.co/spaces/mteb/leaderboard
# Maps sentence & paragraphs to a 384 dimensional dense vector space.
sentence_model = SentenceTransformer(model_path, device=device)
model_name_path = f"{model_path}/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/"
current_torch_home = os.environ.get("TORCH_HOME", "")
if Path(model_name_path).is_dir():
is_empty = not any(Path(model_name_path).iterdir())
if is_empty:
# Download n cache at the specified location
# https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip
os.environ["TORCH_HOME"] = model_path
model_name_path = "sentence-transformers/all-MiniLM-L6-v2"
sentence_model = SentenceTransformer(model_name_path, device=device)
all_res = np.zeros(shape=(len(x), 0))
res = sentence_model.encode(x, batch_size=batch_size, show_progress_bar=True)
all_res = np.hstack((all_res, res))
del sentence_model
os.environ["TORCH_HOME"] = current_torch_home
return all_res


Expand Down Expand Up @@ -49,3 +61,15 @@ def save_query(output_path: str, query, response, extracted_entity: Optional[dic
with open(f"{output_path}/var/lib/tmp/data/history.jsonl", "a") as outfile:
json.dump(chat_history, outfile)
outfile.write("\n")


def setup_dir(base_path: str):
dir_list = [
"var/lib/tmp/data",
"var/lib/tmp/.cache",
"var/lib/tmp/.cache/models/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2",
]
for _dl in dir_list:
p = Path(f"{base_path}/{_dl}")
if not p.is_dir():
p.mkdir(parents=True, exist_ok=True)

0 comments on commit 743d7d0

Please sign in to comment.