diff --git a/docs/source/en/model_doc/vivit.md b/docs/source/en/model_doc/vivit.md
index 4426493a0ff585..c3e3df14ab988b 100644
--- a/docs/source/en/model_doc/vivit.md
+++ b/docs/source/en/model_doc/vivit.md
@@ -23,6 +23,43 @@ The abstract from the paper is the following:
 This model was contributed by [jegormeister](https://huggingface.co/jegormeister). The original code (written in JAX) can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit).

+### Using Scaled Dot Product Attention (SDPA)
+
+PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
+encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
+[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
+page for more information.
+
+SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
+`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
+
+```python
+import torch
+from transformers import VivitModel
+model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400", attn_implementation="sdpa", torch_dtype=torch.float16)
+...
+```
+
+For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
+
+On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with the `google/vivit-b-16x2-kinetics400` model in `float32`, we saw the following speedups during training and inference.
+
+### Training
+
+| num_training_steps | batch_size | is cuda | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
+|---------------------:|-------------:|----------:|--------------:|----------------------:|---------------------:|-----------------:|
+| 100 | 1 | True | 7.122 | 2575.28 | 5932.54 | 130.364 |
+
+### Inference
+
+| num_batches | batch_size | is cuda | is half | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
+|---------------|--------------|-----------|-----------|---------------|------------------|---------------|-----------------|
+| 20 | 1 | True | False | 15.422 | 715.807 | 317.079 | 125.75 |
+| 20 | 2 | True | False | 17.146 | 1234.75 | 447.175 | 176.122 |
+| 20 | 4 | True | False | 18.093 | 2275.82 | 709.864 | 220.6 |
+| 20 | 8 | True | False | 19.284 | 4358.19 | 1233.24 | 253.393 |
+
 ## VivitConfig

 [[autodoc]] VivitConfig
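The doc snippet above is intentionally abbreviated. For reference, a self-contained way to exercise the documented SDPA path could look like the sketch below. It is not part of the patch, and it assumes a CUDA device plus the `google/vivit-b-16x2-kinetics400` checkpoint with its default input layout of 32 frames of 224x224 RGB; a real pipeline would build `pixel_values` with `VivitImageProcessor` instead of random tensors.

```python
# Sketch: exercising the SDPA attention path documented above.
# Assumes a CUDA device and the google/vivit-b-16x2-kinetics400 checkpoint
# with its default 32-frame, 224x224 inputs.
import torch
from transformers import VivitModel

model = VivitModel.from_pretrained(
    "google/vivit-b-16x2-kinetics400",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda")
model.eval()

# Dummy clip shaped (batch_size, num_frames, num_channels, height, width);
# replace with VivitImageProcessor output for real videos.
pixel_values = torch.randn(1, 32, 3, 224, 224, dtype=torch.float16, device="cuda")

with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)
```

Swapping `attn_implementation` between `"sdpa"` and `"eager"` is the only change needed to compare the two paths.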
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index ed3b26029d0094..2b0ebce579e54f 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -275,6 +275,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
 * [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
 * [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell)
+* [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel)
 * [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
 * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
 * [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py
index 3d543503284489..9b6516a25af45b 100755
--- a/src/transformers/models/vivit/modeling_vivit.py
+++ b/src/transformers/models/vivit/modeling_vivit.py
@@ -227,6 +227,51 @@ def forward(
         return outputs


+# Adapted from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit
+class VivitSdpaSelfAttention(VivitSelfAttention):
+    def __init__(self, config: VivitConfig) -> None:
+        super().__init__(config)
+        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        if output_attentions or head_mask is not None:
+            logger.warning_once(
+                "VivitSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
+                " `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
+                " the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
+                ' removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                head_mask,
+                output_attentions,
+            )
+
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            self.attention_probs_dropout_prob if self.training else 0.0,
+            is_causal=False,
+            scale=None,
+        )
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, None
+
+
 # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
 class VivitSelfOutput(nn.Module):
     """
@@ -286,6 +331,13 @@ def forward(
         return outputs


+# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit
+class VivitSdpaAttention(VivitAttention):
+    def __init__(self, config: VivitConfig) -> None:
+        super().__init__(config)
+        self.attention = VivitSdpaSelfAttention(config)
+
+
 class VivitIntermediate(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -320,6 +372,12 @@ def forward(self, hidden_states, input_tensor):
         return hidden_states


+VIVIT_ATTENTION_CLASSES = {
+    "eager": VivitAttention,
+    "sdpa": VivitSdpaAttention,
+}
+
+
 class VivitLayer(nn.Module):
     """This corresponds to the EncoderBlock class in the scenic/vivit implementation."""

@@ -327,7 +385,7 @@ def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-        self.attention = VivitAttention(config)
+        self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config)
         self.intermediate = VivitIntermediate(config)
         self.output = VivitOutput(config)
         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -436,6 +494,7 @@ class VivitPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = []
+    _supports_sdpa = True

     def _init_weights(self, module):
         """Initialize the weights"""
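To sanity-check that the new `VivitSdpaSelfAttention` path matches the eager implementation, one possible check (not part of the patch, and separate from the shared SDPA test suite) is to build a small, randomly initialised ViViT, reload it with both `attn_implementation` values, and compare the outputs. The config values below are arbitrary small placeholders, not the tester defaults from the test changes that follow, and the save/reload round trip is just a convenient way to obtain two models with identical weights but different attention implementations.

```python
# Sketch: comparing eager vs. SDPA outputs on a tiny, randomly initialised ViViT.
# The config values are small placeholders chosen only to keep the model cheap.
import tempfile

import torch
from transformers import VivitConfig, VivitModel

config = VivitConfig(
    image_size=32,
    num_frames=4,
    tubelet_size=[2, 4, 4],
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
)

torch.manual_seed(0)
model = VivitModel(config).eval()

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    eager_model = VivitModel.from_pretrained(tmp_dir, attn_implementation="eager").eval()
    sdpa_model = VivitModel.from_pretrained(tmp_dir, attn_implementation="sdpa").eval()

pixel_values = torch.randn(2, config.num_frames, 3, config.image_size, config.image_size)

with torch.no_grad():
    eager_out = eager_model(pixel_values=pixel_values).last_hidden_state
    sdpa_out = sdpa_model(pixel_values=pixel_values).last_hidden_state

print(torch.allclose(eager_out, sdpa_out, atol=1e-4))  # expected: True
```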
diff --git a/tests/models/vivit/test_modeling_vivit.py b/tests/models/vivit/test_modeling_vivit.py
index 7cce77e6fc0019..8e6b0825948d40 100644
--- a/tests/models/vivit/test_modeling_vivit.py
+++ b/tests/models/vivit/test_modeling_vivit.py
@@ -65,6 +65,8 @@ def __init__(
         layer_norm_eps=1e-06,
         qkv_bias=True,
         scope=None,
+        attn_implementation="eager",
+        mask_ratio=0.5,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -86,12 +88,15 @@ def __init__(
         self.layer_norm_eps = layer_norm_eps
         self.qkv_bias = qkv_bias
         self.scope = scope
+        self.attn_implementation = attn_implementation

         self.seq_length = (
             (self.image_size // self.tubelet_size[2])
             * (self.image_size // self.tubelet_size[1])
             * (self.num_frames // self.tubelet_size[0])
         ) + 1  # CLS token
+        self.mask_ratio = mask_ratio
+        self.num_masks = int(mask_ratio * self.seq_length)

     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor(
@@ -122,6 +127,7 @@ def get_config(self):
             initializer_range=self.initializer_range,
             layer_norm_eps=self.layer_norm_eps,
             qkv_bias=self.qkv_bias,
+            attn_implementation=self.attn_implementation,
         )
         config.num_labels = self.num_labels
         return config
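The speedup tables in the documentation hunk come from a local benchmark script; a rough, simplified way to measure the same kind of inference speedup is sketched below. This is only an illustration: it assumes a CUDA device and the same public checkpoint, and it is not the exact methodology behind the reported numbers.

```python
# Sketch: rough inference timing of eager vs. SDPA attention, in the spirit of the
# benchmark tables added to the docs above (not the script used for those numbers).
import time

import torch
from transformers import VivitModel


def benchmark(attn_implementation, num_batches=20, batch_size=2):
    model = VivitModel.from_pretrained(
        "google/vivit-b-16x2-kinetics400",
        attn_implementation=attn_implementation,
        torch_dtype=torch.float16,
    ).to("cuda").eval()
    pixel_values = torch.randn(batch_size, 32, 3, 224, 224, dtype=torch.float16, device="cuda")

    with torch.no_grad():
        # Warm-up iterations so kernel selection and caching are excluded from the timing.
        for _ in range(3):
            model(pixel_values=pixel_values)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_batches):
            model(pixel_values=pixel_values)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_batches


eager = benchmark("eager")
sdpa = benchmark("sdpa")
print(f"eager: {eager * 1000:.1f} ms/batch, sdpa: {sdpa * 1000:.1f} ms/batch, speedup: {eager / sdpa:.2f}x")
```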