
[core ] Integrate Flash attention 2 in most used models #25598

Merged
merged 99 commits into from
Sep 22, 2023
Changes from 10 commits
Commits
99 commits
8bb77a1
v1
younesbelkada Aug 18, 2023
2e18421
oops
younesbelkada Aug 18, 2023
fe5795e
working v1
younesbelkada Aug 18, 2023
4bd15e2
fixup
younesbelkada Aug 18, 2023
49fe318
add some TODOs
younesbelkada Aug 18, 2023
f5d440b
Merge branch 'add-flash-attn-2' of https://github.com/younesbelkada/t…
younesbelkada Aug 18, 2023
50491e8
fixup
younesbelkada Aug 18, 2023
7df78c0
Merge remote-tracking branch 'upstream/main' into add-flash-attn-2
younesbelkada Aug 22, 2023
0e30d13
padding support + try with module replacement
younesbelkada Aug 23, 2023
ad8b905
nit
younesbelkada Aug 23, 2023
3c31f10
alternative design
younesbelkada Sep 1, 2023
2628bf3
oops
younesbelkada Sep 1, 2023
56d0b49
Merge remote-tracking branch 'upstream/main' into HEAD
younesbelkada Sep 1, 2023
20d1b37
add `use_cache` support for llama
younesbelkada Sep 1, 2023
a82f1ca
v1 falcon
younesbelkada Sep 1, 2023
c72e8ff
nit
younesbelkada Sep 1, 2023
66823f9
a bit of refactor
younesbelkada Sep 1, 2023
41f8f3d
nit
younesbelkada Sep 1, 2023
a64a1a9
nits nits
younesbelkada Sep 1, 2023
67e3fc2
add v1 padding support falcon (even though it seemed to work before)
younesbelkada Sep 1, 2023
8444ab6
nit
younesbelkada Sep 1, 2023
8b1c2df
falcon works
younesbelkada Sep 1, 2023
c3ebcd2
fixup
younesbelkada Sep 1, 2023
1c212d8
v1 tests
younesbelkada Sep 1, 2023
4618701
nit
younesbelkada Sep 1, 2023
85ec946
fix generation llama flash
fxmarty Sep 1, 2023
2248f20
Merge branch 'add-flash-attn-2' of https://github.com/younesbelkada/t…
fxmarty Sep 1, 2023
0881ced
update tests
younesbelkada Sep 1, 2023
a8a1b2d
Merge branch 'add-flash-attn-2' of https://github.com/younesbelkada/t…
younesbelkada Sep 1, 2023
2be3e03
fix tests + nits
younesbelkada Sep 1, 2023
b6d3e58
fix copies
younesbelkada Sep 1, 2023
b47e85c
fix nit
younesbelkada Sep 1, 2023
db8bd64
test- padding mask
younesbelkada Sep 1, 2023
58848ab
stype
younesbelkada Sep 4, 2023
3f73557
add more mem efficient support
younesbelkada Sep 4, 2023
baae736
Update src/transformers/modeling_utils.py
younesbelkada Sep 4, 2023
55f6140
fixup
younesbelkada Sep 4, 2023
10d5c1b
Merge remote-tracking branch 'upstream/main' into add-flash-attn-2
younesbelkada Sep 4, 2023
3fb221a
nit
younesbelkada Sep 4, 2023
a931aeb
fixup
younesbelkada Sep 4, 2023
68a1204
remove it from config when saving
younesbelkada Sep 4, 2023
36e0d6e
fixup
younesbelkada Sep 4, 2023
2beeb68
revert docstring
younesbelkada Sep 4, 2023
7b5da2c
add more checks
younesbelkada Sep 4, 2023
b99a582
use values
younesbelkada Sep 4, 2023
adaed45
oops
younesbelkada Sep 4, 2023
7f06af6
new version
fxmarty Sep 5, 2023
2d36c6f
fixup
younesbelkada Sep 5, 2023
a663fa4
Merge branch 'main' into add-flash-attn-2
younesbelkada Sep 5, 2023
9d3693f
add same trick for falcon
younesbelkada Sep 5, 2023
65ae59c
nit
younesbelkada Sep 5, 2023
43185b5
add another test
younesbelkada Sep 11, 2023
c61157e
change tests
younesbelkada Sep 11, 2023
2f17792
fix issues with GC and also falcon
younesbelkada Sep 11, 2023
65c3861
fixup
younesbelkada Sep 11, 2023
165a503
oops
younesbelkada Sep 11, 2023
5abc702
Update src/transformers/models/falcon/modeling_falcon.py
younesbelkada Sep 13, 2023
5069e4a
add init_rope
younesbelkada Sep 13, 2023
11400d8
Merge branch 'add-flash-attn-2' of https://github.com/younesbelkada/t…
younesbelkada Sep 13, 2023
ace7939
updates
younesbelkada Sep 13, 2023
fe9b16d
fix copies
younesbelkada Sep 13, 2023
6174c06
Merge branch 'main' into add-flash-attn-2
younesbelkada Sep 13, 2023
acfc954
fixup
younesbelkada Sep 13, 2023
33a0f62
fixup
younesbelkada Sep 13, 2023
ee8ba20
more clarification
younesbelkada Sep 13, 2023
e28fb0b
fixup
younesbelkada Sep 13, 2023
025727c
right padding tests
younesbelkada Sep 13, 2023
8f7e400
add docs
younesbelkada Sep 13, 2023
3259392
add FA in docker image
younesbelkada Sep 13, 2023
57a077b
more clarifications
younesbelkada Sep 13, 2023
e62b0b8
add some figures
younesbelkada Sep 13, 2023
7419438
add todo
younesbelkada Sep 13, 2023
3ba5e98
rectify comment
younesbelkada Sep 14, 2023
585e463
Change to FA2
younesbelkada Sep 14, 2023
ec0f8b9
Update docs/source/en/perf_infer_gpu_one.md
younesbelkada Sep 19, 2023
3e5ea35
split in two lines
younesbelkada Sep 19, 2023
4bb1bc5
change test name
younesbelkada Sep 19, 2023
3ea4633
Merge remote-tracking branch 'origin/main' into add-flash-attn-2
younesbelkada Sep 19, 2023
b67c21e
add more tests
younesbelkada Sep 19, 2023
5b73557
some clean up
younesbelkada Sep 19, 2023
48e3bcf
remove `rearrange` deps
younesbelkada Sep 19, 2023
0461384
add more docs
younesbelkada Sep 19, 2023
8d72a66
revert changes on dockerfile
younesbelkada Sep 19, 2023
73b2f07
Revert "revert changes on dockerfile"
younesbelkada Sep 19, 2023
fb7654c
revert changes on dockerfile
younesbelkada Sep 19, 2023
a737bde
Apply suggestions from code review
younesbelkada Sep 20, 2023
80951ae
address some comments
younesbelkada Sep 20, 2023
6f7ff42
docs
younesbelkada Sep 20, 2023
257a633
use inheritance
younesbelkada Sep 20, 2023
360da70
Update src/transformers/testing_utils.py
younesbelkada Sep 20, 2023
1d91bc4
fixup
younesbelkada Sep 20, 2023
8ecab97
Merge branch 'main' into add-flash-attn-2
younesbelkada Sep 20, 2023
7c5720f
Apply suggestions from code review
younesbelkada Sep 21, 2023
28b82e2
Update src/transformers/modeling_utils.py
younesbelkada Sep 21, 2023
84b5793
final comments
younesbelkada Sep 21, 2023
949172f
clean up
younesbelkada Sep 21, 2023
825c7e0
style
younesbelkada Sep 21, 2023
1af232c
add cast + warning for PEFT models
younesbelkada Sep 22, 2023
d7f16c5
fixup
younesbelkada Sep 22, 2023
63 changes: 63 additions & 0 deletions src/transformers/modeling_utils.py
@@ -69,6 +69,7 @@
is_accelerate_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_flash_attn_available,
is_offline_mode,
is_optimum_available,
is_peft_available,
@@ -1098,6 +1099,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_parallelizable = False
supports_gradient_checkpointing = False

# Flash Attention 2 support
_supports_flash_attn_2 = False

@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
@@ -1231,6 +1235,65 @@ def make_inputs_require_grads(module, input, output):

self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)

def enable_flash_attn_2(self) -> None:
"""
Enable the Flash Attention 2.0 implementation for this model, for more memory-efficient inference and
training. If you are not familiar with Flash Attention, check out the official repository:
https://github.com/Dao-AILab/flash-attention

To use Flash Attention 1.0, go through the `BetterTransformer` API instead; see this section of the
documentation to learn more:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
"""
if not self._supports_flash_attn_2:
raise ValueError(
"The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
"request support for this architecture."
)
if not is_flash_attn_available():
raise ImportError(
"Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
" installing it."
)
else:
is_flash_greater_than_2 = version.parse(importlib.metadata.version("flash_attn")) > version.parse("2.0.0")
if not is_flash_greater_than_2:
raise ValueError(
"You need flash_attn package version to be greater than 2.0. Make sure to have that version installed."
)
Member: Maybe print the version they currently have installed.

Contributor (Author): Makes sense!

_is_bettertransformer = getattr(self, "use_bettertransformer", False)

if _is_bettertransformer:
raise ValueError(
"Flash Attention 2 and BetterTransformer API are not compatible. Please use one API or the other."
)

self._enable_flash_attn_2()
self._flash_attn_2_enabled = True

def disable_flash_attn_2(self) -> None:
"""
Disables the Flash Attention 2.0 implementation for this model, reverting to the default attention
implementation.
"""
if not self._supports_flash_attn_2:
raise ValueError(
"The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
"request support for this architecture."
)

_flash_attn_2_enabled = self._flash_attn_2_enabled

if not _flash_attn_2_enabled:
raise ValueError(
"Flash Attention 2.0 is not enabled. Please enable it with `model.enable_flash_attn_2()`."
)

self._disable_flash_attn_2()
self._flash_attn_2_enabled = False

def disable_input_require_grads(self):
"""
Removes the `_require_grads_hook`.
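For reviewers who want to exercise the new switch end to end, a minimal usage sketch of the `enable_flash_attn_2` / `disable_flash_attn_2` API added above. It assumes a CUDA GPU, half-precision weights (Flash Attention 2 only supports fp16/bf16), and `flash-attn >= 2.0.0` installed; the checkpoint name is only illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative; any architecture with `_supports_flash_attn_2 = True`

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Swap the eager attention modules for their Flash Attention 2 counterparts.
model.enable_flash_attn_2()

inputs = tokenizer("Flash Attention 2 makes long-context inference", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))

# Revert to the default attention implementation.
model.disable_flash_attn_2()
```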
204 changes: 202 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -30,10 +30,22 @@
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...pytorch_utils import reset_and_attach_new_hooks
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig


if is_flash_attn_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input # noqa


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"
@@ -57,6 +69,59 @@ def _make_causal_mask(
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


def _convert_to_padding_mask(attention_mask: torch.Tensor, mask_value: float = 0.0):
"""
Convert causal attention mask to key-padding mask
"""
if len(attention_mask.size()) != 4:
raise ValueError(
"Expecting attention_mask to have 4 dimensions, got tensor of shape: " f"{attention_mask.size()}"
)

batch_size = attention_mask.size(0)
key_length = attention_mask.size(-1)

padding_mask = torch.ones((batch_size, key_length), device=attention_mask.device)

for i in range(batch_size):
mask_slice = attention_mask[i, :, -1, :]
padding_mask[i, :] = torch.all(mask_slice == mask_value, dim=0)

return padding_mask


def recursively_replace_module(model, old_class, target_class):
"""
Recursively replace all `old_class` instances in the model with a target class. The target class should have the
same sub-module names as the old class.

Args:
model (`torch.nn.Module`):
The model or the child module used for recursion
old_class (`class`):
The target old class to replace
target_class (`class`):
The new class that is going to be used in the replaced module.
"""
for name, module in model.named_children():
if isinstance(module, old_class):
torch_device = module.q_proj.weight.device
with torch.device(torch_device):
new_module = target_class(module.config)

for inner_module_name, inner_module in module.named_modules():
Contributor: should use module.named_children() instead of module.named_modules()

  1. named_modules() also returns the module itself, under a blank name
  2. the function is already recursive, so it should only check child modules rather than all modules (though the result is the same for llama)
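To make that point concrete, a quick illustration (not part of the diff) with a bare module holding two linear children:

```python
import torch.nn as nn

attn = nn.Module()
attn.q_proj = nn.Linear(8, 8, bias=False)
attn.k_proj = nn.Linear(8, 8, bias=False)

print([name for name, _ in attn.named_modules()])   # ['', 'q_proj', 'k_proj'] -- includes `attn` itself under a blank name
print([name for name, _ in attn.named_children()])  # ['q_proj', 'k_proj'] -- direct children only
```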

setattr(new_module, inner_module_name, inner_module)
Collaborator: Is this the place where we get the original weights into the new module?


if hasattr(module, "_hf_hook"):
reset_and_attach_new_hooks(module, new_module)

model._modules[name] = new_module
module = None

if module is not None and len(list(module.children())) > 0:
recursively_replace_module(module, old_class, target_class)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
@@ -256,6 +321,7 @@ def __init__(self, config: LlamaConfig):
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

self._init_rope()

def _init_rope(self):
@@ -328,7 +394,6 @@ def forward(

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

@@ -358,6 +423,7 @@
)

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if self.config.pretraining_tp > 1:
@@ -373,6 +439,129 @@ def forward(
return attn_output, attn_weights, past_key_value


class LlamaFlashAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

# Copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

self._init_rope()

# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

# Copied from transformers.models.llama.modeling_llama.LlamaAttention._shape
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention does not support output_attentions
output_attentions = False

bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

past_key_value = None

# TODO: llama does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout

padding_mask = _convert_to_padding_mask(attention_mask)

# contains at least one padding token
if padding_mask.sum().item() != bsz * kv_seq_len:
query_states, indices, current_query_length, query_max_seqlen = unpad_input(query_states, padding_mask)
key_states, _, current_key_length, key_max_seqlen = unpad_input(key_states, padding_mask)
value_states, _, _, _ = unpad_input(value_states, padding_mask)

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=current_query_length,
cu_seqlens_k=current_key_length,
max_seqlen_q=query_max_seqlen,
max_seqlen_k=key_max_seqlen,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)

attn_output = pad_input(attn_output_unpad, indices, bsz, kv_seq_len)
else:
attn_output = flash_attn_func(query_states, key_states, value_states, dropout_rate, causal=True)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
@@ -464,6 +653,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True

def _init_weights(self, module):
std = self.config.initializer_range
@@ -480,6 +670,16 @@ def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value

def _enable_flash_attn_2(self):
for _, module in self.named_children():
if len(list(module.children())) > 0:
recursively_replace_module(module, LlamaAttention, LlamaFlashAttention)

def _disable_flash_attn_2(self):
for _, module in self.named_children():
if len(list(module.children())) > 0:
recursively_replace_module(module, LlamaFlashAttention, LlamaAttention)


LLAMA_INPUTS_DOCSTRING = r"""
Args:
31 changes: 31 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -286,3 +286,34 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
non-overlapping lifetimes may have the same id.
"""
return tensor.device, storage_ptr(tensor), storage_size(tensor)


def reset_and_attach_new_hooks(old_module, new_module) -> None:
"""
Attach a new hook to `new_module` that mirrors the `accelerate` hook currently attached to `old_module`.

Args:
old_module (`torch.nn.Module`):
The old module that carries the existing hook.
new_module (`torch.nn.Module`):
The new module that does not have any hook yet.
"""
import accelerate
from accelerate.hooks import add_hook_to_module, remove_hook_from_module

hook = old_module._hf_hook

hook_cls = getattr(accelerate.hooks, hook.__class__.__name__)
hook_attr = hook.__dict__
filtered_old_hook_attr = {}
old_hook_init_signature = inspect.signature(hook_cls.__init__)
for k in hook_attr.keys():
if k in old_hook_init_signature.parameters:
filtered_old_hook_attr[k] = hook_attr[k]

new_hook = hook_cls(**filtered_old_hook_attr)

remove_hook_from_module(old_module)
add_hook_to_module(new_module, new_hook)
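To show where this helper fits, a hedged sketch of transferring the `accelerate` hook when swapping an attention module; it only runs on this PR branch (where `LlamaFlashAttention` and `reset_and_attach_new_hooks` exist), and the checkpoint name is illustrative:

```python
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaFlashAttention
from transformers.pytorch_utils import reset_and_attach_new_hooks

# device_map="auto" makes accelerate attach an AlignDevicesHook to each dispatched sub-module.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")

old_attn = model.model.layers[0].self_attn
new_attn = LlamaFlashAttention(old_attn.config)

if hasattr(old_attn, "_hf_hook"):
    # Re-create an equivalent accelerate hook on the replacement and detach it from the old module;
    # copying the weights over is handled separately by `recursively_replace_module`.
    reset_and_attach_new_hooks(old_attn, new_attn)

print(hasattr(new_attn, "_hf_hook"))  # True once the hook has been transferred
```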
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -114,6 +114,7 @@
is_detectron2_available,
is_essentia_available,
is_faiss_available,
is_flash_attn_available,
is_flax_available,
is_ftfy_available,
is_in_notebook,