Commit dd65a0f
- Separate table_aliases into its own field instead of nesting the logic within cot_instructions, which we prefer to leave as an experimental field
- Rename join_hints to join_str to better match the variable name
Parent: 9330b5c
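For orientation, here is a minimal sketch (not taken from the diff below) of what the field separation and the rename look like at the prompt-formatting level. The toy template and the values are illustrative only; the real templates live in prompts/*.md and are filled in utils/gen_prompt.py.

```python
# Toy template using the renamed placeholders; not the repo's actual prompt.
template = (
    "DDL statements:\n{table_metadata_string}\n"
    "{join_str}\n\n"
    "{table_aliases}Generate a valid SQL query.\n"
)

# Previously the alias text rode along inside cot_instructions; now it is passed
# as its own table_aliases field, and the join hints string is passed as join_str.
prompt = template.format(
    table_metadata_string="CREATE TABLE users (id bigint, name text);",
    join_str="-- users can be joined to orders on users.id = orders.user_id",
    table_aliases="-- users AS u\n-- orders AS o\n",
)
print(prompt)
```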

File tree

7 files changed (+31, -34 lines)

README.md

Lines changed: 13 additions & 10 deletions
@@ -381,14 +381,18 @@ python -W ignore main.py \

### Bedrock

+ Before running this, you would need to export the following environment variables for the boto3 client to work:
+ - `AWS_ACCESS_KEY_ID`
+ - `AWS_SECRET_ACCESS_KEY`
+ - `AWS_DEFAULT_REGION`
+
```bash
python3 main.py \
-db postgres \
-q data/instruct_basic_postgres.csv data/instruct_advanced_postgres.csv data/questions_gen_postgres.csv \
-o results/bedrock_llama_70b_basic.csv results/bedrock_llama_70b_advanced.csv results/bedrock_llama_70b_v1.csv \
-g bedrock \
-f prompts/prompt_cot_postgres.md \
- --cot_table_alias prealias \
-m meta.llama3-70b-instruct-v1:0 \
-c 0 \
-p 10
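As an aside (not part of the diff), a minimal sketch of how boto3 picks these variables up from the environment; the `bedrock-runtime` service name here is an assumption made for illustration and is not taken from this commit.

```python
import os
import boto3  # assumed dependency of the bedrock runner

# Fail early if the variables the README asks for are missing.
required = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION")
missing = [name for name in required if not os.environ.get(name)]
if missing:
    raise RuntimeError(f"Export these before running: {', '.join(missing)}")

# With the variables exported, boto3 resolves credentials and the region from
# the environment; no explicit arguments are needed when constructing the client.
client = boto3.client("bedrock-runtime")
```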
@@ -405,7 +409,6 @@ python3 main.py \
-o results/together_llama_70b_basic.csv results/together_llama_70b_advanced.csv results/together_llama_70b_v1.csv \
-g together \
-f prompts/prompt_together.json \
- --cot_table_alias prealias \
-m "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \
-c 0 \
-p 10
@@ -437,14 +440,14 @@ You can use the following flags in the command line to change the configurations

### Inference-technique-related parameters

- | CLI Flags | Description |
- | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- | --- |
- | -f, --prompt_file | Markdown file with the prompt used for query generation. You can pass in a list of prompts to test sequentially without reloading the script. |
- | -b, --num_beams | Indicates the number of beams you want to use for beam search at inference. Only available for `hf_runner`, `vllm_runner`, and `api_runner`. |
- | -c, --num_columns | Number of columns, default 20. To not prune the columns, set it to 0. |
- | -s, --shuffle_metadata | Shuffle metadata, default False. This shuffles the order of the tables within the schema and the order of the columns within each table but does not shift columns between tables (to preserve the structure of the database). |
- | -k, --k_shot | Used when you want to include k-shot examples in your prompt. Make sure that the column 'k_shot_prompt' exists in your questions_file. |
- | --cot_table_alias | Used when you want to include chain-of-thought instructions before the actual SQL generation. Allowed values are `instruct`, `prealias` and `pregen`. If using `instruct` or `prealias`, make sure that the placeholder '{cot_instructions}' exists in your prompt file. `instruct` will get your model to generate the chain-of-thought table aliases, while `prealias` would already generate the aliases in the prompt. | |
+ | CLI Flags | Description | |
+ | ---------------------- | ------------- | --- |
+ | -f, --prompt_file | Markdown file with the prompt used for query generation. You can pass in a list of prompts to test sequentially without reloading the script. |
+ | -b, --num_beams | Indicates the number of beams you want to use for beam search at inference. Only available for `hf_runner`, `vllm_runner`, and `api_runner`. |
+ | -c, --num_columns | Number of columns, default 20. To not prune the columns, set it to 0. |
+ | -s, --shuffle_metadata | Shuffle metadata, default False. This shuffles the order of the tables within the schema and the order of the columns within each table but does not shift columns between tables (to preserve the structure of the database). |
+ | -k, --k_shot | Used when you want to include k-shot examples in your prompt. Make sure that the column 'k_shot_prompt' exists in your questions_file. |
+ | --cot_table_alias | (Experimental) Used when you want to include chain-of-thought instructions before the actual SQL generation. The only allowed value is `instruct`. If using `instruct`, make sure that the placeholder '{cot_instructions}' exists in your prompt file. `instruct` will get your model to generate the chain-of-thought table aliases. |

### Execution-related parameters

main.py

Lines changed: 3 additions & 1 deletion
@@ -27,7 +27,9 @@
parser.add_argument("-c", "--num_columns", type=int, default=0)
parser.add_argument("-s", "--shuffle_metadata", action="store_true")
parser.add_argument("-k", "--k_shot", action="store_true")
- parser.add_argument("--cot_table_alias", type=str)
+ parser.add_argument(
+     "--cot_table_alias", type=str, choices=["instruct", "pregen", ""], default=""
+ )
# execution-related parameters
parser.add_argument("-o", "--output_file", nargs="+", type=str, required=True)
parser.add_argument("-p", "--parallel_threads", type=int, default=5)

prompts/prompt_cot.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ Generate a {db_type} query to answer this question: `{user_question}`
DDL statements:
{table_metadata_string}

- {cot_instructions}Generate a valid {db_type} query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ {table_aliases}Generate a valid {db_type} query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I will reflect on the user's request before answering the question.

prompts/prompt_cot_postgres.md

Lines changed: 2 additions & 2 deletions
@@ -6,9 +6,9 @@ Generate a SQL query to answer this question: `{user_question}`
{instructions}
DDL statements:
{table_metadata_string}
- {join_hints}
+ {join_str}

- {cot_instructions}Generate a valid SQL query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ {table_aliases}Generate a valid SQL query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I will reflect on the user's request before answering the question.

prompts/prompt_cot_sqlite.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ Generate a {db_type} query to answer this question: `{user_question}`
DDL statements:
{table_metadata_string}

- {cot_instructions}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ {table_aliases}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I was asked to generate a SQL query for this question: `{user_question}`

utils/gen_prompt.py

Lines changed: 2 additions & 4 deletions
@@ -287,10 +287,8 @@ def generate_prompt(
        query_1=query_1,
        cot_instructions=cot_instructions,
        instruction_reflections=instruction_reflections,
-         join_hints=join_str,
+         table_aliases=table_aliases,
+         join_str=join_str,
        pruned_join_hints=pruned_join_str,
    )
-     if cot_pregen:
-         table_aliases = generate_aliases(table_names)
-         prompt = prompt + table_aliases
    return prompt

utils/questions.py

Lines changed: 9 additions & 15 deletions
@@ -93,6 +93,11 @@ def prepare_questions_df(
    else:
        question_query_df["table_metadata_string"] = ""

+     # get table_aliases
+     question_query_df["table_aliases"] = question_query_df["db_name"].apply(
+         get_table_aliases
+     )
+
    # get prev_invalid_sql if applicable
    if "prev_invalid_sql" in question_query_df.columns:
        question_query_df["prev_invalid_sql"] = question_query_df[
@@ -127,25 +132,14 @@ def prepare_questions_df(
    else:
        question_query_df["query_1"] = ""

-     # add all cot instructions to the `cot_instructions` column
+     # add all cot instructions to the respective columns
+     question_query_df["cot_instructions"] = ""
+     question_query_df["cot_pregen"] = False
    if cot_table_alias == "instruct":
        question_query_df["cot_instructions"] = (
            "List the table aliases for each table as comments, starting with the most relevant tables to the question."
        )
-     elif cot_table_alias == "prealias":
-         question_query_df["cot_instructions"] = question_query_df["db_name"].apply(
-             get_table_aliases
-         )
-         question_query_df["table_aliases"] = question_query_df["db_name"].apply(
-             get_table_aliases
-         )
-     else:
-         question_query_df["cot_instructions"] = ""
-         question_query_df["table_aliases"] = ""
-
-     if cot_table_alias == "pregen":
+     elif cot_table_alias == "pregen":
        question_query_df["cot_pregen"] = True
-     else:
-         question_query_df["cot_pregen"] = False

    return question_query_df
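Taken together, the two hunks above reduce to the following shape. This is a condensed sketch for readability rather than the verbatim function: the helper is factored out here for self-containment, and `get_table_aliases` is the repo's existing helper passed in as an argument.

```python
import pandas as pd


def set_cot_fields(df: pd.DataFrame, cot_table_alias: str, get_table_aliases) -> pd.DataFrame:
    # table_aliases is now always derived from db_name, independent of the flag.
    df["table_aliases"] = df["db_name"].apply(get_table_aliases)
    # cot_instructions and cot_pregen start from neutral defaults...
    df["cot_instructions"] = ""
    df["cot_pregen"] = False
    # ...and are overridden only for their respective --cot_table_alias values.
    if cot_table_alias == "instruct":
        df["cot_instructions"] = (
            "List the table aliases for each table as comments, "
            "starting with the most relevant tables to the question."
        )
    elif cot_table_alias == "pregen":
        df["cot_pregen"] = True
    return df
```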
