From 3e3ae39351b8232bdb334b5b03d692e1646c27da Mon Sep 17 00:00:00 2001
From: pramitchoudhary
Date: Wed, 21 Jun 2023 08:48:33 -0700
Subject: [PATCH] Bookkeeping for v0.0.3 and fix for updating table metadata

---
 README.md            |  4 ++--
 pyproject.toml       |  2 +-
 sidekick/prompter.py | 26 ++++++++++++++++++++++----
 sidekick/query.py    |  7 +++++--
 4 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index ee4620a..d55fe86 100644
--- a/README.md
+++ b/README.md
@@ -17,10 +17,10 @@ Step1:
 - docker pull postgres (will pull the latest version)
 - docker run --rm --name pgsql-dev -e POSTGRES_PASSWORD=abc -p 5432:5432 postgres
 
-Step2: Download and install .whl --> s3://sql-sidekick/releases/sql_sidekick-0.0.2-py3-none-any.whl
+Step2: Download and install .whl --> s3://sql-sidekick/releases/sql_sidekick-0.0.3-py3-none-any.whl
 - python3 -m venv .sidekickvenv
 - source .sidekickvenv/bin/activate
-- python3 -m pip install sql_sidekick-0.0.2-py3-none-any.whl
+- python3 -m pip install sql_sidekick-0.0.3-py3-none-any.whl
 ```
 ## Start
 ```
diff --git a/pyproject.toml b/pyproject.toml
index 75dd282..10d6d5f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sql-sidekick"
-version = "0.0.2"
+version = "0.0.3"
 license = "Proprietary"
 description = "An AI assistant for SQL"
 authors = [
diff --git a/sidekick/prompter.py b/sidekick/prompter.py
index bb84f33..e64b08f 100644
--- a/sidekick/prompter.py
+++ b/sidekick/prompter.py
@@ -86,6 +86,24 @@ def _get_table_info(cache_path: str):
     return table_info_path
 
 
+def update_table_info(cache_path: str, table_info_path: str = None, table_name: str = None):
+    """Create or update table_context.json with the active table name and schema info path.
+
+    Existing entries are preserved; only the fields passed in are overwritten.
+    """
+    table_metadata = {}
+    if Path(f"{cache_path}/table_context.json").exists():
+        with open(f"{cache_path}/table_context.json", "r") as f:
+            table_metadata = json.load(f)
+    if table_name:
+        table_metadata["tables_in_use"] = [table_name]
+    if table_info_path:
+        table_metadata["schema_info_path"] = table_info_path
+
+    with open(f"{cache_path}/table_context.json", "w") as outfile:
+        json.dump(table_metadata, outfile, indent=4, sort_keys=False)
+
+
 @configure.command("db-setup", help="Enter information to configure postgres database locally")
 @click.option("--db_name", "-n", default="querydb", help="Database name", prompt="Enter Database name")
 @click.option("--hostname", "-h", default="localhost", help="Database hostname", prompt="Enter hostname name")
@@ -130,17 +148,17 @@ def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: i
         else:
             break
 
+    if table_info_path is None:
+        table_info_path = _get_table_info(path)
+
     if val.lower() == "y" or val.lower() == "yes":
         table_value = input("Enter table name: ")
         click.echo(f"Table name: {table_value}")
         # set table name
         db_obj.table_name = table_value.replace(" ", "_")
-
-        if table_info_path is None:
-            table_info_path = _get_table_info(path)
-
         db_obj.create_table(table_info_path)
+        update_table_info(path, table_info_path, db_obj.table_name)
     # Check if table exists; pending --> and doesn't have any rows
     if db_obj.has_table():
         click.echo(f"Checked table {db_obj.table_name} exists in the DB.")
diff --git a/sidekick/query.py b/sidekick/query.py
index 7f3adc1..d0df3aa 100644
--- a/sidekick/query.py
+++ b/sidekick/query.py
@@ -8,9 +8,11 @@
 import sqlglot
 import toml
 from langchain import OpenAI
-from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
+from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
+                         LLMPredictor, ServiceContext, SQLDatabase)
 from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
+from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
+                                              TASK_PROMPT)
 from sidekick.logger import logger
 from sidekick.utils import csv_parser, filter_samples, remove_duplicates
 from sqlalchemy import create_engine
@@ -171,6 +173,7 @@ def generate_tasks(self, table_names: list, input_question: str):
         filtered_context = (
             filter_samples(input_question, updated_context, m_path) if len(updated_context) > 1 else updated_context
         )
+        logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}")
         _queries = "\n".join(filtered_context)
         self.content_queries = _queries
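
Reviewer note: the sketch below is not part of the patch. It is a minimal, standalone mirror of what the new `update_table_info` helper is intended to do, included only to illustrate the resulting `table_context.json` (existing keys merged rather than replaced). The temporary cache directory, the `demo_table` name, and the `table_info.jsonl` path are hypothetical placeholders.

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory


def update_table_info(cache_path: str, table_info_path: str = None, table_name: str = None):
    # Standalone mirror of the patched helper: merge new values into any
    # existing context file instead of replacing it wholesale.
    context_file = Path(cache_path) / "table_context.json"
    table_metadata = json.loads(context_file.read_text()) if context_file.exists() else {}
    if table_name:
        table_metadata["tables_in_use"] = [table_name]
    if table_info_path:
        table_metadata["schema_info_path"] = table_info_path
    context_file.write_text(json.dumps(table_metadata, indent=4, sort_keys=False))


with TemporaryDirectory() as cache_dir:
    # First call records the active table; the second adds the schema path
    # without dropping the previously stored table name.
    update_table_info(cache_dir, table_name="demo_table")
    update_table_info(cache_dir, table_info_path=f"{cache_dir}/table_info.jsonl")
    print((Path(cache_dir) / "table_context.json").read_text())
```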