Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix a bug for prompt version 0.6 in jaqket_v2 #105

Merged
merged 2 commits into from
Oct 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 25 additions & 133 deletions lm_eval/tasks/ja/jaqket_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,125 +138,41 @@ def doc_to_target(self, doc):
answer = answer_list[0]
return answer

def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.

:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param provide_description: bool
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
def fewshot_context(self, doc, num_fewshot, **kwargs):
    """Build a few-shot context that leaves room for the longest gold answer.

    The context is produced by the parent class's ``fewshot_context``. If it
    does not fit, the number of few-shot examples is decreased one at a time
    and the context is rebuilt until it fits.

    :param doc: dict
        The document to evaluate; ``doc["answers"]["text"]`` is read to size
        the answer budget. On success the chosen context is also stored in
        ``doc["context"]``.
    :param num_fewshot: int
        The maximum number of few-shot examples to try.
    :param kwargs:
        Forwarded unchanged to the parent ``fewshot_context``.
    :returns: str
        The first context (trying ``num_fewshot`` examples, then fewer) whose
        token count is at most ``self.max_length`` minus the longest answer.
    :raises ValueError:
        If no context fits, i.e. even the 0-shot prompt is too long.
    """
    # Reserve room for the longest reference answer.
    max_num_tokens = max(
        len(self._tokenize(answer)) for answer in doc["answers"]["text"]
    )
    max_length = self.max_length - max_num_tokens

    # Stays None only if the loop never runs (negative ``num_fewshot``), so
    # the error message below can always be formatted.
    ctx = None

    # If the prompt is too long with fewshot examples, reduce the number of
    # examples until it fits.
    while num_fewshot >= 0:
        ctx = super().fewshot_context(doc, num_fewshot, **kwargs)
        if len(self._tokenize(ctx)) <= max_length:
            doc["context"] = ctx
            return ctx
        num_fewshot -= 1

    # If we got here then even 0 fewshot is too long. This must be raised:
    # returning the exception object would hand callers a non-string context.
    raise ValueError(
        f"0-shot prompt is too long for max length {max_length}:\n{ctx}"
    )
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)

if hasattr(self, "FEWSHOT_SEP"):
FEWSHOT_SEP = self.FEWSHOT_SEP
elif hasattr(self, "SEP"):
FEWSHOT_SEP = f"{self.SEP}{self.SEP}"
else:
FEWSHOT_SEP = "\n\n"

if description:
description += FEWSHOT_SEP
elif hasattr(self, "DESCRIPTION"):
description = self.DESCRIPTION
else:
description = ""

if num_fewshot == 0:
labeled_examples = ""
def _tokenize(self, text, **kwargs):
encode_fn = self.tokenizer.encode
if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
encode_params = dict(add_special_tokens=False)
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)

fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

labeled_examples = (
FEWSHOT_SEP.join(
[
self.doc_to_answering_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ FEWSHOT_SEP
)

example = self.doc_to_text(doc)
return description + labeled_examples + example

def preprocess_ctx(self, ctx, max_length):
    """Drop leading few-shot examples from ``ctx`` until it fits ``max_length``.

    ``ctx`` is split on ``self.FEWSHOT_SEP``: the text before the first
    separator is treated as the description, and the remaining pieces as the
    examples followed by the final QA prompt.

    :param ctx: str
        The full context string.
    :param max_length: int
        Maximum allowed length of ``self.tokenizer.encode(ctx)``.
    :returns: str
        ``ctx`` unchanged if it already fits, otherwise ``ctx`` with as many
        leading examples removed as needed.
    :raises ValueError:
        If the context is still too long and there is no example left to
        remove (including when ``ctx`` contains no separator at all).
    """
    # Loop instead of recursing so the number of examples cannot hit the
    # interpreter's recursion limit.
    while len(self.tokenizer.encode(ctx)) > max_length:
        sep = self.FEWSHOT_SEP
        # if ctx is too long, split on a tag that separates each example
        description, found, remainder = ctx.partition(sep)
        ctxs = remainder.split(sep) if found else []

        # if there is no example and still the description + QA prompt is too long, fail
        if len(ctxs) < 2:
            raise ValueError(
                f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
            )

        # delete the first example, the last includes QA prompt to be answered by lm
        ctx = sep.join([description, *ctxs[1:]])
    return ctx
encode_params = {}
return encode_fn(text, **encode_params, **kwargs)

def construct_requests(self, doc, ctx):
    """Build the generation request for one document.

    :param doc: dict
        The document being evaluated; ``doc["answers"]["text"]`` is read when
        the dynamic-length branch is taken.
    :param ctx: str
        The prompt passed to ``rf.greedy_until``.
    :returns:
        Whatever ``rf.greedy_until`` returns for this prompt.
    """
    if DYNAMIC_MAX_LENGTH == "false" or not hasattr(self.tokenizer, "encode"):
        continuation = rf.greedy_until(ctx, [self.SEP])
    else:
        # Token count of the longest reference answer; passed as the third
        # argument to rf.greedy_until (presumably a generation-length cap --
        # TODO confirm against rf.greedy_until's signature).
        max_num_tokens = max(
            [len(self._tokenize(answer)) for answer in doc["answers"]["text"]]
        )
        # NOTE(review): ctx is used as given here; this assumes it was already
        # fitted to the length budget when it was built -- confirm.
        continuation = rf.greedy_until(ctx, [self.SEP], max_num_tokens)
    return continuation

Expand Down Expand Up @@ -433,30 +349,6 @@ def doc_to_answering_text(self, doc):
qa_prompt = self.doc_to_qa_prompt(doc)
return f"ユーザー: {answer_candidate}{self.SEP}{qa_prompt}{self.SEP}システム: "

def preprocess_ctx(self, ctx, max_length):
    """Drop leading few-shot examples from ``ctx`` until it fits ``max_length``.

    ``ctx`` is split once on ``self.END_OF_DESCRIPTION``; the part after it is
    split on ``self.START_OF_FEWSHOT``. The piece at index 1 is removed on
    each pass and the context is rebuilt with the same delimiters.

    :param ctx: str
        The full context string.
    :param max_length: int
        Maximum allowed length of ``self.tokenizer.encode(ctx)``.
    :returns: str
        ``ctx`` unchanged if it already fits, otherwise ``ctx`` with as many
        leading examples removed as needed.
    :raises ValueError:
        If the context is still too long and removing another piece would
        remove the final QA prompt (or ``ctx`` has no END_OF_DESCRIPTION).
    """
    # Loop instead of recursing so the number of examples cannot hit the
    # interpreter's recursion limit.
    while len(self.tokenizer.encode(ctx)) > max_length:
        # if ctx is too long, split on a tag that separates each example
        description, found, remainder = ctx.partition(self.END_OF_DESCRIPTION)
        ctxs = remainder.split(self.START_OF_FEWSHOT) if found else []

        # if there is no example and still the description + QA prompt is too long, fail.
        # At least three pieces are required: with only two, index 1 is the
        # last piece, which includes the QA prompt, and must not be deleted.
        if len(ctxs) < 3:
            raise ValueError(
                f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
            )

        # delete the first example, the last includes QA prompt to be answered by lm
        del ctxs[1]

        ctx = self.END_OF_DESCRIPTION.join(
            [description, self.START_OF_FEWSHOT.join(ctxs)]
        )
    return ctx


class JAQKETV2WithRinnaBilingualInstructionSFT(JAQKETV2WithRinnaInstructionSFT):
"""
Expand Down
Loading