diff --git a/eval/anthropic_runner.py b/eval/anthropic_runner.py index f8bb1b7..c759f1c 100644 --- a/eval/anthropic_runner.py +++ b/eval/anthropic_runner.py @@ -77,8 +77,6 @@ def run_anthropic_eval(args): row = input_rows[i] result_dict = f.result() query_gen = result_dict["query"] - print("Query for") - print(query_gen) reason = result_dict["reason"] err = result_dict["err"] # save custom metrics @@ -142,9 +140,6 @@ def run_anthropic_eval(args): os.makedirs(output_dir) output_df.to_csv(output_file, index=False, float_format="%.2f") - # get average rate of exact matches - avg_acc = output_df["exact_match"].sum() / len(output_df) - print(f"Average rate of exact match: {avg_acc:.2f}") # get average rate of correct results avg_subset = output_df["correct"].sum() / len(output_df) print(f"Average correct rate: {avg_subset:.2f}") diff --git a/query_generators/anthropic.py b/query_generators/anthropic.py index 5cb77c1..b03f824 100644 --- a/query_generators/anthropic.py +++ b/query_generators/anthropic.py @@ -6,6 +6,7 @@ from query_generators.query_generator import QueryGenerator from utils.pruning import prune_metadata_str +from utils.gen_prompt import to_prompt_schema anthropic = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) @@ -95,18 +96,30 @@ def generate_query( self.query = "" self.reason = "" + if self.use_public_data: + from defog_data.metadata import dbs + else: + # raise Exception("Replace this with your private data import") + from defog_data_private.metadata import dbs + with open(self.prompt_file) as file: model_prompt = file.read() question_instructions = question + " " + instructions if table_metadata_string == "": - pruned_metadata_ddl, join_str = prune_metadata_str( - question_instructions, - self.db_name, - self.use_public_data, - columns_to_keep, - shuffle, - ) - pruned_metadata_str = pruned_metadata_ddl + join_str + if columns_to_keep > 0: + pruned_metadata_ddl, join_str = prune_metadata_str( + question_instructions, + self.db_name, + self.use_public_data, + columns_to_keep, + shuffle, + ) + pruned_metadata_str = pruned_metadata_ddl + join_str + elif columns_to_keep == 0: + md = dbs[self.db_name]["table_metadata"] + pruned_metadata_str = to_prompt_schema(md, shuffle) + else: + raise ValueError("columns_to_keep must be >= 0") else: pruned_metadata_str = table_metadata_string prompt = model_prompt.format(