This repository has been archived by the owner on Dec 25, 2023. It is now read-only.

support sdxl plus and fix mask #36

Merged 2 commits on Sep 30, 2023
ip_adapter.py: 27 changes (16 additions, 11 deletions)
@@ -78,16 +78,15 @@ def load_state_dict(self, state_dict):
             self.to_kvs[i].weight.data = state_dict[key]
 
 class IPAdapterModel(torch.nn.Module):
-    def __init__(self, state_dict, plus, cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+    def __init__(self, state_dict, plus, cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4, sdxl_plus=False):
         super().__init__()
         self.device = "cuda"
         self.plus = plus
         if self.plus:
             self.image_proj_model = Resampler(
-                dim=cross_attention_dim,
+                dim=1280 if sdxl_plus else cross_attention_dim,
                 depth=4,
-                dim_head=64,
-                heads=12,
+                heads=20 if sdxl_plus else 12,
                 num_queries=clip_extra_context_tokens,
                 embedding_dim=clip_embeddings_dim,
                 output_dim=cross_attention_dim,
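For context on this hunk: the SDXL "plus" projector runs the Resampler at an inner width of 1280 with 20 attention heads (presumably relying on the Resampler's default head size of 64, since 20 × 64 = 1280), while other "plus" models keep 12 heads at the UNet's cross-attention width. A minimal sketch of assembling those keyword arguments under that assumption; the helper name is illustrative, not code from this PR:

```python
# Illustrative sketch only: how the Resampler kwargs differ for SDXL "plus".
# `resampler_kwargs` is a hypothetical helper, not a function in this repo.
def resampler_kwargs(cross_attention_dim, clip_embeddings_dim,
                     clip_extra_context_tokens, sdxl_plus=False):
    return dict(
        dim=1280 if sdxl_plus else cross_attention_dim,  # inner width of the Resampler
        depth=4,
        heads=20 if sdxl_plus else 12,                   # 20 heads * 64 dims per head = 1280
        num_queries=clip_extra_context_tokens,           # image tokens fed to cross-attention
        embedding_dim=clip_embeddings_dim,               # CLIP vision hidden size
        output_dim=cross_attention_dim,                  # UNet cross-attention width
    )
```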
@@ -138,8 +137,8 @@ def INPUT_TYPES(s):
     CATEGORY = "loaders"
 
     def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=None):
-        self.dtype = torch.float32 if dtype == "fp32" else torch.float16
-        device = "cuda"
+        device = comfy.model_management.get_torch_device()
+        self.dtype = torch.float32 if dtype == "fp32" or device.type == "mps" else torch.float16
         self.weight = weight # ip_adapter scale
 
         ip_state_dict = torch.load(os.path.join(CURRENT_DIR, os.path.join(CURRENT_DIR, "models", model_name)), map_location="cpu")
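This hunk swaps the hard-coded CUDA device for ComfyUI's model-management lookup and forces fp32 on Apple MPS, where fp16 is unreliable. A standalone sketch of the same selection logic, assuming only that ComfyUI's comfy.model_management.get_torch_device() is available; the function name is illustrative:

```python
import torch
import comfy.model_management

# Sketch of the device/dtype choice above; `pick_device_and_dtype` is an
# illustrative name, not a function in this repo.
def pick_device_and_dtype(dtype: str):
    device = comfy.model_management.get_torch_device()
    # fp16 is unreliable on Apple MPS, so fall back to fp32 there.
    use_fp32 = dtype == "fp32" or device.type == "mps"
    return device, (torch.float32 if use_fp32 else torch.float16)
```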
@@ -148,6 +147,9 @@ def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=None):
         # cross_attention_dim is equal to text_encoder output
         self.cross_attention_dim = ip_state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
 
+        self.sdxl = self.cross_attention_dim == 2048
+        self.sdxl_plus = self.sdxl and self.plus
+
         # number of tokens of ip_adapter embedding
         if self.plus:
             self.clip_extra_context_tokens = ip_state_dict["image_proj"]["latents"].shape[1]
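The SDXL check works because SDXL's UNet cross-attention width is 2048 (the hidden sizes of its two text encoders, 768 and 1280, concatenated), whereas SD1.x uses 768; an SDXL "plus" adapter is then simply an SDXL adapter that also uses the Resampler. A small sketch of reading those facts straight from a checkpoint, assuming the same state-dict layout the node loads; the helper name and the "plus" heuristic are illustrative, not this node's own logic:

```python
import torch

# Illustrative: infer the model variant and token count from an IP-Adapter checkpoint.
# Assumes the same state-dict layout the node loads above.
def inspect_ip_adapter(path: str):
    sd = torch.load(path, map_location="cpu")
    cross_attention_dim = sd["ip_adapter"]["1.to_k_ip.weight"].shape[1]
    is_sdxl = cross_attention_dim == 2048     # 768 + 1280 from SDXL's two text encoders
    is_plus = "latents" in sd["image_proj"]   # heuristic: "plus" checkpoints carry Resampler latents
    tokens = sd["image_proj"]["latents"].shape[1] if is_plus else None  # non-plus branch omitted
    return cross_attention_dim, is_sdxl, is_plus, tokens
```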
@@ -156,16 +158,14 @@ def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=None):
 
         cond, uncond, outputs = self.clip_vision_encode(clip_vision, image, self.plus)
         self.clip_embeddings_dim = cond.shape[-1]
 
-        # sd_v1-2: 1024, sd_xl: 2048
-        self.sdxl = self.cross_attention_dim == 2048
-
         self.ipadapter = IPAdapterModel(
             ip_state_dict,
             plus = self.plus,
             cross_attention_dim = self.cross_attention_dim,
             clip_embeddings_dim = self.clip_embeddings_dim,
-            clip_extra_context_tokens = self.clip_extra_context_tokens
+            clip_extra_context_tokens = self.clip_extra_context_tokens,
+            sdxl_plus = self.sdxl_plus
         )
 
         self.ipadapter.to(device, dtype=self.dtype)
@@ -178,6 +178,11 @@ def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=None):
 
         new_model = model.clone()
 
+        if mask is not None:
+            if mask.dim() == 3:
+                mask = mask[0]
+            mask = mask.to(device)
+
         '''
         patch_name of sdv1-2: ("input" or "output" or "middle", block_id)
         patch_name of sdxl: ("input" or "output" or "middle", block_id, transformer_index)
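The new mask block normalizes ComfyUI MASK inputs, which usually arrive batched as a [B, H, W] tensor: the first mask is taken and moved to the compute device once, so the patch kwargs below can pass it through unchanged. A standalone sketch of that normalization, under the assumption that a single mask is wanted; the function name is illustrative:

```python
from typing import Optional
import torch

# Sketch of the mask normalization above: ComfyUI masks are usually [B, H, W].
def normalize_mask(mask: Optional[torch.Tensor], device: torch.device) -> Optional[torch.Tensor]:
    if mask is None:
        return None
    if mask.dim() == 3:       # batched: keep only the first mask
        mask = mask[0]
    return mask.to(device)    # move to the compute device once, up front
```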
@@ -189,7 +194,7 @@ def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=None):
             "dtype": self.dtype,
             "cond": self.image_emb,
             "uncond": self.uncond_image_emb,
-            "mask": mask if mask is None else mask.to(device)
+            "mask": mask
         }
 
         if not self.sdxl: