diff --git a/docs/source/en/model_doc/albert.md b/docs/source/en/model_doc/albert.md
index a75e6757804862..d195203615de83 100644
--- a/docs/source/en/model_doc/albert.md
+++ b/docs/source/en/model_doc/albert.md
@@ -59,7 +59,52 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). This
 - Layers are split in groups that share parameters (to save memory).
 Next sentence prediction is replaced by a sentence ordering prediction: in the inputs, we have two sentences A and B (that are consecutive) and we either feed A followed by B or B followed by A. The model must predict if they have been swapped or not.
 
-
+### Using Scaled Dot Product Attention (SDPA)
+
+PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
+encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
+[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
+page for more information.
+
+SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
+`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
+
+```
+from transformers import AlbertModel
+model = AlbertModel.from_pretrained("albert/albert-base-v1", torch_dtype=torch.float16, attn_implementation="sdpa")
+...
+```
+
+For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
+
+On a local benchmark (GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16`, we saw the
+following speedups during training and inference.
+
+#### Training for 100 iterations
+
+|batch_size|seq_len|Time per batch (eager - s)| Time per batch (sdpa - s)| Speedup (%)| Eager peak mem (MB)| sdpa peak mem (MB)| Mem saving (%)|
+|----------|-------|--------------------------|--------------------------|------------|--------------------|-------------------|---------------|
+|2 |256 |0.028 |0.024 |14.388 |358.411 |321.088 |11.624 |
+|2 |512 |0.049 |0.041 |17.681 |753.458 |602.660 |25.022 |
+|4 |256 |0.044 |0.039 |12.246 |679.534 |602.660 |12.756 |
+|4 |512 |0.090 |0.076 |18.472 |1434.820 |1134.140 |26.512 |
+|8 |256 |0.081 |0.072 |12.664 |1283.825 |1134.140 |13.198 |
+|8 |512 |0.170 |0.143 |18.957 |2820.398 |2219.695 |27.062 |
+
+#### Inference with 50 batches
+
+|batch_size|seq_len|Per token latency eager (ms)|Per token latency SDPA (ms)|Speedup (%) |Mem eager (MB)|Mem SDPA (MB)|Mem saved (%)|
+|----------|-------|----------------------------|---------------------------|------------|--------------|-------------|-------------|
+|4 |128 |0.083 |0.071 |16.967 |48.319 |48.45 |-0.268 |
+|4 |256 |0.148 |0.127 |16.37 |63.4 |63.922 |-0.817 |
+|4 |512 |0.31 |0.247 |25.473 |110.092 |94.343 |16.693 |
+|8 |128 |0.137 |0.124 |11.102 |63.4 |63.66 |-0.409 |
+|8 |256 |0.271 |0.231 |17.271 |91.202 |92.246 |-1.132 |
+|8 |512 |0.602 |0.48 |25.47 |186.159 |152.564 |22.021 |
+|16 |128 |0.252 |0.224 |12.506 |91.202 |91.722 |-0.567 |
+|16 |256 |0.526 |0.448 |17.604 |148.378 |150.467 |-1.388 |
+|16 |512 |1.203 |0.96 |25.365 |338.293 |271.102 |24.784 |
 
 This model was contributed by [lysandre](https://huggingface.co/lysandre). This model jax version was contributed by [kamalkraj](https://huggingface.co/kamalkraj). 
 The original code can be found [here](https://github.com/google-research/ALBERT).
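Editorial note (not part of the patch): the usage snippet added to the documentation above is deliberately abbreviated and does not import `torch`. A self-contained version of the same idea might look like the sketch below; the `AutoTokenizer`, device handling, and float32 CPU fallback are illustrative additions, only the `albert/albert-base-v1` checkpoint and `attn_implementation="sdpa"` come from the patch.

```python
import torch
from transformers import AlbertModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision gives the best SDPA speedups on GPU; fall back to float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32

# `attn_implementation="sdpa"` makes the request explicit; with torch>=2.1.1 SDPA is
# already used by default whenever an implementation is available for the model.
model = AlbertModel.from_pretrained(
    "albert/albert-base-v1",
    torch_dtype=dtype,
    attn_implementation="sdpa",
).to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v1")
inputs = tokenizer("ALBERT shares parameters across layers.", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)
```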
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index b0109a0e8dc1a9..999da9bb2c44e5 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -197,6 +197,7 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
 PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
 
 For now, Transformers supports SDPA inference and training for the following architectures:
+* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
 * [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
 * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
 * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
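Editorial note (not part of the patch): both the eager path and the SDPA path in the modeling change below feed attention an *additive* mask, in which padded key positions are pushed to a large negative value so that softmax gives them (near-)zero weight. The standalone sketch below illustrates that convention with plain `torch.nn.functional.scaled_dot_product_attention`; the shapes and the hand-built mask are illustrative only and this is not the Transformers helper itself.

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 12, 6, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# 2D padding mask as produced by a tokenizer: 1 = real token, 0 = padding.
padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                             [1, 1, 1, 1, 0, 0]])

# Expand to (batch, 1, 1, src_len) so it broadcasts over heads and query positions,
# then turn masked-out positions into a large negative additive bias.
additive_mask = padding_mask[:, None, None, :].to(query.dtype)
additive_mask = (1.0 - additive_mask) * torch.finfo(query.dtype).min

# Bidirectional (encoder-style) attention, hence is_causal=False, as in the ALBERT code.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=additive_mask, is_causal=False)
print(out.shape)  # torch.Size([2, 12, 6, 64])
```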
diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index ac4958798b2cdd..6ccb266009e193 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -24,6 +24,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
@@ -34,7 +35,12 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_2_2,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -358,6 +364,66 @@ def forward(
         return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
 
 
+class AlbertSdpaAttention(AlbertAttention):
+    def __init__(self, config):
+        super().__init__(config)
+        self.dropout_prob = config.attention_probs_dropout_prob
+        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
+        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
+            logger.warning(
+                "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
+                "the eager attention implementation, but specifying the eager implementation will be required from "
+                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
+                '`attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(hidden_states, attention_mask, head_mask, output_attentions)
+
+        batch_size, seq_len, _ = hidden_states.size()
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+            query_layer = query_layer.contiguous()
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
+        attention_output = torch.nn.functional.scaled_dot_product_attention(
+            query=query_layer,
+            key=key_layer,
+            value=value_layer,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=False,
+        )
+
+        attention_output = attention_output.transpose(1, 2)
+        attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
+
+        projected_context_layer = self.dense(attention_output)
+        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
+        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
+        return (layernormed_context_layer,)
+
+
+ALBERT_ATTENTION_CLASSES = {
+    "eager": AlbertAttention,
+    "sdpa": AlbertSdpaAttention,
+}
+
+
 class AlbertLayer(nn.Module):
     def __init__(self, config: AlbertConfig):
         super().__init__()
@@ -366,7 +432,7 @@ def __init__(self, config: AlbertConfig):
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
         self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = AlbertAttention(config)
+        self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
         self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
         self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
         self.activation = ACT2FN[config.hidden_act]
@@ -496,6 +562,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
     config_class = AlbertConfig
     load_tf_weights = load_tf_weights_in_albert
     base_model_prefix = "albert"
+    _supports_sdpa = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -635,6 +702,9 @@ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
             self.pooler = None
             self.pooler_activation = None
 
+        self.attn_implementation = config._attn_implementation
+        self.position_embedding_type = config.position_embedding_type
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -708,14 +778,28 @@ def forward(
             else:
                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
-        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
-        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
         embedding_output = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
         )
+
+        use_sdpa_attention_mask = (
+            self.attn_implementation == "sdpa"
+            and self.position_embedding_type == "absolute"
+            and head_mask is None
+            and not output_attentions
+        )
+
+        if use_sdpa_attention_mask:
+            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                attention_mask, embedding_output.dtype, tgt_len=seq_length
+            )
+        else:
+            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
+
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,