diff --git a/README.md b/README.md index e8d259f..3feecf5 100644 --- a/README.md +++ b/README.md @@ -371,7 +371,7 @@ You can use the following flags in the command line to change the configurations | -c, --num_columns | Number of columns, default 20. To not prune the columns, set it to 0. | | -s, --shuffle_metadata | Shuffle metadata, default False. This shuffles the order of the tables within the schema and the order of the columns within each table but does not shift columns between tables (to preserve the structure of the database). | | -k, --k_shot | Used when you want to include k-shot examples in your prompt. Make sure that the column 'k_shot_prompt' exists in your questions_file. | -| --cot_table_alias | Used when you want to include chain-of-thought instructions before the actual sql generation. Allowed values are `instruct` and `pregen`. If using `instruct`, make sure that the placeholder '{cot_instructions}' exists in your prompt file. | | +| --cot_table_alias | Used when you want to include chain-of-thought instructions before the actual sql generation. Allowed values are `instruct`, `prealias` and `pregen`. If using `instruct` or `prealias`, make sure that the placeholder '{cot_instructions}' exists in your prompt file. `instruct` will have your model generate the chain-of-thought table aliases, while `prealias` pre-generates the aliases and includes them in the prompt. 
| | ### Execution-related parameters diff --git a/run_checkpoints.sh b/run_checkpoints.sh index 2971c04..454eaf9 100755 --- a/run_checkpoints.sh +++ b/run_checkpoints.sh @@ -1,6 +1,6 @@ #!/bin/zsh -model_names=("sqlcoder_8b_fullft_ds_003_llama3_mgn1_b1_0900_b2_0990") +model_names=("sqlcoder_8b_fullft_ds_011_llama3_mgn1_b1_0900_b2_0990") PORT=8082 # avoid 8081 as it's used by nginx export CUDA_VISIBLE_DEVICES=0 # set gpu you want to use (just 1 will do) diff --git a/run_checkpoints_cot.sh b/run_checkpoints_cot.sh index bdb7c6c..9f4fb51 100755 --- a/run_checkpoints_cot.sh +++ b/run_checkpoints_cot.sh @@ -1,6 +1,6 @@ #!/bin/zsh -model_names=("sqlcoder_8b_fullft_ds_003_llama3_mgn1_b1_0900_b2_0990") +model_names=("sqlcoder_8b_fullft_ds_011_llama3_mgn1_b1_0900_b2_0990") PORT=8083 # avoid 8081 as it's used by nginx export CUDA_VISIBLE_DEVICES=1 # set gpu you want to use (just 1 will do) @@ -46,7 +46,7 @@ for model_name in "${model_names[@]}"; do --api_url "http://localhost:${PORT}/generate" \ --api_type "vllm" \ -p 10 \ - --cot_table_alias + --cot_table_alias "prealias" # finally, kill the api server pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}" done diff --git a/utils/questions.py b/utils/questions.py index 7584c44..5119fd2 100644 --- a/utils/questions.py +++ b/utils/questions.py @@ -2,12 +2,26 @@ import pandas as pd +def get_table_aliases(db_name: str) -> str: + from defog_data.metadata import dbs + from defog_utils.utils_db import generate_aliases + + metadata = dbs[db_name]["table_metadata"] + table_names = list(metadata.keys()) + aliases = generate_aliases(table_names) + aliases_instruction = ( + "Use the following table aliases when referencing tables in the query:\n" + + aliases + ) + return aliases_instruction + + def prepare_questions_df( questions_file: str, db_type: str, num_questions: Optional[int] = None, k_shot: bool = False, - cot_table_alias: bool = False, + cot_table_alias: str = "", ): question_query_df = pd.read_csv(questions_file, 
nrows=num_questions) question_query_df["db_type"] = db_type @@ -118,6 +132,10 @@ def prepare_questions_df( question_query_df["cot_instructions"] = ( "List the table aliases for each table as comments, starting with the most relevant tables to the question." ) + elif cot_table_alias == "prealias": + question_query_df["cot_instructions"] = question_query_df["db_name"].apply( + get_table_aliases + ) else: question_query_df["cot_instructions"] = ""