Skip to content

Commit

Permalink
[Gemma2] Support FA2 softcapping (#31887)
Browse files Browse the repository at this point in the history
* Support softcapping

* strictly greater than

* update
  • Loading branch information
ArthurZucker committed Jul 11, 2024
1 parent 2e43416 commit e002fcd
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
Expand Down Expand Up @@ -382,6 +383,7 @@ def forward(
q_len,
dropout=dropout_rate,
softmax_scale=self.scaling,
softcap=self.config.attn_logit_softcapping,
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
Expand All @@ -402,6 +404,7 @@ def _flash_attention_forward(
dropout=0.0,
softmax_scale=None,
cache_position=0,
softcap=None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
Expand Down Expand Up @@ -432,7 +435,9 @@ def _flash_attention_forward(
use_sliding_windows = (
_flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window
)
flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {}
flash_kwargs = {"softcap"} if is_flash_attn_greater_or_equal("2.6.0") else {}
if use_sliding_windows:
flash_kwargs.update({"window_size": (self.sliding_window, self.sliding_window)})
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
is_essentia_available,
is_faiss_available,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
is_fsdp_available,
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,13 @@ def is_flash_attn_greater_or_equal_2_10():
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")


def is_flash_attn_greater_or_equal(version):
if not _is_package_available("flash_attn"):
return False

return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(version)


def is_torchdistx_available():
return _torchdistx_available

Expand Down

0 comments on commit e002fcd

Please sign in to comment.