Skip to content

Commit

Permalink
add prealias option
Browse files Browse the repository at this point in the history
  • Loading branch information
wongjingping committed Jun 7, 2024
1 parent e366b60 commit 7b3cded
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ You can use the following flags in the command line to change the configurations
| -c, --num_columns | Number of columns, default 20. To not prune the columns, set it to 0. |
| -s, --shuffle_metadata | Shuffle metadata, default False. This shuffles the order of the tables within the schema and the order of the columns within each table but does not shift columns between tables (to preserve the structure of the database). |
| -k, --k_shot | Used when you want to include k-shot examples in your prompt. Make sure that the column 'k_shot_prompt' exists in your questions_file. |
| --cot_table_alias | Used when you want to include chain-of-thought instructions before the actual sql generation. Allowed values are `instruct` and `pregen`. If using `instruct`, make sure that the placeholder '{cot_instructions}' exists in your prompt file. | |
| --cot_table_alias | Used when you want to include chain-of-thought instructions before the actual sql generation. Allowed values are `instruct`, `prealias` and `pregen`. If using `instruct` or `prealias`, make sure that the placeholder '{cot_instructions}' exists in your prompt file. `instruct` will get your model generate the chain-of-thought table aliases, while `prealias` would already generate the aliases in the prompt. | |
### Execution-related parameters
Expand Down
2 changes: 1 addition & 1 deletion run_checkpoints.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/zsh

model_names=("sqlcoder_8b_fullft_ds_003_llama3_mgn1_b1_0900_b2_0990")
model_names=("sqlcoder_8b_fullft_ds_011_llama3_mgn1_b1_0900_b2_0990")
PORT=8082 # avoid 8081 as it's used by nginx
export CUDA_VISIBLE_DEVICES=0 # set gpu you want to use (just 1 will do)

Expand Down
4 changes: 2 additions & 2 deletions run_checkpoints_cot.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/zsh

model_names=("sqlcoder_8b_fullft_ds_003_llama3_mgn1_b1_0900_b2_0990")
model_names=("sqlcoder_8b_fullft_ds_011_llama3_mgn1_b1_0900_b2_0990")
PORT=8083 # avoid 8081 as it's used by nginx
export CUDA_VISIBLE_DEVICES=1 # set gpu you want to use (just 1 will do)

Expand Down Expand Up @@ -46,7 +46,7 @@ for model_name in "${model_names[@]}"; do
--api_url "http://localhost:${PORT}/generate" \
--api_type "vllm" \
-p 10 \
--cot_table_alias
--cot_table_alias "prealias"
# finally, kill the api server
pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}"
done
Expand Down
20 changes: 19 additions & 1 deletion utils/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,26 @@
import pandas as pd


def get_table_aliases(db_name: str) -> str:
from defog_data.metadata import dbs
from defog_utils.utils_db import generate_aliases

metadata = dbs[db_name]["table_metadata"]
table_names = list(metadata.keys())
aliases = generate_aliases(table_names)
aliases_instruction = (
"Use the following table aliases when referencing tables in the query:\n"
+ aliases
)
return aliases_instruction


def prepare_questions_df(
questions_file: str,
db_type: str,
num_questions: Optional[int] = None,
k_shot: bool = False,
cot_table_alias: bool = False,
cot_table_alias: str = "",
):
question_query_df = pd.read_csv(questions_file, nrows=num_questions)
question_query_df["db_type"] = db_type
Expand Down Expand Up @@ -118,6 +132,10 @@ def prepare_questions_df(
question_query_df["cot_instructions"] = (
"List the table aliases for each table as comments, starting with the most relevant tables to the question."
)
elif cot_table_alias == "prealias":
question_query_df["cot_instructions"] = question_query_df["db_name"].apply(
get_table_aliases
)
else:
question_query_df["cot_instructions"] = ""

Expand Down

0 comments on commit 7b3cded

Please sign in to comment.