Skip to content

Commit

Permalink
Corrected bug in Flores fewshot tasks and added flores task (#145)
Browse files Browse the repository at this point in the history
* Added special few-shot examples for DiaBLa and Flores

* Added docstrings to minimally describe specialised few-shot tasks

* Added additional few-shot tasks to Flores

* Corrected bug in flores few-shot tasks and added few-shot task wmt_hi2en

* Updated few-shot diabla tasks with different options for language directions

Co-authored-by: Rachel Bawden <[email protected]@users.noreply.github.com>
  • Loading branch information
rbawden and Rachel Bawden authored Nov 17, 2022
1 parent 8cea2f4 commit bdd1d3f
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 11 deletions.
6 changes: 5 additions & 1 deletion lm_eval/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@
"crd3": crd3.CRD3,
# DiaBLa
"diabla": diabla.DiaBLa,
"diabla_1_shot_context": diabla.DiaBLa_1_shot_context,
"diabla_1_shot_context_orig": diabla.DiaBLa_1_shot_context_orig,
"diabla_1_shot_context_same": diabla.DiaBLa_1_shot_context_same,
"diabla_1_shot_context_opposite": diabla.DiaBLa_1_shot_context_opposite,
# XQuAD
"xquad_en": xquad.XQuADEnglish,
"xquad_ar": xquad.XQuADArabic,
Expand All @@ -121,7 +123,9 @@
"flores_101_mt_fewshot_fr2en": flores_101.Flores101MT_fewshot_fr2en,
"flores_101_mt_fewshot_hi2en": flores_101.Flores101MT_fewshot_hi2en,
"flores_101_mt_fewshot_fr2ar": flores_101.Flores101MT_fewshot_fr2ar,
"flores_101_mt_fewshot_en2bn": flores_101.Flores101MT_fewshot_en2bn,
"flores_101_mt_fewshot_wmt_fr2en": flores_101.Flores101MT_fewshot_wmt_fr2en,
"flores_101_mt_fewshot_wmt_hi2en": flores_101.Flores101MT_fewshot_wmt_hi2en,
# Flores101 (Perplexity)
"flores_101_ppl": flores_101.Flores101Perplexity,
# GEM/WebNLG
Expand Down
60 changes: 58 additions & 2 deletions lm_eval/tasks/diabla.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Homepage: http://almanach.inria.fr/software_and_resources/custom/DiaBLa-en.html
"""
from lm_eval.api.task import PromptSourceTask
from typing import List, Tuple, Union, Optional
from typing import List, Tuple, Optional
import datasets
import copy
import numpy as np
Expand Down Expand Up @@ -75,7 +75,7 @@ def invalid_doc_for_prompt(self, doc) -> bool:
return False


class DiaBLa_1_shot_context(PromptSourceTask):
class DiaBLa_1_shot_context_same(PromptSourceTask):
"""
This task is identical to the DiaBLa task, but in the 1-shot setting takes the
1-shot example from the previous sentence in the dialogue if this is available
Expand Down Expand Up @@ -227,3 +227,59 @@ def fewshot_context(
"ctx": ctx,
}
return ctx, logging_info


class DiaBLa_1_shot_context_opposite(DiaBLa_1_shot_context_same):
    """
    This task is identical to the DiaBLa task, but in the 1-shot setting takes the
    1-shot example from the previous sentence in the dialogue if this is available.
    The previous sentence and its translation are arranged so that the example is
    in the *opposite* language direction to that of the current example. N.B. this
    task is not currently designed for more than 1-shot.
    """

    DATASET_PATH = "rbawden/DiaBLa"
    DATASET_NAME = None

    def get_fewshot_template(self):
        """Build, store and return the template used to render the 1-shot example.

        The current prompt template is copied and heuristically rewritten:

        * a Jinja preamble defines `src_sent` / `trg_sent` from the last entry
          of `dialogue_history` (both stay empty when there is no history).
          If the previous utterance was spoken in a different language from
          the current one, its `orig` is the source and its `ref` the target;
          otherwise the two are swapped. Either way the example reads in the
          direction opposite to the current utterance's.
        * `{{ orig }}` / `{{ ref }}` in the template body are redirected to
          those two variables.
        * the body's `utterance_meta.lang == "french"` test is inverted;
          presumably this flips the language labels to match the reversed
          direction (the template text itself is not visible here).

        Returns:
            The rewritten template, also stored on `self.shot_prompt_template`.
        """
        self.shot_prompt_template = copy.deepcopy(self.prompt_template)
        old_jinja = self.shot_prompt_template.jinja

        # Adjacent literals concatenate to exactly the original preamble text.
        preamble = (
            '{% set src_sent = ""%}'
            '{% set trg_sent = "" %}'
            "{% if dialogue_history|length > 0 %}"
            "{% if utterance_meta.lang != dialogue_history[-1].utterance_meta.lang %}"
            "{% set src_sent = dialogue_history[-1].orig %}"
            "{% set trg_sent = dialogue_history[-1].ref %}"
            "{% else %}"
            "{% set src_sent = dialogue_history[-1].ref %}"
            "{% set trg_sent = dialogue_history[-1].orig %}"
            "{% endif %}{% endif %}"
        )

        # The substitutions only touch the original template body, never the
        # preamble (which itself mentions `utterance_meta.lang`).
        body = old_jinja.replace("{{ orig }}", "{{ src_sent }}")
        body = body.replace("{{ ref }}", "{{ trg_sent }}")
        body = body.replace(
            '{% if utterance_meta.lang == "french" %}',
            '{% if utterance_meta.lang != "french" %}',
        )
        self.shot_prompt_template.jinja = preamble + body
        return self.shot_prompt_template


class DiaBLa_1_shot_context_orig(DiaBLa_1_shot_context_same):
    """
    This task is identical to the DiaBLa task, but in the 1-shot setting takes the
    1-shot example from the previous sentence in the dialogue if this is available.
    The example keeps the previous sentence's *original* language direction (its
    original utterance as source and its reference translation as target), whether
    or not that matches the direction of the current example. N.B. this task is not
    currently designed for more than 1-shot.
    """

    DATASET_PATH = "rbawden/DiaBLa"
    DATASET_NAME = None

    def get_fewshot_template(self):
        """Build, store and return the template used to render the 1-shot example.

        The current prompt template is copied and heuristically rewritten:

        * a Jinja preamble sets `src_sent` / `trg_sent` to the `orig` / `ref`
          of the last entry of `dialogue_history` (both stay empty when there
          is no history).
        * `{{ orig }}` / `{{ ref }}` in the template body are redirected to
          those two variables.
        * every `utterance_meta.lang` in the body is redirected to the previous
          utterance's `utterance_meta.lang`, so language-dependent parts of the
          template follow the previous utterance rather than the current one.

        Returns:
            The rewritten template, also stored on `self.shot_prompt_template`.
        """
        self.shot_prompt_template = copy.deepcopy(self.prompt_template)
        old_jinja = self.shot_prompt_template.jinja

        # Adjacent literals concatenate to exactly the original preamble text.
        preamble = (
            '{% set src_sent = ""%}{% set trg_sent = "" %}'
            "{% if dialogue_history|length > 0 %}"
            "{% set src_sent = dialogue_history[-1].orig %}"
            "{% set trg_sent = dialogue_history[-1].ref %}"
            "{% endif %}"
        )

        # The substitutions are applied to the original template body only.
        body = old_jinja.replace("{{ orig }}", "{{ src_sent }}")
        body = body.replace("{{ ref }}", "{{ trg_sent }}")
        body = body.replace(
            "utterance_meta.lang", "dialogue_history[-1].utterance_meta.lang"
        )
        self.shot_prompt_template.jinja = preamble + body
        return self.shot_prompt_template
200 changes: 192 additions & 8 deletions lm_eval/tasks/flores_101.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Github: https://github.com/facebookresearch/flores
"""
from lm_eval.api.task import PromptSourceTask, PerplexityTask
from typing import List, Tuple, Union, Optional
from typing import List, Tuple, Optional
import datasets
import copy
import re
Expand Down Expand Up @@ -128,12 +128,181 @@ def __init__(

def fewshot_docs(self) -> Tuple[str, List[dict]]:
    """Return the pool of few-shot documents taken from WMT14 fr-en.

    Returns:
        A `(source_label, docs)` pair, which is the shape `fewshot_context`
        unpacks. `source_label` is the string "valid"; `docs` is the
        "translation" column of the WMT14 "fr-en" validation split
        (presumably one dict per sentence pair keyed by language code, as
        the `{{ fr }}` / `{{ en }}` few-shot fields suggest).
    """
    wmt = datasets.load_dataset(
        "wmt14",
        "fr-en",
        cache_dir=self.cache_dir,
        download_mode=self.download_mode,
    )
    return "valid", wmt["validation"]["translation"]

def fewshot_context(
    self, doc: dict, num_fewshot: int, rng: Optional[np.random.Generator]
) -> Tuple[str, dict]:
    """Assemble the full prompt: `num_fewshot` labeled examples followed by
    the unlabeled prompt for `doc`.

    :param doc: dict
        The document to be prompted for.
    :param num_fewshot: int
        How many labeled examples to place before the prompt.
    :param rng: numpy.random.Generator
        Generator used to sample the few-shot examples and, for each of them,
        one target among its references. Required even when `num_fewshot`
        is 0.
    :returns: Tuple[str, dict]
        ctx: str
            The few-shot context (labeled examples + prompt).
        logging_info: dict
            Indices of the sampled examples and chosen targets, the label of
            the few-shot source, the number of shots, and `ctx` itself.
    """
    assert (
        rng is not None
    ), "A `numpy.random.Generator` argument must be provided to `rng`"

    # (Re)build the template used to render the few-shot examples.
    self.get_fewshot_template()

    shot_indices, target_indices, shot_source = [], [], None
    shots_block = ""
    if num_fewshot != 0:
        shot_source, pool = self.fewshot_docs()
        shots, shot_indices = self.fewshot_examples(
            pool, k=num_fewshot, rng=rng, prompt=doc
        )
        rendered = []
        for shot in shots:
            # Examples are rendered with the few-shot template, not the
            # template of the prompt itself.
            shot_text = self.doc_to_shot_text(shot)
            candidates = self.doc_to_shot_target(shot)
            # One reference is drawn at random from the available targets.
            pick = int(rng.integers(0, len(candidates)))
            rendered.append(
                self.format_example(
                    shot_text, candidates[pick].strip(), self.text_target_separator
                )
            )
            target_indices.append(pick)
        # A trailing separator sits between the last example and the prompt.
        shots_block = self.example_separator.join(rendered) + self.example_separator

    ctx = shots_block + self.doc_to_text(doc)
    return ctx, {
        "fewshot_idx": shot_indices,
        "fewshot_target_idx": target_indices,
        "fewshot_source": shot_source,
        "fewshot_num": num_fewshot,
        "ctx": ctx,
    }

def doc_to_shot_text(self, doc: dict) -> str:
    """Render `doc` with the few-shot template and return its input-text half."""
    rendered_text, _unused_target = self.shot_prompt_template.apply(doc)
    return rendered_text

def doc_to_shot_target(self, doc: dict) -> List[str]:
    """Render `doc` with the few-shot template and return its target half."""
    _unused_text, rendered_target = self.shot_prompt_template.apply(doc)
    return rendered_target

def fewshot_values(self):
    """Return the values dropped into the few-shot template, in the order
    (source language name, target language name, source sentence field,
    target sentence field)."""
    language_names = ("French", "English")
    sentence_fields = ("{{ fr }}", "{{ en }}")
    return language_names + sentence_fields

# heuristically hack the prompt template used to create few-shot examples
def get_fewshot_template(self):
    """Build, store and return the template used to render few-shot examples.

    The prompt template is copied and its two language names and two
    `{{ ... }}` sentence fields are swapped for the values returned by
    `fewshot_values()`. The language names of the prompt are read from the
    last two dash-separated parts of the template name; the sentence fields
    are the two `{{ ... }}` expressions found in the template text.

    All four substitutions are applied in a single pass over the original
    text. Replacing one after another would corrupt the result whenever a
    new value equals a later old value (e.g. an English-to-French prompt with
    French-to-English few-shot examples would end up "English to English").

    Returns:
        The rewritten template, also stored on `self.shot_prompt_template`.

    Raises:
        AssertionError: if a prompt language name does not occur exactly
            once in the template text.
        ValueError: if the template name or text does not yield exactly two
            language names / two `{{ ... }}` fields (tuple unpacking).
    """
    self.shot_prompt_template = copy.deepcopy(self.prompt_template)
    template_text = self.shot_prompt_template.jinja

    # What has to be replaced in the prompt ...
    src_lang, trg_lang = self.prompt_template.name.split("-")[-2:]
    src_sent, trg_sent = re.findall("{{ .+? }}", self.prompt_template.jinja)
    # ... and what is dropped in instead.
    new_src_lang, new_trg_lang, new_src_sent, new_trg_sent = self.fewshot_values()

    # Literal counting: language names may contain regex metacharacters.
    assert template_text.count(src_lang) == 1
    assert template_text.count(trg_lang) == 1

    replacements = {}
    for old_text, new_text in (
        (src_lang, new_src_lang),
        (trg_lang, new_trg_lang),
        (src_sent, new_src_sent),
        (trg_sent, new_trg_sent),
    ):
        # First mapping wins if the same old text appears twice.
        replacements.setdefault(old_text, new_text)

    # Longest alternatives first so a short key can never pre-empt a longer
    # one that starts at the same position.
    pattern = re.compile(
        "|".join(
            re.escape(old_text)
            for old_text in sorted(replacements, key=len, reverse=True)
        )
    )
    self.shot_prompt_template.jinja = pattern.sub(
        lambda match: replacements[match.group(0)], template_text
    )
    return self.shot_prompt_template

def fewshot_examples(
    self,
    docs: datasets.Dataset,
    k: int,
    rng: np.random.Generator,
    prompt: dict = None,
) -> Tuple[List[dict], List[int]]:
    """Draw up to `k` random valid examples from `docs`.

    Args:
        docs (datasets.Dataset):
            The collection of documents to sample few-shot examples from.
        k (int):
            The number of few-shot examples wanted.
        rng (np.random.Generator):
            The pseudo-random number generator used to shuffle the candidate
            indices.
        prompt (Optional[dict]):
            The prompt document. Accepted for interface compatibility but not
            compared against the candidates: in this task the few-shot pool
            is a different dataset, so an example is never the prompt.
    Returns:
        A tuple of two lists: the selected examples, and their indices in
        `docs` (in selection order). Fewer than `k` examples are returned if
        `docs` does not hold enough valid ones.
    """
    candidate_order = np.arange(len(docs)).tolist()
    rng.shuffle(candidate_order)

    picked_docs, picked_indices = [], []
    for position in candidate_order:
        # Stop as soon as enough examples have been collected.
        if len(picked_docs) >= k:
            break
        candidate = docs[position]
        if self.invalid_doc_for_prompt(candidate):
            continue
        picked_docs.append(candidate)
        picked_indices.append(int(position))
    return picked_docs, picked_indices


class Flores101MT_fewshot_wmt_hi2en(Flores101MT_fewshot_wmt_fr2en):
    """
    This task is identical to the Flores101MT task, except in the few-shot setting
    where few-shot examples are created using examples from the WMT14 Hindi-to-English
    development set, whatever the language specified in the prompt.
    """

    VERSION = 0
    DATASET_PATH = "gsarti/flores_101"
    DATASET_NAME = "all"

    def fewshot_docs(self) -> Tuple[str, List[dict]]:
        """Return the pool of few-shot documents taken from WMT14 hi-en.

        Returns:
            A `(source_label, docs)` pair. `source_label` is the string
            "valid"; `docs` is the "translation" column of the WMT14 "hi-en"
            validation split, mirroring the fr-en parent task so that each
            document exposes the language fields used by `fewshot_values`
            directly (presumably `hi` and `en` -- verify against the dataset).
        """
        wmt = datasets.load_dataset(
            "wmt14",
            "hi-en",
            cache_dir=self.cache_dir,
            download_mode=self.download_mode,
        )
        return "valid", wmt["validation"]["translation"]

    def fewshot_values(self):
        """Return (source language name, target language name, source
        sentence field, target sentence field) for the Hindi-to-English
        few-shot examples, overriding the parent's French-to-English values."""
        return "Hindi", "English", "{{ hi }}", "{{ en }}"


class Flores101MT_fewshot_fr2en(Flores101MT):
Expand Down Expand Up @@ -166,7 +335,7 @@ def get_fewshot_template(self):
(src_lang, new_src_lang),
(trg_lang, new_trg_lang),
(src_sent, new_src_sent),
(trg_sent, new_src_sent),
(trg_sent, new_trg_sent),
]:
self.shot_prompt_template.jinja = self.shot_prompt_template.jinja.replace(
old_text, new_text
Expand Down Expand Up @@ -274,6 +443,21 @@ def fewshot_values(self):
return "French", "Arabic", "{{ sentence_fra }}", "{{ sentence_ara }}"


class Flores101MT_fewshot_en2bn(Flores101MT_fewshot_fr2en):
    """
    Variant of the Flores101MT task that differs only in the few-shot setting:
    the few-shot examples always translate from English (source) into Bengali
    (target), regardless of the language pair named in the prompt.
    """

    DATASET_PATH = "gsarti/flores_101"
    DATASET_NAME = "all"
    VERSION = 0

    def fewshot_values(self):
        """Return (source language name, target language name, source
        sentence field, target sentence field) for the few-shot examples."""
        language_names = ("English", "Bengali")
        sentence_fields = ("{{ sentence_eng }}", "{{ sentence_ben }}")
        return language_names + sentence_fields


class Flores101Perplexity(PerplexityTask):
"""Computes the perplexity for a specific language translation of Flores-101.
Expand Down

0 comments on commit bdd1d3f

Please sign in to comment.