Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Yunnglin committed Jan 3, 2025
1 parent 1bc459d commit e0fc9b4
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 342 deletions.
4 changes: 2 additions & 2 deletions docs/en/advanced_guides/collection/schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Where:
- `name` is the name of the data mixing schema.
- `datasets` is a list of datasets, where each dataset (DatasetInfo) includes attributes such as `name`, `weight`, `task_type`, `tags`, and `args`.
- `name` is the name of the dataset. Supported dataset names can be found in the [dataset list](../../get_started/supported_dataset.md#1-native-supported-datasets).
- `weight` is the weight of the dataset, used for weighted sampling. The default is 1, and all data will be normalized during sampling.
- `weight` is the weight of the dataset, used for weighted sampling. The default is 1.0, and all data will be normalized during sampling. (The value must be greater than 0)
- `task_type` is the task type of the dataset and can be filled in as needed.
- `tags` are labels for the dataset, which can also be filled in as needed.
- `args` are parameters for the dataset, and the configurable parameters can be found in the [dataset parameters](../../get_started/parameters.md#dataset-parameters).
Expand All @@ -42,7 +42,7 @@ complex_schema = CollectionSchema(name='math&reasoning', datasets=[
]),
])
```
- `weight` is the weight of the data mixing schema, used for weighted sampling. The default is 1, and all data will be normalized during sampling.
- `weight` is the weight of the data mixing schema, used for weighted sampling. The default is 1.0, and all data will be normalized during sampling. (The value must be greater than 0)
- `datasets` can contain CollectionSchema, enabling the nesting of datasets. During evaluation, the name of the `CollectionSchema` will be recursively added to the tags of each sample.

## Using the Schema
Expand Down
10 changes: 5 additions & 5 deletions docs/zh/advanced_guides/collection/schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ simple_schema = CollectionSchema(name='reasoning', datasets=[
- `name` 是数据混合schema的名称
- `datasets` 是数据集列表,每个数据集(DatasetInfo)包含 `name``weight``task_type``tags``args` 等属性。
- `name` 是数据集的名称,支持的数据集名称见[数据集列表](../../get_started/supported_dataset.md#1-原生支持的数据集)
- `weight` 是数据集的权重,用于加权采样,默认为1,采样时所有数据会归一化
- `weight` 是数据集的权重,类型为float,用于加权采样,默认为1.0,采样时所有数据会归一化(数值需要大于0)
- `task_type` 是数据集的任务类型,可自行填写
- `tags` 是数据集的标签,可自行填写
- `args` 是数据集的参数,可指定的参数见[数据集参数](../../get_started/parameters.md#数据集参数)
Expand All @@ -32,17 +32,17 @@ complex_schema = CollectionSchema(name='math&reasoning', datasets=[
CollectionSchema(name='math', weight=3, datasets=[
DatasetInfo(name='gsm8k', weight=1, task_type='math', tags=['en']),
DatasetInfo(name='competition_math', weight=1, task_type='math', tags=['en']),
DatasetInfo(name='cmmlu', weight=1, task_type='math', tags=['zh'], args={'subset_list': ['college_mathematics', 'high_school_mathematics']}),
DatasetInfo(name='ceval', weight=1, task_type='math', tags=['zh'], args={'subset_list': ['advanced_mathematics', 'high_school_mathematics', 'discrete_mathematics', 'middle_school_mathematics']}),
DatasetInfo(name='cmmlu', weight=1, task_type='math_examination', tags=['zh'], args={'subset_list': ['college_mathematics', 'high_school_mathematics']}),
DatasetInfo(name='ceval', weight=1, task_type='math_examination', tags=['zh'], args={'subset_list': ['advanced_mathematics', 'high_school_mathematics', 'discrete_mathematics', 'middle_school_mathematics']}),
]),
CollectionSchema(name='reasoning', weight=1, datasets=[
DatasetInfo(name='arc', weight=1, task_type='reasoning', tags=['en']),
DatasetInfo(name='ceval', weight=1, task_type='reasoning', tags=['zh'], args={'subset_list': ['logic']}),
DatasetInfo(name='ceval', weight=1, task_type='reasoning_examination', tags=['zh'], args={'subset_list': ['logic']}),
DatasetInfo(name='race', weight=1, task_type='reasoning', tags=['en']),
]),
])
```
- `weight` 是数据混合schema的权重,用于加权采样,默认为1,采样时所有数据会归一化
- `weight` 是数据混合schema的权重,类型为float,用于加权采样,默认为1.0,采样时所有数据会归一化(数值需要大于0)
- `datasets` 中可以包含CollectionSchema,从而实现数据集的嵌套;在评测时,`CollectionSchema`的名称会递归添加到每个样本的tag中

## 使用schema
Expand Down
286 changes: 9 additions & 277 deletions evalscope/benchmarks/competition_math/competition_math_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from evalscope.benchmarks import Benchmark, DataAdapter
from evalscope.metrics import WeightedAverageAccuracy
from evalscope.metrics.math_accuracy import is_equiv, last_boxed_only_string, remove_boxed
from evalscope.models import ChatGenerationModelAdapter
from evalscope.utils.logger import get_logger

