[InstructBlip] Fix int8/fp4 issues #24888

Merged: 3 commits merged on Jul 18, 2023
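
The int8/fp4 cases in the title refer to loading the model through bitsandbytes quantization, where the language model weights are quantized while other tensors stay in half or full precision. For context, a rough sketch of that loading path; the checkpoint name, image path, and prompt are illustrative and not taken from this PR:

# Rough sketch of the quantized loading path this PR targets (checkpoint name,
# image path, and prompt are illustrative; use load_in_4bit=True for fp4).
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

model_id = "Salesforce/instructblip-vicuna-7b"  # assumed checkpoint name
processor = InstructBlipProcessor.from_pretrained(model_id)
model = InstructBlipForConditionalGeneration.from_pretrained(
    model_id,
    load_in_8bit=True,   # bitsandbytes int8 quantization
    device_map="auto",
)

image = Image.open("example.jpg").convert("RGB")
inputs = processor(images=image, text="What is shown in this image?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(outputs, skip_special_tokens=True)[0])
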
5 changes: 3 additions & 2 deletions in src/transformers/models/instructblip/modeling_instructblip.py
@@ -558,7 +558,6 @@ def get_input_embeddings(self):
return self.embeddings


# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerMultiHeadAttention with Blip2->InstructBlip
class InstructBlipQFormerMultiHeadAttention(nn.Module):
def __init__(self, config, is_cross_attention=False):
super().__init__()
@@ -659,13 +658,14 @@ def forward(
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_scores_dtype = attention_scores.dtype

if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)

if is_cross_attention and self.save_attention:
self.save_attention_map(attention_probs)
@@ -1038,6 +1038,7 @@ def forward(
else:
embeddings = query_embeds

embeddings = embeddings.to(self.layernorm.weight.dtype)
embeddings = self.layernorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
Collaborator:
Shouldn't this be put back to the original dtype? Or is the layernorm in the right dtype?

Contributor Author:
Yes, the LN is in the right dtype here; it is the embeddings that are in full precision.
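
A minimal plain-PyTorch sketch, not the actual modeling code, of the two dtype mismatches the casts in this diff work around: the softmax output can end up in a different dtype than the scores (for instance, when an fp32 mask is added to fp16 scores), and full-precision embeddings do not match a half-precision layernorm.

# Minimal plain-PyTorch sketch (not the actual InstructBlip code) of the two
# dtype mismatches the casts above guard against.
import torch
import torch.nn as nn

# 1) An fp32 attention mask upcasts fp16 scores, so the softmax output would
#    no longer match the fp16 value tensor; cast it back to the scores' dtype.
scores = torch.randn(1, 4, 4, dtype=torch.float16)
mask = torch.zeros(1, 4, 4, dtype=torch.float32)
scores_dtype = scores.dtype
probs = nn.Softmax(dim=-1)(scores + mask)   # float32 once the mask is added
probs = probs.to(scores_dtype)              # back to float16, as in the diff
value = torch.randn(1, 4, 8, dtype=torch.float16)
context = torch.matmul(probs, value)        # dtypes agree again

# 2) Embeddings kept in full precision do not match a half-precision layernorm,
#    so they are cast to the layernorm weight's dtype before normalization.
layernorm = nn.LayerNorm(8).to(torch.float16)
embeddings = torch.randn(1, 4, 8, dtype=torch.float32)
embeddings = embeddings.to(layernorm.weight.dtype)
out = layernorm(embeddings)
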
