
Add sdpa for Vivit #33757

Merged (18 commits) on Oct 15, 2024
Changes from 13 commits
37 changes: 37 additions & 0 deletions docs/source/en/model_doc/vivit.md
@@ -23,6 +23,43 @@ The abstract from the paper is the following:

This model was contributed by [jegormeister](https://huggingface.co/jegormeister). The original code (written in JAX) can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit).

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
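
For reference, the operator can also be called directly on query, key and value tensors. The short sketch below is only illustrative: the tensor sizes are made up, and the shape convention is `(batch, num_heads, seq_len, head_dim)`.

```python
import torch
import torch.nn.functional as F

# Illustrative tensors only; shape is (batch, num_heads, seq_len, head_dim).
query = torch.randn(1, 12, 197, 64)
key = torch.randn(1, 12, 197, 64)
value = torch.randn(1, 12, 197, 64)

# PyTorch dispatches to the fastest backend available for the given inputs and
# hardware (FlashAttention, memory-efficient attention, or the math fallback).
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 12, 197, 64])
```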

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```python
import torch
from transformers import VivitModel

model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
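
As a rough end-to-end check, the sketch below runs the SDPA-backed model in `torch.float16` on randomly generated frames. It assumes a CUDA device and the 32-frame, 224x224 input size used by this checkpoint; a real workflow would feed frames preprocessed with `VivitImageProcessor` instead.

```python
import torch
from transformers import VivitModel

model = VivitModel.from_pretrained(
    "google/vivit-b-16x2-kinetics400",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda")

# Random frames standing in for a real video clip:
# (batch_size, num_frames, num_channels, height, width)
pixel_values = torch.randn(1, 32, 3, 224, 224, dtype=torch.float16, device="cuda")

with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```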

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and the `google/vivit-b-16x2-kinetics400` model, we saw the following speedups during training and inference.

### Training
| num_training_steps | batch_size | is cuda | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---------------------:|-------------:|----------:|--------------:|----------------------:|---------------------:|-----------------:|
| 100 | 1 | True | 7.122 | 2575.28 | 5932.54 | 130.364 |



### Inference
| num_batches | batch_size | is cuda | is half | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|---------------|--------------|-----------|-----------|---------------|------------------|---------------|-----------------|
| 20 | 1 | True | False | 15.422 | 715.807 | 317.079 | 125.75 |
| 20 | 2 | True | False | 17.146 | 1234.75 | 447.175 | 176.122 |
| 20 | 4 | True | False | 18.093 | 2275.82 | 709.864 | 220.6 |
| 20 | 8 | True | False | 19.284 | 4358.19 | 1233.24 | 253.393 |


## VivitConfig

[[autodoc]] VivitConfig
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -274,6 +274,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModel)
* [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
48 changes: 47 additions & 1 deletion src/transformers/models/vivit/modeling_vivit.py
@@ -227,6 +227,38 @@ def forward(
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit
class VivitSdpaSelfAttention(VivitSelfAttention):
    def __init__(self, config: VivitConfig) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
class VivitSelfOutput(nn.Module):
    """
@@ -286,6 +318,13 @@ def forward(
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit
class VivitSdpaAttention(VivitAttention):
    def __init__(self, config: VivitConfig) -> None:
        super().__init__(config)
        self.attention = VivitSdpaSelfAttention(config)


class VivitIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -320,14 +359,20 @@ def forward(self, hidden_states, input_tensor):
        return hidden_states


VIVIT_ATTENTION_CLASSES = {
    "eager": VivitAttention,
    "sdpa": VivitSdpaAttention,
}


class VivitLayer(nn.Module):
    """This corresponds to the EncoderBlock class in the scenic/vivit implementation."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = VivitAttention(config)
        self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = VivitIntermediate(config)
        self.output = VivitOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -436,6 +481,7 @@ class VivitPreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = []
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
6 changes: 6 additions & 0 deletions tests/models/vivit/test_modeling_vivit.py
@@ -65,6 +65,8 @@ def __init__(
        layer_norm_eps=1e-06,
        qkv_bias=True,
        scope=None,
        attn_implementation="eager",
        mask_ratio=0.5,
    ):
        self.parent = parent
        self.batch_size = batch_size
@@ -86,12 +88,15 @@ def __init__(
        self.layer_norm_eps = layer_norm_eps
        self.qkv_bias = qkv_bias
        self.scope = scope
        self.attn_implementation = attn_implementation

        self.seq_length = (
            (self.image_size // self.tubelet_size[2])
            * (self.image_size // self.tubelet_size[1])
            * (self.num_frames // self.tubelet_size[0])
        ) + 1  # CLS token
        self.mask_ratio = mask_ratio
        self.num_masks = int(mask_ratio * self.seq_length)

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor(
@@ -122,6 +127,7 @@ def get_config(self):
            initializer_range=self.initializer_range,
            layer_norm_eps=self.layer_norm_eps,
            qkv_bias=self.qkv_bias,
            attn_implementation=self.attn_implementation,
        )
        config.num_labels = self.num_labels
        return config