-
Notifications
You must be signed in to change notification settings - Fork 0
/
checkpoints_to_csv.py
98 lines (80 loc) · 3.62 KB
/
checkpoints_to_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from testing_script import load_model, chat
import csv
import os
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser
from tqdm import trange, tqdm
@dataclass
class ScriptArguments:
    """Command-line arguments parsed via HfArgumentParser.

    Attributes:
        question_file: path to a CSV file whose 'question' column holds the prompts.
        model_name: Hugging Face identifier of the base model to evaluate.
    """

    question_file: Optional[str] = field(
        default=None,
        metadata={"help": "the file path"},
    )
    model_name: Optional[str] = field(
        default="meta-llama/Llama-2-13b-chat-hf",
        metadata={"help": "the model name"},
    )
def get_questions(csv_file_path):
    """Read the 'question' column from a CSV file.

    Parameters
    ----------
    csv_file_path : str
        Path to a CSV file with a header row that includes a 'question' column.

    Returns
    -------
    list[str]
        The questions in file order. On any failure (missing file, missing
        'question' column, ...) an error is printed and an empty list is
        returned so callers that iterate the result do not crash.
    """
    try:
        # newline='' is the documented way to open files for the csv module.
        with open(csv_file_path, 'r', newline='') as csv_file:
            reader = csv.DictReader(csv_file)
            return [row['question'] for row in reader]
    except FileNotFoundError:
        print(f"The file '{csv_file_path}' was not found.")
    except Exception as e:
        print(f"An error occurred: {e}")
    # Bug fix: the original fell through and implicitly returned None on
    # error, which made the caller's `for question in questions` raise
    # TypeError. Return an empty list so failure stays a printed message.
    return []
if __name__ == "__main__":
    # Parse the question-file path and base model name from the command line.
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    csv_file_path = script_args.question_file

    # Questions to ask every model (empty on read failure).
    questions = get_questions(csv_file_path)

    # Build the sweep of model configurations. 'model_name_or_path' is None
    # for the untrained base model, otherwise a checkpoint directory.
    base_model = script_args.model_name
    all_models = [{'model_name': base_model, 'model_name_or_path': None}]  # untrained
    # SFT runs keep only their final checkpoint.
    sft_dirs = ["sft_1", "sft_2", "sft_3", "sft_4", "sft_5"]
    for base_dir in sft_dirs:
        all_models.append({'model_name': base_model,
                           'model_name_or_path': f"{base_dir}/final_checkpoint"})
    # DPO runs are sampled every 20 training steps up to step 200; derive the
    # names instead of hand-enumerating ten string literals.
    dpo_dirs = ["dpo_1", "dpo_2", "dpo_3", "dpo_4", "dpo_5"]
    dpo_checkpoints = [f"checkpoint-{step}" for step in range(20, 201, 20)]
    for base_dir in dpo_dirs:
        for checkpoint in dpo_checkpoints:
            all_models.append({'model_name': base_model,
                               'model_name_or_path': f"{base_dir}/{checkpoint}"})

    # Ask every question of every model, keyed by checkpoint path.
    model_and_answer_pairs = {}
    for model_settings in tqdm(all_models):
        model_name = model_settings['model_name']
        model_name_or_path = model_settings['model_name_or_path']
        try:
            model, tokenizer = load_model(model_name, model_name_or_path)
        except Exception as e:
            # Best effort: a missing or broken checkpoint must not abort the sweep.
            print(f"Skipping {model_name_or_path} due to error: {e}")
            continue
        model_and_answer_pairs[model_name_or_path] = [
            chat(question, model, tokenizer) for question in questions
        ]

    # Write the results: one row per question, one column per model.
    os.makedirs('output_data', exist_ok=True)
    with open('output_data/answers.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        models = list(model_and_answer_pairs.keys())
        # The untrained baseline's key is None; give its column a readable name.
        models_renamed = [m if m is not None else "untrained" for m in models]
        writer.writerow(['question'] + models_renamed)
        for i, question in enumerate(questions):
            # Reuse the ordered `models` list rather than re-walking the dict
            # keys for every row.
            writer.writerow([question] + [model_and_answer_pairs[m][i] for m in models])