Commit 03420a9

Ming Li committed: add eval code and result

1 parent 576be81 · commit 03420a9

25 files changed: +23941 −88 lines changed

evaluation/generation/eva_generation.py

+38 −22
@@ -5,7 +5,7 @@
 import os
 from tqdm import tqdm

-PROMPT_DICT = {
+PROMPT_DICT_ALPACA = {
     "prompt_input": (
         "Below is an instruction that describes a task, paired with an input that provides further context. "
         "Write a response that appropriately completes the request.\n\n"
@@ -17,20 +17,36 @@
         "### Instruction:\n{instruction}\n\n### Response:"
     ),
 }
+PROMPT_DICT_WIZARDLM = {
+    "prompt_input": (
+        "{instruction}\n{input}\n\n### Response:"
+    ),
+    "prompt_no_input": (
+        "{instruction}\n\n### Response:"
+    ),
+}
+PROMPT_DICT_VICUNA = {
+    "prompt_input": (
+        "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {instruction}\nInput:\n{input} ASSISTANT:"
+    ),
+    "prompt_no_input": (
+        "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {instruction} ASSISTANT:"
+    ),
+}

 def parse_args():
-    parser = argparse.ArgumentParser(description="Finetune a transformers model on a summarization task")
+    parser = argparse.ArgumentParser()
     parser.add_argument(
         "--dataset_name",
         type=str,
         default=None,
         help="The name of the dataset to use (via the datasets library).",
     )
     parser.add_argument(
-        "--training_data_source_name",
+        "--prompt",
         type=str,
         default='alpaca',
-        help="The training_data_source_name.",
+        help="alpaca, wiz, vicuna.",
     )
     parser.add_argument(
         "--num_beams",
@@ -47,12 +63,6 @@ def parse_args():
         help="Path to pretrained model or model identifier from huggingface.co/models.",
         required=False,
     )
-    parser.add_argument(
-        "--per_device_eval_batch_size",
-        type=int,
-        default=8,
-        help="Batch size (per device) for the evaluation dataloader.",
-    )
     parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
     parser.add_argument("--max_length", type=int, default=1024)
     args = parser.parse_args()
@@ -69,26 +79,29 @@ def main():

     model.to(device)
     model.eval()
-    if(args.training_data_source_name=='alpaca'or args.training_data_source_name=='alpaca_gpt4'):
-        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+
+    if args.prompt == 'alpaca':
+        prompt_input, prompt_no_input = PROMPT_DICT_ALPACA["prompt_input"], PROMPT_DICT_ALPACA["prompt_no_input"]
+    elif args.prompt == 'wiz':
+        prompt_input, prompt_no_input = PROMPT_DICT_WIZARDLM["prompt_input"], PROMPT_DICT_WIZARDLM["prompt_no_input"]
+    elif args.prompt == 'vicuna':
+        prompt_input, prompt_no_input = PROMPT_DICT_VICUNA["prompt_input"], PROMPT_DICT_VICUNA["prompt_no_input"]
+

     if(args.dataset_name=="vicuna"):
-        dataset_path = './test_data/vicuna_test_set.jsonl'
+        dataset_path = 'evaluation/test_data/vicuna_test_set.jsonl'
         prompt_key = 'text'
     elif(args.dataset_name=="koala"):
-        dataset_path = './test_data/koala_test_set.jsonl'
+        dataset_path = 'evaluation/test_data/koala_test_set.jsonl'
         prompt_key = 'prompt'
     elif(args.dataset_name=="sinstruct"):
-        dataset_path = './test_data/sinstruct_test_set.jsonl'
+        dataset_path = 'evaluation/test_data/sinstruct_test_set.jsonl'
         prompt_key = 'instruction'
     elif(args.dataset_name=="wizardlm"):
-        dataset_path = './test_data/wizardlm_test_set.jsonl'
+        dataset_path = 'evaluation/test_data/wizardlm_test_set.jsonl'
         prompt_key = 'Instruction'
-    elif(args.dataset_name=="truthfulqa"):
-        dataset_path = './test_data/truthfulqa_test_set.jsonl'
-        prompt_key = 'Question'
     elif(args.dataset_name=="lima"):
-        dataset_path = './test_data/lima_test_set.jsonl'
+        dataset_path = 'evaluation/test_data/lima_test_set.jsonl'
         prompt_key = 'conversations'

     with open(dataset_path) as f:
@@ -111,14 +124,17 @@ def main():
         generate_ids = model.generate(input_ids, max_length=args.max_length)
         outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         point['raw_output'] = outputs
-        point['response'] = outputs.split("Response:")[1]
+        if args.prompt in ['alpaca','wiz']:
+            point['response'] = outputs.split("Response:")[1]
+        elif args.prompt in ['vicuna']:
+            point['response'] = outputs.split("ASSISTANT:")[1]
         results.append(point)

     output_dir = os.path.join(args.model_name_or_path, 'test_inference')
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)

-    saved_name = args.dataset_name + "_" + str(args.seed) + '_' + str(args.max_length) + ".json"
+    saved_name = args.dataset_name + "_" + str(args.max_length) + ".json"
     with open(os.path.join(output_dir, saved_name), "w") as f:
         json.dump(results, f, indent=4)

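The changes above replace the single Alpaca template with per-style prompt dictionaries selected by the new --prompt flag, and split the decoded output on the cue that matches the chosen template. As a minimal sketch (not part of the commit), the snippet below fills the WizardLM template and recovers the response the same way; build_prompt, the example instruction, and the fake model output are hypothetical, while the template strings and the "Response:" split mirror the diff.

# Illustrative sketch; templates copied from the diff, helper and sample text are made up.
PROMPT_DICT_WIZARDLM = {
    "prompt_input": "{instruction}\n{input}\n\n### Response:",
    "prompt_no_input": "{instruction}\n\n### Response:",
}

def build_prompt(instruction, input_text=""):
    # Pick the template with or without an input field, as eva_generation.py does.
    if input_text:
        return PROMPT_DICT_WIZARDLM["prompt_input"].format(instruction=instruction, input=input_text)
    return PROMPT_DICT_WIZARDLM["prompt_no_input"].format(instruction=instruction)

prompt = build_prompt("Name three uses for a paperclip.")
raw_output = prompt + " 1. Holding paper. 2. Resetting a router. 3. Marking a page."
# The script keeps the text after the template's cue as the model's answer.
response = raw_output.split("Response:")[1]
print(response.strip())

For the Vicuna template the same recovery is done by splitting on "ASSISTANT:" instead.
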
evaluation/generation/eval.py

+4 −31
@@ -84,7 +84,7 @@ def get_json_list(file_path):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="ChatGPT-based QA evaluation.")
+    parser = argparse.ArgumentParser()
     parser.add_argument("--wraped_file",default='')
     parser.add_argument("--api_key",type=str,default='')
     parser.add_argument("--api_model",type=str,default='gpt-3.5-turbo')
@@ -120,8 +120,6 @@ def get_json_list(file_path):
         prompt_key = 'instruction'
     elif(dataset_name=="wizardlm"):
         prompt_key = 'Instruction'
-    elif(dataset_name=="truthfulqa"):
-        prompt_key = 'Question'
     elif(dataset_name=="lima"):
         prompt_key = 'conversations'

@@ -192,12 +190,8 @@ def get_json_list(file_path):
     predictions_all.append(predictions)

     all_scores = []
-    ans1_win_idsx_list = []
-    ans2_win_idsx_list = []
     for reverse in range(2):
         scores_list = []
-        ans1_win_idsx = [0 for _ in range(total_len)]
-        ans2_win_idsx = [0 for _ in range(total_len)]
         predictions = predictions_all[reverse]
         for idx, prediction in enumerate(predictions):
             review = prediction['choices'][0]['message']['content']
@@ -207,40 +201,19 @@ def get_json_list(file_path):
             qa_jsons[idx][review_key] = review
             qa_jsons[idx][scores_key] = str(scores)
             scores_list.append(scores)
-            if scores[0] > scores[1]:
-                if not reverse:
-                    ans1_win_idsx[idx] = 1
-                else:
-                    ans2_win_idsx[idx] = 1
-            elif scores[1] > scores[0]:
-                if not reverse:
-                    ans2_win_idsx[idx] = 1
-                else:
-                    ans1_win_idsx[idx] = 1

         all_scores.append(scores_list)
         avg_scores = np.array(scores_list).mean(0)
         avg_key = 'average_scores' if not reverse else 'average_scores_reverse'
         meta_info[avg_key] = str(avg_scores.tolist())

-        ans1_win_idsx_list.append(ans1_win_idsx)
-        ans2_win_idsx_list.append(ans2_win_idsx)
-
-    ans1_win_idx_overall = np.array(ans1_win_idsx_list[0]) * np.array(ans1_win_idsx_list[1])
-    ans2_win_idx_overall = np.array(ans2_win_idsx_list[0]) * np.array(ans2_win_idsx_list[1])
-    # ans1_win_count = ans1_win_idx_overall.sum()
-    # ans2_win_count = ans2_win_idx_overall.sum()
-
-    # meta_info['ans1_win_count'] = ans1_win_count.tolist()
-    # meta_info['ans2_win_count'] = ans2_win_count.tolist()
-
     wraped_info['Meta_Info'] = meta_info
     wraped_info['data'] = qa_jsons

-    if args.api_model == 'gpt-4':
+    if 'gpt-4' in args.api_model:
         output_review_file = args.wraped_file.strip('.json') + '_reviews_gpt4.json'
-    else:
-        output_review_file = args.wraped_file.strip('.json') + '_reviews.json'
+    elif 'gpt-3.5' in args.api_model:
+        output_review_file = args.wraped_file.strip('.json') + '_reviews_gpt3.5.json'
     with open(f"{output_review_file}", "w") as f:
         json.dump(wraped_info, f, indent=4)
     pass

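eval.py now names its review file after the judge model, so gpt-3.5 and gpt-4 runs no longer overwrite each other. A minimal sketch of that naming branch with assumed values (the 'wraped' spelling follows the script's own --wraped_file argument):

# Sketch of the output-naming branch added to eval.py; values are illustrative.
api_model = 'gpt-4'                          # assumed value of --api_model
wraped_file = 'koala_wrap.json'              # assumed value of --wraped_file
if 'gpt-4' in api_model:
    output_review_file = wraped_file.strip('.json') + '_reviews_gpt4.json'
elif 'gpt-3.5' in api_model:
    output_review_file = wraped_file.strip('.json') + '_reviews_gpt3.5.json'
print(output_review_file)                    # prints: koala_wrap_reviews_gpt4.json

Note that str.strip('.json') removes any leading or trailing characters drawn from the set {., j, s, o, n} rather than the literal suffix; it happens to work for these filenames but would over-trim a name such as 'comparison.json'.
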
evaluation/generation/eval_generation_wrap.py

+3 −4
@@ -3,19 +3,20 @@
 import argparse

 def parse_args():
-    parser = argparse.ArgumentParser(description="Finetune a transformers model on a summarization task")
+    parser = argparse.ArgumentParser()
     parser.add_argument("--dataset_name", type=str, default='', help="The name of the dataset to use.")
     parser.add_argument("--fname1", type=str, default='')
     parser.add_argument("--fname2", type=str, default='')
     parser.add_argument("--save_name", type=str, default='') # a vs b format
+    parser.add_argument("--max_length", type=int, default=1024)

     args = parser.parse_args()
     return args

 args = parse_args()

 print('args.dataset_name',args.dataset_name)
-f_name = args.dataset_name+'_0_1024.json'
+f_name = args.dataset_name+'_'+str(args.max_length)+'.json'
 args.fname1 = os.path.join(args.fname1,f_name)
 args.fname2 = os.path.join(args.fname2,f_name)
 print('args.fname1',args.fname1)
@@ -37,8 +38,6 @@ def parse_args():
     prompt_key = 'instruction'
 elif(args.dataset_name=="wizardlm"):
     prompt_key = 'Instruction'
-elif(args.dataset_name=="truthfulqa"):
-    prompt_key = 'Question'
 elif(args.dataset_name=="lima"):
     prompt_key = 'conversations'

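The wrap script now builds the per-dataset filename from --max_length instead of the hard-coded '_0_1024' suffix, matching the saved_name change in eva_generation.py above. A small sketch under assumed paths (the two model directories are illustrative; 'test_inference' is the subfolder eva_generation.py writes to):

import os

# Sketch of how the wrapped comparison locates both models' generations.
dataset_name, max_length = 'koala', 1024                          # assumed --dataset_name / --max_length
f_name = dataset_name + '_' + str(max_length) + '.json'           # 'koala_1024.json'
fname1 = os.path.join('ckpts/model_a/test_inference', f_name)     # assumed --fname1 directory
fname2 = os.path.join('ckpts/model_b/test_inference', f_name)     # assumed --fname2 directory
print(fname1)
print(fname2)
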
evaluation/generation/review_eval_score.py

+23 −31
@@ -2,20 +2,22 @@
 import json
 import numpy as np
 import matplotlib.pyplot as plt
+import argparse

-review_home_path = 'logs/xxx1-VSxxx2'
-datasets = ['Vicuna','Koala','WizardLM','SInstruct','LIMA']
-# datasets = ['Vicuna','Koala','WizardLM','SInstruct']
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--review_home_path", type=str, default='', help="home path that save the reviews")
+    parser.add_argument('--task_list', nargs='+', type=str, default=['Vicuna','Koala','WizardLM','SInstruct','LIMA'])
+    parser.add_argument("--key1", type=str, default='Model1')
+    parser.add_argument("--key2", type=str, default='Model2')
+    parser.add_argument("--save_name", type=str, default='result') # a vs b format
+    parser.add_argument("--max_length", type=int, default=1024)
+    parser.add_argument("--api_model",type=str,default='gpt-3.5-turbo')

-save_name = review_home_path.split('/')[-1]
+    args = parser.parse_args()
+    return args

-key1, key2 = save_name.split('-VS-')[0],save_name.split('-VS-')[1]
-title_ = save_name
-
-# key1 = 'Pre-Experienced Selected by Alpaca (15%)'
-# # key2 = 'WizardLM' + r"$^*$" + '(100%)'
-# key2 = 'Alpaca (100%)'
-# title_ = key1 + ' vs. ' + key2
+args = parse_args()


 def survey(results, category_names):
@@ -84,20 +86,22 @@ def get_scores_all(pure_data):
             score3 += 1
     return [score1, score2, score3]

-for dataset in datasets:
+for dataset in args.task_list:
     review_path = ''
-    for root, ds, fs in os.walk(review_home_path):
+    for root, ds, fs in os.walk(args.review_home_path):
         for f in fs:
-            if 'reviews' in f and f.endswith('.json') and dataset.lower() in f:
-                review_path = os.path.join(root, f)
-            # if 'reviews_gpt4' in f and f.endswith('.json') and dataset.lower() in f:
-            #     review_path = os.path.join(root, f)
+            if 'gpt-3.5' in args.api_model:
+                if 'reviews_gpt3.5' in f and f.endswith('.json') and dataset.lower() in f:
+                    review_path = os.path.join(root, f)
+            elif 'gpt-4' in args.api_model:
+                if 'reviews_gpt4' in f and f.endswith('.json') and dataset.lower() in f:
+                    review_path = os.path.join(root, f)
     with open(review_path, "r") as f:
         review_data = json.load(f)
     pure_data = review_data['data']

     scores = get_scores_all(pure_data)
-    category_names = [f"{key1} wins", "Tie", f"{key2} wins"]
+    category_names = [f"{args.key1} wins", "Tie", f"{args.key2} wins"]
     results[dataset] = scores

 def cal_rate(results):
@@ -112,18 +116,6 @@ def cal_rate(results):

 cal_rate(results)
 survey(results, category_names)
-img_path = os.path.join(review_home_path,save_name+'.jpg')
-plt.title(title_)
+img_path = os.path.join(args.review_home_path,args.save_name+'.jpg')
 plt.savefig(img_path)
 pass
-
-# from PIL import Image
-# def crop_edges(image_path, left, upper, right, lower):
-#     with Image.open(image_path) as img:
-#         width, height = img.size
-#         cropped = img.crop((left, upper, width - right, height - lower))
-#         return cropped
-# cropped_img = crop_edges(img_path,45,45,45,45)
-# cropped_img.save(img_path)
-# pass
-

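With the hard-coded review_home_path and dataset list replaced by arguments, the plotting script can be pointed at any pairwise-comparison folder and at reviews from either judge model. A sketch of the lookup under assumed argument values ('logs/model_a-VS-model_b' is illustrative); it matches the '_reviews_gpt4.json' / '_reviews_gpt3.5.json' suffixes written by eval.py:

import os

# Sketch of the review-file lookup added above; argument values are assumed examples.
api_model = 'gpt-4'                                   # assumed --api_model
review_home_path = 'logs/model_a-VS-model_b'          # assumed --review_home_path
task_list = ['Vicuna', 'Koala', 'WizardLM', 'SInstruct', 'LIMA']   # default --task_list

for dataset in task_list:
    review_path = ''
    for root, ds, fs in os.walk(review_home_path):
        for f in fs:
            # Keep only review files produced by the chosen judge model for this dataset.
            if 'gpt-4' in api_model and 'reviews_gpt4' in f and f.endswith('.json') and dataset.lower() in f:
                review_path = os.path.join(root, f)
            elif 'gpt-3.5' in api_model and 'reviews_gpt3.5' in f and f.endswith('.json') and dataset.lower() in f:
                review_path = os.path.join(root, f)
    print(dataset, review_path or '(no review found)')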