Add sdpa support for Albert (#32092)
* Add sdpa support for Albert

* [run_slow] albert

* Add benchmarks and PR suggestion

* Fix quality

* Fix

* [run_slow] albert
OmarManzoor authored Sep 3, 2024
1 parent e969d88 commit 03c12d0
Showing 3 changed files with 138 additions and 8 deletions.
47 changes: 46 additions & 1 deletion docs/source/en/model_doc/albert.md
@@ -59,7 +59,52 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). This
- Layers are split in groups that share parameters (to save memory).
- Next sentence prediction is replaced by a sentence ordering prediction: in the inputs, we have two sentences A and B (that are consecutive), and we either feed A followed by B or B followed by A. The model must predict whether they have been swapped.


### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```python
import torch
from transformers import AlbertModel

model = AlbertModel.from_pretrained("albert/albert-base-v1", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```
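
As a quick, informal sanity check, the attention implementation that was actually selected is recorded on the loaded configuration and can be inspected via `model.config._attn_implementation` (an internal attribute).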

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16`, we saw the
following speedups during training and inference.

#### Training for 100 iterations

|batch_size|seq_len|Time per batch (eager - s)| Time per batch (sdpa - s)| Speedup (%)| Eager peak mem (MB)| sdpa peak mem (MB)| Mem saving (%)|
|----------|-------|--------------------------|--------------------------|------------|--------------------|-------------------|---------------|
|2 |256 |0.028 |0.024 |14.388 |358.411 |321.088 |11.624 |
|2 |512 |0.049 |0.041 |17.681 |753.458 |602.660 |25.022 |
|4 |256 |0.044 |0.039 |12.246 |679.534 |602.660 |12.756 |
|4 |512 |0.090 |0.076 |18.472 |1434.820 |1134.140 |26.512 |
|8 |256 |0.081 |0.072 |12.664 |1283.825 |1134.140 |13.198 |
|8 |512 |0.170 |0.143 |18.957 |2820.398 |2219.695 |27.062 |

#### Inference with 50 batches

|batch_size|seq_len|Per token latency eager (ms)|Per token latency SDPA (ms)|Speedup (%) |Mem eager (MB)|Mem SDPA (MB)|Mem saved (%)|
|----------|-------|----------------------------|---------------------------|------------|--------------|-----------|-------------|
|4 |128 |0.083 |0.071 |16.967 |48.319 |48.45 |-0.268 |
|4 |256 |0.148 |0.127 |16.37 |63.4 |63.922 |-0.817 |
|4 |512 |0.31 |0.247 |25.473 |110.092 |94.343 |16.693 |
|8 |128 |0.137 |0.124 |11.102 |63.4 |63.66 |-0.409 |
|8 |256 |0.271 |0.231 |17.271 |91.202 |92.246 |-1.132 |
|8 |512 |0.602 |0.48 |25.47 |186.159 |152.564 |22.021 |
|16 |128 |0.252 |0.224 |12.506 |91.202 |91.722 |-0.567 |
|16 |256 |0.526 |0.448 |17.604 |148.378 |150.467 |-1.388 |
|16 |512 |1.203 |0.96 |25.365 |338.293 |271.102 |24.784 |

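For context, a comparison along these lines can be sketched as follows. This is only a minimal, illustrative example: the checkpoint, batch size, sequence length, and warmup/iteration counts here are assumptions, not the exact benchmark script used to produce the tables above.

```python
# Rough, illustrative sketch of an eager-vs-SDPA inference latency comparison for ALBERT.
import time

import torch
from transformers import AlbertModel, AlbertTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AlbertTokenizer.from_pretrained("albert/albert-base-v1")
inputs = tokenizer(
    ["A short example sentence."] * 8,
    padding="max_length",
    max_length=256,
    truncation=True,
    return_tensors="pt",
).to(device)


def mean_latency(attn_implementation: str, n_iters: int = 50) -> float:
    model = AlbertModel.from_pretrained(
        "albert/albert-base-v1", torch_dtype=dtype, attn_implementation=attn_implementation
    ).to(device)
    model.eval()
    with torch.no_grad():
        for _ in range(5):  # warmup
            model(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            model(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters


print(f"eager: {mean_latency('eager'):.4f}s/batch, sdpa: {mean_latency('sdpa'):.4f}s/batch")
```
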
This model was contributed by [lysandre](https://huggingface.co/lysandre). The JAX version of this model was contributed by
[kamalkraj](https://huggingface.co/kamalkraj). The original code can be found [here](https://github.com/google-research/ALBERT).
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -201,6 +201,7 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

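As a minimal sketch of how this can be combined with PyTorch's kernel-selection context manager (the checkpoint below is illustrative, and `torch.backends.cuda.sdp_kernel` requires a recent PyTorch release; it errors out if the requested backend cannot handle the given inputs):

```python
# Sketch: load a model with SDPA and restrict SDPA to the FlashAttention kernel,
# which fails fast if that backend cannot be used for these inputs.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v1")
model = AutoModel.from_pretrained(
    "albert/albert-base-v1", torch_dtype=torch.float16, attn_implementation="sdpa"
).to("cuda")

inputs = tokenizer("Hello, world!", return_tensors="pt").to("cuda")

with torch.no_grad():
    # Only the FlashAttention backend of SDPA is enabled inside this context.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        outputs = model(**inputs)
```
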
For now, Transformers supports SDPA inference and training for the following architectures:
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
98 changes: 91 additions & 7 deletions src/transformers/models/albert/modeling_albert.py
@@ -24,6 +24,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
@@ -34,7 +35,12 @@
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import (
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    is_torch_greater_or_equal_than_2_2,
    prune_linear_layer,
)
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -358,6 +364,66 @@ def forward(
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)


class AlbertSdpaAttention(AlbertAttention):
    def __init__(self, config):
        super().__init__(config)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
            logger.warning(
                "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
                "the eager attention implementation, but specifying the eager implementation will be required from "
                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(hidden_states, attention_mask, head_mask, output_attentions)

        batch_size, seq_len, _ = hidden_states.size()
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        attention_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_layer,
            key=key_layer,
            value=value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=False,
        )

        attention_output = attention_output.transpose(1, 2)
        attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)

        projected_context_layer = self.dense(attention_output)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer,)


ALBERT_ATTENTION_CLASSES = {
    "eager": AlbertAttention,
    "sdpa": AlbertSdpaAttention,
}


class AlbertLayer(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()
@@ -366,7 +432,7 @@ def __init__(self, config: AlbertConfig):
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
@@ -496,6 +562,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights."""
@@ -635,6 +702,9 @@ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
            self.pooler = None
            self.pooler_activation = None

        self.attn_implementation = config._attn_implementation
        self.position_embedding_type = config.position_embedding_type

        # Initialize weights and apply final processing
        self.post_init()

@@ -708,14 +778,28 @@ def forward(
        else:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        use_sdpa_attention_mask = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        if use_sdpa_attention_mask:
            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask, embedding_output.dtype, tgt_len=seq_length
            )
        else:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,