Skip to content

Commit

Permalink
fixed anthropic runner issues
Browse files Browse the repository at this point in the history
  • Loading branch information
rishsriv committed Jun 24, 2024
1 parent 93b1860 commit 4be51eb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
5 changes: 0 additions & 5 deletions eval/anthropic_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ def run_anthropic_eval(args):
row = input_rows[i]
result_dict = f.result()
query_gen = result_dict["query"]
print("Query for")
print(query_gen)
reason = result_dict["reason"]
err = result_dict["err"]
# save custom metrics
Expand Down Expand Up @@ -142,9 +140,6 @@ def run_anthropic_eval(args):
os.makedirs(output_dir)
output_df.to_csv(output_file, index=False, float_format="%.2f")

# get average rate of exact matches
avg_acc = output_df["exact_match"].sum() / len(output_df)
print(f"Average rate of exact match: {avg_acc:.2f}")
# get average rate of correct results
avg_subset = output_df["correct"].sum() / len(output_df)
print(f"Average correct rate: {avg_subset:.2f}")
Expand Down
29 changes: 21 additions & 8 deletions query_generators/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from query_generators.query_generator import QueryGenerator
from utils.pruning import prune_metadata_str
from utils.gen_prompt import to_prompt_schema

anthropic = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))

Expand Down Expand Up @@ -95,18 +96,30 @@ def generate_query(
self.query = ""
self.reason = ""

if self.use_public_data:
from defog_data.metadata import dbs
else:
# raise Exception("Replace this with your private data import")
from defog_data_private.metadata import dbs

with open(self.prompt_file) as file:
model_prompt = file.read()
question_instructions = question + " " + instructions
if table_metadata_string == "":
pruned_metadata_ddl, join_str = prune_metadata_str(
question_instructions,
self.db_name,
self.use_public_data,
columns_to_keep,
shuffle,
)
pruned_metadata_str = pruned_metadata_ddl + join_str
if columns_to_keep > 0:
pruned_metadata_ddl, join_str = prune_metadata_str(
question_instructions,
self.db_name,
self.use_public_data,
columns_to_keep,
shuffle,
)
pruned_metadata_str = pruned_metadata_ddl + join_str
elif columns_to_keep == 0:
md = dbs[self.db_name]["table_metadata"]
pruned_metadata_str = to_prompt_schema(md, shuffle)
else:
raise ValueError("columns_to_keep must be >= 0")
else:
pruned_metadata_str = table_metadata_string
prompt = model_prompt.format(
Expand Down

0 comments on commit 4be51eb

Please sign in to comment.