Commit 09349fb

Create eval_degradation.py
1 parent ade3ce0 commit 09349fb

File tree: 1 file changed (+265 -0)

eval_degradation.py (+265)
@@ -0,0 +1,265 @@
# -*- coding: utf-8 -*-

# modified from: https://github.com/kongds/Prompt-BERT/blob/main/evaluation.py

import sys
import os
import logging

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

import torch
import fcntl
import time
import argparse
from prettytable import PrettyTable
from transformers import AutoTokenizer
from angle_emb import Pooler
from modeling_llama import LlamaForCausalLM
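
# NOTE: `modeling_llama` above is assumed to be a local, modified copy of the
# Transformers LLaMA implementation whose forward() accepts an extra `n_layer`
# argument (passed below when encoding), presumably to run only the first
# n_layer decoder layers; the stock transformers LlamaForCausalLM does not
# accept that argument.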


# Import SentEval
sys.path.insert(0, '../SentEval')
import senteval


PATH_TO_DATA = '../SentEval/data'


def print_table(task_names, scores):
    tb = PrettyTable()
    tb.field_names = task_names
    tb.add_row(scores)
    print(tb)


def lock_and_write_file(file_path, content):
    with open(file_path, 'a') as file:
        while True:
            try:
                # Acquire an exclusive lock (non-blocking)
                fcntl.flock(file, fcntl.LOCK_EX | fcntl.LOCK_NB)
            except IOError:
                # Lock is held by another process: wait and retry
                print("File is locked by another process. Can't write.")
                time.sleep(1)
                continue
            try:
                # Perform the write under the lock
                file.write(content + '\n')
                file.flush()
            finally:
                # Release the lock
                fcntl.flock(file, fcntl.LOCK_UN)
            break


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default='Summarize sentence "{text}" in one word:"')
    parser.add_argument("--tokenizer_name", type=str, default='')
    parser.add_argument("--pooling_strategy", type=str, default='cls_avg')
    parser.add_argument("--n_layer", type=int, default=None)
    parser.add_argument("--apply_bfloat16", type=int, default=1, choices=[0, 1])
    parser.add_argument("--model_name_or_path", type=str,
                        help="Transformers' model name or path")
    parser.add_argument("--max_length", type=int, default=64,
                        help="Maximum sequence length for tokenization")
    parser.add_argument("--mode", type=str,
                        choices=['dev', 'test', 'fasttest'],
                        default='test',
                        help="What evaluation mode to use (dev: fast mode, dev results; "
                             "test: full mode, test results; fasttest: fast mode, test results)")
    parser.add_argument("--task_set", type=str,
                        choices=['sts', 'transfer', 'full', 'na'],
                        default='sts',
                        help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'")
    parser.add_argument('--load_kbit', type=int,
                        choices=[4, 8, 16],
                        default=8,
                        help="Load the model in k-bit precision")

    parser.add_argument('--avg', action='store_true')
    parser.add_argument('--lora_weight', type=str, default=None)
    parser.add_argument('--pretrained_model_path', type=str, default=None)
    parser.add_argument('--checkpoint_path', type=str, default=None)

    args = parser.parse_args()

    if args.apply_bfloat16:
        model = LlamaForCausalLM.from_pretrained(args.model_name_or_path).bfloat16().cuda()
    else:
        model = LlamaForCausalLM.from_pretrained(args.model_name_or_path,
                                                 device_map='auto',
                                                 torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    class Model:
        def __init__(self, model) -> None:
            self.model = model

        def encode(self, texts, **kwargs):
            # print("texts>>>>>", texts)
            inputs = tokenizer(texts, padding='longest', truncation=True, max_length=args.max_length, return_tensors="pt")
            for key, val in inputs.items():
                inputs[key] = val.cuda()
            hidden_states = self.model(output_hidden_states=True, return_dict=True, n_layer=args.n_layer, **inputs).hidden_states[-1]
            batch_size = hidden_states.shape[0]
            if self.model.config.pad_token_id is None and batch_size != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
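            # Last-token pooling: use the hidden state at the position just
            # before the first pad token in each sequence (this assumes the
            # tokenizer right-pads the batch).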
            sequence_lengths = (torch.eq(inputs['input_ids'], self.model.config.pad_token_id).long().argmax(-1) - 1).to(
                hidden_states.device
            )

            outputs = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths].float().detach().cpu().numpy()
            return outputs

    model = Model(model=model)

    # Set up the tasks
    if args.task_set == 'sts':
        args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
        if args.mode == 'dev':
            args.tasks = ['STSBenchmark-dev']
    elif args.task_set == 'transfer':
        args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC']
    elif args.task_set == 'full':
        args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
        args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC']

    # Set params for SentEval
    if args.mode == 'dev' or args.mode == 'fasttest':
        # Fast mode
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5, 'batch_size': 32}
        params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 32,
                                'tenacity': 3, 'epoch_size': 2}
    elif args.mode == 'test':
        # Full mode
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10, 'batch_size': 16}
        params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                                'tenacity': 5, 'epoch_size': 4}
    else:
        raise NotImplementedError

    # SentEval prepare and batcher
    def prepare(params, samples):
        return

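    # batcher: called by SentEval with a list of tokenized sentences; it joins
    # the tokens back into strings, optionally wraps each sentence in the
    # prompt template, and returns one embedding per sentence.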
    def batcher(params, batch, max_length=None):
        # Handle rare token encoding issues in the dataset
        if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes):
            batch = [[word.decode('utf-8') for word in s] for s in batch]

        sentences = [' '.join(s) for s in batch]
        if max_length == 500:
            sentences = [tokenizer.decode(tokenizer.encode(s, add_special_tokens=False)[:max_length]) for s in sentences]
            max_length = 512

        if args.prompt is not None:
            for i, s in enumerate(sentences):
                if len(s) > 0 and s[-1] not in '.?"\'':
                    s += '.'
                s = s.replace('"', '\'')
                if len(s) > 0 and '?' == s[-1]:
                    s = s[:-1] + '.'
                sentences[i] = args.prompt.format(text=s)

        return model.encode(sentences, to_numpy=True, max_length=args.max_length)

    results = {}
    for task in args.tasks:
        se = senteval.engine.SE(params, batcher, prepare)
        result = se.eval(task)
        results[task] = result

    # Print evaluation results
    if args.mode == 'dev':
        print("------ %s ------" % (args.mode))

        task_names = []
        scores = []
        for task in ['STSBenchmark-dev']:
            task_names.append(task)
            if task in results:
                scores.append("%.2f" % (results[task]['dev']['spearman'][0] * 100))
            else:
                scores.append("0.00")
        print_table(task_names, scores)

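        # Each line of `dev_results` has the form "<prompt> <score> <model path>",
        # so the dev score is field index 1 when splitting on whitespace.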
        if args.checkpoint_path is not None:
            # evaluate checkpoints on dev
            if os.path.exists(os.path.join(args.checkpoint_path, 'dev_results')):
                max_scores = 0
                with open(os.path.join(args.checkpoint_path, 'dev_results'), 'r') as f:
                    for i in f:
                        max_scores = max(max_scores, float(i.split()[1]))
            else:
                max_scores = 0

            # save best checkpoint
            if float(scores[-1]) >= max_scores:
                import shutil
                if args.lora_weight is not None:
                    shutil.copytree(args.lora_weight, os.path.join(args.checkpoint_path, 'best_model'), dirs_exist_ok=True)
                else:
                    shutil.copytree(args.model_name_or_path, os.path.join(args.checkpoint_path, 'best_model'), dirs_exist_ok=True)

            # log dev results
            with open(os.path.join(args.checkpoint_path, 'dev_results'), 'a') as f:
                prefix = args.prompt if not args.avg else 'avg'
                line = prefix + ' ' + str(scores[-1]) + ' ' + \
                    (args.lora_weight if args.lora_weight is not None else args.model_name_or_path)
                f.write(line + '\n')

        task_names = []
        scores = []
        for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']:
            task_names.append(task)
            if task in results:
                scores.append("%.2f" % (results[task]['devacc']))
            else:
                scores.append("0.00")
        task_names.append("Avg.")
        scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
        print_table(task_names, scores)

    elif args.mode == 'test' or args.mode == 'fasttest':
        print("------ %s ------" % (args.mode))

        task_names = []
        scores = []
        for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']:
            task_names.append(task)
            if task in results:
                if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
                    scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100))
                else:
                    scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100))
            else:
                scores.append("0.00")
        task_names.append("Avg.")
        scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
        print_table(task_names, scores)

        # write results and template to file
        if args.prompt is not None and args.task_set != 'transfer':
            with open('./sts-org-results', 'a') as f:
                bits = f'{args.load_kbit}bit'
                model_name = args.model_name_or_path.split('/')[-1] + f'({bits})'
                f.write(args.prompt.replace(' ', '_') + ' ' + model_name + ' ' + ' '.join([str(s) for s in scores]) + '\n')

        task_names = []
        scores = []
        for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']:
            task_names.append(task)
            if task in results:
                scores.append("%.2f" % (results[task]['acc']))
            else:
                scores.append("0.00")
        task_names.append("Avg.")
        scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores)))
        print_table(task_names, scores)


if __name__ == "__main__":
    main()
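
# Example invocation (a sketch; the model path below is a placeholder, and
# `modeling_llama.py` plus the SentEval data under ../SentEval are assumed
# to be present):
#
#   python eval_degradation.py \
#       --model_name_or_path /path/to/llama-checkpoint \
#       --mode test --task_set sts \
#       --prompt 'Summarize sentence "{text}" in one word:"' \
#       --apply_bfloat16 1 --n_layer 16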
