
Commit 89ba880

jnanliu and NathanHB authored
Add G-Pass@k Metric (#589)
* add gpassk metric
* fix pre-commit error
* fix return type check
* fix metrics
* support gpassk for aime24/25 and math_500
* fix List to list
* remove List

---------

Co-authored-by: Nathan Habib <[email protected]>
1 parent 88e3a3b commit 89ba880

File tree

3 files changed: +263 -0 lines changed

src/lighteval/metrics/metrics.py (+59 lines)

@@ -51,6 +51,7 @@
     Extractiveness,
     F1_score,
     Faithfulness,
+    GPassAtK,
     LoglikelihoodAcc,
     MajAtK,
     PassAtK,

@@ -578,6 +579,64 @@ class Metrics(Enum):
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
+    g_pass_at_16 = SampleLevelMetricGrouping(
+        metric_name="G-Pass@16:48_samples",
+        sample_level_fn=GPassAtK(k=16, n=48, strip_strings=True).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn={metric: np.mean for metric in GPassAtK(k=16, n=48, strip_strings=True).all_metrics},
+        higher_is_better={metric: True for metric in GPassAtK(k=16, n=48, strip_strings=True).all_metrics},
+    )
+    g_pass_at_8_16 = SampleLevelMetricGrouping(
+        metric_name="G-Pass@8-16:48_samples",
+        sample_level_fn=GPassAtK(k=[8, 16], n=48, strip_strings=True).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn={metric: np.mean for metric in GPassAtK(k=16, n=48, strip_strings=True).all_metrics},
+        higher_is_better={metric: True for metric in GPassAtK(k=16, n=48, strip_strings=True).all_metrics},
+    )
+    g_pass_at_16_expr_gold = SampleLevelMetricGrouping(
+        metric_name="G-Pass@16:48_samples",
+        sample_level_fn=GPassAtK(
+            k=16,
+            n=48,
+            strip_strings=True,
+            sample_scoring_function=lambda pred, ref, doc: multilingual_extractive_match_metric(
+                language=Language.ENGLISH,
+                fallback_mode="first_match",
+                precision=5,
+                gold_extraction_target=(ExprExtractionConfig(),),
+                # Match boxed first before trying other regexes
+                pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
+                aggregation_function=max,
+            ).sample_level_fn([ref], [pred], doc),
+        ).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn={metric: np.mean for metric in GPassAtK(k=16, n=48, strip_strings=True).all_metrics},
+        higher_is_better={metric: True for metric in GPassAtK(k=16, n=48, strip_strings=True).all_metrics},
+    )
+    g_pass_at_16_latex_gold = SampleLevelMetricGrouping(
+        metric_name="G-Pass@16:48_samples",
+        sample_level_fn=GPassAtK(
+            k=16,
+            n=48,
+            strip_strings=True,
+            sample_scoring_function=lambda pred, ref, doc: multilingual_extractive_match_metric(
+                language=Language.ENGLISH,
+                fallback_mode="first_match",
+                precision=5,
+                gold_extraction_target=(LatexExtractionConfig(),),
+                # Match boxed first before trying other regexes
+                pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
+                aggregation_function=max,
+            ).sample_level_fn([ref], [pred], doc),
+        ).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn={metric: np.mean for metric in GPassAtK(k=16, n=48, strip_strings=True).all_metrics},
+        higher_is_better={metric: True for metric in GPassAtK(k=16, n=48, strip_strings=True).all_metrics},
+    )
     perfect_exact_match = SampleLevelMetric(
         metric_name="perfect_em",
         sample_level_fn=ExactMatches().compute,

src/lighteval/metrics/metrics_sample.py (+162 lines)

@@ -36,6 +36,7 @@
 from nltk.tokenize.treebank import TreebankWordTokenizer
 from nltk.translate.bleu_score import sentence_bleu
 from pydantic import BaseModel
+from scipy.stats import hypergeom
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from lighteval.metrics.imports.bert_scorer import BERTScorer

@@ -1189,3 +1190,164 @@ def pass_at_k(self, all_scores: list[int]) -> float:
             return 1.0

         return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))
+
+
+class GPassAtK:
+    def __init__(
+        self,
+        k: Union[int, list[int]],
+        n: int = None,
+        thresholds: list[float] = [0.0, 0.25, 0.5, 0.75, 1.0],
+        normalize_gold: Callable = None,
+        normalize_pred: Callable = None,
+        strip_strings: bool = False,
+        sample_scoring_function: Union[Callable[[str, str], float], str] = None,
+    ):
+        """Computes G-Pass@k from http://arxiv.org/abs/2412.13147
+
+        Args:
+            k (int, list): The number of attempts to be considered (a single value or a list of values).
+            n (int): Number of samples to generate.
+            thresholds (list): Thresholds on the fraction of the k attempts that must be correct.
+            normalize_gold (callable, optional): Function to use to normalize the reference strings.
+                Defaults to None if no normalization is applied.
+            normalize_pred (callable, optional): Function to use to normalize the predicted strings.
+                Defaults to None if no normalization is applied.
+            strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
+            sample_scoring_function (callable or str, optional): Function to use to score each sample.
+                Either pass the full function (should take a string prediction and a string gold, and return a score between 0 and 1),
+                a string (any of `prefix`, `suffix` or `full`) to define the type of exact match wanted, or nothing to default to "full".
+                `prefix` checks if the prediction starts with the gold,
+                `suffix` if the prediction ends with the gold,
+                `full` if the prediction and gold are equal.
+        """
+        self.k = as_list(k)
+        self.n = n
+        self.thresholds = thresholds
+        self.normalize_gold = normalize_gold
+        self.normalize_pred = normalize_pred
+        self.strip_strings = strip_strings
+
+        # Manage the logic of per-prediction sample scoring
+        if callable(sample_scoring_function):
+            self.score_sample = sample_scoring_function
+            self.type_exact_match = None
+        else:
+            if isinstance(sample_scoring_function, str):
+                if sample_scoring_function not in ["prefix", "suffix", "full"]:
+                    raise ValueError(
+                        f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
+                    )
+                self.type_exact_match = sample_scoring_function
+            else:
+                self.type_exact_match = "full"
+            self.score_sample = self.default_sample_scoring
+
+    def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
+        """Computes the metric over a list of golds and predictions for one single item with possibly many samples.
+        It applies normalisation (if needed) to model prediction and gold, computes their per-prediction score,
+        then aggregates the scores over the samples using G-Pass@k.
+
+        Args:
+            predictions (list[str]): The sampled predicted strings.
+            formatted_doc (Doc): The document holding the reference target.
+
+        Returns:
+            dict[str, float]: G-Pass@k and mG-Pass@k scores for the current sample.
+        """
+        golds = formatted_doc.get_golds()
+
+        if len(golds) > 1:
+            raise Exception("Cannot compute G-Pass@k with several golds")
+
+        if self.n is None:
+            self.n = len(predictions)
+            logger.warning(
+                "n undefined in the G-Pass@k. We assume it's the same as the sample's number of predictions."
+            )
+        elif len(predictions) < self.n:
+            logger.warning(f"Number of predictions is less than {self.n} for G-Pass@k.")
+
+        gold = self.get_processed_gold(golds[0])
+
+        all_scores = []
+        for pred in predictions[: self.n]:
+            cur_pred = self.get_processed_pred(pred=pred)
+            all_scores.append(self.score_sample(cur_pred, gold, formatted_doc))
+
+        return self.g_pass_at_k(all_scores)
+
+    def get_processed_gold(self, gold: str) -> str:
+        if self.strip_strings:
+            gold = gold.strip()
+
+        if self.normalize_gold:
+            gold = self.normalize_gold(gold)
+
+        return gold
+
+    def get_processed_pred(self, pred: str) -> str:
+        if not pred:
+            return ""
+
+        if self.strip_strings:
+            pred = pred.strip()
+
+        if self.normalize_pred:
+            pred = self.normalize_pred(pred)
+
+        return pred
+
+    def default_sample_scoring(self, pred: str, gold: str, formatted_doc: Doc = None) -> int:
+        if self.type_exact_match == "prefix":
+            return 1 if pred.startswith(gold) else 0
+        if self.type_exact_match == "suffix":
+            return 1 if pred.endswith(gold) else 0
+        return 1 if gold == pred else 0
+
+    def g_pass_at_k(self, all_scores: list[int]) -> dict[str, float]:
+        """Computes G-Pass@k as detailed in http://arxiv.org/abs/2412.13147"""
+        c: int = sum(all_scores)
+        n: int = self.n
+        ks: list[int] = self.k
+        thresholds: list[float] = self.thresholds
+
+        def _compute_g_pass_at_k(n, c, k, m):
+            if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
+                return 0.0
+            return hypergeom.sf(m - 1, n, c, k)
+
+        def compute_g_pass_at_k(n, c, k, t):
+            m = max(int(np.ceil(k * t)), 1)
+            return _compute_g_pass_at_k(n, c, k, m)
+
+        def compute_mg_pass_at_k(n, c, k):
+            low, high = int(np.ceil(k * 0.5)), k
+
+            mg_pass_at_k = 0.0
+            for i in range(low + 1, high + 1):
+                mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i)
+            mg_pass_at_k = 2 * mg_pass_at_k / k
+
+            return mg_pass_at_k
+
+        metrics = {}
+        for k in ks:
+            for t in thresholds:
+                metrics[f"G-Pass@{k}_{t}"] = compute_g_pass_at_k(n, c, k, t)
+            metrics[f"mG-Pass@{k}"] = compute_mg_pass_at_k(n, c, k)
+
+        return metrics
+
+    @property
+    def all_metrics(self):
+        ks: list[int] = self.k
+        thresholds: list[float] = self.thresholds
+
+        metrics = []
+        for k in ks:
+            for t in thresholds:
+                metrics.append(f"G-Pass@{k}_{t}")
+            metrics.append(f"mG-Pass@{k}")
+
+        return metrics
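For intuition on the arithmetic in g_pass_at_k: with n generations of which c are correct, hypergeom.sf(m - 1, n, c, k) is the probability that a uniform draw of k generations (without replacement) contains at least m = ceil(t * k) correct ones, i.e. G-Pass@k at threshold t; mG-Pass@k then averages the tails for thresholds above 0.5. A small standalone sketch with made-up counts (not from the commit):

# Standalone sketch of the G-Pass@k arithmetic with made-up numbers.
import numpy as np
from scipy.stats import hypergeom

n, c, k = 48, 30, 16  # 48 generations, 30 of them correct, judged on draws of 16
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    m = max(int(np.ceil(k * t)), 1)
    # P(at least m correct among k draws without replacement)
    print(f"G-Pass@{k}_{t} = {hypergeom.sf(m - 1, n, c, k):.4f}")

# mG-Pass@k: scaled sum of the tail probabilities for m in (ceil(k/2), k]
low = int(np.ceil(k * 0.5))
mg = 2 * sum(hypergeom.sf(i - 1, n, c, k) for i in range(low + 1, k + 1)) / k
print(f"mG-Pass@{k} = {mg:.4f}")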

src/lighteval/tasks/default_tasks.py (+42 lines)

@@ -329,6 +329,20 @@
     ],
     version=1,
 )
+aime24_gpassk = LightevalTaskConfig(
+    name="aime24_gpassk",
+    suite=["lighteval"],
+    prompt_function=prompt.aime_prompt_fn,
+    hf_repo="HuggingFaceH4/aime_2024",
+    hf_subset="default",
+    hf_avail_splits=["train"],
+    evaluation_splits=["train"],
+    few_shots_split=None,
+    few_shots_select=None,
+    generation_size=8192,
+    metric=[Metrics.g_pass_at_16_expr_gold],
+    version=1,
+)
 aime25 = LightevalTaskConfig(
     name="aime25",
     suite=["lighteval"],

@@ -346,6 +360,20 @@
     ],
     version=1,
 )
+aime25_gpassk = LightevalTaskConfig(
+    name="aime25_gpassk",
+    suite=["lighteval"],
+    prompt_function=prompt.aime_prompt_fn,
+    hf_repo="yentinglin/aime_2025",
+    hf_subset="default",
+    hf_avail_splits=["train"],
+    evaluation_splits=["train"],
+    few_shots_split=None,
+    few_shots_select=None,
+    generation_size=8192,
+    metric=[Metrics.g_pass_at_16_expr_gold],
+    version=1,
+)
 anachronisms_bigbench = LightevalTaskConfig(
     name="anachronisms",
     suite=["bigbench", "bigbench_json"],

@@ -9661,6 +9689,20 @@
     metric=[Metrics.latex_gold_metric],
     version=1,
 )
+math_500_gpassk = LightevalTaskConfig(
+    name="math_500_gpassk",
+    suite=["lighteval"],
+    prompt_function=prompt.math_500,
+    hf_repo="HuggingFaceH4/MATH-500",
+    hf_subset="default",
+    hf_avail_splits=["test"],
+    evaluation_splits=["test"],
+    few_shots_split=None,
+    few_shots_select=None,
+    generation_size=8192,
+    metric=[Metrics.g_pass_at_16_latex_gold],
+    version=1,
+)
 math_algebra_lighteval = LightevalTaskConfig(
     name="math:algebra",
     suite=["lighteval", "math"],
