Skip to content

Commit

Permalink
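Renamed variables for clarity: the DDL part returned by `prune_metadata_str` is now named `*_metadata_ddl` and kept separate from `join_str`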
Browse files (browse the repository at this point in the history)
  • Loading branch information
wendy-aw committed Jun 14, 2024
1 parent ab26dfe commit 5087818
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 17 deletions.
3 changes: 2 additions & 1 deletion eval/gemini_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def generate_prompt(
question_instructions = question + " " + instructions

if table_metadata_string == "":
pruned_metadata_str = prune_metadata_str(
pruned_metadata_ddl, join_str = prune_metadata_str(
question_instructions, db_name, public_data, num_columns_to_keep, shuffle
)
pruned_metadata_str = pruned_metadata_ddl + join_str
else:
pruned_metadata_str = table_metadata_string

Expand Down
3 changes: 2 additions & 1 deletion eval/mistral_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ def generate_prompt(
question_instructions = question + " " + instructions

if table_metadata_string == "":
pruned_metadata_str = prune_metadata_str(
pruned_metadata_ddl, join_str = prune_metadata_str(
question_instructions, db_name, public_data, columns_to_keep, shuffle
)
pruned_metadata_str = pruned_metadata_ddl + join_str
else:
pruned_metadata_str = table_metadata_string

Expand Down
3 changes: 2 additions & 1 deletion query_generators/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,14 @@ def generate_query(
model_prompt = file.read()
question_instructions = question + " " + instructions
if table_metadata_string == "":
pruned_metadata_str = prune_metadata_str(
pruned_metadata_ddl, join_str = prune_metadata_str(
question_instructions,
self.db_name,
self.use_public_data,
columns_to_keep,
shuffle,
)
pruned_metadata_str = pruned_metadata_ddl + join_str
else:
pruned_metadata_str = table_metadata_string
prompt = model_prompt.format(
Expand Down
3 changes: 2 additions & 1 deletion query_generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,14 @@ def generate_query(
question_instructions = question + " " + instructions
if table_metadata_string == "":
if columns_to_keep > 0:
table_metadata_string = prune_metadata_str(
table_metadata_ddl, join_str = prune_metadata_str(
question_instructions,
self.db_name,
self.use_public_data,
columns_to_keep,
shuffle,
)
table_metadata_string = table_metadata_ddl + join_str
elif columns_to_keep == 0:
md = dbs[self.db_name]["table_metadata"]
table_metadata_string = to_prompt_schema(md, shuffle)
Expand Down
22 changes: 11 additions & 11 deletions utils/gen_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,15 @@ def generate_prompt(
if columns_to_keep > 0:
from utils.pruning import prune_metadata_str

table_metadata_string, join_str = prune_metadata_str(
table_metadata_ddl, join_str = prune_metadata_str(
question_instructions,
db_name,
public_data,
columns_to_keep,
shuffle_metadata,
)
# remove triple backticks
table_metadata_string = table_metadata_string.replace("```", "").strip()
table_metadata_ddl = table_metadata_ddl.replace("```", "").strip()
elif columns_to_keep == 0:
if public_data:
import defog_data.supplementary as sup
Expand All @@ -155,7 +155,7 @@ def generate_prompt(

md = dbs[db_name]["table_metadata"]
table_names = list(md.keys())
table_metadata_string = to_prompt_schema(md, shuffle_metadata)
table_metadata_ddl = to_prompt_schema(md, shuffle_metadata)

# get join_str from column_join
join_list = []
Expand All @@ -173,34 +173,34 @@ def generate_prompt(
raise ValueError("columns_to_keep must be >= 0")

# add schema creation statements if relevant
schema_names = get_schema_names(table_metadata_string)
schema_names = get_schema_names(table_metadata_ddl)
if schema_names:
for schema_name in schema_names:
table_metadata_string = (
table_metadata_ddl = (
f"CREATE SCHEMA IF NOT EXISTS {schema_name};\n"
+ table_metadata_string
+ table_metadata_ddl
)
# transform metadata string to target dialect if necessary
if db_type in ["postgres", "snowflake"]:
table_metadata_string = table_metadata_string + join_str
table_metadata_string = table_metadata_ddl + join_str
elif db_type == "bigquery":
table_metadata_string = (
ddl_to_bigquery(table_metadata_string, "postgres", db_name, "")[0]
ddl_to_bigquery(table_metadata_ddl, "postgres", db_name, "")[0]
+ join_str
)
elif db_type == "mysql":
table_metadata_string = (
ddl_to_mysql(table_metadata_string, "postgres", db_name, "")[0]
ddl_to_mysql(table_metadata_ddl, "postgres", db_name, "")[0]
+ join_str
)
elif db_type == "sqlite":
table_metadata_string = (
ddl_to_sqlite(table_metadata_string, "postgres", db_name, "")[0]
ddl_to_sqlite(table_metadata_ddl, "postgres", db_name, "")[0]
+ join_str
)
elif db_type == "tsql":
table_metadata_string = (
ddl_to_tsql(table_metadata_string, "postgres", db_name, "")[0]
ddl_to_tsql(table_metadata_ddl, "postgres", db_name, "")[0]
+ join_str
)
else:
Expand Down
4 changes: 2 additions & 2 deletions utils/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def prune_metadata_str(
emb = emb_tuple[0]
csv_descriptions = emb_tuple[1]
try:
table_metadata_csv, join_str = get_md_emb(
table_metadata_ddl, join_str = get_md_emb(
question,
emb[db_name],
csv_descriptions[db_name],
Expand All @@ -226,4 +226,4 @@ def prune_metadata_str(
raise ValueError(f"DB name `{db_name}` not found in public data")
else:
raise ValueError(f"DB name `{db_name}` not found in private data")
return table_metadata_csv, join_str
return table_metadata_ddl, join_str

0 comments on commit 5087818

Please sign in to comment.