Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F.scaled_dot_product_attention support #26572

Merged
merged 114 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
74be54c
add sdpa
fxmarty Oct 3, 2023
9d14f0d
wip
fxmarty Oct 3, 2023
f803de3
cleaning
fxmarty Oct 3, 2023
c0bcbfa
add ref
fxmarty Oct 3, 2023
38332d7
yet more cleaning
fxmarty Oct 3, 2023
3b47502
and more :)
fxmarty Oct 3, 2023
dd646c1
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Oct 31, 2023
79c12a9
wip llama
fxmarty Oct 31, 2023
dc929cd
working llama
fxmarty Oct 31, 2023
17954dd
add output_attentions=True support
fxmarty Oct 31, 2023
f48f4fa
bigcode sdpa support
fxmarty Oct 31, 2023
dfc47a5
fixes
fxmarty Oct 31, 2023
eba83c1
gpt-bigcode support, require torch>=2.1.1
fxmarty Nov 3, 2023
5693535
add falcon support
fxmarty Nov 3, 2023
3758375
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Nov 3, 2023
ca87380
fix conflicts falcon
fxmarty Nov 3, 2023
969dda9
style
fxmarty Nov 3, 2023
06766ec
fix attention_mask definition
fxmarty Nov 3, 2023
5c648d4
remove output_attentions from attnmaskconverter
fxmarty Nov 3, 2023
674bff4
support whisper without removing any Copied from statement
fxmarty Nov 3, 2023
dd89c3c
fix mbart default to eager renaming
fxmarty Nov 3, 2023
f31c7b3
fix typo in falcon
fxmarty Nov 6, 2023
280c078
fix is_causal in SDPA
fxmarty Nov 8, 2023
e41ecfa
check is_flash_attn_2_available in the models init as well in case th…
fxmarty Nov 17, 2023
951bce0
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Nov 17, 2023
6f7964d
add warnings when falling back on the manual implementation
fxmarty Nov 17, 2023
0e38a95
precise doc
fxmarty Nov 17, 2023
1bd07aa
wip replace _flash_attn_enabled by config.attn_implementation
fxmarty Nov 17, 2023
feae821
fix typo
fxmarty Nov 17, 2023
2032e64
add tests
fxmarty Nov 17, 2023
d98c2f9
style
fxmarty Nov 17, 2023
ab59f9d
add a copy.deepcopy on the config in from_pretrained, as we do not wa…
fxmarty Nov 17, 2023
98a3825
obey to config.attn_implementation if a config is passed in from_pret…
fxmarty Nov 17, 2023
098a62e
fix is_torch_sdpa_available when torch is not installed
fxmarty Nov 17, 2023
b960912
remove dead code
fxmarty Nov 17, 2023
9df4c8f
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Nov 21, 2023
f1df402
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
f49c2a3
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
3a22d8d
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
f0fa993
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
f084040
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
885bbe4
Update src/transformers/models/bart/modeling_bart.py
fxmarty Nov 21, 2023
4dd5523
remove duplicate pretraining_tp code
fxmarty Nov 21, 2023
349c99b
add dropout in llama
fxmarty Nov 21, 2023
5e56014
precise comment on attn_mask
fxmarty Nov 21, 2023
951f70e
add fmt: off for _unmask_unattended docstring
fxmarty Nov 21, 2023
c4e207e
precise num_masks comment
fxmarty Nov 21, 2023
e752d93
nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion
fxmarty Nov 21, 2023
a072c5d
cleanup modeling_utils
fxmarty Nov 22, 2023
f700973
backward compatibility
fxmarty Nov 22, 2023
e267764
fix style as requested
fxmarty Nov 22, 2023
d044d81
style
fxmarty Nov 22, 2023
a9e7606
improve documentation
fxmarty Nov 22, 2023
1727210
test pass
fxmarty Nov 22, 2023
ae86680
style
fxmarty Nov 22, 2023
5706ecb
add _unmask_unattended tests
fxmarty Nov 22, 2023
d2326e2
skip meaningless tests for idefics
fxmarty Nov 22, 2023
c0f849e
hard_check SDPA requirements when specifically requested
fxmarty Nov 22, 2023
0fa8de0
standardize the use if XXX_ATTENTION_CLASSES
fxmarty Nov 22, 2023
637e473
fix SDPA bug with mem-efficient backend on CUDA when using fp32
fxmarty Nov 22, 2023
55ec325
fix test
fxmarty Nov 22, 2023
33ef389
rely on SDPA is_causal parameter to handle the causal mask in some cases
fxmarty Nov 22, 2023
2e6bc3e
fix FALCON_ATTENTION_CLASSES
fxmarty Nov 23, 2023
688d86e
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Nov 23, 2023
5913dee
remove _flash_attn_2_enabled occurences
fxmarty Nov 23, 2023
11ab3ae
fix test
fxmarty Nov 23, 2023
b74894d
add OPT to the list of supported flash models
fxmarty Nov 23, 2023
4ff1057
improve test
fxmarty Nov 23, 2023
8bd6c81
properly test on different SDPA backends, on different dtypes & prope…
fxmarty Nov 24, 2023
a11c114
remove remaining _flash_attn_2_enabled occurence
fxmarty Nov 24, 2023
b5593a1
Update src/transformers/modeling_utils.py
fxmarty Nov 24, 2023
1bc983a
Update src/transformers/modeling_utils.py
fxmarty Nov 24, 2023
316b448
Update src/transformers/modeling_utils.py
fxmarty Nov 24, 2023
52178ba
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 24, 2023
231e354
Update docs/source/en/perf_infer_gpu_one.md
fxmarty Nov 24, 2023
f907b3f
remove use_attn_implementation
fxmarty Nov 24, 2023
0e9e9f2
fix docstring & slight bug
fxmarty Nov 24, 2023
c47c24e
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 4, 2023
5c77b94
make attn_implementation internal (_attn_implementation)
fxmarty Dec 4, 2023
cd9e209
typos
fxmarty Dec 4, 2023
e475f25
fix tests
fxmarty Dec 5, 2023
48a6bfc
deprecate use_flash_attention_2=True
fxmarty Dec 6, 2023
8e7f8b5
fix test
fxmarty Dec 6, 2023
7a85efc
add back llama that was removed by mistake
fxmarty Dec 6, 2023
3649553
fix tests
fxmarty Dec 6, 2023
f09a65c
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 7, 2023
c1b87b8
remove _flash_attn_2_enabled occurences bis
fxmarty Dec 7, 2023
8950b60
add check & test that passed attn_implementation is valid
fxmarty Dec 7, 2023
18c2678
fix falcon torchscript export
fxmarty Dec 7, 2023
d96e0d2
fix device of mask in tests
fxmarty Dec 7, 2023
bb20113
Merge branch 'fix-device-mask-tests' into torch-sdpa-preliminary-support
fxmarty Dec 7, 2023
9e133c9
add tip about torch.jit.trace and move bt doc below sdpa
fxmarty Dec 7, 2023
76a1e17
fix parameterized.expand order
fxmarty Dec 7, 2023
65aeba6
move tests from test_modeling_attn_mask_utils to test_modeling_utils …
fxmarty Dec 7, 2023
09ab820
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 7, 2023
48d95ea
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 8, 2023
546dd51
update sdpaattention class with the new cache
fxmarty Dec 8, 2023
2045915
Update src/transformers/configuration_utils.py
fxmarty Dec 8, 2023
eb11883
Update src/transformers/models/bark/modeling_bark.py
fxmarty Dec 8, 2023
920686e
address review comments
fxmarty Dec 8, 2023
2146857
WIP torch.jit.trace fix. left: test both eager & sdpa
fxmarty Dec 8, 2023
9b48591
add test for torch.jit.trace for both eager/sdpa
fxmarty Dec 8, 2023
4315638
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 8, 2023
cc7fc4e
fix falcon with torch==2.0 that needs to use sdpa
fxmarty Dec 8, 2023
84d9605
Merge branch 'torch-sdpa-preliminary-support' of https://github.com/f…
fxmarty Dec 8, 2023
8486770
fix doc
fxmarty Dec 8, 2023
c6181f2
hopefully last fix
fxmarty Dec 8, 2023
7ebfd1d
fix key_value_length that has no default now in mask converter
fxmarty Dec 8, 2023
dacf149
is it flacky?
fxmarty Dec 8, 2023
f196bef
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 8, 2023
810de1a
fix speculative decoding bug
fxmarty Dec 8, 2023
f116cce
tests do pass
fxmarty Dec 8, 2023
4721c36
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 8, 2023
3f06a3a
fix following #27907
fxmarty Dec 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/llm_tutorial_optimization.md
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ flush()
```

For comparison, let's run the same function, but enable Flash Attention instead.
To do so, we convert the model to [BetterTransformers](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is based on Flash Attention.
To do so, we convert the model to [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is able to use Flash Attention.

```python
model.to_bettertransformer()
Expand Down
6 changes: 3 additions & 3 deletions docs/source/en/model_doc/bark.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ pip install -U flash-attn --no-build-isolation

