From de71fa59f56c89499fa2404a50c672809a9d3b0c Mon Sep 17 00:00:00 2001
From: AnyISalIn
Date: Wed, 25 Oct 2023 13:28:15 +0800
Subject: [PATCH] fix error of peft lora when xformers enabled (#5506)

Signed-off-by: AnyISalIn
---
 src/diffusers/models/attention_processor.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 9856f3c7739c..89ecf143be22 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -909,6 +909,8 @@ def __call__(
     ):
         residual = hidden_states

+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -936,15 +938,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, scale=scale)
-        value = attn.to_v(encoder_hidden_states, scale=scale)
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)

         query = attn.head_to_batch_dim(query).contiguous()
         key = attn.head_to_batch_dim(key).contiguous()
@@ -957,7 +959,7 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)

         # dropout
         hidden_states = attn.to_out[1](hidden_states)
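
For context, the failure mode this patch addresses can be reproduced with a minimal sketch. Assumptions: `LoRACompatibleLinearStub` below is a hypothetical stand-in for diffusers' `LoRACompatibleLinear`, not the real class. When the PEFT backend is active, the attention projections are plain `torch.nn.Linear` modules, so calling them with `scale=scale` raises a `TypeError`; the `args = () if USE_PEFT_BACKEND else (scale,)` pattern forwards `scale` only to the legacy LoRA-compatible layers that accept it.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for diffusers' LoRACompatibleLinear: its forward()
# accepts an extra `scale` argument, unlike a plain nn.Linear.
class LoRACompatibleLinearStub(nn.Linear):
    def forward(self, hidden_states, scale: float = 1.0):
        # A real implementation would add scale * lora(hidden_states); here we
        # only demonstrate the signature difference.
        return super().forward(hidden_states)

USE_PEFT_BACKEND = True  # with PEFT, projections are plain nn.Linear layers

to_q = nn.Linear(8, 8) if USE_PEFT_BACKEND else LoRACompatibleLinearStub(8, 8)
hidden_states, scale = torch.randn(2, 8), 1.0

# Pre-patch behavior: always passing scale breaks under the PEFT backend.
# to_q(hidden_states, scale=scale)  # TypeError: unexpected keyword argument 'scale'

# Patched pattern: forward `scale` only when the layer can accept it.
args = () if USE_PEFT_BACKEND else (scale,)
query = to_q(hidden_states, *args)
print(query.shape)  # torch.Size([2, 8])
```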