|
36 | 36 | from nltk.tokenize.treebank import TreebankWordTokenizer
|
37 | 37 | from nltk.translate.bleu_score import sentence_bleu
|
38 | 38 | from pydantic import BaseModel
|
| 39 | +from scipy.stats import hypergeom |
39 | 40 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
40 | 41 |
|
41 | 42 | from lighteval.metrics.imports.bert_scorer import BERTScorer
|
@@ -1189,3 +1190,164 @@ def pass_at_k(self, all_scores: list[int]) -> float:
|
1189 | 1190 | return 1.0
|
1190 | 1191 |
|
1191 | 1192 | return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))
|
| 1193 | + |
| 1194 | + |
class GPassAtK:
    """G-Pass@k metric from http://arxiv.org/abs/2412.13147.

    Scores each of n sampled predictions against a single gold, then computes
    G-Pass@k_t — the hypergeometric probability that at least ceil(k * t) of k
    drawn samples pass — for every k and threshold t, plus mG-Pass@k (a
    weighted average of G-Pass@k_t over thresholds above 0.5).
    """

    def __init__(
        self,
        k: Union[int, list[int]],
        n: Union[int, None] = None,
        thresholds: Union[list[float], None] = None,
        normalize_gold: Callable = None,
        normalize_pred: Callable = None,
        strip_strings: bool = False,
        sample_scoring_function: Union[Callable[[str, str], float], str] = None,
    ):
        """Computing G-Pass@k from http://arxiv.org/abs/2412.13147

        Args:
            k (int, list): The number of successful attempts to be considered.
            n (int): Number of samples to generate. If None, it is inferred from
                the first sample's number of predictions at compute time.
            thresholds (list): Thresholds to control successful attempts in k generate.
                Defaults to [0.0, 0.25, 0.5, 0.75, 1.0].
            normalize_gold (callable, optional): Function to use to normalize the reference strings.
                Defaults to None if no normalization is applied.
            normalize_pred (callable, optional): Function to use to normalize the predicted strings.
                Defaults to None if no normalization is applied.
            strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
            sample_scoring_function (callable or str, optional): Function to use to score each sample.
                Either pass the full function (called as fn(prediction, gold, formatted_doc) and
                expected to return a score between 0 and 1),
                a string (any of `prefix`, `suffix` or `full`) to define the type of exact match that you want, or nothing to defaults to "full".
                    `prefix` checks if the prediction starts with the gold,
                    `suffix` if the prediction ends with the gold,
                    `full` if the prediction and gold are equal
        """
        self.k = as_list(k)
        self.n = n
        # None-sentinel instead of a mutable default argument (shared list pitfall).
        self.thresholds = [0.0, 0.25, 0.5, 0.75, 1.0] if thresholds is None else thresholds
        self.normalize_gold = normalize_gold
        self.normalize_pred = normalize_pred
        self.strip_strings = strip_strings

        # Manage the logic of the per-prediction sample scoring.
        if callable(sample_scoring_function):
            self.score_sample = sample_scoring_function
            self.type_exact_match = None
        else:
            if isinstance(sample_scoring_function, str):
                if sample_scoring_function not in ["prefix", "suffix", "full"]:
                    raise ValueError(
                        f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
                    )
                self.type_exact_match = sample_scoring_function
            else:
                self.type_exact_match = "full"
            self.score_sample = self.default_sample_scoring

    def compute(self, predictions: list[str], formatted_doc: "Doc", **kwargs) -> dict[str, float]:
        """Computes the metric over a list of golds and predictions for one single item with possibly many samples.
        It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
        then aggregates the scores over the samples using a pass@k.

        Args:
            predictions (list[str]): k predicted strings
            formatted_doc (Doc): Original document of the sample; its golds are used as the reference.

        Returns:
            dict[str, float]: G-Pass@{k}_{t} and mG-Pass@{k} scores for the current sample.
        """
        golds = formatted_doc.get_golds()

        if len(golds) > 1:
            raise Exception("Cannot compute G-Pass@k with several golds")

        if self.n is None:
            # NOTE(review): this persists on the instance, so the first sample's
            # prediction count becomes n for all subsequent samples.
            self.n = len(predictions)
            logger.warning(
                "n undefined in the G-Pass@k. We assume it's the same as the sample's number of predictions."
            )
        elif len(predictions) < self.n:
            logger.warning(f"Number of predictions is less than {self.n} for G-Pass@k.")

        gold = self.get_processed_gold(golds[0])

        all_scores = []
        for pred in predictions[: self.n]:
            cur_pred = self.get_processed_pred(pred=pred)
            # Every scorer (default or user-supplied) is called with the doc as well.
            all_scores.append(self.score_sample(cur_pred, gold, formatted_doc))

        return self.g_pass_at_k(all_scores)

    def get_processed_gold(self, gold: str) -> str:
        """Strips/normalizes the reference string according to the configured options."""
        if self.strip_strings:
            gold = gold.strip()

        if self.normalize_gold:
            gold = self.normalize_gold(gold)

        return gold

    def get_processed_pred(self, pred: str) -> str:
        """Strips/normalizes a predicted string; empty/None predictions map to ""."""
        if not pred:
            return ""

        if self.strip_strings:
            pred = pred.strip()

        if self.normalize_pred:
            pred = self.normalize_pred(pred)

        return pred

    def default_sample_scoring(self, pred: str, gold: str, formatted_doc: "Doc" = None) -> int:
        """Exact-match scoring (prefix/suffix/full).

        `formatted_doc` is accepted (and ignored) because `compute` invokes every
        scorer as fn(pred, gold, formatted_doc); without it this method raised a
        TypeError whenever the default scorer was used.
        """
        if self.type_exact_match == "prefix":
            return 1 if pred.startswith(gold) else 0
        if self.type_exact_match == "suffix":
            return 1 if pred.endswith(gold) else 0
        return 1 if gold == pred else 0

    def g_pass_at_k(self, all_scores: list[int]) -> dict[str, float]:
        """Computation of G-Pass@k details from http://arxiv.org/abs/2412.13147"""
        c: int = sum(all_scores)  # number of passing samples out of n
        n: int = self.n
        ks: list[int] = self.k
        thresholds: list[float] = self.thresholds

        def _compute_g_pass_at_k(n, c, k, m):
            # P(at least m passing samples among k drawn without replacement
            # from n samples of which c pass) — hypergeometric survival function.
            if m > min(c, k) or k > n or c < 0 or n <= 0 or m < 0:
                return 0.0
            return hypergeom.sf(m - 1, n, c, k)

        def compute_g_pass_at_k(n, c, k, t):
            # Require at least ceil(k * t) passes (minimum 1).
            m = max(int(np.ceil(k * t)), 1)
            return _compute_g_pass_at_k(n, c, k, m)

        def compute_mg_pass_at_k(n, c, k):
            # Integrates G-Pass@k over thresholds in (0.5, 1.0], cf. the paper.
            low, high = int(np.ceil(k * 0.5)), k

            mg_pass_at_k = 0.0
            for i in range(low + 1, high + 1):
                mg_pass_at_k += _compute_g_pass_at_k(n, c, k, i)
            mg_pass_at_k = 2 * mg_pass_at_k / k

            return mg_pass_at_k

        metrics = {}
        for k in ks:
            for t in thresholds:
                metrics[f"G-Pass@{k}_{t}"] = compute_g_pass_at_k(n, c, k, t)
            metrics[f"mG-Pass@{k}"] = compute_mg_pass_at_k(n, c, k)

        return metrics

    @property
    def all_metrics(self) -> list[str]:
        """Names of every metric key `g_pass_at_k` will emit, in emission order."""
        ks: list[int] = self.k
        thresholds: list[float] = self.thresholds

        metrics = []
        for k in ks:
            for t in thresholds:
                metrics.append(f"G-Pass@{k}_{t}")
            metrics.append(f"mG-Pass@{k}")

        return metrics
0 commit comments