FIX Prompt learning with latest transformers error (#2140)
The error in PEFT occurs after this transformers change:

huggingface/transformers#33870

Now, in our tests, some model_kwargs no longer necessarily contain
past_key_values, resulting in a KeyError. We now account for this
possibility. The affected models were opt and gpt2.
BenjaminBossan authored Oct 9, 2024
1 parent 8efa0cb commit 1eab9bd
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/peft/peft_model.py
@@ -1756,7 +1756,7 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
         if peft_config.peft_type == PeftType.POLY:
             model_kwargs["task_ids"] = task_ids
         if peft_config.is_prompt_learning:
-            if uses_cache and (model_kwargs["past_key_values"] is not None):
+            if uses_cache and (model_kwargs.get("past_key_values", None) is not None):
                 # change in the logic of `prepare_inputs_for_generation` makes the below code necessary
                 # In prompt learning methods, past key values are longer when compared to the `input_ids`.
                 # As such only consider the last input ids in the autogressive generation phase.
@@ -1786,7 +1786,7 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
                 kwargs["token_type_ids"] = None

             # no past_key_values or past_key_values empty cache
-            requires_prompt_injection = (model_kwargs["past_key_values"] is None) or (
+            requires_prompt_injection = (model_kwargs.get("past_key_values", None) is None) or (
                 isinstance(model_kwargs["past_key_values"], transformers.Cache)
                 and not model_kwargs["past_key_values"].get_seq_length()
             )
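
The guard changed by this commit can be sketched in isolation. This is a minimal illustration, not PEFT code: `model_kwargs` here is a hypothetical dict that, as described in the commit message, no longer contains the `past_key_values` key after the transformers change.

```python
# Hypothetical model_kwargs after the transformers change: no "past_key_values" key.
model_kwargs = {"input_ids": [1, 2, 3]}

# Old logic: direct indexing raises KeyError when the key is absent.
try:
    has_cache = model_kwargs["past_key_values"] is not None
except KeyError:
    has_cache = False  # the error this commit fixes

# New logic: dict.get() returns the default (None) for a missing key,
# so the check degrades gracefully instead of raising.
has_cache_fixed = model_kwargs.get("past_key_values", None) is not None

print(has_cache, has_cache_fixed)
```

Both checks evaluate to False when the key is absent; the difference is only that the old form raises KeyError rather than falling through.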
