
add flash-attn deterministic option to flash-attn>=2.4.1 #31961

Merged

Conversation

junrae6454 (Contributor)

What does this PR do?

This PR introduces support for the deterministic mode available in Flash Attention version 2.4.1 and above.
Deterministic mode ensures reproducible results across runs, which is important for debugging and scientific research.
The changes enable deterministic behavior through an environment variable and integrate it into the _flash_attention_forward function; a rough sketch of how the pieces fit together follows the list below.

  • Flash Attention Forward Function Update:

    • Added deterministic parameter to _flash_attention_forward to control deterministic behavior.
    • Modified the function to check for Flash Attention version 2.4.1 or higher and set the deterministic argument accordingly.
  • Trainer Utilities Update:

    • Set the FLASH_ATTENTION_DETERMINISTIC environment variable in enable_full_determinism.
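
Putting the two pieces together, the intended flow can be sketched as follows. This is a simplified, self-contained illustration of the description above rather than the merged implementation; flash_attn_supports_deterministic and _resolve_deterministic are hypothetical stand-ins for the real version check and environment-variable lookup.

import os

# enable_full_determinism (trainer utilities) exposes the setting to the
# attention code through an environment variable.
os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"

# Hypothetical stand-in for the flash-attn >= 2.4.1 capability check.
flash_attn_supports_deterministic = True


def _resolve_deterministic(deterministic=None):
    # Hypothetical helper: an explicit argument wins; otherwise fall back to
    # the FLASH_ATTENTION_DETERMINISTIC environment variable.
    if deterministic is None:
        deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
    return deterministic


# Inside the attention forward path, the keyword is only added when the
# installed flash-attn version actually accepts it.
flash_kwargs = {}
if flash_attn_supports_deterministic:
    flash_kwargs["deterministic"] = _resolve_deterministic()

print(flash_kwargs)  # {'deterministic': True}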

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker

@ArthurZucker (Collaborator) left a comment

sounds good! are there other args we should be reading from env?

from typing import Optional, Tuple

import torch
import torch.nn.functional as F

-from .utils import is_flash_attn_2_available
+from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_41

Suggested change
-from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_41
+from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal

@@ -177,6 +181,9 @@ def _flash_attention_forward(
     )
     flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

+    if is_flash_attn_greater_or_equal_2_41():

Suggested change
-    if is_flash_attn_greater_or_equal_2_41():
+    if is_flash_attn_greater_or_equal("2.4.1"):
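
Since the existing helper takes a plain version string, the spelling of the minimum version matters. The small standalone check below, included only for clarity, shows why the threshold is written with both dots ("2.4.1" rather than "2.41"):

from packaging import version

# "2.4.1" parses to release (2, 4, 1), while "2.41" parses to (2, 41),
# which compares as a much newer version.
assert version.parse("2.4.0") < version.parse("2.4.1") < version.parse("2.41")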

Comment on lines 817 to 821
def is_flash_attn_greater_or_equal_2_41():
    if not _is_package_available("flash_attn"):
        return False

    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.4.1")

Suggested change
-def is_flash_attn_greater_or_equal_2_41():
-    if not _is_package_available("flash_attn"):
-        return False
-    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.4.1")

unnecessary, see def is_flash_attn_greater_or_equal(library_version: str):
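
For reference, the existing generic helper the review points to can be approximated by the self-contained sketch below; this is a rough equivalent for illustration, not the exact library code.

import importlib.metadata

from packaging import version


def is_flash_attn_greater_or_equal(library_version: str) -> bool:
    # Rough sketch: report False when flash_attn is not installed, otherwise
    # compare the installed version against the requested minimum.
    try:
        installed = importlib.metadata.version("flash_attn")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(library_version)


# Usage as suggested above:
# is_flash_attn_greater_or_equal("2.4.1")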

@junrae6454 (Contributor, Author)

> sounds good! are there other args we should be reading from env?

@ArthurZucker The other changes relate to the model structure and the model's return value, so there doesn't seem to be anything else that needs to be read from the environment.

@ArthurZucker mentioned this pull request on Jul 16, 2024.
@ArthurZucker (Collaborator) left a comment

thanks for iterating

@ArthurZucker ArthurZucker merged commit 036d3de into huggingface:main Jul 16, 2024
18 of 20 checks passed
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request on Jul 19, 2024:

add flash-attn deterministic option to flash-attn>=2.4.1 (#31961)

* add flash-attn deterministic option to flash-attn>=2.4.1

* Add Missing Import

* Fix ruff linting issues

* Replace `is_flash_attn_greater_or_equal_2_41` with the existing `is_flash_attn_greater_or_equal`

---------

Co-authored-by: jun.4 <[email protected]>
MHRDYN7 pushed a commit to MHRDYN7/transformers that referenced this pull request on Jul 23, 2024 (same commit message as above).
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request on Jul 24, 2024 (same commit message as above).
itazap pushed a commit that referenced this pull request on Jul 25, 2024 (same commit message as above).