Skip to content

Commit

Permalink
Bug fixes related to UI flow
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary authored Jan 24, 2024
2 parents 88973fc + 2513b40 commit 83dd5b4
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 132 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,9 @@ download_demo_data:
run:
./.sidekickvenv/bin/python3 start.py

clean:
rm -rf ./db
rm -rf ./var

cloud_bundle:
h2o bundle -L debug 2>&1 | tee -a h2o-bundle.log
2 changes: 1 addition & 1 deletion about.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

**Target Audience:** Data (Machine Learning) Scientists, Citizen Data Scientists, Data Engineers, Managers and Business Analysts

**Actively Being Maintained:** Yes (Demo release: _In active RnD_)
**Actively Being Maintained:** Yes (Demo release)

**Last Updated:** January, 2024

Expand Down
5 changes: 3 additions & 2 deletions app.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ name = "ai.h2o.wave.sql-sidekick"
title = "SQL-Sidekick"
description = "QnA with tabular data using NLQ"
LongDescription = "about.md"
Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
Version = "0.2.0"
InstanceLifecycle = "MANAGED"
Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP", "GENERATIVE_AI"]
Version = "0.2.1"

[Runtime]
MemoryLimit = "64Gi"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sql-sidekick"
version = "0.2.0"
version = "0.2.1"
license = "Apache-2.0 license"
description = "An AI assistant for SQL generation"
authors = [
Expand Down
52 changes: 33 additions & 19 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ aiosignal==1.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.
ansicon==1.89.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows"
anyio==4.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
async-timeout==4.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
attrs==23.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
attrs==23.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
beautifulsoup4==4.12.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
bitsandbytes==0.41.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
blessed==1.20.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
Expand All @@ -13,73 +13,86 @@ certifi==2023.11.17 ; python_full_version >= "3.8.1" and python_full_version <=
charset-normalizer==3.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
click==8.1.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
colorama==0.4.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
databricks-sql-connector==3.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
dataclasses-json==0.6.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
deprecated==1.2.14 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
distro==1.8.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
distro==1.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
editor==1.6.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
et-xmlfile==1.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
exceptiongroup==1.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
filelock==3.13.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
frozenlist==1.4.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
fsspec==2023.12.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
greenlet==3.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
h11==0.14.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
h2o-wave==0.26.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
h2ogpte==1.2.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
h2ogpte==1.2.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
httpcore==0.17.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
httpx==0.24.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
huggingface-hub==0.20.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
huggingface-hub==0.20.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
idna==3.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
inquirer==3.1.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
inquirer==3.2.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
instructorembedding==1.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
jinja2==3.1.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
jinja2==3.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
jinxed==1.2.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows"
joblib==1.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
llama-index==0.9.20 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
jsonpatch==1.33 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
jsonpointer==2.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
langchain-community==0.0.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
langchain-core==0.1.11 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
langsmith==0.0.81 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
llama-index==0.9.32 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
loguru==0.7.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
lxml==4.9.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
lz4==4.3.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
markupsafe==2.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
marshmallow==3.20.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
marshmallow==3.20.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
mpmath==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
multidict==6.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
mypy-extensions==1.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
nest-asyncio==1.5.8 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
nest-asyncio==1.5.9 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
networkx==3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
nltk==3.8.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
numpy==1.24.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
openai==1.6.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
oauthlib==3.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
openai==1.8.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
openpyxl==3.1.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
packaging==23.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
pandas==1.5.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
pandasql==0.7.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
pillow==10.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
pillow==10.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
psutil==5.9.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
psycopg2-binary==2.9.9 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
pyarrow==14.0.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
pydantic==1.10.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
pydantic[dotenv]==1.10.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
python-dateutil==2.8.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
python-dotenv==1.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
python-editor==1.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
pytz==2023.3.post1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
pyyaml==6.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
readchar==4.0.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
regex==2023.10.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
regex==2023.12.25 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
requests==2.31.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
runs==1.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
safetensors==0.4.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
scikit-learn==1.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
scipy==1.10.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sentence-transformers==2.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sentencepiece==0.1.99 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
setuptools==69.0.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
setuptools==69.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
six==1.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sniffio==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
soupsieve==2.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sqlalchemy-utils==0.41.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sqlalchemy==1.4.50 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sqlalchemy[asyncio]==1.4.50 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sqlalchemy==2.0.25 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sqlalchemy[asyncio]==2.0.25 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sqlglot==12.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sqlparse==0.4.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
starlette==0.34.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
starlette==0.35.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
sympy==1.12 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
tenacity==8.2.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
threadpoolctl==3.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
thrift==0.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
tiktoken==0.5.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
tokenizers==0.15.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
toml==0.10.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
Expand All @@ -90,9 +103,10 @@ transformers==4.36.2 ; python_full_version >= "3.8.1" and python_full_version <=
typing-extensions==4.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
typing-inspect==0.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
urllib3==2.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
uvicorn==0.25.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
wcwidth==0.2.12 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
uvicorn==0.26.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
wcwidth==0.2.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
websockets==11.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
win32-setctime==1.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and sys_platform == "win32"
wrapt==1.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
xmod==1.8.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
yarl==1.9.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
4 changes: 2 additions & 2 deletions sidekick/configs/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
### *History*:\n{_sample_queries}
### *Question*: For table {_table_name}, {_question}
# SELECT 1
### *Tasks for table {_table_name}*:\n{_tasks}
### *Plan for table {_table_name}*:\n{_tasks}
### *Policies for SQL generation*:
# Avoid overly complex SQL queries, favor concise human readable SQL queries which are easy to understand and debug
# Avoid patterns that might be vulnerable to SQL injection
Expand Down Expand Up @@ -118,7 +118,7 @@
- Only use supplied table names: **{table_name}** for generation
- Only use column names from the CREATE TABLE statement: **{column_info}** for generation. DO NOT USE any other column names outside of this.
- Avoid overly complex SQL queries, favor concise human readable SQL queries which are easy to understand and debug
- Avoid patterns that might be vulnerable to SQL injection, e.g. sanitize inputs
- Avoid patterns that might be vulnerable to SQL injection, e.g. use proper sanitization and escaping for raw user input
- Always cast the numerator as float when computing ratios
- Always use COUNT(1) instead of COUNT(*)
- If the question is asking for a rate, use COUNT to compute percentage
Expand Down
22 changes: 12 additions & 10 deletions sidekick/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
execute_query_pd, extract_table_names,
generate_suggestions, save_query, setup_dir)

