[core] Integrate Flash Attention 2 in most used models #25598

Merged
merged 99 commits on Sep 22, 2023

Changes from 74 commits

Commits (99)
8bb77a1
v1
younesbelkada Aug 18, 2023
2e18421
oops
younesbelkada Aug 18, 2023
fe5795e
working v1
younesbelkada Aug 18, 2023
4bd15e2
fixup
younesbelkada Aug 18, 2023
49fe318
add some TODOs
younesbelkada Aug 18, 2023
f5d440b
Merge branch 'add-flash-attn-2' of https://github.com/younesbelkada/t…
younesbelkada Aug 18, 2023
50491e8
fixup
younesbelkada Aug 18, 2023
7df78c0
Merge remote-tracking branch 'upstream/main' into add-flash-attn-2
younesbelkada Aug 22, 2023
0e30d13
padding support + try with module replacement
younesbelkada Aug 23, 2023
ad8b905
nit
younesbelkada Aug 23, 2023
3c31f10
alternative design
younesbelkada Sep 1, 2023
2628bf3
oops
younesbelkada Sep 1, 2023
56d0b49
Merge remote-tracking branch 'upstream/main' into HEAD
younesbelkada Sep 1, 2023
20d1b37
add `use_cache` support for llama
younesbelkada Sep 1, 2023
a82f1ca
v1 falcon
younesbelkada Sep 1, 2023
c72e8ff
nit
younesbelkada Sep 1, 2023
66823f9
a bit of refactor
younesbelkada Sep 1, 2023
41f8f3d
nit
younesbelkada Sep 1, 2023
a64a1a9
nits nits
younesbelkada Sep 1, 2023
67e3fc2
add v1 padding support falcon (even though it seemed to work before)
younesbelkada Sep 1, 2023
8444ab6
nit
younesbelkada Sep 1, 2023
8b1c2df
falcon works
younesbelkada Sep 1, 2023
c3ebcd2
fixup
younesbelkada Sep 1, 2023
1c212d8
v1 tests
younesbelkada Sep 1, 2023
4618701
nit
younesbelkada Sep 1, 2023
85ec946
fix generation llama flash
fxmarty Sep 1, 2023
2248f20
Merge branch 'add-flash-attn-2' of https://github.com/younesbelkada/t…
fxmarty Sep 1, 2023
0881ced
update tests
younesbelkada Sep 1, 2023
a8a1b2d
Merge branch 'add-flash-attn-2' of https://github.com/younesbelkada/t…
younesbelkada Sep 1, 2023
2be3e03
fix tests + nits
younesbelkada Sep 1, 2023
b6d3e58
fix copies
younesbelkada Sep 1, 2023
b47e85c
fix nit
younesbelkada Sep 1, 2023
db8bd64
test- padding mask
younesbelkada Sep 1, 2023
58848ab
stype
younesbelkada Sep 4, 2023
3f73557
add more mem efficient support
younesbelkada Sep 4, 2023
baae736
Update src/transformers/modeling_utils.py
younesbelkada Sep 4, 2023
55f6140
fixup
younesbelkada Sep 4, 2023
10d5c1b
Merge remote-tracking branch 'upstream/main' into add-flash-attn-2
younesbelkada Sep 4, 2023
3fb221a
nit
younesbelkada Sep 4, 2023
a931aeb
fixup
younesbelkada Sep 4, 2023
68a1204
remove it from config when saving
younesbelkada Sep 4, 2023
36e0d6e
fixup
younesbelkada Sep 4, 2023
2beeb68
revert docstring
younesbelkada Sep 4, 2023
7b5da2c
add more checks
younesbelkada Sep 4, 2023
b99a582
use values
younesbelkada Sep 4, 2023
adaed45
oops
younesbelkada Sep 4, 2023
7f06af6
new version
fxmarty Sep 5, 2023
2d36c6f
fixup
younesbelkada Sep 5, 2023
a663fa4
Merge branch 'main' into add-flash-attn-2
younesbelkada Sep 5, 2023
9d3693f
add same trick for falcon
younesbelkada Sep 5, 2023
65ae59c
nit
younesbelkada Sep 5, 2023
43185b5
add another test
younesbelkada Sep 11, 2023
c61157e
change tests
younesbelkada Sep 11, 2023
2f17792
fix issues with GC and also falcon
younesbelkada Sep 11, 2023
65c3861
fixup
younesbelkada Sep 11, 2023
165a503
oops
younesbelkada Sep 11, 2023
5abc702
Update src/transformers/models/falcon/modeling_falcon.py
younesbelkada Sep 13, 2023
5069e4a
add init_rope
younesbelkada Sep 13, 2023
11400d8
Merge branch 'add-flash-attn-2' of https://github.com/younesbelkada/t…
younesbelkada Sep 13, 2023
ace7939
updates
younesbelkada Sep 13, 2023
fe9b16d
fix copies
younesbelkada Sep 13, 2023
6174c06
Merge branch 'main' into add-flash-attn-2
younesbelkada Sep 13, 2023
acfc954
fixup
younesbelkada Sep 13, 2023
33a0f62
fixup
younesbelkada Sep 13, 2023
ee8ba20
more clarification
younesbelkada Sep 13, 2023
e28fb0b
fixup
younesbelkada Sep 13, 2023
025727c
right padding tests
younesbelkada Sep 13, 2023
8f7e400
add docs
younesbelkada Sep 13, 2023
3259392
add FA in docker image
younesbelkada Sep 13, 2023
57a077b
more clarifications
younesbelkada Sep 13, 2023
e62b0b8
add some figures
younesbelkada Sep 13, 2023
7419438
add todo
younesbelkada Sep 13, 2023
3ba5e98
rectify comment
younesbelkada Sep 14, 2023
585e463
Change to FA2
younesbelkada Sep 14, 2023
ec0f8b9
Update docs/source/en/perf_infer_gpu_one.md
younesbelkada Sep 19, 2023
3e5ea35
split in two lines
younesbelkada Sep 19, 2023
4bb1bc5
change test name
younesbelkada Sep 19, 2023
3ea4633
Merge remote-tracking branch 'origin/main' into add-flash-attn-2
younesbelkada Sep 19, 2023
b67c21e
add more tests
younesbelkada Sep 19, 2023
5b73557
some clean up
younesbelkada Sep 19, 2023
48e3bcf
remove `rearrange` deps
younesbelkada Sep 19, 2023
0461384
add more docs
younesbelkada Sep 19, 2023
8d72a66
revert changes on dockerfile
younesbelkada Sep 19, 2023
73b2f07
Revert "revert changes on dockerfile"
younesbelkada Sep 19, 2023
fb7654c
revert changes on dockerfile
younesbelkada Sep 19, 2023
a737bde
Apply suggestions from code review
younesbelkada Sep 20, 2023
80951ae
address some comments
younesbelkada Sep 20, 2023
6f7ff42
docs
younesbelkada Sep 20, 2023
257a633
use inheritance
younesbelkada Sep 20, 2023
360da70
Update src/transformers/testing_utils.py
younesbelkada Sep 20, 2023
1d91bc4
fixup
younesbelkada Sep 20, 2023
8ecab97
Merge branch 'main' into add-flash-attn-2
younesbelkada Sep 20, 2023
7c5720f
Apply suggestions from code review
younesbelkada Sep 21, 2023
28b82e2
Update src/transformers/modeling_utils.py
younesbelkada Sep 21, 2023
84b5793
final comments
younesbelkada Sep 21, 2023
949172f
clean up
younesbelkada Sep 21, 2023
825c7e0
style
younesbelkada Sep 21, 2023
1af232c
add cast + warning for PEFT models
younesbelkada Sep 22, 2023
d7f16c5
fixup
younesbelkada Sep 22, 2023
5 changes: 5 additions & 0 deletions docker/transformers-all-latest-gpu/Dockerfile
@@ -52,6 +52,11 @@ RUN python3 -m pip install --no-cache-dir bitsandbytes
# Add auto-gptq for gptq quantization testing
RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

# Add flash attention
# commands copied from https://github.com/Dao-AILab/flash-attention#installation-and-features
RUN python3 -m pip uninstall -y ninja && python3 -m pip install ninja
Contributor Author: I need to test this and make sure building this new docker image works as expected

Contributor Author: See my comment in #26268 (comment)

RUN python3 -m pip install flash-attn --no-build-isolation

# Add einops for additional model testing
RUN python3 -m pip install --no-cache-dir einops

4 changes: 4 additions & 0 deletions docs/source/en/perf_infer_gpu_many.md
@@ -22,6 +22,10 @@ Note: A multi GPU setup can use the majority of the strategies described in the

</Tip>

## Flash Attention 2

Flash Attention 2 integration also works in a multi-GPU setup; check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2).

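As a minimal, hedged sketch (the model id is illustrative; it assumes the flash-attn package and multiple NVIDIA GPUs are available), loading with a `device_map` that spans your GPUs is enough to use Flash Attention 2 across devices:

```python
import torch
from transformers import AutoModelForCausalLM

# device_map="auto" shards the model across all visible GPUs; FA2 still requires fp16/bf16 weights
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.float16,
    device_map="auto",
    use_flash_attn_2=True,
)
```
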
## BetterTransformer

[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
136 changes: 136 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -17,6 +17,142 @@ rendered properly in your Markdown viewer.

In addition to this guide, relevant information can also be found in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).

## Flash Attention 2

<Tip>

Note that this feature is experimental and might change considerably in future versions. For instance, the Flash Attention 2 API might migrate to the `BetterTransformer` API in the near future.

Comment on lines +24 to +25

Collaborator: > the Flash Attention 2 API might migrate to BetterTransformer API in the near future.
did not know this was planned 😄 If not let's just not say anything

Contributor: Well, it is all about providing a single meaningful API to users and avoiding confusing them. In PyTorch 2.2 (hopefully not too late!), we'll be in a state where FA2 will be supported by SDPA, so basically a duplicate of this.

Contributor Author (younesbelkada, Sep 21, 2023): As Felix said, the goal in the future would be to have a unified API through BetterTransformer, hence marking the API as experimental.

</Tip>

Flash Attention 2 can considerably speed up the training and inference of transformer-based models. Flash Attention 2 was introduced in the [official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. The Flash Attention paper can be found [here](https://arxiv.org/abs/2205.14135).

Make sure to follow the installation guide on the repository mentioned above to properly install Flash Attention 2. Once that package is installed, you can benefit from this feature.

We natively support Flash Attention 2 for some models; the currently supported architectures are:

- Llama
- Falcon

You can request to add Flash Attention 2 support for more models by opening an issue on GitHub!

These models can be used for inference and training, including training with padding tokens, which is currently not supported by the `BetterTransformer` API below.

<Tip>

Flash Attention 2 can only be used with models in fp16 or bf16 dtype, and only on NVIDIA GPUs. Make sure to cast your model to the appropriate dtype and load it on a supported device before using this feature.

</Tip>

### Quick usage

To enable Flash Attention 2 in your model, simply add `use_flash_attn_2=True` to the `from_pretrained` arguments:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attn_2=True,
)
```

You can then use the model for generation or fine-tuning.
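For example, a minimal generation sketch reusing the `model` and `tokenizer` from the snippet above (the prompt and generation settings are illustrative; FA2 requires the model to be on GPU):

```python
model = model.to("cuda")  # FA2 only runs on NVIDIA GPUs

inputs = tokenizer("The benefits of Flash Attention 2 are", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```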

### Expected speedups

You can benefit from a considerable speedup for fine-tuning and inference, especially for long sequences.
However, note that because Flash Attention does not support computing attention scores with padding tokens under the hood, we need to manually pad/unpad the attention scores for batched inference when the sequence contains padding tokens. This leads to a significant slowdown for batched `generate` with padding tokens. To overcome this, use Flash Attention without padding tokens in the sequence during training (e.g. by packing a dataset, i.e. concatenating sequences until reaching the maximum sequence length).
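As a rough illustration of packing (the function name and block size are made up here; the examples scripts such as `run_clm.py` implement their own version of this):

```python
def pack_sequences(tokenized_examples, block_size=2048):
    """Concatenate tokenized sequences and split them into fixed-size blocks so no padding is needed."""
    concatenated = [token for example in tokenized_examples for token in example]
    total_length = (len(concatenated) // block_size) * block_size
    return [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]


# pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], block_size=4) -> [[1, 2, 3, 4], [5, 6, 7, 8]]
```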
Member: IMO you should use pad or padding, but not padd

Member: (applies to the entire document)

Member: > e.g. by packing a dataset, i.e. concatenating sequences until reaching the maximum sequence length
I'd add a link to a doc explaining that in our docs and/or to some of our examples that do it (for example, I think that's what's happening here in run_clm)

Contributor Author: Yep, added a few lines in that direction, let me know how that sounds to you.

Below is the expected speedup you can get for a simple forward pass on `tiiuae/falcon-7b` with a sequence length of 4096 and various batch sizes, without padding tokens:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png">
</div>

Below is the expected speedup you can get for a simple forward pass on `meta-llama/Llama-7b-hf` with a sequence length of 4096 and various batch sizes, without padding tokens:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-7b-inference-large-seqlen.png">
</div>

TODO: @younesbelkada add more figures and cases where FA fails.
Contributor Author: I'll address that a bit later, I need to check first if we can merge younesbelkada#5

Note that Flash Attention makes the attention computation more memory efficient, meaning you can train with much larger sequence lengths without facing CUDA OOM issues.

### Advanced usage
Collaborator: very nice examples!

You can combine this feature with many existing features for model optimization. Check out a few examples below:

### Combining Flash Attention 2 and 8-bit models

You can combine this feature with 8-bit quantization:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    use_flash_attn_2=True,
)
```

### Combining Flash Attention 2 and 4-bit models

You can combine this feature with 4-bit quantization:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    use_flash_attn_2=True,
)
```

### Combining Flash Attention 2 and PEFT

You can combine this feature with PEFT for training adapters using Flash Attention 2 under the hood:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
from peft import LoraConfig

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    use_flash_attn_2=True,
)

lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM",
)

model.add_adapter(lora_config)

... # train your model
```
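A minimal, hedged training-step sketch to show the adapter being optimized (the optimizer, learning rate, and toy batch are illustrative; in practice you would use `Trainer` or your own loop over a packed dataset):

```python
from torch.optim import AdamW

# Only the trainable (adapter) parameters are passed to the optimizer
optimizer = AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)

# Toy batch: causal LM loss with labels equal to the inputs
batch = tokenizer("Flash Attention 2 speeds up training.", return_tensors="pt").to(model.device)
outputs = model(**batch, labels=batch["input_ids"])
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```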

## BetterTransformer

[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
4 changes: 4 additions & 0 deletions docs/source/en/perf_train_gpu_one.md
@@ -228,6 +228,10 @@ For additional information on tf32 vs other precisions, please refer to the foll
[RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004390803) and
[A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1004543189).

## Flash Attention 2

You can speed up training throughput by using the Flash Attention 2 integration in transformers. Check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2) to learn more about how to load a model with Flash Attention 2 modules.

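A minimal, hedged sketch of loading a model with Flash Attention 2 for training (the model id and dtype are illustrative; it mirrors the single GPU guide and assumes the flash-attn package is installed):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.bfloat16,
    use_flash_attn_2=True,
).to("cuda")

model.gradient_checkpointing_enable()  # optional: trade compute for memory on long sequences
```
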
## Optimizer choice

The most common optimizer used to train transformer models is Adam or AdamW (Adam with weight decay). Adam achieves
5 changes: 5 additions & 0 deletions src/transformers/configuration_utils.py
@@ -855,6 +855,9 @@ def to_diff_dict(self) -> Dict[str, Any]:

        self.dict_torch_dtype_to_str(serializable_config_dict)

        if "_flash_attn_2_enabled" in serializable_config_dict:
            del serializable_config_dict["_flash_attn_2_enabled"]

        return serializable_config_dict

    def to_dict(self) -> Dict[str, Any]:
@@ -871,6 +874,8 @@ def to_dict(self) -> Dict[str, Any]:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
if "_flash_attn_2_enabled" in output:
del output["_flash_attn_2_enabled"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
85 changes: 85 additions & 0 deletions src/transformers/modeling_utils.py
@@ -70,6 +70,7 @@
    is_accelerate_available,
    is_auto_gptq_available,
    is_bitsandbytes_available,
    is_flash_attn_available,
    is_offline_mode,
    is_optimum_available,
    is_peft_available,
@@ -1116,6 +1117,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    is_parallelizable = False
    supports_gradient_checkpointing = False

    # Flash Attention 2 support
    _supports_flash_attn_2 = False

    @property
    def dummy_inputs(self) -> Dict[str, torch.Tensor]:
        """
@@ -1239,6 +1243,83 @@ def can_generate(cls) -> bool:
            return False
        return True

    @classmethod
    def _check_and_enable_flash_attn_2(
        cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
    ) -> PretrainedConfig:
Comment on lines +1247 to +1249
Member: good error raising here
"""
If you don't know about Flash Attention, check out the official repository of flash attention:
https://github.com/Dao-AILab/flash-attention

For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
specific section of the documentation to learn more about it:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models

The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not ran on CPU.

If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
can initialize the correct attention module
"""
        if not cls._supports_flash_attn_2:
            raise ValueError(
                "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
                "request support for this architecture."
            )

        if not is_flash_attn_available():
            raise ImportError(
                "Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
                " installing it."
            )
        else:
            is_flash_greater_than_2 = version.parse(importlib.metadata.version("flash_attn")) > version.parse("2.0.0")
            if not is_flash_greater_than_2:
                raise ValueError(
                    "You need flash_attn package version to be greater than 2.0. Make sure to have that version installed."
                )
Member: Maybe print the current version they have installed currently
Contributor Author: Makes sense!

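A hedged sketch of the reviewer's suggestion to report the detected version in the error (the wording actually merged may differ):

```python
import importlib.metadata

from packaging import version

flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
if flash_attention_version <= version.parse("2.0.0"):
    raise ValueError(
        "You need flash_attn package version to be greater than 2.0. "
        f"Detected version {flash_attention_version}. Make sure to have that version installed."
    )
```
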
        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)

        if _is_bettertransformer:
            raise ValueError(
                "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
            )
Comment on lines +1285 to +1288
Member: Should this just toggle it off with an error-level logging statement?

        if torch_dtype is None:
            warnings.warn(
                "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
            )
Member: I don't know if the convention changed, but originally we favored logger.warning statements, as users can choose their level of warning (we don't handle warnings statements with our transformers.logging module).

Contributor Author: Thanks for the heads up! Changed it to logger.warning.

Collaborator: Flagging as it seems to still be a warnings.warn.
        elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(
                f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to"
                " unexpected behaviour."
            )

        if device_map is None:
            if torch.cuda.is_available():
                warnings.warn(
                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU"
                    " after initializing it on CPU with `model.to('cuda')`."
                )
            else:
                raise ValueError(
                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. "
                    "This is not supported. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
                    "or initialising the model on CPU and then moving it to GPU."
                )
        elif (
            device_map is not None
            and isinstance(device_map, dict)
            and ("cpu" in device_map.values() or "disk" in device_map.values())
        ):
            raise ValueError(
                "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
                "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
            )
Comment on lines +1318 to +1320

Collaborator: so basically we can't use it if you don't have enough GPU VRAM. It's not 100% clear to me

Contributor Author: yes, it's also not supported if you explicitly want to do CPU / disk offloading

Contributor: Actually, why is this not supported? It shouldn't be a problem to support Flash Attention + CPU offload IMO (we're supporting it for diffusers)

Member: Would be nice to support indeed, it would enable a bunch of larger models

Collaborator: Let's update the comment

Contributor Author: This would require some work, as you need to instantiate a xxxAttention module on the offloaded parts of the model. For running large models, users can still run FA-2 + quantization, which already gives a huge memory reduction.

Contributor Author: I propose to do it in a follow-up PR
        config._flash_attn_2_enabled = True
        return config

    def enable_input_require_grads(self):
        """
        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
@@ -2369,6 +2450,7 @@ def from_pretrained(
variant = kwargs.pop("variant", None)
_adapter_model_path = kwargs.pop("_adapter_model_path", None)
adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attn_2 = kwargs.pop("use_flash_attn_2", False)

        if is_fsdp_enabled():
            low_cpu_mem_usage = True
@@ -2980,6 +3062,9 @@ def from_pretrained(
        elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
            init_contexts.append(init_empty_weights())

        if use_flash_attn_2:
            config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)

        with ContextManagers(init_contexts):
            model = cls(config, *model_args, **model_kwargs)

@@ -364,7 +364,6 @@ def __init__(self, config: OpenLlamaConfig):
        self.input_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,