Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update attention.py #1416

Merged
merged 4 commits into from
Oct 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions optimum/bettertransformer/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ def gpt_bigcode_wrapped_scaled_dot_product(
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]
kv_seq_len = key.shape[-2]

if self.multi_query:
query_length = query_shape[1]
Expand All @@ -721,30 +722,34 @@ def gpt_bigcode_wrapped_scaled_dot_product(
key = key.expand(-1, self.num_heads, -1, -1)
value = value.expand(-1, self.num_heads, -1, -1)

if batch_size == 1 or self.training:
if query_length > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
# We treat self.training and (batch_size == 1 and query_length == 1) cases separately to still allow the dispatch to Flash Attention.
if self.training:
is_causal = True
attn_mask = None
elif batch_size == 1 and query_length == 1:
is_causal = False
attn_mask = None
elif batch_size == 1 and kv_seq_len == query_length:
is_causal = True
attn_mask = None
elif attention_mask is not None:
mask_value = self._get_mask_value(query.device, query.dtype)

# gpt_bigcode has the bad taste to use a causal mask a
# [batch_size, target_length, 1, source_length] which is different from
# **all** other architectures and not compatible with SDPA.
# We could avoid this transpose by overriding the forward from GPTBigCodeModel,
# but it is probably not worth it.
attention_mask = attention_mask.transpose(1, 2)
attn_mask = torch.where(attention_mask, 0.0, mask_value)
is_causal = False
else:
if attention_mask is not None:
mask_value = self._get_mask_value(query.device, query.dtype)
attn_mask = None
is_causal = True

# gpt_bigcode has the bad taste to use a causal mask a
# [batch_size, target_length, 1, source_length] which is different from
# **all** other architectures and not compatible with SDPA.
# We could avoid this transpose by overriding the forward from GPTBigCodeModel,
# but it is probably not worth it.
attention_mask = attention_mask.transpose(1, 2)
attention_mask = torch.where(attention_mask, 0.0, mask_value)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)

if self.multi_query:
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
Expand Down
Loading