diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index 1254ef558f6b1b..645c38c2046273 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -558,7 +558,6 @@ def get_input_embeddings(self):
         return self.embeddings


-# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerMultiHeadAttention with Blip2->InstructBlip
 class InstructBlipQFormerMultiHeadAttention(nn.Module):
     def __init__(self, config, is_cross_attention=False):
         super().__init__()
@@ -659,13 +658,14 @@ def forward(
             attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        attention_scores_dtype = attention_scores.dtype

         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
             attention_scores = attention_scores + attention_mask

         # Normalize the attention scores to probabilities.
-        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+        attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)

         if is_cross_attention and self.save_attention:
             self.save_attention_map(attention_probs)
@@ -1038,6 +1038,7 @@ def forward(
         else:
             embeddings = query_embeds

+        embeddings = embeddings.to(self.layernorm.weight.dtype)
         embeddings = self.layernorm(embeddings)
         embeddings = self.dropout(embeddings)
         return embeddings
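
For context, here is a minimal, standalone sketch of the dtype-preservation pattern the two code changes above apply. The function name `softmax_in_scores_dtype` and the toy shapes are illustrative only, not part of the patch or the library API:

```python
import torch
from torch import nn
from typing import Optional


def softmax_in_scores_dtype(
    attention_scores: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Normalize attention scores and cast the result back to the scores' dtype.

    Mirrors the `.to(attention_scores_dtype)` cast in the diff: under mixed precision
    the softmax may be computed in float32, and casting back keeps the subsequent
    probs @ value matmul in the lower-precision dtype.
    """
    scores_dtype = attention_scores.dtype
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask
    return nn.Softmax(dim=-1)(attention_scores).to(scores_dtype)


# Usage sketch: low-precision scores stay low-precision after normalization.
scores = torch.randn(1, 12, 32, 32, dtype=torch.bfloat16)
probs = softmax_in_scores_dtype(scores)
assert probs.dtype == torch.bfloat16

# The embeddings change follows the same idea from the other direction: cast the
# input to the dtype of the LayerNorm weights before applying it, so a reduced-precision
# model does not receive float32 query embeddings.
layernorm = nn.LayerNorm(768).to(torch.bfloat16)
query_embeds = torch.randn(1, 32, 768)                    # float32 by default
out = layernorm(query_embeds.to(layernorm.weight.dtype))  # matches the diff's cast
assert out.dtype == torch.bfloat16
```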