Add sdpa for Vivit #33757

Merged · 18 commits · Oct 15, 2024
Changes from 4 commits
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -274,6 +274,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModel)
+ * [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
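With the doc entry in place, opting into the backend is a one-liner at load time. A minimal sketch, not part of this diff; the checkpoint name and input shape are illustrative:

```python
import torch
from transformers import VivitModel

# Request the fused SDPA backend explicitly ("eager" remains the fallback).
model = VivitModel.from_pretrained(
    "google/vivit-b-16x2-kinetics400",
    attn_implementation="sdpa",
)
model.eval()

# Dummy clip: (batch, num_frames, channels, height, width).
pixel_values = torch.randn(1, 32, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 3137, 768]) for this checkpoint
```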
55 changes: 51 additions & 4 deletions src/transformers/models/vivit/modeling_vivit.py
@@ -104,9 +104,10 @@ def __init__(self, config):
torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.tubelet_size
self.config = config

- # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+ # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -129,8 +130,8 @@

dim = embeddings.shape[-1]

- new_height = height // self.patch_size
- new_width = width // self.patch_size
+ new_height = height // self.patch_size[1]
+ new_width = width // self.patch_size[2]

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
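Context for the two-line fix above: ViViT's patch is a 3D tubelet (temporal, height, width) rather than ViT's scalar square patch, so spatial interpolation must divide by the spatial entries. A tiny illustration with the default tubelet size (input dimensions chosen arbitrarily):

```python
# tubelet_size is (temporal, height, width); only the spatial entries
# (indices 1 and 2) matter when interpolating position encodings.
tubelet_size = [2, 16, 16]   # VivitConfig default
height, width = 320, 288     # an illustrative non-square input

new_height = height // tubelet_size[1]  # 320 // 16 = 20
new_width = width // tubelet_size[2]    # 288 // 16 = 18
print(new_height, new_width)  # 20 18
```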
@@ -226,6 +227,38 @@ def forward(
return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit
class VivitSdpaSelfAttention(VivitSelfAttention):
def __init__(self, config: VivitConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

return context_layer, None


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
class VivitSelfOutput(nn.Module):
"""
Expand Down Expand Up @@ -285,6 +318,13 @@ def forward(
return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit
class VivitSdpaAttention(VivitAttention):
def __init__(self, config: VivitConfig) -> None:
super().__init__(config)
self.attention = VivitSdpaSelfAttention(config)


class VivitIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
@@ -319,14 +359,20 @@ def forward(self, hidden_states, input_tensor):
return hidden_states


VIVIT_ATTENTION_CLASSES = {
"eager": VivitAttention,
"sdpa": VivitSdpaAttention,
}


class VivitLayer(nn.Module):
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation."""

def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
- self.attention = VivitAttention(config)
+ self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = VivitIntermediate(config)
self.output = VivitOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -435,6 +481,7 @@ class VivitPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []
+ _supports_sdpa = True

def _init_weights(self, module):
"""Initialize the weights"""
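The new attention class leans on `torch.nn.functional.scaled_dot_product_attention` computing exactly the softmax attention that the eager path spells out, just fused. A self-contained sketch of that equivalence (not from the PR; shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

# Stand-ins for (batch, num_heads, seq_len, head_dim).
q, k, v = (torch.randn(2, 12, 197, 64) for _ in range(3))

# Eager path, as in VivitSelfAttention: softmax(QK^T / sqrt(d)) V.
scores = q @ k.transpose(-1, -2) / (q.size(-1) ** 0.5)
eager_out = scores.softmax(dim=-1) @ v

# Fused path, as in VivitSdpaSelfAttention (dropout_p=0.0 mirrors eval mode).
sdpa_out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)

torch.testing.assert_close(eager_out, sdpa_out, rtol=1e-5, atol=1e-5)
print("eager and SDPA outputs match")
```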
6 changes: 6 additions & 0 deletions tests/models/vivit/test_modeling_vivit.py
@@ -65,6 +65,8 @@ def __init__(
layer_norm_eps=1e-06,
qkv_bias=True,
scope=None,
attn_implementation="eager",
mask_ratio=0.5,
):
self.parent = parent
self.batch_size = batch_size
@@ -86,12 +88,15 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.qkv_bias = qkv_bias
self.scope = scope
+ self.attn_implementation = attn_implementation

self.seq_length = (
(self.image_size // self.tubelet_size[2])
* (self.image_size // self.tubelet_size[1])
* (self.num_frames // self.tubelet_size[0])
) + 1 # CLS token
+ self.mask_ratio = mask_ratio
+ self.num_masks = int(mask_ratio * self.seq_length)

def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
@@ -122,6 +127,7 @@ def get_config(self):
initializer_range=self.initializer_range,
layer_norm_eps=self.layer_norm_eps,
qkv_bias=self.qkv_bias,
+ attn_implementation=self.attn_implementation,
)
config.num_labels = self.num_labels
return config
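Since the tester now forwards attn_implementation into VivitConfig, the dispatch through VIVIT_ATTENTION_CLASSES can be exercised directly. A minimal sketch with arbitrary tiny config values (assuming a PyTorch recent enough to support SDPA):

```python
from transformers import VivitConfig, VivitModel

# attn_implementation is recorded on the config, and each VivitLayer uses
# config._attn_implementation as the key into VIVIT_ATTENTION_CLASSES.
config = VivitConfig(
    image_size=32,
    num_frames=8,
    tubelet_size=[2, 4, 4],
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    attn_implementation="sdpa",
)

model = VivitModel(config)
print(type(model.encoder.layer[0].attention).__name__)  # VivitSdpaAttention
```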