Skip to content

Commit

Permalink
added a utility to generate table aliases, and modified the prompt_co…
Browse files Browse the repository at this point in the history
…t accordingly
  • Loading branch information
rishsriv committed Jun 4, 2024
1 parent 60fe1bb commit 5de773d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
8 changes: 4 additions & 4 deletions prompts/prompt_cot.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
<|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
{instructions}{glossary}
{instructions}
DDL statements:
{table_metadata_string}

{cot_instructions}Generate a valid SQL query that answers the question `{user_question}`, and only references the tables and columns in the DDL statements.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{table_metadata_string}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`
```sql
32 changes: 31 additions & 1 deletion utils/gen_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ def to_prompt_schema(
md_create += ");\n"
return md_create

def generate_aliases(table_names: list) -> str:
"""
Generate aliases for table names
"""
aliases = {}
for table_name in table_names:
alias = table_name[0]
if alias in aliases.values() and "_" in table_name:
alias = table_name.split("_")[0] + table_name.split("_")[1]
if alias in aliases.values():
alias = table_name[:2]
if alias in aliases.values():
alias = table_name[:3]
num = 2
while alias in aliases.values():
alias = table_name[0] + str(num)
num += 1

aliases[table_name] = alias

aliases_str = ""
for table_name, alias in aliases.items():
aliases_str += f"-- {table_name} AS {alias}, "
return aliases

def generate_prompt(
prompt_file,
Expand All @@ -81,6 +105,7 @@ def generate_prompt(
with open(prompt_file, "r") as f:
prompt = f.read()
question_instructions = question + " " + instructions
table_names = []

if table_metadata_string == "":
if columns_to_keep > 0:
Expand Down Expand Up @@ -119,6 +144,7 @@ def generate_prompt(
join_list = ""

md = dbs[db_name]["table_metadata"]
table_names = list(md.keys())
table_metadata_string = to_prompt_schema(md, shuffle_metadata)

schema_names = get_schema_names(table_metadata_string)
Expand Down Expand Up @@ -174,6 +200,10 @@ def generate_prompt(
query_0=query_0,
question_1=question_1,
query_1=query_1,
cot_instructions=cot_instructions,
cot_instructions="",
)

if "cot_instructions" in prompt:
table_aliases = generate_aliases(table_names)
prompt = prompt + table_aliases
return prompt

0 comments on commit 5de773d

Please sign in to comment.