Skip to content

Commit

Permalink
fix a bug for prompt version 0.6 in jaqket_v2 (#105)
Browse files Browse the repository at this point in the history
* fix a bug for prompt version `0.6` in jaqket_v2

* clean up based on pull/80
  • Loading branch information
mkshing authored Oct 24, 2023
1 parent c10a2b8 commit 9999d17
Showing 1 changed file with 25 additions and 133 deletions.
158 changes: 25 additions & 133 deletions lm_eval/tasks/ja/jaqket_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,125 +138,41 @@ def doc_to_target(self, doc):
answer = answer_list[0]
return answer

def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param provide_description: bool
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
def fewshot_context(self, doc, num_fewshot, **kwargs):
max_num_tokens = max(
[len(self._tokenize(answer)) for answer in doc["answers"]["text"]]
)
max_length = self.max_length - max_num_tokens

# If the prompt is too long with fewshot examples, reduce the number of
# examples until it fits.
while num_fewshot >= 0:
ctx = super().fewshot_context(doc, num_fewshot, **kwargs)
if len(self._tokenize(ctx)) <= max_length:
doc["context"] = ctx
return ctx
num_fewshot -= 1

# if we got here then even 0 fewshot is too long
return ValueError(
f"0-shot prompt is too long for max length {max_length}:\n{ctx}"
)
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)

if hasattr(self, "FEWSHOT_SEP"):
FEWSHOT_SEP = self.FEWSHOT_SEP
elif hasattr(self, "SEP"):
FEWSHOT_SEP = f"{self.SEP}{self.SEP}"
else:
FEWSHOT_SEP = "\n\n"

if description:
description += FEWSHOT_SEP
elif hasattr(self, "DESCRIPTION"):
description = self.DESCRIPTION
else:
description = ""

if num_fewshot == 0:
labeled_examples = ""
def _tokenize(self, text, **kwargs):
encode_fn = self.tokenizer.encode
if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
encode_params = dict(add_special_tokens=False)
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)

fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

labeled_examples = (
FEWSHOT_SEP.join(
[
self.doc_to_answering_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ FEWSHOT_SEP
)

example = self.doc_to_text(doc)
return description + labeled_examples + example

def preprocess_ctx(self, ctx, max_length):
# if ctx fits in max length, return
if len(self.tokenizer.encode(ctx)) <= max_length:
return ctx

# if ctx is too long, split on a tag that separates each example
description, remainder = ctx.split(self.FEWSHOT_SEP, 1)
ctxs = remainder.split(self.FEWSHOT_SEP)

# if there is no example and still the description + QA prompt is too long, fail
if len(ctxs) < 2:
raise ValueError(
f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
)

# delete the first example, the last includes QA prompt to be answered by lm
del ctxs[0]

# recur
return self.preprocess_ctx(
self.FEWSHOT_SEP.join([description, *ctxs]), max_length
)
encode_params = {}
return encode_fn(text, **encode_params, **kwargs)

def construct_requests(self, doc, ctx):
if DYNAMIC_MAX_LENGTH == "false" or not hasattr(self.tokenizer, "encode"):
continuation = rf.greedy_until(ctx, [self.SEP])
else:
encode_fn = self.tokenizer.encode
if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
encode_params = dict(add_special_tokens=False)
else:
encode_params = {}
max_num_tokens = max(
[
len(encode_fn(answer, **encode_params))
for answer in doc["answers"]["text"]
]
[len(self._tokenize(answer)) for answer in doc["answers"]["text"]]
)
ctx = self.preprocess_ctx(ctx, max_length=self.max_length - max_num_tokens)
continuation = rf.greedy_until(ctx, [self.SEP], max_num_tokens)
return continuation

Expand Down Expand Up @@ -433,30 +349,6 @@ def doc_to_answering_text(self, doc):
qa_prompt = self.doc_to_qa_prompt(doc)
return f"ユーザー: {answer_candidate}{self.SEP}{qa_prompt}{self.SEP}システム: "

def preprocess_ctx(self, ctx, max_length):
# if ctx fits in max length, return
if len(self.tokenizer.encode(ctx)) <= max_length:
return ctx

# if ctx is too long, split on a tag that separates each example
description, remainder = ctx.split(self.END_OF_DESCRIPTION, 1)
ctxs = remainder.split(self.START_OF_FEWSHOT)

# if there is no example and still the description + QA prompt is too long, fail
if len(ctxs) < 2:
raise ValueError(
f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
)

# delete the first example, the last includes QA prompt to be answered by lm
del ctxs[1]

new_ctx = self.END_OF_DESCRIPTION.join(
[description, self.START_OF_FEWSHOT.join(ctxs)]
)
# recur
return self.preprocess_ctx(new_ctx, max_length)


class JAQKETV2WithRinnaBilingualInstructionSFT(JAQKETV2WithRinnaInstructionSFT):
"""
Expand Down

0 comments on commit 9999d17

Please sign in to comment.