Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't zero out the attention_mask when using sliding window with flash attention #31670

Merged
merged 2 commits into from
Jun 28, 2024

Conversation

winglian
Copy link
Contributor

What does this PR do?

Flash attention has it's own sliding window argument, but currently the attention_mask is clobbered with sliding window leading to incorrect/large loss values (and also breaks sample packing in axolotl):

Here's the existing loss with flash attention and no sample packing (LoRA):

Screenshot 2024-06-27 at 12 17 12 PM

Here's the loss with the fix (no sample packing):

Screenshot 2024-06-27 at 12 20 09 PM

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @SunMarc @ArthurZucker

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can indeed skip that step for FA2! as the sliding argument is enough

@ArthurZucker ArthurZucker merged commit 0142aab into huggingface:main Jun 28, 2024
17 of 20 checks passed
@ArthurZucker
Copy link
Collaborator

ArthurZucker commented Jun 28, 2024

Thanks 🤗

ArthurZucker pushed a commit that referenced this pull request Jun 28, 2024
…h attention (#31670)

* don't zero out the attention_mask when using sliding window with flash attention

* chore: lint
@ArthurZucker
Copy link
Collaborator

Checkout https://pypi.org/project/transformers/4.42.2/ which includes this commit!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants