[SIG] add HeadQA dataset #513

Draft · wants to merge 2 commits into main
4 changes: 4 additions & 0 deletions configs/datasets/HeadQA/HeadQA_ppl.py
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .HeadQA_ppl_983537 import HeadQA_datasets  # noqa: F401, F403
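A quick sanity check for this aggregate config (a sketch of mine, not part of the PR; assumes it is run from the repo root so the relative path resolves):

from mmengine.config import Config

# Config.fromfile resolves read_base(), pulling in HeadQA_ppl_983537.py transparently.
cfg = Config.fromfile('configs/datasets/HeadQA/HeadQA_ppl.py')
print(len(cfg.HeadQA_datasets))  # expected: 4 entries (en/es x validation/test)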
53 changes: 53 additions & 0 deletions configs/datasets/HeadQA/HeadQA_ppl_983537.py
@@ -0,0 +1,53 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import PPLInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import HeadQADataset


_hint = "The following questions come from exams to access a specialized position in the Spanish healthcare system. \n" \
        "Please choose the correct answer according to the question. \n"

HeadQA_infer_cfg = dict(
    ice_template=dict(
        type=PromptTemplate,
        template="This is a {category} question which was extracted from the {year} {name} exam.\n"
                 "{qtext}\n{choices}Answer: {ra}",
    ),
    prompt_template=dict(
        type=PromptTemplate,
        template={
            answer:
            f"{_hint}</E>This is a {{category}} question which was extracted from the {{year}} {{name}} exam.\n"
            f"{{qtext}}\n{{choices}}Answer: {answer}"
            for answer in [1, 2, 3, 4, 5]
        },
        ice_token='</E>',
    ),
    retriever=dict(type=FixKRetriever, fix_id_list=[200, 400, 600, 800, 1000]),
    inferencer=dict(type=PPLInferencer))

HeadQA_eval_cfg = dict(evaluator=dict(type=AccEvaluator))

langs = ['en', 'es']
HeadQA_datasets = []
for lang in langs:
    for _split in ['validation', 'test']:

        HeadQA_reader_cfg = dict(
            input_columns=['name', 'year', 'category', 'qtext', 'choices'],
            output_column='ra',
            test_split=_split)

        HeadQA_datasets.append(
            dict(
                # include the language in the abbr so the en/es entries do not collide
                abbr=f'HeadQA-{lang}-{_split}',
                type=HeadQADataset,
                path='head_qa',
                name=lang,
                reader_cfg=HeadQA_reader_cfg,
                infer_cfg=HeadQA_infer_cfg,
                eval_cfg=HeadQA_eval_cfg))
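For context on the PPL setup above (a sketch of mine, not part of the diff, reusing `_hint` from the file above): PPLInferencer builds one completed string per candidate label, scores each by perplexity, and AccEvaluator compares the lowest-perplexity label against `ra`. Roughly, a single candidate looks like this, using a made-up row (all field values hypothetical):

# Hypothetical row; real values come from the reader_cfg input_columns above.
row = dict(
    name='Cuaderno de examen', year='2013', category='medicine',
    qtext='Which vitamin deficiency causes scurvy?',
    choices='1. Vitamin A\n2. Vitamin B12\n3. Vitamin C\n4. Vitamin D\n')
answer = 3  # one such string is built for each label in [1, 2, 3, 4, 5]

# '</E>' (the ice_token) is where the five FixKRetriever examples are spliced in.
candidate = (
    _hint + '</E>' +
    f"This is a {row['category']} question which was extracted from "
    f"the {row['year']} {row['name']} exam.\n"
    f"{row['qtext']}\n{row['choices']}Answer: {answer}")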
1 change: 1 addition & 0 deletions opencompass/datasets/__init__.py
@@ -32,6 +32,7 @@
from .GaokaoBench import * # noqa: F401, F403
from .govrepcrs import * # noqa: F401, F403
from .gsm8k import * # noqa: F401, F403
from .headqa import * # noqa: F401, F403
from .hellaswag import * # noqa: F401, F403
from .huggingface import * # noqa: F401, F403
from .humaneval import * # noqa: F401, F403
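Note for reviewers (a sketch of mine, assuming `LOAD_DATASET` is a standard mmengine registry keyed by class name): this wildcard import is what executes the `@LOAD_DATASET.register_module()` decorator in `headqa.py`, so the class also becomes resolvable by its string name.

from opencompass.datasets import HeadQADataset
from opencompass.registry import LOAD_DATASET

# The registry lookup and the class object should now agree.
assert LOAD_DATASET.get('HeadQADataset') is HeadQADataset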
24 changes: 24 additions & 0 deletions opencompass/datasets/headqa.py
@@ -0,0 +1,24 @@
from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class HeadQADataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str):
        # `path` is the HuggingFace dataset id ('head_qa'); `name` selects
        # the language config ('en' or 'es').
        dataset = load_dataset(path=path, name=name)

        def preprocess(example):
            # Flatten the list of answer dicts into a single numbered string
            # so the prompt template can interpolate it as {choices}.
            answers = example.pop('answers')
            choices_str = ''
            for ans in answers:
                choices_str += f"{ans['aid']}. {ans['atext']}\n"
            example['choices'] = choices_str
            return example

        # The 'image' column is unused in this text-only PPL setup.
        dataset = dataset.map(preprocess).remove_columns(['image'])
        return dataset
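To make the transformation concrete, a small worked example of `preprocess` on a hypothetical record following the HuggingFace `head_qa` schema (where `answers` is a list of `{aid, atext}` dicts):

example = {
    'qtext': 'Which vitamin deficiency causes scurvy?',
    'ra': 3,
    'answers': [
        {'aid': 1, 'atext': 'Vitamin A'},
        {'aid': 2, 'atext': 'Vitamin B12'},
        {'aid': 3, 'atext': 'Vitamin C'},
        {'aid': 4, 'atext': 'Vitamin D'},
    ],
}

answers = example.pop('answers')
example['choices'] = ''.join(f"{a['aid']}. {a['atext']}\n" for a in answers)
# example['choices'] == '1. Vitamin A\n2. Vitamin B12\n3. Vitamin C\n4. Vitamin D\n'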