__version__ = "0.2.0"
__version__ = "0.2.1"

# Load the config file and initialize required paths
app_base_path = (Path(__file__).parent / "../").resolve()
Expand All @@ -41,6 +41,8 @@
os.environ["TOKENIZERS_PARALLELISM"] = "False"
os.environ["H2O_BASE_MODEL_URL"] = h2ogpt_base_model_url
os.environ["H2O_BASE_MODEL_API_KEY"] = h2ogpt_base_model_key
os.environ["RECOMMENDATION_MODEL_REMOTE_URL"] = h2o_remote_url
os.environ["RECOMMENDATION_MODEL_API_KEY"] = h2o_key

def color(fore="", back="", text=None):
return f"{fore}{back}{text}{Style.RESET_ALL}"
Expand Down Expand Up @@ -103,7 +105,7 @@ def _get_table_info(cache_path: str, table_name: str = None):
if table_info_path is None:
# If table_info_path is None, generate the default schema and set its path
data_path = current_meta["samples_path"]
_, table_info_path = generate_schema(data_path, f"{cache_path}/{table_name}_table_info.jsonl")
_, table_info_path = generate_schema(data_path=data_path, output_path=f"{cache_path}/{table_name}_table_info.jsonl")
table_metadata = {"schema_info_path": table_info_path}
with open(f"{cache_path}/table_context.json", "w") as outfile:
json.dump(table_metadata, outfile, indent=4, sort_keys=False)
Expand Down Expand Up @@ -178,7 +180,7 @@ def recommend_suggestions(cache_path: str, table_name: str, n_qs: int=10):
@click.option("--data_path", default="data.csv", help="Enter the path of csv", type=str)
@click.option("--output_path", default="table_info.jsonl", help="Enter the path of generated schema in jsonl", type=str)
def generate_input_schema(data_path, output_path):
_, o_path = generate_schema(data_path, output_path)
_, o_path = generate_schema(data_path=data_path, output_path=output_path)
click.echo(f"Schema generated for the input data at {o_path}")


Expand Down Expand Up @@ -463,7 +465,7 @@ def ask(
"""

results = []
err = None # TODO - Need to handle errors if occurred
res = err = alt_res = None # TODO - Need to handle errors if occurred
# Book-keeping
base_path = local_base_path if local_base_path else default_base_path
setup_dir(base_path)
Expand Down Expand Up @@ -575,7 +577,7 @@ def ask(
click.echo("Skipping edit...")
if updated_tasks is not None:
sql_g._tasks = updated_tasks
alt_res = None

# The interface could also be used to simply execute user provided SQL
# Keyword: "Execute SQL: <SQL query>"
if (
Expand Down Expand Up @@ -650,12 +652,12 @@ def ask(
attempt = 0
error_condition = lambda e: ('OperationalError'.lower() in e.lower() or 'OperationError'.lower() in e.lower() or 'Syntax error'.lower() in e.lower()) if e else False
if self_correction and error_condition(err):
logger.info("Attempting to auto-correct the query...")
logger.info("Attempting to auto-correct the query during runtime...")
while attempt !=3 and error_condition(err):
try:
logger.debug(f"Attempt: {attempt+1}")
_tmp = err.split("\n")
_err = _tmp[0].split("Error occurred :")[1] if len(_tmp) > 0 else None
_err = _tmp[0].split("Error occurred:")[1] if len(_tmp) > 0 else None
env_url = os.environ["RECOMMENDATION_MODEL_REMOTE_URL"]
env_key = os.environ["RECOMMENDATION_MODEL_API_KEY"]
corr_sql = sql_g.self_correction(input_prompt=_val, error_msg=_err, remote_url=env_url, client_key=env_key)
Expand All @@ -667,7 +669,7 @@ def ask(
logger.error(f"Something went wrong:\n{e}")
attempt += 1
if m:
_t = "\nWarning:\n".join([str(q_res), m])
_t = "\n\n**Warning:**\n".join([str(q_res), m])
q_res = _t
elif option == "pandas":
tables = extract_table_names(_val)
Expand Down Expand Up @@ -697,7 +699,7 @@ def ask(
click.echo("Error in executing the query. Validate generated SQL and try again.")
click.echo("No result to display.")

results.append("**Result:** \n")
results.append("**Result:**\n")
if q_res:
# Check shape of the final result to avoid blowing up memory
# Logging a quick preview of the result
Expand All @@ -718,7 +720,7 @@ def ask(
else:
click.echo("Exiting...")
else:
results = ["I was not able to generate a response for the question. Please try re-phrasing."]
results = ["I was not able to generate a response for the question. Please try re-phrasing or try again."]
alt_res, err = None, None
except (MemoryError, RuntimeError, AttributeError) as e:
logger.error(f"Something went wrong while generating response: {e}")
Expand Down
Loading

0 comments on commit 83dd5b4

Please sign in to comment.