Add sdpa support for Albert #32092

Merged (10 commits) on Sep 3, 2024
47 changes: 46 additions & 1 deletion docs/source/en/model_doc/albert.md
@@ -59,7 +59,52 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). This
- Layers are split in groups that share parameters (to save memory).
- Next sentence prediction is replaced by a sentence ordering prediction: in the inputs, we have two sentences A and B (that are consecutive) and we either feed A followed by B or B followed by A. The model must predict if they have been swapped or not.


### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
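
As a rough illustration of what the operator computes, the sketch below compares a hand-written ("eager") attention with `torch.nn.functional.scaled_dot_product_attention` on random tensors. The tensor sizes and the `eager_attention` helper are illustrative assumptions, not code from this PR.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes only: [batch, heads, seq_len, head_dim].
batch, heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)


def eager_attention(q, k, v):
    """Hypothetical reference: softmax(Q K^T / sqrt(d)) V, without mask or dropout."""
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.size(-1))
    probs = torch.softmax(scores, dim=-1)
    return probs @ v


eager_out = eager_attention(q, k, v)
sdpa_out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# The two results should agree up to floating-point error.
max_abs_diff = (eager_out - sdpa_out).abs().max().item()
print(max_abs_diff)
```

Both paths return a tensor of shape `[batch, heads, seq_len, head_dim]`; the fused operator simply avoids materializing the full attention-probability matrix in user code.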

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
import torch
from transformers import AlbertModel

model = AlbertModel.from_pretrained("albert/albert-base-v1", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```
Review comment (Collaborator):

We should also run a few benchmarks for the model to show the expected speedups when using SDPA.


For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16`, we saw the
following speedups during training and inference.

#### Training for 100 iterations

|batch_size|seq_len|Time per batch (eager - s)| Time per batch (sdpa - s)| Speedup (%)| Eager peak mem (MB)| sdpa peak mem (MB)| Mem saving (%)|
|----------|-------|--------------------------|--------------------------|------------|--------------------|-------------------|---------------|
|2 |256 |0.028 |0.024 |14.388 |358.411 |321.088 |11.624 |
|2 |512 |0.049 |0.041 |17.681 |753.458 |602.660 |25.022 |
|4 |256 |0.044 |0.039 |12.246 |679.534 |602.660 |12.756 |
|4 |512 |0.090 |0.076 |18.472 |1434.820 |1134.140 |26.512 |
|8 |256 |0.081 |0.072 |12.664 |1283.825 |1134.140 |13.198 |
|8 |512 |0.170 |0.143 |18.957 |2820.398 |2219.695 |27.062 |

#### Inference with 50 batches

|batch_size|seq_len|Per token latency eager (ms)|Per token latency SDPA (ms)|Speedup (%) |Mem eager (MB)|Mem SDPA (MB)|Mem saved (%)|
|----------|-------|----------------------------|---------------------------|------------|--------------|-----------|-------------|
|4 |128 |0.083 |0.071 |16.967 |48.319 |48.45 |-0.268 |
|4 |256 |0.148 |0.127 |16.37 |63.4 |63.922 |-0.817 |
|4 |512 |0.31 |0.247 |25.473 |110.092 |94.343 |16.693 |
|8 |128 |0.137 |0.124 |11.102 |63.4 |63.66 |-0.409 |
|8 |256 |0.271 |0.231 |17.271 |91.202 |92.246 |-1.132 |
|8 |512 |0.602 |0.48 |25.47 |186.159 |152.564 |22.021 |
|16 |128 |0.252 |0.224 |12.506 |91.202 |91.722 |-0.567 |
|16 |256 |0.526 |0.448 |17.604 |148.378 |150.467 |-1.388 |
|16 |512 |1.203 |0.96 |25.365 |338.293 |271.102 |24.784 |
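
The percentage columns appear to follow `(eager / sdpa - 1) * 100`. This is inferred from the reported memory figures, not a formula stated in the PR, and `relative_gain_percent` is a hypothetical helper name. Because the timings shown are rounded to three decimals, recomputing the speedup column from them only approximates the reported values.

```python
def relative_gain_percent(eager: float, sdpa: float) -> float:
    """Return how much larger the eager figure is than the SDPA one, in percent."""
    return (eager / sdpa - 1.0) * 100.0


# Training row: batch_size=2, seq_len=256 (peak memory in MB).
train_mem_saving = relative_gain_percent(358.411, 321.088)

# Inference row: batch_size=4, seq_len=512 (memory in MB).
infer_mem_saving = relative_gain_percent(110.092, 94.343)

print(f"{train_mem_saving:.3f}")
print(f"{infer_mem_saving:.3f}")
```

Under this reading, a negative value (as in several short-sequence inference rows) means the SDPA run used slightly more memory than the eager run.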

This model was contributed by [lysandre](https://huggingface.co/lysandre). The JAX version of this model was contributed by
[kamalkraj](https://huggingface.co/kamalkraj). The original code can be found [here](https://github.com/google-research/ALBERT).
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -197,6 +197,7 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

For now, Transformers supports SDPA inference and training for the following architectures:
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
98 changes: 91 additions & 7 deletions src/transformers/models/albert/modeling_albert.py
@@ -24,6 +24,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
@@ -34,7 +35,12 @@
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import (
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    is_torch_greater_or_equal_than_2_2,
    prune_linear_layer,
)
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -358,6 +364,66 @@ def forward(
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)


class AlbertSdpaAttention(AlbertAttention):
    def __init__(self, config):
        super().__init__(config)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
            logger.warning(
                "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
                "the eager attention implementation, but specifying the eager implementation will be required from "
                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(hidden_states, attention_mask, head_mask, output_attentions)

        batch_size, seq_len, _ = hidden_states.size()
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        attention_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_layer,
            key=key_layer,
            value=value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=False,
        )

        attention_output = attention_output.transpose(1, 2)
        attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)

        projected_context_layer = self.dense(attention_output)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer,)
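
A small shape check may make the head split and merge in `AlbertSdpaAttention.forward` easier to follow. The sketch below assumes `transpose_for_scores` behaves like the standalone helper defined here (view to `[batch, seq, heads, head_dim]`, then permute); the helper signature and the sizes are illustrative, not taken from the PR.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes only.
batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
all_head_size = num_heads * head_dim
hidden = torch.randn(batch, seq_len, all_head_size)


def transpose_for_scores(x, num_heads, head_dim):
    """Assumed behaviour: [B, S, H*D] -> [B, H, S, D] via view + permute."""
    new_shape = x.size()[:-1] + (num_heads, head_dim)
    return x.view(new_shape).permute(0, 2, 1, 3)


per_head = transpose_for_scores(hidden, num_heads, head_dim)

# The permuted view is not contiguous in memory, which is why the class
# calls `.contiguous()` before SDPA on the affected torch versions.
print(per_head.shape, per_head.is_contiguous())

# Undo the split the same way the forward pass does after attention.
merged = per_head.transpose(1, 2).reshape(batch, seq_len, all_head_size)
print(torch.equal(merged, hidden))
```

The `transpose(1, 2)` followed by `reshape` is the exact inverse of the split, so the final projection sees the heads concatenated along the last dimension again.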


ALBERT_ATTENTION_CLASSES = {
    "eager": AlbertAttention,
    "sdpa": AlbertSdpaAttention,
}
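
The registry plus subclass fallback used here can be sketched without PyTorch. Everything below (`EagerAttention`, `SdpaAttention`, `build_attention`, the string return values) is a hypothetical stand-in meant only to show the control flow: the dictionary picks a class from the configured implementation name, and the SDPA class defers to its parent whenever a feature it cannot serve is requested.

```python
class EagerAttention:
    """Stand-in for the eager implementation; returns a label instead of tensors."""

    position_embedding_type = "absolute"

    def forward(self, output_attentions=False, head_mask=None):
        return "eager"


class SdpaAttention(EagerAttention):
    """Stand-in for the SDPA implementation with the same fallback condition."""

    def forward(self, output_attentions=False, head_mask=None):
        unsupported = (
            self.position_embedding_type != "absolute"
            or output_attentions
            or head_mask is not None
        )
        if unsupported:
            # Defer to the parent (eager) path for unsupported features.
            return super().forward(output_attentions, head_mask)
        return "sdpa"


ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}


def build_attention(attn_implementation):
    """Look up and instantiate the attention class by name."""
    return ATTENTION_CLASSES[attn_implementation]()


layer = build_attention("sdpa")
print(layer.forward())
print(layer.forward(output_attentions=True))
```

Keeping the fallback inside the subclass means callers never need to know which path actually ran, at the cost of a silent slowdown when the eager path is taken (hence the warning in the real class).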


class AlbertLayer(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()
@@ -366,7 +432,7 @@ def __init__(self, config: AlbertConfig):
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
@@ -496,6 +562,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights."""
@@ -635,6 +702,9 @@ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
            self.pooler = None
            self.pooler_activation = None

        self.attn_implementation = config._attn_implementation
        self.position_embedding_type = config.position_embedding_type

        # Initialize weights and apply final processing
        self.post_init()

@@ -708,14 +778,28 @@ def forward(
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        use_sdpa_attention_mask = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        if use_sdpa_attention_mask:
            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask, embedding_output.dtype, tgt_len=seq_length
            )
        else:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
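
To see what the additive mask does once it reaches SDPA, the sketch below builds one by hand from a 2D padding mask, using the same `(1 - mask) * finfo.min` arithmetic as the eager branch and expanding it to `[batch, 1, tgt_len, src_len]`. It deliberately does not call the private `_prepare_4d_attention_mask_for_sdpa` helper; the shapes and variable names are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes only.
batch, heads, seq_len, head_dim = 2, 2, 5, 4
dtype = torch.float32

# 1 = real token, 0 = padding. The first sequence has two padded positions.
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

# Hand-built additive mask: 0 where attention is allowed, a very large
# negative number where the key position is padding.
expanded = attention_mask[:, None, None, :].expand(batch, 1, seq_len, seq_len).to(dtype)
additive_mask = (1.0 - expanded) * torch.finfo(dtype).min

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

masked_out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)

# Reference for the first sequence: drop the padded keys/values entirely.
trimmed_out = F.scaled_dot_product_attention(q[:1], k[:1, :, :3], v[:1, :, :3])

print(torch.allclose(masked_out[0], trimmed_out[0], atol=1e-5))
```

Masking a key position with a large negative bias gives it zero weight after the softmax, so the masked result for the first sequence matches attention computed over its three real tokens only.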