add flash-attn deterministic option to flash-attn>=2.4.1 #31961
Conversation
sounds good! are there other args we should be reading from env?
```diff
 from typing import Optional, Tuple

 import torch
 import torch.nn.functional as F

-from .utils import is_flash_attn_2_available
+from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_41
```
Suggested change:

```diff
-from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_41
+from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
```
```diff
@@ -177,6 +181,9 @@ def _flash_attention_forward(
     )
     flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

+    if is_flash_attn_greater_or_equal_2_41():
```
Suggested change:

```diff
-    if is_flash_attn_greater_or_equal_2_41():
+    if is_flash_attn_greater_or_equal("2.4.1"):
```
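For context, after applying the suggestion the gated block presumably reads along these lines. The variable name `FLASH_ATTENTION_DETERMINISTIC` is an assumption on my part; the PR description only says "an environment variable":

```python
import os

# Inside _flash_attention_forward, after flash_kwargs is assembled.
# flash-attn only accepts the `deterministic` kwarg from 2.4.1 on,
# hence the version gate.
if is_flash_attn_greater_or_equal("2.4.1"):
    # Opt-in via env var (assumed name); defaults to the faster,
    # non-deterministic backward pass.
    deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
    flash_kwargs["deterministic"] = deterministic
```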
```python
def is_flash_attn_greater_or_equal_2_41():
    if not _is_package_available("flash_attn"):
        return False

    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.4.1")
```
Suggested change (remove the helper):

```diff
-def is_flash_attn_greater_or_equal_2_41():
-    if not _is_package_available("flash_attn"):
-        return False
-
-    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.4.1")
```
unnecessary, see `def is_flash_attn_greater_or_equal(library_version: str):`
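For reference, that existing generic helper presumably follows the same pattern as the version-specific one above, parameterized on the version string (a sketch, not copied from this thread; it assumes the same `_is_package_available`, `version`, and `importlib.metadata` context as the surrounding utils module):

```python
def is_flash_attn_greater_or_equal(library_version: str):
    # Same availability check as above, but reusable for any version bound.
    if not _is_package_available("flash_attn"):
        return False

    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
```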
Replace `is_flash_attn_greater_or_equal_2_41` with the existing `is_flash_attn_greater_or_equal`
@ArthurZucker |
thanks for iterating
add flash-attn deterministic option to flash-attn>=2.4.1 (#31961)

* add flash-attn deterministic option to flash-attn>=2.4.1
* Add Missing Import
* Fix ruff linting issues
* Replace `is_flash_attn_greater_or_equal_2_41` with the existing `is_flash_attn_greater_or_equal`

Co-authored-by: jun.4 <[email protected]>
What does this PR do?
This PR introduces support for the deterministic mode available in Flash Attention 2.4.1 and above.
Deterministic mode ensures reproducible results across runs, which is crucial for debugging and for scientific research.
The changes enable deterministic behavior through an environment variable and wire the option into the `_flash_attention_forward` function.
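A minimal usage sketch (assuming the variable is named `FLASH_ATTENTION_DETERMINISTIC`; the thread does not spell the name out):

```python
import os

# Set before any attention forward pass runs; "1" opts in to
# deterministic flash-attn kernels, anything else keeps the default.
os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
```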
Flash Attention Forward Function Update: `_flash_attention_forward` gates on the installed flash-attn version and passes the deterministic flag through to the kernel (see the diffs above).
Trainer Utilities Update: the trainer's full-determinism setup also sets the environment variable, so deterministic flash-attn is enabled alongside the other determinism switches (a sketch follows below).
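A sketch of the trainer-utilities side, assuming the helper is `enable_full_determinism` and using the assumed variable name from above (neither is confirmed in this thread):

```python
import os

import torch


def enable_full_determinism(seed: int):
    """Sketch: extend the full-determinism setup to cover flash-attn kernels."""
    # Existing determinism switches, abridged (seeding elided).
    torch.use_deterministic_algorithms(True)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    # New in this PR: opt flash-attn >= 2.4.1 into deterministic kernels
    # (assumed env-var name).
    os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
```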
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@ArthurZucker