From 5087818a5c9c4067f279a4e9b92d9184b5e7e971 Mon Sep 17 00:00:00 2001 From: wendy Date: Fri, 14 Jun 2024 12:36:56 +0800 Subject: [PATCH] renamed variables for clarity --- eval/gemini_runner.py | 3 ++- eval/mistral_runner.py | 3 ++- query_generators/anthropic.py | 3 ++- query_generators/openai.py | 3 ++- utils/gen_prompt.py | 22 +++++++++++----------- utils/pruning.py | 4 ++-- 6 files changed, 21 insertions(+), 17 deletions(-) diff --git a/eval/gemini_runner.py b/eval/gemini_runner.py index 459b2ed..031e73a 100644 --- a/eval/gemini_runner.py +++ b/eval/gemini_runner.py @@ -42,9 +42,10 @@ def generate_prompt( question_instructions = question + " " + instructions if table_metadata_string == "": - pruned_metadata_str = prune_metadata_str( + pruned_metadata_ddl, join_str = prune_metadata_str( question_instructions, db_name, public_data, num_columns_to_keep, shuffle ) + pruned_metadata_str = pruned_metadata_ddl + join_str else: pruned_metadata_str = table_metadata_string diff --git a/eval/mistral_runner.py b/eval/mistral_runner.py index 6e92277..b9b56e1 100644 --- a/eval/mistral_runner.py +++ b/eval/mistral_runner.py @@ -45,9 +45,10 @@ def generate_prompt( question_instructions = question + " " + instructions if table_metadata_string == "": - pruned_metadata_str = prune_metadata_str( + pruned_metadata_ddl, join_str = prune_metadata_str( question_instructions, db_name, public_data, columns_to_keep, shuffle ) + pruned_metadata_str = pruned_metadata_ddl + join_str else: pruned_metadata_str = table_metadata_string diff --git a/query_generators/anthropic.py b/query_generators/anthropic.py index 6af470d..5cb77c1 100644 --- a/query_generators/anthropic.py +++ b/query_generators/anthropic.py @@ -99,13 +99,14 @@ def generate_query( model_prompt = file.read() question_instructions = question + " " + instructions if table_metadata_string == "": - pruned_metadata_str = prune_metadata_str( + pruned_metadata_ddl, join_str = prune_metadata_str( question_instructions, self.db_name, self.use_public_data, columns_to_keep, shuffle, ) + pruned_metadata_str = pruned_metadata_ddl + join_str else: pruned_metadata_str = table_metadata_string prompt = model_prompt.format( diff --git a/query_generators/openai.py b/query_generators/openai.py index e45b0f1..fb161d5 100644 --- a/query_generators/openai.py +++ b/query_generators/openai.py @@ -140,13 +140,14 @@ def generate_query( question_instructions = question + " " + instructions if table_metadata_string == "": if columns_to_keep > 0: - table_metadata_string = prune_metadata_str( + table_metadata_ddl, join_str = prune_metadata_str( question_instructions, self.db_name, self.use_public_data, columns_to_keep, shuffle, ) + table_metadata_string = table_metadata_ddl + join_str elif columns_to_keep == 0: md = dbs[self.db_name]["table_metadata"] table_metadata_string = to_prompt_schema(md, shuffle) diff --git a/utils/gen_prompt.py b/utils/gen_prompt.py index c38b04c..ad5c6a8 100644 --- a/utils/gen_prompt.py +++ b/utils/gen_prompt.py @@ -134,7 +134,7 @@ def generate_prompt( if columns_to_keep > 0: from utils.pruning import prune_metadata_str - table_metadata_string, join_str = prune_metadata_str( + table_metadata_ddl, join_str = prune_metadata_str( question_instructions, db_name, public_data, @@ -142,7 +142,7 @@ def generate_prompt( shuffle_metadata, ) # remove triple backticks - table_metadata_string = table_metadata_string.replace("```", "").strip() + table_metadata_ddl = table_metadata_ddl.replace("```", "").strip() elif columns_to_keep == 0: if public_data: import defog_data.supplementary as sup @@ -155,7 +155,7 @@ def generate_prompt( md = dbs[db_name]["table_metadata"] table_names = list(md.keys()) - table_metadata_string = to_prompt_schema(md, shuffle_metadata) + table_metadata_ddl = to_prompt_schema(md, shuffle_metadata) # get join_str from column_join join_list = [] @@ -173,34 +173,34 @@ def generate_prompt( raise ValueError("columns_to_keep must be >= 0") # add schema creation statements if relevant - schema_names = get_schema_names(table_metadata_string) + schema_names = get_schema_names(table_metadata_ddl) if schema_names: for schema_name in schema_names: - table_metadata_string = ( + table_metadata_ddl = ( f"CREATE SCHEMA IF NOT EXISTS {schema_name};\n" - + table_metadata_string + + table_metadata_ddl ) # transform metadata string to target dialect if necessary if db_type in ["postgres", "snowflake"]: - table_metadata_string = table_metadata_string + join_str + table_metadata_string = table_metadata_ddl + join_str elif db_type == "bigquery": table_metadata_string = ( - ddl_to_bigquery(table_metadata_string, "postgres", db_name, "")[0] + ddl_to_bigquery(table_metadata_ddl, "postgres", db_name, "")[0] + join_str ) elif db_type == "mysql": table_metadata_string = ( - ddl_to_mysql(table_metadata_string, "postgres", db_name, "")[0] + ddl_to_mysql(table_metadata_ddl, "postgres", db_name, "")[0] + join_str ) elif db_type == "sqlite": table_metadata_string = ( - ddl_to_sqlite(table_metadata_string, "postgres", db_name, "")[0] + ddl_to_sqlite(table_metadata_ddl, "postgres", db_name, "")[0] + join_str ) elif db_type == "tsql": table_metadata_string = ( - ddl_to_tsql(table_metadata_string, "postgres", db_name, "")[0] + ddl_to_tsql(table_metadata_ddl, "postgres", db_name, "")[0] + join_str ) else: diff --git a/utils/pruning.py b/utils/pruning.py index cb3492b..cb97a38 100644 --- a/utils/pruning.py +++ b/utils/pruning.py @@ -212,7 +212,7 @@ def prune_metadata_str( emb = emb_tuple[0] csv_descriptions = emb_tuple[1] try: - table_metadata_csv, join_str = get_md_emb( + table_metadata_ddl, join_str = get_md_emb( question, emb[db_name], csv_descriptions[db_name], @@ -226,4 +226,4 @@ def prune_metadata_str( raise ValueError(f"DB name `{db_name}` not found in public data") else: raise ValueError(f"DB name `{db_name}` not found in private data") - return table_metadata_csv, join_str + return table_metadata_ddl, join_str