diff --git a/docs/source/en/perf_infer_gpu_many.md b/docs/source/en/perf_infer_gpu_many.md
index 756d2b3ef57b0b..2118b5ddb40431 100644
--- a/docs/source/en/perf_infer_gpu_many.md
+++ b/docs/source/en/perf_infer_gpu_many.md
@@ -22,6 +22,10 @@ Note: A multi GPU setup can use the majority of the strategies described in the
 
+## Flash Attention 2
+
+The Flash Attention 2 integration also works in a multi-GPU setup; check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2).
+
 ## BetterTransformer
 
 [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 855c52ffd98c62..86e137cf14d7a1 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -17,6 +17,154 @@ rendered properly in your Markdown viewer.
 
 In addition to this guide, relevant information can be found as well in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).
 
+## Flash Attention 2
+
+<Tip warning={true}>
+
+Note that this feature is experimental and might change considerably in future versions. For instance, the Flash Attention 2 API might migrate to the `BetterTransformer` API in the near future.
+
+</Tip>
+
+Flash Attention 2 can considerably speed up the training and inference of transformer-based models. Flash Attention 2 was introduced in the [official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. The scientific paper on Flash Attention can be found [here](https://arxiv.org/abs/2205.14135).
+
+Make sure to follow the installation guide in the repository mentioned above to properly install Flash Attention 2. Once that package is installed, you can benefit from this feature.
+
+We natively support Flash Attention 2 for the following models:
+
+- Llama
+- Falcon
+
+You can request Flash Attention 2 support for more models by opening an issue on GitHub, or even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported by the `BetterTransformer` API described below.*
+
+<Tip>
+
+Flash Attention 2 can only be used when the model's dtype is `fp16` or `bf16`, and it runs only on NVIDIA GPUs. Make sure to cast your model to the appropriate dtype and load it on a supported device before using this feature.
+
+</Tip>
+
+### Quick usage
+
+To enable Flash Attention 2 for your model, pass `use_flash_attention_2=True` to `from_pretrained`:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    use_flash_attention_2=True,
+)
+```
+
+Then use the model for generation or fine-tuning.
+
+### Expected speedups
+
+You can benefit from considerable speedups for fine-tuning and inference, especially for long sequences. However, since Flash Attention does not support computing attention scores with padding tokens under the hood, we must manually unpad / pad the inputs for batched inference when the sequence contains padding tokens.
+This leads to a significant slowdown for batched generation with padding tokens.
+
+To overcome this, use Flash Attention without padding tokens in the sequence during training (e.g., by packing a dataset, i.e., concatenating sequences until reaching the maximum sequence length). An example is provided [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L516).
+
+Below is the expected speedup you can get for a simple forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes, without padding tokens:
+
+ +Below is the expected speedup you can get for a simple forward pass on [`meta-llama/Llama-7b-hf`](https://hf.co/meta-llama/Llama-7b-hf) with a sequence length of 4096 and various batch sizes, without padding tokens: + +
+
+For sequences with padding tokens (training with padding tokens or generating with padding tokens), we need to unpad / pad the input sequences to correctly compute the attention scores. For relatively small sequence lengths, a pure forward pass creates an overhead, leading to a small speedup (in the benchmark below, 30% of the input is filled with padding tokens).
+
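+To make this concrete, here is a minimal sketch of batched generation with padding tokens that goes through the unpad / pad path described above; the prompts and generation settings are purely illustrative:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+tokenizer.pad_token = tokenizer.eos_token  # Falcon has no padding token by default
+tokenizer.padding_side = "left"  # left padding is the usual choice for decoder-only generation
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    use_flash_attention_2=True,
+    device_map="auto",
+)
+
+# Prompts of different lengths get padded to a common length,
+# which triggers the unpad / pad overhead discussed above.
+prompts = ["Hello", "Flash Attention 2 speeds up long sequences because"]
+inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
+
+with torch.no_grad():
+    outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+```
+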
+
+But for large sequence lengths you can benefit from interesting speedups for pure inference (and also training).
+
+Note that Flash Attention makes the attention computation more memory efficient, meaning you can train with much larger sequence lengths without facing CUDA OOM issues. It can lead to a memory reduction of up to 20x for large sequence lengths. Check out [the official flash attention repository](https://github.com/Dao-AILab/flash-attention) for more details.
+
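+If you want to reproduce this kind of comparison on your own hardware, a rough timing sketch like the one below can serve as a starting point; the `time_forward` helper, checkpoint, sequence length, and batch size are illustrative and the exact numbers will depend on your GPU:
+
+```python
+import time
+
+import torch
+from transformers import AutoModelForCausalLM
+
+
+def time_forward(use_fa2: bool, seq_len: int = 4096, batch_size: int = 2) -> float:
+    """Time a single forward pass, with or without Flash Attention 2."""
+    model = AutoModelForCausalLM.from_pretrained(
+        "tiiuae/falcon-7b",
+        torch_dtype=torch.bfloat16,
+        use_flash_attention_2=use_fa2,
+        device_map="auto",
+    )
+    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=model.device)
+    with torch.no_grad():
+        model(input_ids)  # warmup
+        torch.cuda.synchronize()
+        start = time.perf_counter()
+        model(input_ids)
+        torch.cuda.synchronize()
+    elapsed = time.perf_counter() - start
+    del model
+    torch.cuda.empty_cache()
+    return elapsed
+
+
+print("vanilla attention :", time_forward(use_fa2=False))
+print("flash attention 2 :", time_forward(use_fa2=True))
+```
+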
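+The dataset-packing approach mentioned earlier (concatenating tokenized examples into fixed-size blocks so that no padding tokens are needed) can be sketched as follows; the block size and the `pack_texts` name are illustrative, and the logic mirrors the `group_texts` function from the linked `run_clm.py` example:
+
+```python
+from itertools import chain
+
+block_size = 2048  # illustrative; use the maximum sequence length you train with
+
+
+def pack_texts(examples):
+    """Concatenate tokenized examples and split them into fixed-size blocks (no padding required)."""
+    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
+    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
+    result = {
+        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+        for k, t in concatenated.items()
+    }
+    result["labels"] = result["input_ids"].copy()
+    return result
+
+
+# Typically applied after tokenization with `datasets.Dataset.map(pack_texts, batched=True)`.
+```
+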
+
+### Advanced usage
+
+You can combine this feature with many existing features for model optimization. Check out a few examples below:
+
+### Combining Flash Attention 2 and 8-bit models
+
+You can combine this feature with 8-bit quantization:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    load_in_8bit=True,
+    use_flash_attention_2=True,
+)
+```
+
+### Combining Flash Attention 2 and 4-bit models
+
+You can combine this feature with 4-bit quantization:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    load_in_4bit=True,
+    use_flash_attention_2=True,
+)
+```
+
+### Combining Flash Attention 2 and PEFT
+
+You can combine this feature with PEFT to train adapters with Flash Attention 2 under the hood:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import LoraConfig
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    load_in_4bit=True,
+    use_flash_attention_2=True,
+)
+
+lora_config = LoraConfig(
+    r=8,
+    task_type="CAUSAL_LM",
+)
+
+model.add_adapter(lora_config)
+
+... # train your model
+```
+
 ## BetterTransformer
 
 [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
diff --git a/docs/source/en/perf_train_gpu_one.md b/docs/source/en/perf_train_gpu_one.md
index f1b0f3976df0f8..17b62c3a1379ca 100644
--- a/docs/source/en/perf_train_gpu_one.md
+++ b/docs/source/en/perf_train_gpu_one.md
@@ -228,6 +228,10 @@ For additional information on tf32 vs other precisions, please refer to the foll
 [RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004390803) and
 [A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1004543189).
 
+## Flash Attention 2
+
+You can speed up training throughput by using the Flash Attention 2 integration in transformers. Check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2) to learn more about how to load a model with Flash Attention 2 modules.
+
 ## Optimizer choice
 
 The most common optimizer used to train transformer models is Adam or AdamW (Adam with weight decay).
Adam achieves diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 00f9b5610e6bab..74086ca2d7fccb 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -855,6 +855,9 @@ def to_diff_dict(self) -> Dict[str, Any]: self.dict_torch_dtype_to_str(serializable_config_dict) + if "_flash_attn_2_enabled" in serializable_config_dict: + del serializable_config_dict["_flash_attn_2_enabled"] + return serializable_config_dict def to_dict(self) -> Dict[str, Any]: @@ -871,6 +874,8 @@ def to_dict(self) -> Dict[str, Any]: del output["_auto_class"] if "_commit_hash" in output: del output["_commit_hash"] + if "_flash_attn_2_enabled" in output: + del output["_flash_attn_2_enabled"] # Transformers version when serializing the model output["transformers_version"] = __version__ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1432c3b78a160c..f4e376e593a25e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -70,6 +70,7 @@ is_accelerate_available, is_auto_gptq_available, is_bitsandbytes_available, + is_flash_attn_available, is_offline_mode, is_optimum_available, is_peft_available, @@ -1116,6 +1117,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix is_parallelizable = False supports_gradient_checkpointing = False + # Flash Attention 2 support + _supports_flash_attn_2 = False + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ @@ -1239,6 +1243,84 @@ def can_generate(cls) -> bool: return False return True + @classmethod + def _check_and_enable_flash_attn_2( + cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None + ) -> PretrainedConfig: + """ + If you don't know about Flash Attention, check out the official repository of flash attention: + https://github.com/Dao-AILab/flash-attention + + For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this + specific section of the documentation to learn more about it: + https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models + + The method checks if the current setup is compatible with Flash Attention as it requires the model to be in + half precision and not ran on CPU. + + If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model + can initialize the correct attention module + """ + if not cls._supports_flash_attn_2: + raise ValueError( + "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to " + "request support for this architecture: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_available(): + raise ImportError( + "Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for" + " installing it." + ) + else: + flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) + is_flash_greater_than_2 = flash_attention_version > version.parse("2.0.0") + if not is_flash_greater_than_2: + raise ValueError( + f"You need flash_attn package version to be greater than 2.0. 
Make sure to have that version installed - detected version {flash_attention_version}" + ) + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + + if _is_bettertransformer: + raise ValueError( + "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) + + if torch_dtype is None: + logger.warning( + "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: + raise ValueError( + f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to" + " unexpected behaviour." + ) + + if device_map is None: + if torch.cuda.is_available(): + logger.warning( + "You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialising the model on CPU and then moving it to GPU." + ) + elif ( + device_map is not None + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " + "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." + ) + config._flash_attn_2_enabled = True + return config + def enable_input_require_grads(self): """ Enables the gradients for the input embeddings. 
This is useful for fine-tuning adapter weights while keeping @@ -2374,6 +2456,7 @@ def from_pretrained( variant = kwargs.pop("variant", None) _adapter_model_path = kwargs.pop("_adapter_model_path", None) adapter_name = kwargs.pop("adapter_name", "default") + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) if is_fsdp_enabled(): low_cpu_mem_usage = True @@ -2985,6 +3068,9 @@ def from_pretrained( elif load_in_8bit or load_in_4bit or low_cpu_mem_usage: init_contexts.append(init_empty_weights()) + if use_flash_attention_2: + config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map) + with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 0d36d8c0e06306..a2e824133b7e2a 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -364,7 +364,6 @@ def __init__(self, config: OpenLlamaConfig): self.input_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index c541fab0a253a7..85a83258517b6f 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -32,11 +32,21 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_available, + logging, +) from ..auto.configuration_auto import sanitize_code_revision from .configuration_falcon import FalconConfig +if is_flash_attn_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + logger = logging.get_logger(__name__) FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [ @@ -67,6 +77,19 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(padding_mask): + seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + # TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities class FalconRotaryEmbedding(nn.Module): """Implementation of RotaryEmbedding from GPT-NeoX. 
@@ -405,6 +428,7 @@ def forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + padding_mask: Optional[torch.LongTensor] = None, ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads @@ -519,6 +543,185 @@ def forward( return output_tensor, present +class FalconFlashAttention2(FalconAttention): + """ + Falcon flash attention module. This module inherits from `FalconAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape( + batch_size * num_kv_heads, + query_length, + self.head_dim, + ) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) + + past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] + query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids) + + if layer_past is not None and use_cache: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, kv_length, head_dim] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, kv_seq_length, _ = key_layer.shape + + torch_dtype = query_layer.dtype + + past_key_value = (key_layer, value_layer) if use_cache else None + + query_layer = ( + query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype) + ) + key_layer = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype) + value_layer = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype) + + if alibi is not None: + raise ValueError("`alibi` is not supported when `use_flash_attn` is True") + + attn_dropout = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
+ input_dtype = query_layer.dtype + if input_dtype == torch.float32: + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to" + " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + " float16." + ) + + query_layer = query_layer.to(torch.float16) + key_layer = key_layer.to(torch.float16) + value_layer = value_layer.to(torch.float16) + + attn_output = self._flash_attention_forward( + query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout + ) + + attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = self.dense(attn_weights) + + if not output_attentions: + attn_weights = None + + return attn_output, past_key_value, attn_weights + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + padding_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if padding_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, padding_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ padding_mask = padding_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class FalconMLP(nn.Module): def __init__(self, config: FalconConfig): super().__init__() @@ -540,7 +743,12 @@ def __init__(self, config: FalconConfig): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config) + + self.self_attention = ( + FalconAttention(config) + if not getattr(config, "_flash_attn_2_enabled", False) + else FalconFlashAttention2(config) + ) self.mlp = FalconMLP(config) self.hidden_dropout = config.hidden_dropout self.config = config @@ -565,6 +773,7 @@ def forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + padding_mask: Optional[torch.LongTensor] = None, ): residual = hidden_states @@ -584,6 +793,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + padding_mask=padding_mask, ) attention_output = attn_outputs[0] @@ -700,6 +910,7 @@ class FalconPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["FalconDecoderLayer"] + _supports_flash_attn_2 = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -917,9 +1128,15 @@ def forward( past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device) + padding_mask = None else: attention_mask = attention_mask.to(hidden_states.device) + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + if self.use_alibi: alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) else: @@ -964,6 +1181,7 @@ def custom_forward(*inputs): causal_mask, position_ids, head_mask[i], + padding_mask, ) else: outputs = block( @@ -975,6 +1193,7 @@ def custom_forward(*inputs): use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + padding_mask=padding_mask, ) hidden_states = outputs[0] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 317a788869ed01..82d0300f60e85f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -31,15 +31,38 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_available, + logging, + replace_return_docstrings, +) from .configuration_llama import LlamaConfig +if is_flash_attn_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LlamaConfig" +def _get_unpad_data(padding_mask): + seqlens_in_batch = padding_mask.sum(dim=-1, 
dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + # Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask( input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 @@ -261,6 +284,7 @@ def __init__(self, config: LlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self._init_rope() def _init_rope(self): @@ -301,6 +325,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -343,7 +368,6 @@ def forward( past_key_value = (key_states, value_states) if use_cache else None - # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -373,6 +397,7 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if self.config.pretraining_tp > 1: @@ -388,11 +413,189 @@ def forward( return attn_output, attn_weights, past_key_value +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # TODO: llama does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to" + " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + " float16." 
+ ) + + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + padding_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if padding_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, padding_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ padding_mask = padding_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) + self.self_attn = ( + LlamaAttention(config=config) + if not getattr(config, "_flash_attn_2_enabled", False) + else LlamaFlashAttention2(config=config) + ) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -405,6 +608,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -432,6 +636,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + padding_mask=padding_mask, ) hidden_states = residual + hidden_states @@ -479,6 +684,7 @@ class LlamaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.initializer_range @@ -669,6 +875,13 @@ def forward( attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) @@ -698,15 +911,12 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions) + return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, + create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids ) else: layer_outputs = decoder_layer( @@ -716,6 +926,7 @@ def custom_forward(*inputs): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + padding_mask=padding_mask, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 654660e4fa9f00..632bd082369778 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -452,7 +452,6 @@ def forward( "The bare Persimmon Model outputting raw hidden-states without any specific head on top.", PERSIMMON_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Persimmon class PersimmonPreTrainedModel(PreTrainedModel): config_class = PersimmonConfig base_model_prefix = "model" @@ -544,7 +543,6 @@ def _set_gradient_checkpointing(self, module, value=False): 
"The bare Persimmon Model outputting raw hidden-states without any specific head on top.", PERSIMMON_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps class PersimmonModel(PersimmonPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PersimmonDecoderLayer`] @@ -553,6 +551,7 @@ class PersimmonModel(PersimmonPreTrainedModel): config: PersimmonConfig """ + # Copied from transformers.models.llama.modeling_llama.LlamaModel.__init__ with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps def __init__(self, config: PersimmonConfig): super().__init__(config) self.padding_idx = config.pad_token_id diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index eaa44ff2246f55..d8ed86bda9fbc4 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -60,6 +60,7 @@ is_detectron2_available, is_essentia_available, is_faiss_available, + is_flash_attn_available, is_flax_available, is_fsdp_available, is_ftfy_available, @@ -392,6 +393,16 @@ def require_torch(test_case): return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) +def require_flash_attn(test_case): + """ + Decorator marking a test that requires Flash Attention. + + These tests are skipped when Flash Attention isn't installed. + + """ + return unittest.skipUnless(is_flash_attn_available(), "test requires Flash Attention")(test_case) + + def require_peft(test_case): """ Decorator marking a test that requires PEFT. diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 5ee787738da511..45501a0aa620c7 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -114,6 +114,7 @@ is_detectron2_available, is_essentia_available, is_faiss_available, + is_flash_attn_available, is_flax_available, is_fsdp_available, is_ftfy_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 621a993aa97df1..67fb1765679967 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -71,6 +71,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) _apex_available = _is_package_available("apex") _bitsandbytes_available = _is_package_available("bitsandbytes") +_flash_attn_available = _is_package_available("flash_attn") # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. 
_bs4_available = importlib.util.find_spec("bs4") is not None _coloredlogs_available = _is_package_available("coloredlogs") @@ -570,6 +571,16 @@ def is_bitsandbytes_available(): return _bitsandbytes_available and torch.cuda.is_available() +def is_flash_attn_available(): + if not is_torch_available(): + return False + + # Let's add an extra check to see if cuda is available + import torch + + return _flash_attn_available and torch.cuda.is_available() + + def is_torchdistx_available(): return _torchdistx_available diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 35a8a2fd3ebe46..0223acbbd72a8a 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -18,9 +18,10 @@ import unittest from parameterized import parameterized +from pytest import mark from transformers import LlamaConfig, is_torch_available, set_seed -from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import require_flash_attn, require_torch, require_torch_gpu, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -375,6 +376,41 @@ def test_model_rope_scaling(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + load_in_4bit=True, + device_map={"": 0}, + ) + + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + texts = ["hi", "Hello this is a very long sentence"] + + tokenizer.padding_side = "right" + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) + + output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_native = tokenizer.batch_decode(output_native) + + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True + ) + + output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_2 = tokenizer.batch_decode(output_fa_2) + + self.assertListEqual(output_native, output_fa_2) + @require_torch class LlamaIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 764554b4368d9b..b8d7367dd7c7d7 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -64,6 +64,7 @@ is_pt_flax_cross_test, is_pt_tf_cross_test, require_accelerate, + require_flash_attn, require_safetensors, require_torch, require_torch_gpu, @@ -2722,6 +2723,191 @@ def test_model_is_small(self): num_params < 1000000 ), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max." 
+ @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_conversion(self): + import torch + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + return + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True + ).to(torch_device) + + for _, module in model.named_modules(): + if "FlashAttention" in module.__class__.__name__: + return + + self.assertTrue(False, "FlashAttention2 modules not found in model") + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference(self): + import torch + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + ) + model.to(torch_device) + + dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[0, 1, 1, 1, 1]]).to(torch_device) + + logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] + logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + + output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits_fa = output_fa.hidden_states[-1] + + output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits = output.hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_padding_right(self): + import torch + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + ) + model.to(torch_device) + + dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device) + + logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] + logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + + output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits_fa = output_fa.hidden_states[-1] + + output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits = 
output.hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_left_padding(self): + import torch + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True + ).to(torch_device) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + self.assertTrue(torch.equal(out, out_fa)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True + ).to(torch_device) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + self.assertTrue(torch.equal(out, out_fa)) + global_rng = random.Random()