Add FlashAttention #24

Merged
epwalsh merged 25 commits into main from flash-attn on Mar 8, 2023

Add FlashAttention #24

merged 25 commits into from
Mar 8, 2023

Conversation

epwalsh (Member) commented on Mar 8, 2023

Closes #23

⚠️ This is blocked right now on Dao-AILab/flash-attention#132 (update: I got this working with an older version of Triton; we'll keep tracking that issue and update Triton when it's fixed - see #26).

Uses the official FlashAttention implementation. I've prebuilt a Python 3.10, PyTorch 1.13.1, CUDA 11.7 wheel for this, which you can install with:

pip install https://storage.googleapis.com/ai2-python-wheels/flash_attn/flash_attn-0.2.8%2Bcu117torch1.13.1-cp310-cp310-linux_x86_64.whl

This will be built into the Docker images.

For now we'll have to use their Triton version, since the CUDA version doesn't support arbitrary attention biases, meaning we can't use ALiBi with it.

The advantages and disadvantages of the Triton implementation are discussed here:

https://github.com/HazyResearch/flash-attention/blob/57ee618170e1adecbf787365cdf330c63768abd2/flash_attn/flash_attn_triton.py#L1-L35
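
To make the ALiBi point concrete, here is a rough sketch (not code from this PR) of building an ALiBi bias and passing it to the Triton kernel. It assumes the flash_attn 0.2.8 interface, where `flash_attn.flash_attn_triton.flash_attn_func` takes `(q, k, v, bias, causal, softmax_scale)` with `q`/`k`/`v` shaped `(batch, seqlen, nheads, headdim)`; check the file linked above for the exact contract.

```python
# Hedged sketch: an ALiBi bias fed to the Triton FlashAttention kernel.
# Assumes flash_attn==0.2.8, where flash_attn_func accepts a bias broadcastable to
# (batch, nheads, seqlen, seqlen) and q/k/v are fp16/bf16 CUDA tensors of shape
# (batch, seqlen, nheads, headdim). Verify against the linked source before relying on it.
import torch
from flash_attn.flash_attn_triton import flash_attn_func


def alibi_slopes(n_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes for a power-of-two number of heads: 2^(-8 * i / n_heads).
    return torch.tensor([2 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])


def alibi_bias(n_heads: int, seq_len: int, device, dtype) -> torch.Tensor:
    pos = torch.arange(seq_len, device=device)
    rel = pos[None, :] - pos[:, None]             # (seq_len, seq_len), j - i
    slopes = alibi_slopes(n_heads).to(device)     # (n_heads,)
    # (1, n_heads, seq_len, seq_len): under a causal mask, rel <= 0, so distant
    # positions get an increasingly negative bias, which is exactly ALiBi.
    return (slopes[:, None, None] * rel[None]).to(dtype).unsqueeze(0)


batch, seq_len, n_heads, head_dim = 2, 1024, 8, 64
q, k, v = (
    torch.randn(batch, seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
    for _ in range(3)
)
bias = alibi_bias(n_heads, seq_len, q.device, q.dtype)
out = flash_attn_func(q, k, v, bias, True)  # causal=True; out: (batch, seqlen, nheads, headdim)
```

The CUDA kernels in this release have no equivalent bias argument, which is why ALiBi forces us onto the Triton path for now.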

And to add to that, one of our contacts at MosaicML says:

In my experience, IF you are not using ALL the GPU's memory, the Triton version is almost always slightly faster than the CUDA FlashAttn implementation.
But it's much slower when you are close to using all the memory (which happens with larger models).
I've filed an issue in the Triton repo.
There I note that, in the FSDP config, setting limit_all_gathers to True enables running the model with the Triton attn implementation at larger scale.
As noted there, on 128 GPUs I can run the 30B model using Triton (which supports bias) and it's nearly as fast as using the CUDA version.
Final note: FlashAttn has a rewrite (slated for April) and they plan to support attn bias in May.
With the current implementation there is this PR for supporting bias. I tried building from that branch and running, but the kernel was much slower in that branch...
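
For reference, the limit_all_gathers knob mentioned above is an argument on PyTorch's FSDP wrapper. Here is a minimal sketch of turning it on; the wrapped module and the other FSDP arguments are placeholders, not this repo's actual training config, and it assumes a PyTorch version whose FSDP constructor accepts limit_all_gathers.

```python
# Hedged sketch of the FSDP setting from the note above. The model below is a stand-in,
# not our actual architecture; assumes torchrun-style env vars (RANK, WORLD_SIZE, etc.).
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.TransformerEncoderLayer(d_model=512, nhead=8).cuda()  # placeholder module

fsdp_model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    # Rate-limit how far ahead the CPU thread issues parameter all-gathers, so prefetched
    # shards don't pile up in GPU memory; per the quoted note, this is what made the
    # Triton attention kernel viable at larger scale.
    limit_all_gathers=True,
)
```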


This PR is based on:

epwalsh changed the title from "Get flash attention working" to "Add FlashAttention" on Mar 8, 2023
epwalsh marked this pull request as ready for review on Mar 8, 2023 at 22:55
epwalsh merged commit 5222c35 into main on Mar 8, 2023
epwalsh deleted the flash-attn branch on Mar 8, 2023 at 23:56