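"""Evaluation metrics for scoring LLM-generated predictions against references.

Metrics are obtained through the AutoMetric.from_name factory; see METRIC_MAPPING
for the supported metric names.
"""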
import os
import time

import numpy as np
import regex as re
from anthropic import RateLimitError
from claudette import Chat, models
from evaluate import load


class Metric:
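    """Base class for all metrics.

    Subclasses override `_load_metric` (called once from `__init__`) and
    `compute(prompts, predictions, references)`.
    """
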
def __init__(self, **kwargs):
self._load_metric(**kwargs)
def _load_metric(self, **kwargs):
raise NotImplementedError("This method should be overridden by subclasses.")
def compute(self, prompts, predictions, references):
raise NotImplementedError("This method should be overridden by subclasses.")
class Rouge(Metric):
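    """ROUGE n-gram overlap, computed with the HuggingFace `evaluate` library."""
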
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _load_metric(self, **kwargs):
self.metric = load("rouge", keep_in_memory=True)
def compute(self, prompts, predictions, references):
return self.metric.compute(predictions=predictions, references=references)
class Bleurt(Metric):
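    """BLEURT learned similarity metric; returns the mean score over all examples."""
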
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _load_metric(self, **kwargs):
self.metric = load("bleurt", keep_in_memory=True)
def compute(self, prompts, predictions, references):
return np.mean(
self.metric.compute(predictions=predictions, references=references)[
"scores"
]
)
class BertScore(Metric):
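    """BERTScore (English); returns mean precision, recall, and F1 over all examples."""
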
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _load_metric(self, **kwargs):
self.metric = load("bertscore", keep_in_memory=True)
def compute(self, prompts, predictions, references):
result = self.metric.compute(
predictions=predictions, references=references, lang="en"
)
return {
"precision": np.mean(result["precision"]),
"recall": np.mean(result["recall"]),
"f1": np.mean(result["f1"]),
}
class Accuracy(Metric):
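    """Classification accuracy computed with `sklearn.metrics.accuracy_score`."""
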
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _load_metric(self, **kwargs):
from sklearn.metrics import accuracy_score
self.metric = accuracy_score
def compute(self, prompts, predictions, references):
return self.metric(references, predictions)
class ExactMatchScore(Metric):
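    """Fraction of predictions whose whitespace-split tokens exactly match the reference."""
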
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _load_metric(self, **kwargs):
pass
def compute(self, prompts, predictions, references):
return np.mean(
[
1 if p.split() == r.split() else 0
for p, r in zip(predictions, references)
]
)
class LevenshteinDistance(Metric):
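    """Mean `fuzzywuzzy.fuzz.ratio` similarity (0-100); a ratio, not a raw edit distance."""
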
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _load_metric(self, **kwargs):
from fuzzywuzzy import fuzz
self.metric = fuzz.ratio
def compute(self, prompts, predictions, references):
return np.mean([self.metric(p, r) for p, r in zip(predictions, references)])
class RulerStringMatch(Metric):
"""
Metric used in RULER.
Reference: https://github.com/hsiehjackson/RULER/blob/main/scripts/eval/synthetic/constants.py
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
@staticmethod
def postprocess_pred(predict_str: str):
predict_str = predict_str.strip()
        # Replace ASCII control characters with newlines
np_pattern = re.compile(r"[\x00-\x1f]")
predict_str = np_pattern.sub("\n", predict_str).strip()
return predict_str
@staticmethod
def string_match_part(refs, preds):
scores = [
max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref])
for pred, ref in zip(preds, refs)
]
score = sum(scores) / len(preds) * 100
return {"score": round(score, 4)}
@staticmethod
def string_match_all(refs, preds):
scores = [
sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
for pred, ref in zip(preds, refs)
]
score = sum(scores) / len(preds) * 100
return {"score": round(score, 4)}
def _load_metric(self, **kwargs):
if kwargs.get("match_part", False):
self.metric = self.string_match_part
else:
self.metric = self.string_match_all
def compute(self, prompts, predictions, references):
predictions = [self.postprocess_pred(pred) for pred in predictions]
return self.metric(references, predictions)
REFERENCE_TEMPLATE = """You are shown ground-truth answer(s) and asked to judge the quality of an LLM-generated answer.
Assign it a score from 1-5 where 1 is the worst and 5 is the best based on how similar it is to the ground-truth(s).
Do NOT explain your choice. Simply return a number from 1-5.
====GROUND TRUTHS====
{labels}
====ANSWER====
{prediction}"""
PREFILL = "The score (1-5) is:"
class LLMRouge(Metric):
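    """LLM-judged similarity: a Claude model (via `claudette`) rates each prediction
    1-5 against its ground-truth answers using REFERENCE_TEMPLATE. Requires the
    ANTHROPIC_API_KEY environment variable.
    """
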
def __init__(self, num_retries=5, **kwargs) -> None:
assert (
"ANTHROPIC_API_KEY" in os.environ
), "Please set the ANTHROPIC_API_KEY environment variable."
super().__init__(**kwargs)
self.num_retries = num_retries
def _load_metric(self, **kwargs):
name = kwargs.get("name", "haiku")
matching_names = [m for m in models if name in m]
assert len(matching_names) > 0, f"Model name {name} not found in {models}"
assert (
len(matching_names) == 1
), f"Model name {name} found x{len(matching_names)} in {models}"
self.chat = Chat(
matching_names[0], sp="""You are a helpful and concise assistant."""
)
def parse_int(self, text):
return int(re.search(r"\d+", text).group())
def compute(self, prompts, predictions, labels):
scores = []
for p, ls in zip(predictions, labels):
prompt = REFERENCE_TEMPLATE.format(labels="\n---\n".join(ls), prediction=p)
# Clear conversation history
self.chat.h = []
try:
score = (
self.chat(prompt, prefill=PREFILL)
.content[0]
.text[len(PREFILL) :]
.strip()
)
except RateLimitError:
retries = 0
while retries < self.num_retries:
time.sleep(10)
try:
score = (
self.chat(prompt, prefill=PREFILL)
.content[0]
.text[len(PREFILL) :]
.strip()
)
break
except RateLimitError:
retries += 1
                        if retries == self.num_retries:
                            # Retries exhausted: re-raise the last RateLimitError.
                            raise
score = self.parse_int(score)
scores.append(score)
return {"llm_rouge": sum(scores) / len(scores)}
LLM_JUDGE_TEMPLATE = """You are shown a prompt and asked to assess the quality of an LLM-generated answer on the following dimensions:
===CRITERIA===
{criteria}
Respond with "criteria: score" for each criteria with a newline for each criteria.
Assign a score from 1-5 where 1 is the worst and 5 is the best based on how well the answer meets the criteria.
====PROMPT====
{prompt}
====ANSWER====
{prediction}"""
CRITERIA = {
"helpful": "The answer executes the action requested by the prompt without extraneous detail.",
"coherent": "The answer is logically structured and coherent (ignore the prompt).",
"faithful": "The answer is faithful to the prompt and does not contain false information.",
}
class LLMJudge(LLMRouge):
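    """Reference-free LLM judge: rates each answer 1-5 on every criterion in
    CRITERIA (helpful, coherent, faithful) given only the prompt and the answer.
    """
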
def __init__(self, **kwargs) -> None:
assert (
"ANTHROPIC_API_KEY" in os.environ
), "Please set the ANTHROPIC_API_KEY environment variable."
super().__init__(**kwargs)
self.criteria = list(sorted([k for k in CRITERIA]))
self.criteria_def = "\n".join([f"{k}: {CRITERIA[k]}" for k in self.criteria])
self.prefill = (
f"\n\n====SCORES for {', '.join(self.criteria)}====\n\n{self.criteria[0]}:"
)
def parse_scorecard(self, scorecard):
try:
return {
k: int(v)
for k, v in dict(
re.findall(rf"({'|'.join(self.criteria)})\W+(\d+)", scorecard)
).items()
}
except Exception as e:
print(e)
raise Exception(
f"Could not parse LLM-generated scorecard for {self.__class__}:\n{scorecard}"
)
def claudette_scorecard(self, prompt, prediction):
prompt = LLM_JUDGE_TEMPLATE.format(
criteria=self.criteria_def, prompt=prompt, prediction=prediction
)
# Clear conversation history
self.chat.h = []
scorecard = (
self.chat(prompt, prefill=self.prefill)
.content[0]
.text[len(self.prefill) - len(self.criteria[0]) - 1 :]
.strip()
)
return scorecard
def compute(self, prompts, predictions, labels):
scores = []
for prompt, pred in zip(prompts, predictions):
scorecard = self.claudette_scorecard(prompt, pred)
score_dict = self.parse_scorecard(scorecard)
scores.append(score_dict)
return {k: np.mean([s[k] for s in scores]) for k in self.criteria}
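
# Maps metric names accepted by AutoMetric.from_name to their implementing classes.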
METRIC_MAPPING = {
"accuracy": Accuracy,
"bertscore": BertScore,
"bleurt": Bleurt,
"exact_match": ExactMatchScore,
"levenshtein": LevenshteinDistance,
"llm-rouge": LLMRouge,
"llm-as-a-judge": LLMJudge,
"rouge": Rouge,
"ruler-string-match": RulerStringMatch,
}
class AutoMetric:
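    """Factory for metrics; use `AutoMetric.from_name(name, **kwargs)` instead of
    instantiating this class directly.
    """
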
def __init__(self):
        raise EnvironmentError(
            "AutoMetric is not meant to be instantiated directly; use AutoMetric.from_name(metric_name, **kwargs) instead."
        )
    @staticmethod
    def from_name(metric_name, **kwargs):
if metric_name not in METRIC_MAPPING:
raise ValueError(f"Invalid metric name: {metric_name}")
return METRIC_MAPPING[metric_name](**kwargs)
if __name__ == "__main__":
metric = AutoMetric.from_name("llm-as-a-judge")
predictions = [
"The answer to 2x2 is 4.",
"The answer to 2x2 is 5.",
]
labels = [["4"], ["4"]]
prompts = [
"What is 2x2?",
"What is 2x2?",
]
    # LLMJudge ignores reference labels; answers are scored against the prompt criteria.
    print(metric.compute(prompts=prompts, predictions=predictions, labels=labels))