# eval_CoT.py
# Scores generated chain-of-thought outputs against gold numeric answers and
# reports the resulting accuracy.
import json
from tqdm import tqdm
import os
from nltk import sent_tokenize
import re
def load_json_file(in_file, return_dict=False):
    """Read a JSON Lines file (one JSON document per line).

    Args:
        in_file: Path of the file to read.
        return_dict: When true, decode every non-blank line with
            ``json.loads`` and return the decoded objects. When false,
            return the raw lines exactly as read (line terminators and
            blank lines included).

    Returns:
        A list of decoded JSON values when ``return_dict`` is true,
        otherwise a list of raw line strings.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(in_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    if not return_dict:
        return lines
    # A blank line (e.g. an extra newline at the end of the file) holds no
    # record and would make json.loads raise, so it is skipped.
    return [json.loads(line) for line in lines if line.strip()]
# Marker the model is expected to emit before its final answer.
_ANSWER_RE = re.compile(r'Answer: ([\S]+)')

# One number: an optional sign (only when it is not glued to a preceding
# digit, so "8-3" yields 8 and 3 rather than 8 and -3), optional thousands
# separators, and an optional fractional part.
_NUMBER_RE = re.compile(r'(?<![\d.])-?(?:\d+(?:,\d{3})*(?:\.\d+)?|\.\d+)')

# Absolute tolerance used when comparing a predicted number to the gold one.
_NUMBER_TOLERANCE = 1e-6


def _gold_to_number(value):
    """Convert a gold answer (a number or a string such as "1,234") to float."""
    if isinstance(value, str):
        value = value.replace(',', '')
    return float(value)


def _numbers_in(text):
    """Return every number mentioned in *text* as a list of floats."""
    return [float(token.replace(',', '')) for token in _NUMBER_RE.findall(text)]


def test_code_solution_for_predict(test_file=None, golden_file=None, results_file=None):
    """Score chain-of-thought predictions against gold numeric answers.

    Predictions and gold records are paired positionally. For each pair the
    predicted answer text is the token following ``Answer: `` in the
    prediction's ``output`` field, or, when that marker is absent, the last
    sentence of the output. A prediction counts as correct when any number
    found in that text equals the gold ``num_answer``.

    Numbers are compared numerically rather than by substring, so a gold
    answer of 5 no longer matches "15", and "1,234" matches 1234.

    Args:
        test_file: JSON Lines file of predictions; each record needs an
            ``output`` string.
        golden_file: JSON Lines file of gold records; each record needs a
            ``num_answer`` that is a number or a numeric string (thousands
            separators allowed).
        results_file: Path the accuracy line is written to (overwritten).

    Returns:
        The accuracy as a float: correct predictions divided by the number
        of prediction records, or 0.0 when there are no predictions.

    Raises:
        KeyError: If a record lacks ``output`` / ``num_answer``.
        ValueError: If a gold ``num_answer`` is not numeric.
    """
    data = load_json_file(test_file, return_dict=True)
    golden_data = load_json_file(golden_file, return_dict=True)
    right = 0
    all_num = len(data)
    # zip stops at the shorter input; tell tqdm so the progress bar has a total.
    paired = zip(data, golden_data)
    for item, y in tqdm(paired, total=min(all_num, len(golden_data))):
        cot = item["output"]
        answer = _gold_to_number(y["num_answer"])
        match = _ANSWER_RE.search(cot)
        if match:
            final_answer = match.group(1)
        else:
            sens = sent_tokenize(cot)
            # An empty output tokenizes to no sentences; treat it as no answer.
            final_answer = sens[-1] if sens else ""
        if any(abs(number - answer) <= _NUMBER_TOLERANCE
               for number in _numbers_in(final_answer)):
            right += 1
    # The denominator stays the number of predictions; guard the empty case.
    acc = right / all_num if all_num else 0.0
    print("Accuracy: {}".format(acc))
    with open(results_file, "w") as f_out:
        f_out.write("Accuracy: {}".format(acc))
    return acc
if __name__ == '__main__':
    # Previously evaluated setups, kept for reference. To re-run one,
    # substitute its three paths for the MultiArith paths below.
    #
    #   gsm8k  (model dir: model/3.5-turbo/codet5-large)
    #     predictions: <model dir>/generated_predictions.json
    #     gold:        dataset/gsm8k-generat-data/gsm8k-1/test_file/test_add_code.json
    #     results:     <save_dir>/results.txt
    #                  (save_dir was never defined here; presumably the model dir)
    #
    #   SVAMP  (model dir: model/3.5-turbo/codet5-small)
    #     predictions: <model dir>/SVAMP/generated_predictions.json
    #     gold:        dataset/SVAMP/svamp_refine.json
    #     results:     <model dir>/SVAMP_results.txt

    # MultiArith
    model_dir = "model/3.5-turbo/codet5-large"
    predictions_path = model_dir + "/MultiArith/" + "generated_predictions.json"
    golden_path = "dataset/MultiArith/multi_arith_refine.json"
    accuracy_path = model_dir + "/" + "MultiArith_results.txt"
    test_code_solution_for_predict(
        test_file=predictions_path,
        golden_file=golden_path,
        results_file=accuracy_path,
    )