Commit

support torch>=2.0
songweige committed Sep 27, 2023
1 parent 12da0a4 commit a02b874
Showing 2 changed files with 22 additions and 28 deletions.
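Background for this change: when PyTorch 2.0+ is installed, diffusers routes attention through AttnProcessor2_0 (the F.scaled_dot_product_attention path) by default, so rich-text hooks added only to the plain processor are never reached; the commit therefore rewrites AttnProcessor2_0 itself. A minimal sketch of attaching such a processor to a pipeline's UNet, assuming an illustrative checkpoint and that the patched class is importable from scripts/models/attention_processor.py (this wiring is not part of the diff):

import torch
from diffusers import StableDiffusionPipeline

from scripts.models.attention_processor import AttnProcessor2_0  # patched class from this commit

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Replace every attention layer's processor with the rich-text-aware one.
pipe.unet.set_attn_processor(AttnProcessor2_0())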
48 changes: 21 additions & 27 deletions scripts/models/attention_processor.py
@@ -1107,17 +1107,16 @@ def __call__(

 class AttnProcessor2_0:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    Default processor for performing attention-related computations.
     """

-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    # Rich-Text: inject self-attention maps
     def __call__(
         self,
         attn: Attention,
         hidden_states,
+        real_attn_probs=None,
+        attn_weights=None,
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
@@ -1136,13 +1135,7 @@ def __call__(
         batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
-        inner_dim = hidden_states.shape[-1]
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -1157,21 +1150,18 @@ def __call__(
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        if real_attn_probs is None:
+            # Rich-Text: font size
+            attention_probs = attn.get_attention_scores(query, key, attention_mask, attn_weights=attn_weights)
+        else:
+            # Rich-Text: inject self-attention maps
+            attention_probs = real_attn_probs
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -1186,7 +1176,11 @@ def __call__(

         hidden_states = hidden_states / attn.rescale_output_factor

-        return hidden_states
+        # Rich-Text Modified: return attn probs
+        # We return the map averaged over heads to save memory footprint
+        attention_probs_avg = attn.reshape_batch_dim_to_heads_and_average(
+            attention_probs)
+        return hidden_states, [attention_probs_avg, attention_probs]


 class LoRAXFormersAttnProcessor(nn.Module):
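For reference, the manual attention path that replaces F.scaled_dot_product_attention above can be sketched in plain PyTorch as follows. This is an illustrative approximation rather than the repository's code: the helper name and the exact way per-token weights are applied are assumptions mirroring the attn_weights (font-size reweighting) and real_attn_probs (injected self-attention maps) arguments in the diff.

import math
import torch

def manual_attention(query, key, value, heads, attn_weights=None, real_attn_probs=None):
    # query/key/value: (batch, seq_len, inner_dim); fold heads into the batch
    # dimension, analogous to attn.head_to_batch_dim in diffusers.
    b, _, inner_dim = query.shape
    head_dim = inner_dim // heads

    def to_batch_dim(t):
        t = t.reshape(b, -1, heads, head_dim).permute(0, 2, 1, 3)
        return t.reshape(b * heads, -1, head_dim)

    q, k, v = to_batch_dim(query), to_batch_dim(key), to_batch_dim(value)

    if real_attn_probs is None:
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim)
        probs = scores.softmax(dim=-1)                 # (batch*heads, q_len, k_len)
        if attn_weights is not None:                   # assumed reweighting scheme
            probs = probs * attn_weights
            probs = probs / probs.sum(dim=-1, keepdim=True)
    else:
        probs = real_attn_probs                        # injected self-attention maps

    out = torch.bmm(probs, v)                          # (batch*heads, q_len, head_dim)
    out = out.reshape(b, heads, -1, head_dim).permute(0, 2, 1, 3).reshape(b, -1, inner_dim)
    return out, probs

Unlike the fused SDPA kernel, this path materializes the full attention-probability tensor, which is what lets the processor reweight it, overwrite it, and return it to the caller alongside the hidden states.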
2 changes: 1 addition & 1 deletion scripts/rich_text_on_tab.py
@@ -229,7 +229,7 @@ def generate(
                      minimum=0,
                      maximum=1,
                      step=0.01,
-                     value=0.)
+                     value=0.3)
         inject_background = gr.Slider(label='Unformatted token preservation',
                                       info='(To affect less the tokens without any rich-text attributes, increase this.)',
                                       minimum=0,
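The second file only raises a Gradio slider default from 0. to 0.3. The slider's label line sits above the hunk and is not shown, so the variable and label below are placeholders; a minimal sketch of the changed declaration under that assumption:

import gradio as gr

# Placeholder name and label; only minimum/maximum/step/value are visible in the hunk.
some_injection_slider = gr.Slider(label='<label not shown in the hunk>',
                                  minimum=0,
                                  maximum=1,
                                  step=0.01,
                                  value=0.3)  # default raised from 0. by this commit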
