[core] Integrate Flash attention 2 in most used models (#25598)
@@ -17,6 +17,154 @@ rendered properly in your Markdown viewer.

In addition to this guide, relevant information can also be found in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).

## Flash Attention 2

<Tip>

Note that this feature is experimental and might change considerably in future versions. For instance, the Flash Attention 2 API might migrate to the `BetterTransformer` API in the near future.

</Tip>

Flash Attention 2 can considerably speed up training and inference for transformer-based models. Flash Attention 2 was introduced in the [official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. The scientific paper on Flash Attention can be found [here](https://arxiv.org/abs/2205.14135).

Make sure to follow the installation guide in the repository mentioned above to properly install Flash Attention 2. Once that package is installed, you can benefit from this feature.

We natively support Flash Attention 2 for the following models:
|
||
- Llama | ||
- Falcon | ||
|
||
You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.* | ||
|
||
<Tip> | ||
|
||
Flash Attention 2 can only be used when the models' dtype is `fp16` or `bf16` and runs only on NVIDIA-GPU devices. Make sure to cast your model to the appropriate dtype and load them on a supported device before using that feature. | ||
|
||
</Tip> | ||
|
||

### Quick usage

To enable Flash Attention 2 in your model, add `use_flash_attention_2=True` to the `from_pretrained` arguments:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
)
```

And use it for generation or fine-tuning.

### Expected speedups

You can benefit from considerable speedups for fine-tuning and inference, especially for long sequences. However, since Flash Attention does not support computing attention scores with padding tokens under the hood, we must manually pad / unpad the attention scores for batched inference when the sequence contains padding tokens. This leads to a significant slowdown for batched generation with padding tokens.

To overcome this, use Flash Attention without padding tokens in the sequence during training (e.g., by packing a dataset, i.e., concatenating sequences until reaching the maximum sequence length). An example is provided [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L516).
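
As a rough illustration, the packing idea can be sketched in plain Python (a hypothetical helper for illustration, not the exact `run_clm.py` implementation):

```python
def pack_sequences(tokenized_examples, max_length):
    """Concatenate tokenized sequences, then split into fixed-size blocks.

    No padding tokens are needed because every block is exactly
    max_length tokens long (a trailing remainder shorter than
    max_length is dropped).
    """
    # Flatten all token ids into one long stream.
    concatenated = [tok for example in tokenized_examples for tok in example]
    # Keep only as many tokens as fit into whole blocks.
    total_length = (len(concatenated) // max_length) * max_length
    return [
        concatenated[i : i + max_length]
        for i in range(0, total_length, max_length)
    ]


examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
blocks = pack_sequences(examples, max_length=4)
print(blocks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

Every resulting block is full-length, so a packed batch never triggers the pad / unpad overhead described above.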

Below is the expected speedup you can get for a simple forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes, without padding tokens:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png">
</div>

Below is the expected speedup you can get for a simple forward pass on [`meta-llama/Llama-7b-hf`](https://hf.co/meta-llama/Llama-7b-hf) with a sequence length of 4096 and various batch sizes, without padding tokens:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-7b-inference-large-seqlen.png">
</div>

For sequences with padding tokens (training or generating with padding tokens), we need to unpad / pad the input sequences to correctly compute the attention scores. For relatively small sequence lengths, this creates an overhead on a pure forward pass, leading to only a small speedup (in the benchmark below, 30% of the input is filled with padding tokens):

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-small-seqlen-padding.png">
</div>
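
The unpad / pad round-trip can be illustrated with a plain-Python sketch (hypothetical helpers; the real implementation uses `flash_attn`'s CUDA-side utilities on tensors):

```python
PAD = 0  # assumed padding token id for this illustration

def unpad(batch):
    """Flatten a padded batch into one packed token stream plus per-row lengths."""
    tokens, lengths = [], []
    for row in batch:
        real = [t for t in row if t != PAD]
        tokens.extend(real)
        lengths.append(len(real))
    return tokens, lengths

def repad(tokens, lengths, width):
    """Inverse of unpad: rebuild the rectangular padded batch."""
    batch, start = [], 0
    for n in lengths:
        row = tokens[start : start + n]
        batch.append(row + [PAD] * (width - n))
        start += n
    return batch

batch = [[5, 6, 7, PAD], [8, 9, PAD, PAD]]
tokens, lengths = unpad(batch)
print(tokens, lengths)  # [5, 6, 7, 8, 9] [3, 2]
print(repad(tokens, lengths, 4) == batch)  # True
```

The two conversions are exact inverses, but doing them on every forward pass is the overhead the benchmark above measures.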
|
||
But for large sequence length you can benefit from interesting speedup for pure inference (also training) | ||
|
||
Note that Flash Attention makes the attention computation more memory efficient, meaning you can train with much larger sequence lengths without facing CUDA OOM issues. It can lead up to memory reduction up to 20 for large sequence length. Check out [the official flash attention repository](https://github.com/Dao-AILab/flash-attention) for more details. | ||
|
||
<div style="text-align: center"> | ||
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png"> | ||
</div> | ||
|
||
|
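
To get a feel for the memory argument: standard attention materializes a `[batch, heads, seq_len, seq_len]` score matrix that grows quadratically with sequence length, while Flash Attention avoids materializing it. A back-of-the-envelope sketch (illustrative head count only, assuming fp16):

```python
def attn_scores_bytes(batch, heads, seq_len, bytes_per_el=2):
    """Size of the full attention-score matrix in standard attention (fp16)."""
    return batch * heads * seq_len * seq_len * bytes_per_el

# e.g. a 7B-class model configuration (illustrative: 32 attention heads)
small = attn_scores_bytes(batch=1, heads=32, seq_len=2048)
large = attn_scores_bytes(batch=1, heads=32, seq_len=16384)
print(small / 2**30, "GiB")  # 0.25 GiB
print(large / 2**30, "GiB")  # 16.0 GiB
```

An 8x longer sequence costs 64x the score-matrix memory, which is exactly the quadratic blow-up Flash Attention sidesteps.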

### Advanced usage

> **Reviewer comment:** very nice examples!

You can combine this feature with many existing features for model optimization. Check out a few examples below:

### Combining Flash Attention 2 and 8-bit models

You can combine this feature with 8-bit quantization:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    use_flash_attention_2=True,
)
```

### Combining Flash Attention 2 and 4-bit models

You can combine this feature with 4-bit quantization:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    use_flash_attention_2=True,
)
```

### Combining Flash Attention 2 and PEFT

You can combine this feature with PEFT to train adapters with Flash Attention 2 under the hood:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    use_flash_attention_2=True,
)

lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM",
)

model.add_adapter(lora_config)

...  # train your model
```

## BetterTransformer

[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.

```diff
@@ -70,6 +70,7 @@
     is_accelerate_available,
     is_auto_gptq_available,
     is_bitsandbytes_available,
+    is_flash_attn_available,
     is_offline_mode,
     is_optimum_available,
     is_peft_available,
```

```diff
@@ -1116,6 +1117,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     is_parallelizable = False
     supports_gradient_checkpointing = False
 
+    # Flash Attention 2 support
+    _supports_flash_attn_2 = False
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
```

```diff
@@ -1239,6 +1243,84 @@ def can_generate(cls) -> bool:
             return False
         return True
 
+    @classmethod
+    def _check_and_enable_flash_attn_2(
+        cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
+    ) -> PretrainedConfig:
+        """
+        If you don't know about Flash Attention, check out the official repository of flash attention:
+        https://github.com/Dao-AILab/flash-attention
+
+        For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
+        specific section of the documentation to learn more about it:
+        https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
+
+        The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
+        half precision and not run on CPU.
+
+        If all checks pass, the method will create an attribute in the config, `_flash_attn_2_enabled`, so that the
+        model can initialize the correct attention module.
+        """
+        if not cls._supports_flash_attn_2:
+            raise ValueError(
+                "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
+                "request support for this architecture: https://github.com/huggingface/transformers/issues/new"
+            )
+
+        if not is_flash_attn_available():
+            raise ImportError(
+                "Flash Attention 2.0 is not available. Please refer to the documentation of "
+                "https://github.com/Dao-AILab/flash-attention for installing it."
+            )
+        else:
+            flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
+            is_flash_greater_than_2 = flash_attention_version > version.parse("2.0.0")
+            if not is_flash_greater_than_2:
+                raise ValueError(
+                    f"You need the flash_attn package version to be greater than 2.0. Make sure to have that version "
+                    f"installed - detected version {flash_attention_version}"
+                )
+
+        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+
+        if _is_bettertransformer:
+            raise ValueError(
+                "Flash Attention 2 and the BetterTransformer API are not compatible. Please make sure to disable "
+                "BetterTransformer by doing model.reverse_bettertransformer()"
+            )
+
+        if torch_dtype is None:
+            logger.warning(
+                "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to "
+                "unexpected behaviour"
+            )
+        elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, "
+                "this might lead to unexpected behaviour."
+            )
+
+        if device_map is None:
+            if torch.cuda.is_available():
+                logger.warning(
+                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move "
+                    "the model to GPU after initializing it on CPU with `model.to('cuda')`."
+                )
+            else:
+                raise ValueError(
+                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU "
+                    "available. This is not supported yet. Please make sure to have access to a GPU and either "
+                    "initialise the model on a GPU by passing a device_map or initialising the model on CPU and then "
+                    "moving it to GPU."
+                )
+        elif (
+            device_map is not None
+            and isinstance(device_map, dict)
+            and ("cpu" in device_map.values() or "disk" in device_map.values())
+        ):
+            raise ValueError(
+                "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not "
+                "supported. Please make sure to initialise the model on a GPU by passing a device_map that contains "
+                "only GPU devices as keys."
+            )
+
+        config._flash_attn_2_enabled = True
+        return config
 
     def enable_input_require_grads(self):
         """
         Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
```

> **Reviewer comment (lines +1247 to +1249):** good error raising here

> **Reviewer comment (lines +1285 to +1288):** Should this just toggle it off with an

> **Reviewer thread (lines +1318 to +1320):**
> - so basically we can't use it if you don't have enough GPU VRAM. It's not 100% clear for me
> - yes, also not supported if you explicitly want to do CPU / disk offloading
> - Actually why is this not supported? It shouldn't be a problem to support Flash Attention + CPU offload IMO (we're supporting it for diffusers)
> - Would be nice to support indeed, would enable a bunch of larger models
> - Let's update the comment
> - This would require some work as you need to instantiate a
> - I propose to do it in a follow-up PR

```diff
@@ -2374,6 +2456,7 @@
         variant = kwargs.pop("variant", None)
         _adapter_model_path = kwargs.pop("_adapter_model_path", None)
         adapter_name = kwargs.pop("adapter_name", "default")
+        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
 
         if is_fsdp_enabled():
             low_cpu_mem_usage = True
```

```diff
@@ -2985,6 +3068,9 @@
         elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
             init_contexts.append(init_empty_weights())
 
+        if use_flash_attention_2:
+            config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)
+
         with ContextManagers(init_contexts):
             model = cls(config, *model_args, **model_kwargs)
```
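
Downstream, a model class can then key its attention-module choice off the `config._flash_attn_2_enabled` attribute set above. A minimal sketch of that dispatch pattern (dummy classes for illustration only, not the actual Llama/Falcon code):

```python
class DummyConfig:
    """Stand-in for a PretrainedConfig carrying the flag set by _check_and_enable_flash_attn_2."""

    def __init__(self, flash_attn_2_enabled=False):
        self._flash_attn_2_enabled = flash_attn_2_enabled


class Attention:
    name = "eager"


class FlashAttention2(Attention):
    name = "flash_attention_2"


def build_attention(config):
    """Pick the attention implementation based on the config flag."""
    if getattr(config, "_flash_attn_2_enabled", False):
        return FlashAttention2()
    return Attention()


print(build_attention(DummyConfig()).name)                           # eager
print(build_attention(DummyConfig(flash_attn_2_enabled=True)).name)  # flash_attention_2
```

Using `getattr` with a default keeps configs that predate the flag working unchanged.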

> **Reviewer comment:** did not know this was planned 😄 If not let's just not say anything

> **Reviewer comment:** Well it is all about providing a single meaningful API to users and avoiding confusing them. In PyTorch 2.2 (hopefully not too late!), we'll be in a state where FA2 will be supported by SDPA, so basically a duplicate of this.

> **Reviewer comment:** As felix said, the goal in the future would be to have a unified API through `BetterTransformer`, hence marking the API as experimental