Skip to content

Commit

Permalink
minor compatibility changes to make new eval work with hf_runner
Browse files Browse the repository at this point in the history
  • Loading branch information
rishsriv committed Aug 15, 2023
1 parent c9a365b commit 9aedeb3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 23 deletions.
2 changes: 1 addition & 1 deletion eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ def compare_query_results(
query_gen: str,
db_name: str,
db_creds: dict,
timeout: float,
question: str,
query_category: str,
timeout: float = 10.0,
) -> "tuple[bool, bool]":
"""
Compares the results of two queries and returns a tuple of booleans, where the first element is
Expand Down
39 changes: 17 additions & 22 deletions eval/hf_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from eval.eval import compare_df, query_postgres_db, subset_df
from eval.eval import compare_query_results
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
Expand All @@ -7,7 +7,6 @@
from psycopg2.extensions import QueryCanceledError
from time import time
import gc
import traceback


def prepare_questions_df(questions_file, num_questions):
Expand Down Expand Up @@ -128,29 +127,25 @@ def run_hf_eval(
question = row["question"]
query_category = row["query_category"]
exact_match = correct = 0
generated_result = expected_result = None
db_creds = {
"host": "localhost",
"port": 5432,
"user": "postgres",
"password": "postgres",
"database": db_name,
}

try:
expected_result = query_postgres_db(golden_query, db_name).rename(
columns=str.lower
exact_match, correct = compare_query_results(
query_gold=golden_query,
query_gen=generated_query,
db_name=db_name,
db_creds=db_creds,
question=question,
query_category=query_category,
)

generated_result = query_postgres_db(generated_query, db_name).rename(
columns=str.lower
)

exact_match = correct = int(
compare_df(
expected_result, generated_result, query_category, question
)
)
if not exact_match:
correct = subset_df(
df_sub=expected_result,
df_super=generated_result,
query_category=query_category,
question=question,
)
row["exact_match"] = int(exact_match)
row["correct"] = int(correct)
row["exact_match"] = int(exact_match)
row["correct"] = int(correct)
row["error_msg"] = ""
Expand Down

0 comments on commit 9aedeb3

Please sign in to comment.