diff --git a/ip_adapter.py b/ip_adapter.py
index bab71c0..9a55020 100644
--- a/ip_adapter.py
+++ b/ip_adapter.py
@@ -78,16 +78,15 @@ def load_state_dict(self, state_dict):
             self.to_kvs[i].weight.data = state_dict[key]
 
 class IPAdapterModel(torch.nn.Module):
-    def __init__(self, state_dict, plus, cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+    def __init__(self, state_dict, plus, cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4, sdxl_plus=False):
         super().__init__()
-        self.device = "cuda"
         self.plus = plus
         if self.plus:
             self.image_proj_model = Resampler(
-                dim=cross_attention_dim,
+                dim=1280 if sdxl_plus else cross_attention_dim,
                 depth=4,
                 dim_head=64,
-                heads=12,
+                heads=20 if sdxl_plus else 12,
                 num_queries=clip_extra_context_tokens,
                 embedding_dim=clip_embeddings_dim,
                 output_dim=cross_attention_dim,
@@ -138,8 +137,8 @@ def INPUT_TYPES(s):
     CATEGORY = "loaders"
 
     def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=None):
-        self.dtype = torch.float32 if dtype == "fp32" else torch.float16
-        device = "cuda"
+        device = comfy.model_management.get_torch_device()
+        self.dtype = torch.float32 if dtype == "fp32" or device.type == "mps" else torch.float16
         self.weight = weight # ip_adapter scale
 
         ip_state_dict = torch.load(os.path.join(CURRENT_DIR, os.path.join(CURRENT_DIR, "models", model_name)), map_location="cpu")
@@ -148,6 +147,9 @@ def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=Non
 
         # cross_attention_dim is equal to text_encoder output
         self.cross_attention_dim = ip_state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
+        self.sdxl = self.cross_attention_dim == 2048
+        self.sdxl_plus = self.sdxl and self.plus
+
         # number of tokens of ip_adapter embedding
         if self.plus:
             self.clip_extra_context_tokens = ip_state_dict["image_proj"]["latents"].shape[1]
@@ -156,16 +158,14 @@ def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=Non
 
         cond, uncond, outputs = self.clip_vision_encode(clip_vision, image, self.plus)
         self.clip_embeddings_dim = cond.shape[-1]
-
-        # sd_v1-2: 1024, sd_xl: 2048
-        self.sdxl = self.cross_attention_dim == 2048
 
         self.ipadapter = IPAdapterModel(
             ip_state_dict,
             plus = self.plus,
             cross_attention_dim = self.cross_attention_dim,
             clip_embeddings_dim = self.clip_embeddings_dim,
-            clip_extra_context_tokens = self.clip_extra_context_tokens
+            clip_extra_context_tokens = self.clip_extra_context_tokens,
+            sdxl_plus = self.sdxl_plus
         )
 
         self.ipadapter.to(device, dtype=self.dtype)
@@ -178,6 +178,11 @@ def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=Non
 
         new_model = model.clone()
 
+        if mask is not None:
+            if mask.dim() == 3:
+                mask = mask[0]
+            mask = mask.to(device)
+
         '''
         patch_name of sdv1-2: ("input" or "output" or "middle", block_id)
         patch_name of sdxl: ("input" or "output" or "middle", block_id, transformer_index)
@@ -189,7 +194,7 @@ def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=Non
             "dtype": self.dtype,
             "cond": self.image_emb,
             "uncond": self.uncond_image_emb,
-            "mask": mask if mask is None else mask.to(device)
+            "mask": mask
         }
 
         if not self.sdxl:
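
Note: below is a minimal standalone sketch (not part of the patch) of the device/dtype selection and mask normalization the diff introduces. get_torch_device and prepare_inputs are hypothetical stand-ins; the real code relies on comfy.model_management.get_torch_device() inside ComfyUI.

import torch

def get_torch_device():
    # Hypothetical stand-in for comfy.model_management.get_torch_device():
    # prefer CUDA, then Apple MPS, otherwise CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def prepare_inputs(dtype_name, mask=None):
    device = get_torch_device()
    # fp16 is unreliable on MPS, so fall back to fp32 there (as the patch does).
    dtype = torch.float32 if dtype_name == "fp32" or device.type == "mps" else torch.float16
    if mask is not None:
        # Masks may arrive batched as (B, H, W); keep a single (H, W) mask on the target device.
        if mask.dim() == 3:
            mask = mask[0]
        mask = mask.to(device)
    return device, dtype, mask

# Example: device, dtype, mask = prepare_inputs("fp16", torch.ones(1, 64, 64))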