Skip to content

Commit

Permalink
- split md str from join str to facilitate translations
Browse files Browse the repository at this point in the history
- fixed splitting in get_md_emb for numeric/decimal cases with parentheses
  • Loading branch information
wendy-aw committed Jun 14, 2024
1 parent 6690087 commit 397a499
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 34 deletions.
18 changes: 12 additions & 6 deletions tests/test_utils_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_get_md_emb_no_shuffle(test_metadata):
threshold = 0.0

# Call the function and get the result
result = get_md_emb(
result, join_str = get_md_emb(
question,
column_emb,
column_csv,
Expand All @@ -112,11 +112,13 @@ def test_get_md_emb_no_shuffle(test_metadata):
id integer, --unique id for country, not iso code
);
```
"""
expected_join_str = """
Here is a list of joinable columns:
airport.country_id can be joined with country.id
"""
assert result == expected
assert join_str == expected_join_str


def test_get_md_emb_shuffle(test_metadata):
Expand All @@ -127,7 +129,7 @@ def test_get_md_emb_shuffle(test_metadata):
threshold = 0.0

# Call the function and get the result
result = get_md_emb(
result, join_str = get_md_emb(
question,
column_emb,
column_csv,
Expand All @@ -154,11 +156,14 @@ def test_get_md_emb_shuffle(test_metadata):
airport_name text, --name of the airport
);
```
"""
expected_join_str = """
Here is a list of joinable columns:
airport.country_id can be joined with country.id
"""

assert result == expected
assert join_str == expected_join_str


def test_get_md_emb_sql_emb_empty(test_metadata):
Expand All @@ -168,7 +173,7 @@ def test_get_md_emb_sql_emb_empty(test_metadata):
threshold = 1.0 # arbitrarily high threshold to test empty results

# Call the function and get the result
result = get_md_emb(
result, join_str = get_md_emb(
question,
column_emb,
column_csv,
Expand All @@ -179,6 +184,7 @@ def test_get_md_emb_sql_emb_empty(test_metadata):
threshold,
)
assert result == ""
assert join_str == ""


def test_get_md_emb_coldesc(test_metadata_diff_coldesc):
Expand All @@ -189,7 +195,7 @@ def test_get_md_emb_coldesc(test_metadata_diff_coldesc):
threshold = 0.0

# Call the function and get the result
result = get_md_emb(
result, join_str = get_md_emb(
question,
column_emb,
column_csv,
Expand Down
41 changes: 20 additions & 21 deletions utils/gen_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def generate_prompt(
question_instructions = question + " " + instructions
table_names = []

column_join = {}
join_str = ""
# retrieve metadata, either pruned or full
if table_metadata_string == "":
if columns_to_keep > 0:
from utils.pruning import prune_metadata_str

table_metadata_string = prune_metadata_str(
table_metadata_string, join_str = prune_metadata_str(
question_instructions,
db_name,
public_data,
Expand All @@ -156,22 +156,21 @@ def generate_prompt(
md = dbs[db_name]["table_metadata"]
table_names = list(md.keys())
table_metadata_string = to_prompt_schema(md, shuffle_metadata)
else:
raise ValueError("columns_to_keep must be >= 0")

# get join list if retrieving full metadata
join_list = []
for values in column_join.values():
col_1, col_2 = values[0]
# add to join_list
join_str = f"{col_1} can be joined with {col_2}"
if join_str not in join_list:
join_list.append(join_str)

if len(join_list) > 0:
join_list = "\nHere is a list of joinable columns:\n" + "\n".join(join_list)
# get join_str from column_join
join_list = []
for values in column_join.values():
col_1, col_2 = values[0]
# add to join_list
join_str = f"{col_1} can be joined with {col_2}"
if join_str not in join_list:
join_list.append(join_str)
if len(join_list) > 0:
join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list)
else:
join_str = ""
else:
join_list = ""
raise ValueError("columns_to_keep must be >= 0")

# add schema creation statements if relevant
schema_names = get_schema_names(table_metadata_string)
Expand All @@ -183,26 +182,26 @@ def generate_prompt(
)
# transform metadata string to target dialect if necessary
if db_type in ["postgres", "snowflake"]:
table_metadata_string = table_metadata_string + join_list
table_metadata_string = table_metadata_string + join_str
elif db_type == "bigquery":
table_metadata_string = (
ddl_to_bigquery(table_metadata_string, "postgres", db_name, "")[0]
+ join_list
+ join_str
)
elif db_type == "mysql":
table_metadata_string = (
ddl_to_mysql(table_metadata_string, "postgres", db_name, "")[0]
+ join_list
+ join_str
)
elif db_type == "sqlite":
table_metadata_string = (
ddl_to_sqlite(table_metadata_string, "postgres", db_name, "")[0]
+ join_list
+ join_str
)
elif db_type == "tsql":
table_metadata_string = (
ddl_to_tsql(table_metadata_string, "postgres", db_name, "")[0]
+ join_list
+ join_str
)
else:
raise ValueError(
Expand Down
16 changes: 9 additions & 7 deletions utils/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def get_md_emb(
else:
table_name, column_name = table_col_name.split(".", 1)
# check if column_info has numeric or decimal in it
if "numeric" in column_info.lower() or "decimal" in column_info.lower():
if "numeric(" in column_info.lower() or "decimal(" in column_info.lower():
column_type, col_desc = column_info.split("),", 1)
column_type += ")"
else:
Expand Down Expand Up @@ -188,10 +188,12 @@ def get_md_emb(
md_str = format_topk_sql(topk_table_columns, shuffle)

if len(join_list) > 0:
md_str += "\nHere is a list of joinable columns:\n"
md_str += "\n".join(join_list)
md_str += "\n"
return md_str
join_str = "\nHere is a list of joinable columns:\n"
join_str += "\n".join(join_list)
join_str += "\n"
else:
join_str = ""
return md_str, join_str


def prune_metadata_str(
Expand All @@ -210,7 +212,7 @@ def prune_metadata_str(
emb = emb_tuple[0]
csv_descriptions = emb_tuple[1]
try:
table_metadata_csv = get_md_emb(
table_metadata_csv, join_str = get_md_emb(
question,
emb[db_name],
csv_descriptions[db_name],
Expand All @@ -224,4 +226,4 @@ def prune_metadata_str(
raise ValueError(f"DB name `{db_name}` not found in public data")
else:
raise ValueError(f"DB name `{db_name}` not found in private data")
return table_metadata_csv
return table_metadata_csv, join_str

0 comments on commit 397a499

Please sign in to comment.