Skip to content

Commit 1ff516c

Browse files
committed
renamed fields for openai runner
1 parent e64d987 commit 1ff516c

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

eval/openai_runner.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def run_openai_eval(args):
1414
question_query_df["generated_query"] = ""
1515
question_query_df["reason"] = ""
1616
question_query_df["error_msg"] = ""
17+
question_query_df["exact_match"] = 0
1718
question_query_df["correct"] = 0
18-
question_query_df["subset"] = 0
1919
question_query_df["error_query_gen"] = 0
2020
question_query_df["error_db_exec"] = 0
2121
question_query_df["timeout"] = 0
@@ -84,7 +84,7 @@ def run_openai_eval(args):
8484
db_name = row["db_name"]
8585
question = row["question"]
8686
query_category = row["query_category"]
87-
correct = subset = 0
87+
exact_match = correct = 0
8888
generated_result = expected_result = None
8989
db_creds = {
9090
"host": "localhost",
@@ -103,23 +103,23 @@ def run_openai_eval(args):
103103
query_gen, db_name, db_creds, args.timeout_exec
104104
)
105105
generated_result = generated_result.rename(columns=str.lower)
106-
correct = subset = int(
106+
exact_match = correct = int(
107107
compare_df(
108108
expected_result, generated_result, query_category, question
109109
)
110110
)
111-
if not correct:
112-
subset = subset_df(
111+
if not exact_match:
112+
correct = subset_df(
113113
df_sub=expected_result,
114114
df_super=generated_result,
115115
query_category=query_category,
116116
question=question,
117117
verbose=args.verbose,
118118
)
119+
row["exact_match"] = int(exact_match)
119120
row["correct"] = int(correct)
120-
row["subset"] = int(subset)
121121
row["error_msg"] = ""
122-
if subset:
122+
if correct:
123123
total_correct += 1
124124
except QueryCanceledError as e:
125125
row["timeout"] = 1
@@ -136,8 +136,8 @@ def run_openai_eval(args):
136136
output_df.to_csv(args.output_file, index=False, float_format="%.2f")
137137

138138
# get average accuracy
139-
avg_acc = output_df["correct"].sum() / len(output_df)
139+
avg_acc = output_df["exact_match"].sum() / len(output_df)
140140
print(f"Average accuracy: {avg_acc:.2f}")
141141
# get average subset or correct accuracy
142-
avg_subset = output_df["subset"].sum() / len(output_df)
142+
avg_subset = output_df["correct"].sum() / len(output_df)
143143
print(f"Average subset accuracy: {avg_subset:.2f}")

0 commit comments

Comments
 (0)