Expand Down Expand Up @@ -76,11 +77,11 @@ def gen_prompt(self, input_d: dict, few_shot_list: list, **kwargs) -> dict:
use_fewshot = self.few_shot_num > 0
full_prompt = self._generate_prompt(input_d, use_fewshot=use_fewshot)

return {'data': [full_prompt]}
return {'data': [full_prompt], 'system_prompt': 'Put the final answer in \\boxed{}.'}

def get_gold_answer(self, input_d: dict) -> str:
# Extract the gold answer from the input dict.
return self._preprocess_input(input_d['solution'])
return remove_boxed(last_boxed_only_string(input_d['solution']))

def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = 'checkpoint') -> str:
"""
Expand All @@ -94,13 +95,16 @@ def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: st
Returns:
The parsed answer. Depending on the dataset. Usually a string for chat.
"""
# TODO: check answer extraction
# Note: Use same extraction method for both of checkpoint/service/custom
return self._math_postprocess(result)
try:
result = remove_boxed(last_boxed_only_string(result))
except Exception:
return None
return result

def match(self, gold: str, pred: str) -> float:
res = 0
if self._is_equiv(pred, gold):
if is_equiv(pred, gold):
res = 1

return res
Expand All @@ -120,275 +124,3 @@ def _generate_prompt(cls, input_d: dict, use_fewshot: bool = True) -> str:
else:
context = 'Problem:\n' + problem + '\nSolution:\n'
return context

@classmethod
def _preprocess_input(cls, input: str) -> str:
"""
Preprocess the input data, remove the boxed solution.
Args:
input_d: The raw input. A single data format of the Competition Math.
Returns:
The preprocessed input.
"""
return cls._remove_boxed(cls._last_boxed_only_string(input))

@classmethod
def _remove_boxed(cls, s):
if s is None:
return s

if '\\boxed ' in s:
left = '\\boxed '
assert s[:len(left)] == left
return s[len(left):]

left = '\\boxed{'

assert s[:len(left)] == left
assert s[-1] == '}'

return s[len(left):-1]

@classmethod
def _last_boxed_only_string(cls, string):

idx = string.rfind('\\boxed')
if '\\boxed ' in string:
return '\\boxed ' + string.split('\\boxed ')[-1].split('$')[0]
if idx < 0:
idx = string.rfind('\\fbox')
if idx < 0:
return None

i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == '{':
num_left_braces_open += 1
if string[i] == '}':
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1

if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]

return retval

@classmethod
def _is_equiv(cls, str1, str2, verbose=False):
if str1 is None and str2 is None:
logger.warning('WARNING: Both None')
return True
if str1 is None or str2 is None:
return False

try:
ss1 = cls.strip_string(str1)
ss2 = cls.strip_string(str2)
if verbose:
logger.info(f'ss1: {ss1}, ss2: {ss2}')
return ss1 == ss2
except Exception:
return str1 == str2

@classmethod
def strip_string(cls, string):
# linebreaks
string = string.replace('\n', '')

# remove inverse spaces
string = string.replace('\\!', '')

# replace \\ with \
string = string.replace('\\\\', '\\')

# replace tfrac and dfrac with frac
string = string.replace('tfrac', 'frac')
string = string.replace('dfrac', 'frac')

# remove \left and \right
string = string.replace('\\left', '')
string = string.replace('\\right', '')

# Remove circ (degrees)
string = string.replace('^{\\circ}', '')
string = string.replace('^\\circ', '')

# remove dollar signs
string = string.replace('\\$', '')

# remove units (on the right)
string = cls.remove_right_units(string)

# remove percentage
string = string.replace('\\%', '')
string = string.replace('\%', '') # noqa: W605

# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(' .', ' 0.')
string = string.replace('{.', '{0.')
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == '.':
string = '0' + string

# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split('=')) == 2:
if len(string.split('=')[0]) <= 2:
string = string.split('=')[1]

# fix sqrt3 --> sqrt{3}
string = cls.fix_sqrt(string)

# remove spaces
string = string.replace(' ', '')

# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = cls.fix_fracs(string)

# manually change 0.5 --> \frac{1}{2}
if string == '0.5':
string = '\\frac{1}{2}'

# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = cls.fix_a_slash_b(string)

return string

@classmethod
def remove_right_units(cls, string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if '\\text{ ' in string:
splits = string.split('\\text{ ')
assert len(splits) == 2
return splits[0]
else:
return string

@classmethod
def fix_fracs(cls, string):
substrs = string.split('\\frac')
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += '\\frac'
if substr[0] == '{':
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != '{':
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}{' + b + '}' + post_substr
else:
new_str += '{' + a + '}{' + b + '}'
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}' + b + post_substr
else:
new_str += '{' + a + '}' + b
string = new_str
return string

