Skip to content

Commit

Permalink
Bookkeeping v0.0.3 n fixed updating table metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Jun 21, 2023
1 parent 65cfd1d commit 3e3ae39
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 9 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ Step1:
- docker pull postgres (will pull the latest version)
- docker run --rm --name pgsql-dev -e POSTGRES_PASSWORD=abc -p 5432:5432 postgres
Step2: Download and install .whl --> s3://sql-sidekick/releases/sql_sidekick-0.0.2-py3-none-any.whl
Step2: Download and install .whl --> s3://sql-sidekick/releases/sql_sidekick-0.0.3-py3-none-any.whl
- python3 -m venv .sidekickvenv
- source .sidekickvenv/bin/activate
- python3 -m pip install sql_sidekick-0.0.2-py3-none-any.whl
- python3 -m pip install sql_sidekick-0.0.3-py3-none-any.whl
```
## Start
```
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sql-sidekick"
version = "0.0.2"
version = "0.0.3"
license = "Proprietary"
description = "An AI assistant for SQL"
authors = [
Expand Down
26 changes: 22 additions & 4 deletions sidekick/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,24 @@ def _get_table_info(cache_path: str):
return table_info_path


def update_table_info(cache_path: str, table_info_path: str = None, table_name: str = None):
if Path(f"{cache_path}/table_context.json").exists():
f = open(f"{cache_path}/table_context.json", "r")
table_metadata = json.load(f)
if table_name:
table_metadata["tables_in_use"] = [table_name]
if table_info_path:
table_metadata["schema_info_path"] = table_info_path
else:
if table_name:
table_metadata = {"tables_in_use": [table_name]}
if table_info_path:
table_metadata = {"schema_info_path": table_info_path}

with open(f"{cache_path}/table_context.json", "w") as outfile:
json.dump(table_metadata, outfile, indent=4, sort_keys=False)


@configure.command("db-setup", help="Enter information to configure postgres database locally")
@click.option("--db_name", "-n", default="querydb", help="Database name", prompt="Enter Database name")
@click.option("--hostname", "-h", default="localhost", help="Database hostname", prompt="Enter hostname name")
Expand Down Expand Up @@ -130,17 +148,17 @@ def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: i
else:
break

if table_info_path is None:
table_info_path = _get_table_info(path)

if val.lower() == "y" or val.lower() == "yes":
table_value = input("Enter table name: ")
click.echo(f"Table name: {table_value}")
# set table name
db_obj.table_name = table_value.replace(" ", "_")

if table_info_path is None:
table_info_path = _get_table_info(path)

db_obj.create_table(table_info_path)

update_table_info(path, table_info_path, db_obj.table_name)
# Check if table exists; pending --> and doesn't have any rows
if db_obj.has_table():
click.echo(f"Checked table {db_obj.table_name} exists in the DB.")
Expand Down
7 changes: 5 additions & 2 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import sqlglot
import toml
from langchain import OpenAI
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
LLMPredictor, ServiceContext, SQLDatabase)
from llama_index.indices.struct_store import SQLContextContainerBuilder
from sidekick.configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
TASK_PROMPT)
from sidekick.logger import logger
from sidekick.utils import csv_parser, filter_samples, remove_duplicates
from sqlalchemy import create_engine
Expand Down Expand Up @@ -171,6 +173,7 @@ def generate_tasks(self, table_names: list, input_question: str):
filtered_context = (
filter_samples(input_question, updated_context, m_path) if len(updated_context) > 1 else updated_context
)
logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}")
_queries = "\n".join(filtered_context)
self.content_queries = _queries

Expand Down

0 comments on commit 3e3ae39

Please sign in to comment.