Support InstantStyle #7586

Closed
wants to merge 4 commits
Changes from 2 commits
src/diffusers/loaders/ip_adapter.py (3 changes: 2 additions & 1 deletion)

@@ -136,6 +136,7 @@ def load_ip_adapter(
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+        target_blocks = kwargs.pop("target_blocks", ["block"])

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False

@@ -226,7 +227,7 @@ def load_ip_adapter(

        # load ip-adapter into unet
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage, target_blocks=target_blocks)

    def set_ip_adapter_scale(self, scale):
        """
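For context, here is a minimal usage sketch of the new `target_blocks` kwarg introduced above (not part of the diff). The pipeline class, repository, and weight file are the usual public IP-Adapter SDXL assets, and the block pattern follows InstantStyle's style-only recipe; treat all of them as illustrative assumptions.

```python
# Hypothetical usage of the `target_blocks` kwarg added in this PR.
# Checkpoint names and the block pattern are illustrative assumptions.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Only attention processors whose name contains one of these substrings keep
# the IP-Adapter branch; every other processor is constructed with skip=True.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.bin",
    target_blocks=["up_blocks.0.attentions.1"],  # style-oriented block used by InstantStyle
)
```

The default `target_blocks=["block"]` matches every UNet attention processor name (they all contain "block"), so omitting the argument keeps current behavior.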
src/diffusers/loaders/unet.py (14 changes: 11 additions & 3 deletions)

@@ -800,7 +800,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us

        return image_projection

-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False, target_blocks=['block']):
        from ..models.attention_processor import (
            AttnProcessor,
            AttnProcessor2_0,

@@ -864,11 +864,19 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
                        num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]

                with init_context():
+
+                    selected = False
+                    for block_name in target_blocks:
+                        if block_name in name:
+                            selected = True
+                            break
+
Collaborator: does this work?

Suggested change:
-                    selected = False
-                    for block_name in target_blocks:
-                        if block_name in name:
-                            selected = True
-                            break
+                    selected = any(block_name in name for block_name in target_blocks)

Contributor: Yes, it is much clearer. Updated.

                    attn_procs[name] = attn_processor_class(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=num_image_text_embeds,
+                        skip=not selected
                    )

                value_dict = {}

@@ -887,14 +895,14 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F

        return attn_procs

-    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False, target_blocks=["block"]):
        if not isinstance(state_dicts, list):
            state_dicts = [state_dicts]
        # Set encoder_hid_proj after loading ip_adapter weights,
        # because `IPAdapterPlusImageProjection` also has `attn_processors`.
        self.encoder_hid_proj = None

-        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage, target_blocks=target_blocks)
        self.set_attn_processor(attn_procs)

        # convert IP-Adapter Image Projection layers to diffusers
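As a side note on the selection logic above: it is a plain substring match against the UNet's attention-processor names. Here is a standalone sketch of the same test; the names below are typical SDXL processor keys written out for illustration and are not taken from this diff.

```python
# Standalone sketch of the block-selection test used in
# _convert_ip_adapter_attn_to_diffusers (later simplified to any(...)).
# The processor names are typical UNet keys, listed here only as examples.
example_names = [
    "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
    "up_blocks.0.attentions.1.transformer_blocks.0.attn2.processor",
]
target_blocks = ["up_blocks.0.attentions.1"]  # hypothetical style-only selection

for name in example_names:
    selected = any(block_name in name for block_name in target_blocks)
    # Processors that are not selected get skip=True and ignore the image prompt.
    print(f"{name} -> skip={not selected}")
```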
src/diffusers/models/attention_processor.py (98 changes: 51 additions & 47 deletions)

@@ -2108,7 +2108,7 @@ class IPAdapterAttnProcessor(nn.Module):
            the weight scale of image prompt.
    """

-    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, skip=False):
        super().__init__()

        self.hidden_size = hidden_size

@@ -2117,6 +2117,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens
+        self.skip = skip

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)

@@ -2208,29 +2209,30 @@ def __call__(
        ip_adapter_masks = [None] * len(self.scale)

        # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
-        ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = attn.head_to_batch_dim(ip_key)
-            ip_value = attn.head_to_batch_dim(ip_value)
-
-            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-            current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-            current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
-
-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
-
-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
-
-            hidden_states = hidden_states + scale * current_ip_hidden_states
+        if not self.skip:
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = attn.head_to_batch_dim(ip_key)
+                ip_value = attn.head_to_batch_dim(ip_value)
+
+                ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                if mask is not None:
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                    )
+
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                    current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+
+                hidden_states = hidden_states + scale * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)

@@ -2263,7 +2265,7 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
            the weight scale of image prompt.
    """

-    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, skip=False):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):

@@ -2283,6 +2285,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale
+        self.skip = skip

        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]

@@ -2382,36 +2385,37 @@ def __call__(
        ip_adapter_masks = [None] * len(self.scale)

        # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
-        ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            current_ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-
-            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                batch_size, -1, attn.heads * head_dim
-            )
-            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
-
-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
-
-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
-
-            hidden_states = hidden_states + scale * current_ip_hidden_states
+        if not self.skip:
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = F.scaled_dot_product_attention(
+                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+                if mask is not None:
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                    )
+
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                    current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+
+                hidden_states = hidden_states + scale * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
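For completeness, a minimal sketch of what the new flag implies for loading, assuming the constructor signature proposed in this diff (the `hidden_size` and `cross_attention_dim` values are SDXL-typical numbers, not taken from the PR): a processor built with `skip=True` still creates its `to_k_ip`/`to_v_ip` layers, so IP-Adapter state dicts load unchanged, and only the image-prompt branch guarded by `if not self.skip` is bypassed at inference.

```python
# Sketch assuming the `skip` kwarg from this PR is available; dims are
# illustrative SDXL-like values.
from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0

proc = IPAdapterAttnProcessor2_0(
    hidden_size=1280, cross_attention_dim=2048, num_tokens=[4], skip=True
)

print(proc.skip)                             # True: __call__ bypasses the IP branch
print(len(proc.to_k_ip), len(proc.to_v_ip))  # 1 1: projection layers still exist,
                                             # so checkpoint weights still load
```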