@classmethod
def fix_sqrt(cls, string):
if '\\sqrt' not in string:
return string
splits = string.split('\\sqrt')
new_string = splits[0]
for split in splits[1:]:
if split[0] != '{':
a = split[0]
new_substr = '\\sqrt{' + a + '}' + split[1:]
else:
new_substr = '\\sqrt' + split
new_string += new_substr
return new_string

@classmethod
def fix_a_slash_b(cls, string):
if len(string.split('/')) != 2:
return string
a = string.split('/')[0]
b = string.split('/')[1]
try:
a = int(a)
b = int(b)
assert string == '{}/{}'.format(a, b)
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
return new_string
except AssertionError:
return string

@classmethod
def _math_postprocess(cls, text: str) -> str:
SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''), (r'\ ', ''), (' ', ''), ('mbox', 'text'),
(',\\text{and}', ','), ('\\text{and}', ','), ('\\text{m}', '\\text{}'), ('\\le', '<')]
REMOVED_EXPRESSIONS = [
'square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'ft', 'hours', 'km', 'units', '\\ldots', 'sue',
'points', 'feet', 'minutes', 'digits', 'cents', 'degrees', 'cm', 'gm', 'pounds', 'meters', 'meals', 'edges',
'students', 'childrentickets', 'multiples', '\\text{s}', '\\text{.}', '\\text{\ns}', '\\text{}^2',
'\\text{}^3', '\\text{\n}', '\\text{}', r'\mathrm{th}', r'^\circ', r'^{\circ}', r'\;', r',\!', '{,}', '"',
'\\dots', '\n', '\r', '\f'
]
import re

def normalize_final_answer(final_answer: str) -> str:
"""Normalize a final answer to a quantitative reasoning question."""
# final_answer = final_answer.split('=')[-1]
for before, after in SUBSTITUTIONS:
final_answer = final_answer.replace(before, after)
for expr in REMOVED_EXPRESSIONS:
final_answer = final_answer.replace(expr, '')

# Extract answer that is in LaTeX math, is bold,
# is surrounded by a box, etc.
final_answer = re.sub(r'(\\text\{)(.*?)(\})', '\\2', final_answer)
final_answer = re.sub(r'(\\textbf\{)(.*?)(\})', '\\2', final_answer)
final_answer = re.sub(r'(\\overline\{)(.*?)(\})', '\\2', final_answer)
final_answer = re.sub(r'(\\boxed\{)(.*)(\})', '\\2', final_answer)
assert '\n' not in final_answer
assert '\r' not in final_answer
assert '\f' not in final_answer
if len(re.findall(r'finalansweris(.*)', final_answer)) > 0:
final_answer = re.findall(r'finalansweris(.*)', final_answer)[-1]

if len(re.findall(r'oxed\{(.*?)\}', final_answer)) > 0:
final_answer = re.findall(r'oxed\{(.*?)\}', final_answer)[-1]

if len(re.findall(r'\$(.*?)\$', final_answer)) > 0:
final_answer = re.findall(r'\$(.*?)\$', final_answer)[-1]
final_answer = final_answer.strip()
if 'rac' in final_answer and '\\frac' not in final_answer:
final_answer = final_answer.replace('rac', '\\frac')

final_answer = re.sub(r'(frac)([^{])(.)', 'frac{\\2}{\\3}', final_answer)
final_answer = re.sub(r'(sqrt)([^{])', 'sqrt{\\2}', final_answer)
final_answer = final_answer.replace('$', '')

# Normalize 100,000 -> 100000
if final_answer.replace(',', '').isdigit():
final_answer = final_answer.replace(',', '')

return final_answer

for maybe_ans in text.split('.'):
if 'final answer' in maybe_ans.lower():
return normalize_final_answer(maybe_ans)
return normalize_final_answer(text.split('.')[0])
2 changes: 1 addition & 1 deletion evalscope/collections/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from evalscope.collections.data_generator import StratifiedSampler, UniformSampler, WeightedSampler
from evalscope.collections.evaluator import EvaluatorCollection
from evalscope.collections.sampler import StratifiedSampler, UniformSampler, WeightedSampler
from evalscope.collections.schema import CollectionSchema, DatasetInfo
2 changes: 1 addition & 1 deletion evalscope/collections/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm

from evalscope.benchmarks import Benchmark
from evalscope.collections.data_generator import DatasetEntry
from evalscope.collections.sampler import DatasetEntry
from evalscope.config import TaskConfig
from evalscope.constants import AnswerKeys, DumpMode, EvalType, ReviewKeys
from evalscope.evaluator import Evaluator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, schema: CollectionSchema, count: Optional[int] = None):

@abstractmethod
def sample(self) -> List[dict]:
pass
raise NotImplementedError

def _collect_dataset_data(self, dataset_info_list: List[DatasetInfo]) -> List[DatasetEntry]:
all_data = []
Expand Down
Loading

0 comments on commit e0fc9b4

Please sign in to comment.