[InstructBlip] Fix int8/fp4 issues (#24888)
* fix dtype issue

* revert `.float()`

* fix copies
younesbelkada authored Jul 18, 2023
1 parent 3ec10e6 commit a9e067a
Showing 1 changed file with 3 additions and 2 deletions.
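A note on the underlying issue, added here as a hedged aside rather than part of the commit: when the model is loaded in int8 or fp4, quantized loading typically keeps norm layers in float32 while activations flow through in float16, so a tensor can reach a module in a different dtype than that module's weights. The sketch below (arbitrary shapes and sizes, not InstructBLIP's real ones) shows that failure pattern and the cast that the third hunk of the diff adds before the Q-Former embeddings' layernorm:

    import torch
    import torch.nn as nn

    # Hypothetical reproduction: fp32 LayerNorm weights vs. fp16 activations,
    # as commonly happens under 8-bit/4-bit loading.
    layernorm = nn.LayerNorm(64)                               # weights stay in float32
    embeddings = torch.randn(1, 32, 64, dtype=torch.float16)   # activations in float16

    # Mirrors the added line `embeddings = embeddings.to(self.layernorm.weight.dtype)`:
    # cast the input to the weight dtype before normalizing.
    embeddings = embeddings.to(layernorm.weight.dtype)
    hidden_states = layernorm(embeddings)
    print(hidden_states.dtype)  # torch.float32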
5 changes: 3 additions & 2 deletions src/transformers/models/instructblip/modeling_instructblip.py
@@ -558,7 +558,6 @@ def get_input_embeddings(self):
         return self.embeddings
 
 
-# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerMultiHeadAttention with Blip2->InstructBlip
 class InstructBlipQFormerMultiHeadAttention(nn.Module):
     def __init__(self, config, is_cross_attention=False):
         super().__init__()
@@ -659,13 +658,14 @@ def forward(
             attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
 
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        attention_scores_dtype = attention_scores.dtype
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
             attention_scores = attention_scores + attention_mask
 
         # Normalize the attention scores to probabilities.
-        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+        attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
 
         if is_cross_attention and self.save_attention:
             self.save_attention_map(attention_probs)
@@ -1038,6 +1038,7 @@ def forward(
         else:
             embeddings = query_embeds
 
+        embeddings = embeddings.to(self.layernorm.weight.dtype)
         embeddings = self.layernorm(embeddings)
         embeddings = self.dropout(embeddings)
         return embeddings
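As a second hedged aside (again not from the commit): the `.to(attention_scores_dtype)` cast in the attention hunk above guards against dtype promotion. If fp16 attention scores are summed with an fp32 attention mask, the result is promoted to float32, and the softmax output would then no longer match the fp16 value states it is multiplied with later in the forward pass. A minimal sketch with illustrative shapes:

    import torch
    import torch.nn as nn

    scores = torch.randn(1, 8, 4, 4, dtype=torch.float16)   # fp16 attention scores
    mask = torch.zeros(1, 1, 4, 4, dtype=torch.float32)     # e.g. an fp32 extended attention mask

    scores_dtype = scores.dtype                              # remember the original dtype
    scores = scores + mask                                   # type promotion -> float32
    probs = nn.Softmax(dim=-1)(scores).to(scores_dtype)      # cast back, as in the diff
    print(scores.dtype, probs.dtype)                         # torch.float32 torch.float16
    # In the model, `probs` is then matmul'ed with fp16 value states, which
    # requires both operands to share a dtype.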
