-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for generative answering of multiple_choice tasks #2601
base: main
Are you sure you want to change the base?
Changes from all commits
c225602
0bd64c2
5cca68f
d9e49af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,7 @@ | |
from lm_eval import utils | ||
from lm_eval.api import samplers | ||
from lm_eval.api.instance import Instance, OutputType | ||
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity | ||
from lm_eval.api.metrics import bits_per_byte, exact_match_fn, mean, weighted_perplexity | ||
from lm_eval.api.registry import ( | ||
AGGREGATION_REGISTRY, | ||
DEFAULT_METRIC_REGISTRY, | ||
|
@@ -80,6 +80,8 @@ class TaskConfig(dict): | |
use_prompt: Optional[str] = None | ||
description: str = "" | ||
target_delimiter: str = " " | ||
choice_delimiter: str = " / " | ||
option_delimiter: str = "\n" | ||
fewshot_delimiter: str = "\n\n" | ||
fewshot_config: Optional[dict] = None | ||
# runtime configuration options | ||
|
@@ -111,16 +113,15 @@ def __post_init__(self) -> None: | |
if "until" not in self.generation_kwargs: | ||
self.generation_kwargs["until"] = [self.fewshot_delimiter] | ||
else: | ||
if self.output_type == "generate_until": | ||
# ensure that we greedily generate in absence of explicit arguments otherwise | ||
self.generation_kwargs = { | ||
"until": ( | ||
None | ||
if self.fewshot_delimiter is None | ||
else [self.fewshot_delimiter] | ||
), | ||
"do_sample": False, | ||
} | ||
# ensure that we greedily generate in absence of explicit arguments otherwise | ||
self.generation_kwargs = { | ||
"until": ( | ||
None | ||
if self.fewshot_delimiter is None | ||
else [self.fewshot_delimiter] | ||
), | ||
"do_sample": False, | ||
} | ||
|
||
def __getitem__(self, item): | ||
return getattr(self, item) | ||
|
@@ -380,6 +381,7 @@ def build_all_requests( | |
system_instruction: Optional[str] = None, | ||
apply_chat_template: bool = False, | ||
fewshot_as_multiturn: bool = False, | ||
multiple_choice_generate: Union[bool, str] = False, | ||
chat_template: Optional[Callable] = None, | ||
tokenizer_name: str = "", | ||
) -> None: | ||
|
@@ -391,6 +393,7 @@ def build_all_requests( | |
cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}" | ||
cache_key += "-chat_template" if apply_chat_template else "" | ||
cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else "" | ||
cache_key += "-multiple_choice_generate" if multiple_choice_generate else "" | ||
cache_key += ( | ||
f"-system_prompt_hash{utils.hash_string(system_instruction)}" | ||
if system_instruction is not None | ||
|
@@ -435,12 +438,22 @@ def build_all_requests( | |
total=num_docs, | ||
): | ||
# sample fewshot context #TODO: need to offset doc_id by rank now! | ||
doc_system_instruction = system_instruction or "" | ||
if self.OUTPUT_TYPE == "multiple_choice" and multiple_choice_generate: | ||
if doc_system_instruction: | ||
doc_system_instruction += " " | ||
if multiple_choice_generate == "abcd": | ||
doc_system_instruction += "Please include \"ANSWER: <letter>\" in your response with the letter of the correct last answer." | ||
else: | ||
doc_system_instruction += "Please answer with the letter of the correct last answer." | ||
Comment on lines
+446
to
+448
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about non-english tasks that are already inside this repo? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a great point. This approach follows the openai/evals philosophy which hardcodes the strings. Maybe there can be a way to override these instructions task by task, or forced to be provided by users during invocation. Depends on the overall philosophy of what's defined in the task and what in the harness, which I'm not sure I completely from the outside at the first look. |
||
|
||
fewshot_ctx = self.fewshot_context( | ||
doc, | ||
0 if self.config.num_fewshot is None else self.config.num_fewshot, | ||
system_instruction, | ||
doc_system_instruction, | ||
apply_chat_template, | ||
fewshot_as_multiturn, | ||
multiple_choice_generate, | ||
chat_template, | ||
) | ||
|
||
|
@@ -450,6 +463,7 @@ def build_all_requests( | |
ctx=fewshot_ctx, | ||
metadata=(self.config["task"], doc_id, self.config.repeats), | ||
apply_chat_template=apply_chat_template, | ||
multiple_choice_generate=multiple_choice_generate, | ||
) | ||
|
||
if not isinstance(inst, list): | ||
|
@@ -1024,6 +1038,7 @@ def fewshot_context( | |
system_instruction: Optional[str] = None, | ||
apply_chat_template: bool = False, | ||
fewshot_as_multiturn: bool = False, | ||
multiple_choice_generate: Union[bool, str] = False, | ||
chat_template: Optional[Callable] = None, | ||
) -> str: | ||
"""Returns a fewshot context string that is made up of a prepended description | ||
|
@@ -1039,6 +1054,8 @@ def fewshot_context( | |
Whether to apply the chat template to the fewshot context. | ||
:param fewshot_as_multiturn: bool | ||
Whether to provide the fewshot examples as a multiturn conversation or a single user turn. | ||
:param multiple_choice_generate: Union[bool, str] | ||
Whether to generate multiple choice answer from scratch rather than pick by logprobs. | ||
:param chat_template: | ||
callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string. | ||
:returns: str | ||
|
@@ -1085,6 +1102,17 @@ def fewshot_context( | |
labeled_examples += self.sampler.get_context(doc, num_fewshot) | ||
|
||
example = self.doc_to_text(doc) | ||
if self.config.doc_to_choice is not None and multiple_choice_generate: | ||
if not isinstance(example, str): | ||
raise NotImplementedError("--multiple_choice_generate is implemented only for simple text docs") | ||
if multiple_choice_generate == "abcd": | ||
choices = self.doc_to_choice(doc) | ||
for label, choice in zip(list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")[:len(choices)], choices): | ||
example += f"{self.config.option_delimiter}({label}) {choice}" | ||
else: | ||
example += self.config.target_delimiter | ||
example += "(" + self.config.choice_delimiter.join(self.doc_to_choice(doc)) + ")" | ||
|
||
if apply_chat_template: | ||
if self.multiple_input: | ||
return chat_template(labeled_examples) | ||
|
@@ -1300,17 +1328,24 @@ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]: | |
return None | ||
|
||
def construct_requests( | ||
self, doc: dict, ctx: str, **kwargs | ||
self, doc: dict, ctx: str, multiple_choice_generate: Union[bool, str], **kwargs | ||
) -> Union[List[Instance], Instance]: | ||
apply_chat_template = kwargs.pop("apply_chat_template", False) | ||
|
||
aux_arguments = None | ||
|
||
if self.OUTPUT_TYPE == "loglikelihood": | ||
self.multiple_choice_generate = multiple_choice_generate | ||
output_type = self.OUTPUT_TYPE | ||
if output_type == "multiple_choice" and multiple_choice_generate: | ||
output_type = "generate_until" | ||
if self.multiple_input: | ||
raise NotImplementedError("The \"multiple input\" mode of multiple_choice tasks is not implemented for --multiple_choice_generate.") | ||
|
||
if output_type == "loglikelihood": | ||
arguments = (ctx, self.doc_to_target(doc)) | ||
elif self.OUTPUT_TYPE == "loglikelihood_rolling": | ||
elif output_type == "loglikelihood_rolling": | ||
arguments = (self.doc_to_target(doc),) | ||
elif self.OUTPUT_TYPE == "multiple_choice": | ||
elif output_type == "multiple_choice": | ||
choices = self.doc_to_choice(doc) | ||
target_delimiter = self.config.target_delimiter | ||
if apply_chat_template: | ||
|
@@ -1337,7 +1372,7 @@ def construct_requests( | |
|
||
arguments.extend(aux_arguments) | ||
|
||
elif self.OUTPUT_TYPE == "generate_until": | ||
elif output_type == "generate_until": | ||
arguments = (ctx, deepcopy(self.config.generation_kwargs)) | ||
|
||
multimodal_arg = {} | ||
|
@@ -1355,7 +1390,7 @@ def construct_requests( | |
else: | ||
arguments = arguments + (multimodal_arg,) | ||
|
||
if self.OUTPUT_TYPE == "multiple_choice": | ||
if output_type == "multiple_choice": | ||
request_list = [ | ||
Instance( | ||
request_type="loglikelihood", | ||
|
@@ -1370,7 +1405,7 @@ def construct_requests( | |
return request_list | ||
|
||
return Instance( | ||
request_type=self.OUTPUT_TYPE, | ||
request_type=output_type, | ||
doc=doc, | ||
arguments=arguments, | ||
idx=0, | ||
|
@@ -1411,7 +1446,7 @@ def process_results(self, doc, results): | |
else {} | ||
), | ||
} | ||
elif self.OUTPUT_TYPE == "multiple_choice": | ||
elif self.OUTPUT_TYPE == "multiple_choice" and not self.multiple_choice_generate: | ||
lls, is_greedy = zip(*results) | ||
|
||
# retrieve choices in List[str] form, to compute choice lengths, etc. | ||
|
@@ -1492,14 +1527,22 @@ def process_results(self, doc, results): | |
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 | ||
result_dict["acc_mutual_info"] = acc_mutual_info | ||
|
||
elif self.OUTPUT_TYPE == "generate_until": | ||
elif self.OUTPUT_TYPE == "generate_until" or (self.OUTPUT_TYPE == "multiple_choice" and self.multiple_choice_generate): | ||
gold = self.doc_to_target(doc) | ||
result = results[0] | ||
if self.config.doc_to_choice is not None: | ||
# If you set doc_to_choice, | ||
# it assumes that doc_to_target returns a number. | ||
choices = self.doc_to_choice(doc) | ||
gold = choices[gold] | ||
if self.multiple_choice_generate == "abcd": | ||
try: | ||
result_label = re.findall(r"ANSWER: ([A-Z])", result)[-1] | ||
result_i = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ").index(result_label) | ||
result = choices[result_i] | ||
except (AttributeError, ValueError, IndexError): | ||
eval_logger.warning(f"[{self}] LLM did not pick a valid result ('{result}')") | ||
result = choices[0] # XXX guess "randomly" | ||
# we expect multiple_targets to be a list. | ||
elif self.multiple_target: | ||
gold = list(gold) | ||
|
@@ -1511,6 +1554,12 @@ def process_results(self, doc, results): | |
gold = type(result)(gold) | ||
|
||
for metric in self._metric_fn_list.keys(): | ||
metric_fn = self._metric_fn_list[metric] | ||
metric_result_key = metric | ||
if self.OUTPUT_TYPE == "multiple_choice" and self.multiple_choice_generate: | ||
metric_fn = exact_match_fn | ||
metric_result_key = "exact_match" | ||
|
||
if self.multiple_target: | ||
# in the case where we have multiple targets, | ||
# return true if any are true | ||
|
@@ -1522,7 +1571,7 @@ def process_results(self, doc, results): | |
gold = [gold] | ||
if metric == "exact_match": | ||
result = [result for _ in range(len(gold))] | ||
scores = self._metric_fn_list[metric]( | ||
scores = metric_fn( | ||
references=gold, | ||
predictions=result, | ||
**self._metric_fn_kwargs[metric], | ||
|
@@ -1531,15 +1580,15 @@ def process_results(self, doc, results): | |
else: | ||
for gold_option in gold: | ||
try: | ||
result_score = self._metric_fn_list[metric]( | ||
result_score = metric_fn( | ||
references=[gold_option], | ||
predictions=[result], | ||
**self._metric_fn_kwargs[metric], | ||
) | ||
except ( | ||
TypeError | ||
): # TODO: this is hacky and I don't want to do it | ||
result_score = self._metric_fn_list[metric]( | ||
result_score = metric_fn( | ||
[gold_option, result] | ||
) | ||
if isinstance(result_score, dict): | ||
|
@@ -1552,16 +1601,16 @@ def process_results(self, doc, results): | |
result_score = 0.0 | ||
else: | ||
try: | ||
result_score = self._metric_fn_list[metric]( | ||
result_score = metric_fn( | ||
references=[gold], | ||
predictions=[result], | ||
**self._metric_fn_kwargs[metric], | ||
) | ||
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics | ||
result_score = self._metric_fn_list[metric]([gold, result]) | ||
result_score = metric_fn([gold, result]) | ||
if isinstance(result_score, dict): | ||
# TODO: this handles the case where HF evaluate returns a dict. | ||
result_score = result_score[metric] | ||
result_score = result_score[metric_result_key] | ||
result_dict[metric] = result_score | ||
else: | ||
raise ValueError( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May I suggest to not hardcode these. What if doc_system_instruction supposed to be delimited with some other delimiter? What if set of choices is not 4 letters, not these 4 letters, or not letters at all? This framework supports external tasks and also have multiple forks already, so there may be (I am not using "are" because of no intention to google proof of this idea) multiple choice tasks set up differently than "abcd".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to be clear, the "abcd" is just a (wannabe) userfriendly name for the feature, the letters aren't actually directly derived from the value.
Maybe the name is just confusing, but more modes can be added easily.