From 3d66f5b21b85f757ed9ed6b427ac2cee90247e2e Mon Sep 17 00:00:00 2001 From: huyiwen <1020030101@qq.com> Date: Mon, 27 May 2024 22:11:53 +0800 Subject: [PATCH] [Dataset] improve regex of MGSM --- utilization/dataset/mgsm.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/utilization/dataset/mgsm.py b/utilization/dataset/mgsm.py index 79ccbfc7..665e101c 100644 --- a/utilization/dataset/mgsm.py +++ b/utilization/dataset/mgsm.py @@ -1,5 +1,4 @@ import re -import signal from functools import cached_property from ..metric import Accuracy @@ -17,16 +16,16 @@ class Mgsm(GenerationDataset): 'answer_number': 11, 'equation_solution': '5 + 6 = 11.' """ - + instruction = "Answer the following question in {{lang}}.\n\nQuestion: {{question.replace('\n', ' ')}}\nAnswer:" - + evaluation_set = "test" example_set = "train" load_args = ("juletxara/mgsm",) metrics = [Accuracy()] extra_model_args = dict(temperature=0) - _decimal_separator = re.compile(r"(\d),(\d)") + _decimal_separator = re.compile(r"(?<=\d),(?=\d)") _extract_numbers = re.compile(r"[-+]?\d*\.\d+|\d+") def init_arguments(self): @@ -40,7 +39,7 @@ def post_processing(self, predictions): new_predictions = [] for pred in predictions: # replace numbers like `x,xxx` with `xxxx` - pred = self._decimal_separator.sub(r"\1\2", pred) + pred = self._decimal_separator.sub("", pred) numbers = self._extract_numbers.findall(pred) if numbers: new_predictions.append(numbers[-1])