##### Usage

To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the `attn_implementation="flash_attention_2"` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:

```python
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
```

##### Performance comparison
Expand Down Expand Up @@ -114,7 +114,7 @@ import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

# load in fp16 and use Flash Attention 2
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)

# enable CPU offload
model.enable_cpu_offload()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/distilbert.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> device = "cuda" # the device to load the model onto

>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")

>>> text = "Replace me by any text you'd like."

Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/gpt_bigcode.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")

>>> prompt = "def hello_world():"
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/gpt_neo.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")

>>> prompt = "def hello_world():"
Expand Down
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/gpt_neox.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ pip install -U flash-attn --no-build-isolation

### Usage

To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:

```python
>>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
...
```

Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/mistral.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

>>> prompt = "My favourite condiment is"
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/opt.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import OPTForCausalLM, GPT2Tokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
Expand Down
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/phi.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import PhiForCausalLM, AutoTokenizer

>>> # define the model and tokenizer and push the model and tokens to the GPU.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, use_flash_attention_2=True).to("cuda")
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")

>>> # feel free to change the prompt to your liking.
Expand Down Expand Up @@ -163,4 +163,4 @@ Below is an expected speedup diagram that compares pure inference time between t
- forward

</pt>
</frameworkcontent>
</frameworkcontent>
107 changes: 71 additions & 36 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,29 @@ FlashAttention-2 is experimental and may change considerably in future versions.
1. additionally parallelizing the attention computation over sequence length
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them

FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)

You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest to refer to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).

FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest to use the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.

To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]:
To enable FlashAttention-2, pass the argument `attn_implementation="flash_attention_2"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import torch
Expand All @@ -54,13 +70,15 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

<Tip>

FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.

Note that `use_flash_attention_2=True` can also be used to enable Flash Attention 2, but is deprecated in favor of `attn_implementation="flash_attention_2"`.

</Tip>

Expand All @@ -77,14 +95,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)

# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

Expand Down Expand Up @@ -124,41 +142,21 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
</div>

## BetterTransformer

<Tip>

Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.

</Tip>

BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are:

1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps
2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors

BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood.

Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation).

Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method:

```python
model = model.to_bettertransformer()
```

You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling:
## FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention

```py
model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available.

### FlashAttention
For now, Transformers supports inference and training through SDPA for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
fxmarty marked this conversation as resolved.
Show resolved Hide resolved

SDPA can also call FlashAttention kernels under the hood. FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it.
Note that FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type before using it.

To enable FlashAttention or to check whether it is available in a given setting (hardware, problem size), use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether a backend is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:

```diff
import torch
Expand Down Expand Up @@ -187,6 +185,43 @@ RuntimeError: No available kernel. Aborting execution.
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```

## BetterTransformer

<Tip warning={true}>

Part of BetterTransformer features are being upstreamed in Transformers, with native `torch.nn.scaled_dot_product_attention` default support. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to support natively SDPA in Transformers.

</Tip>


<Tip>

Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.

</Tip>

BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are:

1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps
2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors

BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood.

Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation).

Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method:

```python
model = model.to_bettertransformer()
```

You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling:

```py
model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```

## bitsandbytes

bitsandbytes is a quantization library that includes support for 4-bit and 8-bit quantization. Quantization reduces your model size compared to its native full precision version, making it easier to fit large models onto GPUs with limited memory.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ AWQ quantization can also be combined with [FlashAttention-2](perf_infer_gpu_one
```py
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```


Expand Down
10 changes: 5 additions & 5 deletions docs/source/ja/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Flash Attention 2は、モデルのdtypeが`fp16`または`bf16`の場合にの

### Quick usage

モデルでFlash Attention 2を有効にするには、`from_pretrained`の引数に`use_flash_attention_2`を追加します。
モデルでFlash Attention 2を有効にするには、`from_pretrained`の引数に`attn_implementation="flash_attention_2"`を追加します。


```python
Expand All @@ -57,7 +57,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

Expand Down Expand Up @@ -114,7 +114,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

Expand All @@ -132,7 +132,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

Expand All @@ -151,7 +151,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)

lora_config = LoraConfig(
Expand Down
Loading
Loading