fix error of peft lora when xformers enabled (#5506)
Signed-off-by: AnyISalIn <[email protected]>
AnyISalIn authored Oct 25, 2023
1 parent dcbfe66 commit de71fa5
Showing 1 changed file with 6 additions and 4 deletions.
src/diffusers/models/attention_processor.py

@@ -909,6 +909,8 @@ def __call__(
     ):
         residual = hidden_states
 
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -936,15 +938,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, scale=scale)
-        value = attn.to_v(encoder_hidden_states, scale=scale)
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
 
         query = attn.head_to_batch_dim(query).contiguous()
         key = attn.head_to_batch_dim(key).contiguous()
@@ -957,7 +959,7 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
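The change builds the projection-call arguments once and splats them positionally: with the PEFT backend enabled, the to_q / to_k / to_v / to_out[0] layers are plain torch.nn.Linear modules that reject a scale keyword, while the legacy LoRA-compatible layers accept it. Below is a minimal, self-contained sketch of that pattern; LoraScaledLinear is a hypothetical stand-in, not the diffusers implementation.

import torch
import torch.nn as nn

USE_PEFT_BACKEND = True  # illustrative flag mirroring the constant used in the diff

class LoraScaledLinear(nn.Linear):
    # Hypothetical stand-in for a LoRA-aware linear layer that accepts a scale argument.
    def forward(self, hidden_states, scale: float = 1.0):
        return super().forward(hidden_states)  # LoRA delta and scaling elided

to_q = nn.Linear(8, 8) if USE_PEFT_BACKEND else LoraScaledLinear(8, 8)

scale = 1.0
# Build the optional argument tuple once, then splat it positionally so the same
# call works for both layer types; passing scale=scale to a plain nn.Linear
# would raise a TypeError.
args = () if USE_PEFT_BACKEND else (scale,)
query = to_q(torch.randn(2, 8), *args)

The commit applies the same splat to the key, value, and output projections, so the xformers attention processor no longer passes an unexpected keyword when PEFT-backed LoRA layers are in use.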
