Skip to content

Commit

Permalink
[Dataset] improve regex of MGSM
Browse files Browse the repository at this point in the history
  • Loading branch information
huyiwen committed May 27, 2024
1 parent 1126c9d commit 3d66f5b
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions utilization/dataset/mgsm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
- import signal
from functools import cached_property

from ..metric import Accuracy
Expand All @@ -17,16 +16,16 @@ class Mgsm(GenerationDataset):
'answer_number': 11,
'equation_solution': '5 + 6 = 11.'
"""

instruction = "Answer the following question in {{lang}}.\n\nQuestion: {{question.replace('\n', ' ')}}\nAnswer:"

evaluation_set = "test"
example_set = "train"
load_args = ("juletxara/mgsm",)
metrics = [Accuracy()]
extra_model_args = dict(temperature=0)

- _decimal_separator = re.compile(r"(\d),(\d)")
+ _decimal_separator = re.compile(r"(?<=\d),(?=\d)")
_extract_numbers = re.compile(r"[-+]?\d*\.\d+|\d+")

def init_arguments(self):
Expand All @@ -40,7 +39,7 @@ def post_processing(self, predictions):
new_predictions = []
for pred in predictions:
# replace numbers like `x,xxx` with `xxxx`
- pred = self._decimal_separator.sub(r"\1\2", pred)
+ pred = self._decimal_separator.sub("", pred)
numbers = self._extract_numbers.findall(pred)
if numbers:
new_predictions.append(numbers[-1])
Expand Down

0 comments on commit 3d66f5b

Please sign in to comment.