Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Gemma2] Support FA2 softcapping #31887

Merged
merged 3 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
Expand Down Expand Up @@ -382,6 +383,7 @@ def forward(
q_len,
dropout=dropout_rate,
softmax_scale=self.scaling,
softcap=self.config.attn_logit_softcapping,
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
Expand All @@ -402,6 +404,7 @@ def _flash_attention_forward(
dropout=0.0,
softmax_scale=None,
cache_position=0,
softcap=None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
Expand Down Expand Up @@ -432,7 +435,9 @@ def _flash_attention_forward(
use_sliding_windows = (
_flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window
)
flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {}
flash_kwargs = {"softcap"} if is_flash_attn_greater_or_equal("2.6.0") else {}
if use_sliding_windows:
flash_kwargs.update({"window_size": (self.sliding_window, self.sliding_window)})
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
is_essentia_available,
is_faiss_available,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
is_fsdp_available,
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,13 @@ def is_flash_attn_greater_or_equal_2_10():
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")


def is_flash_attn_greater_or_equal(version):
if not _is_package_available("flash_attn"):
return False

return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(version)


def is_torchdistx_available():
return _torchdistx_available

Expand Down
Loading