From 66c03be45f30441cc01aaf7496c1339007de4cf1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 8 Jul 2023 09:56:38 +0900 Subject: [PATCH 01/20] Fix TE key names for SD1/2 LoRA are invalid --- networks/lora.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index b6788b995..cd73cbe7c 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -735,7 +735,7 @@ class LoRANetwork(torch.nn.Module): TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" - + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" @@ -877,7 +877,14 @@ def create_modules( self.text_encoder_loras = [] skipped_te = [] for i, text_encoder in enumerate(text_encoders): - text_encoder_loras, skipped = create_modules(False, i + 1, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}:") + else: + index = None + print(f"create LoRA for Text Encoder:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") From d599394f6086b7291535e510e67b30c18fef4c07 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Sat, 8 Jul 2023 15:47:56 +0900 Subject: [PATCH 02/20] support avif --- library/train_util.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 62cd145e1..a3017694c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -89,6 +89,12 @@ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] +try: + import pillow_avif + IMAGE_EXTENSIONS.extend([".avif", ".AVIF"]) +except: + pass + class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: From fe7ede5af33e9d4fecc22ee454f06e28c4bcff74 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 9 Jul 2023 13:33:16 +0900 Subject: [PATCH 03/20] fix wrapper tokenizer not work for weighted prompt --- library/sdxl_train_util.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f4b2c1739..c67a70431 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -98,13 +98,16 @@ def tokenize(self, text, padding=False, truncation=None, max_length=None, return return SimpleNamespace(**{"input_ids": input_ids}) # for weighted prompt - input_ids = open_clip.tokenize(text, context_length=self.model_max_length) + assert isinstance(text, str), f"input must be str: {text}" + + input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0] # tokenizer returns list # find eos - eos_index = (input_ids == self.eos_token_id).nonzero()[0].max() # max index of each batch - input_ids = input_ids[:, : eos_index + 1] # include eos + eos_index = (input_ids == self.eos_token_id).nonzero().max() + input_ids = input_ids[: eos_index + 1] # include eos return SimpleNamespace(**{"input_ids": input_ids}) + def load_tokenizers(args: argparse.Namespace): print("prepare tokenizers") original_path = TOKENIZER_PATH From 1d25703ac32abe8752d12e58e2637883dc6919bb Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 9 Jul 2023 13:33:26 +0900 Subject: [PATCH 04/20] add generation script --- library/sdxl_model_util.py | 2 +- sdxl_gen_img.py | 2547 ++++++++++++++++++++++++++++++++++++ 2 files changed, 2548 insertions(+), 1 deletion(-) create mode 100644 sdxl_gen_img.py diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index b9533af10..ae764b17f 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -182,7 +182,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) info2 = text_model2.load_state_dict(converted_sd) - print("text encoder2:", info2) + print("text encoder 2:", info2) # prepare vae print("building VAE") diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py new file mode 100644 index 000000000..8f1c17d60 --- /dev/null +++ b/sdxl_gen_img.py @@ -0,0 +1,2547 @@ +import itertools +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re + +import diffusers +import numpy as np +import torch +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + # UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import library.sdxl_model_util as sdxl_model_util +import library.sdxl_train_util as sdxl_train_util +from networks.lora import LoRANetwork +import tools.original_control_net as original_control_net +from tools.original_control_net import ControlNetInfo +from library.sdxl_original_unet import SdxlUNet2DConditionModel +from library.original_unet import FlashAttentionFunction + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の設定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + print("Enable memory efficient attention for U-Net") + + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) + + +# TODO common train_util.py +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + replace_vae_attn_to_xformers() + elif sdpa: + replace_vae_attn_to_sdpa() + + +def replace_vae_attn_to_memory_efficient(): + print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + print("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers + + +def replace_vae_attn_to_sdpa(): + print("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + +# endregion + +# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 + + +class PipelineLike: + def __init__( + self, + device, + vae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + tokenizers: List[CLIPTokenizer], + unet: SdxlUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + ): + super().__init__() + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoders = text_encoders + self.tokenizers = tokenizers + self.unet: SdxlUNet2DConditionModel = unet + self.scheduler = scheduler + self.safety_checker = None + + # Textual Inversion # not tested yet + self.token_replacements_list = [] + for _ in range(len(self.text_encoders)): + self.token_replacements_list.append({}) + + # ControlNet # not supported yet + self.control_nets: List[ControlNetInfo] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + + # Textual Inversion + def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): + self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids + + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + + def get_token_replacer(self, tokenizer): + tokenizer_index = self.tokenizers.index(tokenizer) + token_replacements = self.token_replacements_list[tokenizer_index] + + def replace_tokens(tokens): + new_tokens = [] + for token in tokens: + if token in token_replacements: + replacement = token_replacements[token] + new_tokens.extend(replacement) + else: + new_tokens.append(token) + return new_tokens + + return replace_tokens + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + original_height: int = None, + original_width: int = None, + crop_top: int = 0, + crop_left: int = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + **kwargs, + ): + # TODO support secondary prompt + num_images_per_prompt = 1 # fixed because already prompt is repeated + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + reginonal_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + tes_text_embs = [] + tes_uncond_embs = [] + tes_real_uncond_embs = [] + # use last pool + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + token_replacer = self.get_token_replacer(tokenizer) + + text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( + tokenizer, + text_encoder, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_text_embs.append(text_embeddings) + tes_uncond_embs.append(uncond_embeddings) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + token_replacer, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_real_uncond_embs.append(real_uncond_embeddings) + + # concat text encoder outputs + text_embeddings = tes_text_embs[0] + uncond_embeddings = tes_uncond_embs[0] + for i in range(1, len(tes_text_embs)): + text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + if self.control_nets: + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + + # ControlNetのhintにguide imageを流用する + # 前処理はControlNet側で行う + + # create size embs + if original_height is None: + original_height = height + if original_width is None: + original_width = width + if crop_top is None: + crop_top = 0 + if crop_left is None: + crop_left = 0 + emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype) + uc_vector = c_vector.clone().to(self.device, dtype=text_embeddings.dtype) + + c_vector = torch.cat([text_pool, c_vector], dim=1) + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + + vector_embeddings = torch.cat([uc_vector, c_vector]) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[-2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): + init_latent_dist = self.vae.encode( + (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( + self.vae.dtype + ) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if self.control_nets and self.control_net_enabled: + if reginonal_network: + num_sub_and_neg_prompts = len(text_embeddings) // batch_size + text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + else: + text_emb_last = text_embeddings + + # not working yet + noise_pred = original_control_net.call_unet_and_control_net( + i, + num_latent_input, + self.unet, + self.control_nets, + guided_hints, + i / len(timesteps), + latent_model_input, + t, + text_emb_last, + ).sample + else: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return (latents, False) + + latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents.to(self.vae.dtype)).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode( + (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) + ).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + return image + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) + print(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(tokenizer.eos_token_id) + # else: + text_token.append(tokenizer.pad_token_id) + text_weight.append(1.0) + continue + + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + + token = token_replacer(token) # for Textual Inversion + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + text_encoder: CLIPTextModel, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + pool = None + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + # -2 is same for Text Encoder 1 and 2 + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-2] + if pool is None: + pool = enc_out["text_embeds"] # use 1st chunk + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-2] + pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + return text_embeddings, pool + + +def get_weighted_text_embeddings( + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, + token_replacer=None, + device=None, + **kwargs, +): + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + text_encoder, + uncond_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens + return text_embeddings, text_pool, None, None, prompt_tokens + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return [prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + # if "e$$" is found, enumerate all variants + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separator = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + # parse count range + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + print(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make function to enumerate all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separator)) + else: + # make function to choose random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separator)) + + # make each prompt + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0], 1) + prompts.append(current) + else: + # if enumerating, iterate all combinations for previous prompts + prompts = [prompt] + + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: + # make all combinations for existing prompts + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement, 1)) + prompts = new_prompts + + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) + + return prompts + + +# endregion + + +# def load_clip_l14_336(dtype): +# print(f"loading CLIP: {CLIP_ID_L14_336}") +# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + + +class BatchDataExt(NamedTuple): + # バッチ分割が必要なデータ + width: int + height: int + original_width: int + original_height: int + crop_left: int + crop_top: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + + # モデルを読み込む + if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + + use_stable_diffusion_format = os.path.isfile(args.ckpt) + assert use_stable_diffusion_format, "Diffusers pretrained models are not supported yet" + print("load StableDiffusion checkpoint") + text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt, "cpu" + ) + # else: + # print("load Diffusers pretrained models") + # TODO use Diffusers 0.18.1 and support SDXL pipeline + # raise NotImplementedError("Diffusers pretrained models are not supported yet") + # loading_pipe = StableDiffusionXLPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + # text_encoder = loading_pipe.text_encoder + # vae = loading_pipe.vae + # unet = loading_pipe.unet + # tokenizer = loading_pipe.tokenizer + # del loading_pipe + + # # Diffusers U-Net to original U-Net + # original_unet = SdxlUNet2DConditionModel( + # unet.config.sample_size, + # unet.config.attention_head_dim, + # unet.config.cross_attention_dim, + # unet.config.use_linear_projection, + # unet.config.upcast_attention, + # ) + # original_unet.load_state_dict(unet.state_dict()) + # unet = original_unet + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, dtype) + print("additional VAE loaded") + + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) + + # tokenizerを読み込む + print("loading tokenizer") + if use_stable_diffusion_format: + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # schedulerを用意する + sched_init_args = {} + scheduler_num_noises_per_step = 1 + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + + # samplerの乱数をあらかじめ指定するための処理 + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) + + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + print("set clip_sample to True") + scheduler.config.clip_sample = True + + # deviceを決定する + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + + # custom pipelineをコピったやつを生成する + if args.vae_slices: + from library.slicing_vae import SlicingAutoencoderKL + + sli_vae = SlicingAutoencoderKL( + act_fn="silu", + block_out_channels=(128, 256, 512, 512), + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + out_channels=3, + sample_size=512, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + num_slices=args.vae_slices, + ) + sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする + vae = sli_vae + del sli_vae + + vae_dtype = dtype + if args.no_half_vae: + print("set vae_dtype to float32") + vae_dtype = torch.float32 + vae.to(vae_dtype).to(device) + + text_encoder1.to(dtype).to(device) + text_encoder2.to(dtype).to(device) + unet.to(dtype).to(device) + + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + network_pre_calc = args.network_pre_calc + + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_default_muls.append(network_mul) + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights and i < len(args.network_weights): + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs + ) + else: + raise ValueError("No weight. Weight is required.") + if network is None: + return + + mergeable = network.is_mergeable() + if args.network_merge and not mergeable: + print("network is not mergiable. ignore merge option.") + + if not args.network_merge or not mergeable: + network.apply_to([text_encoder1, text_encoder2], unet) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい + print(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + if network_pre_calc: + print("backup original weights") + network.backup_weights() + + networks.append(network) + else: + network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) + + else: + networks = [] + + # upscalerの指定があれば取得する + upscaler = None + if args.highres_fix_upscaler: + print("import upscaler module:", args.highres_fix_upscaler) + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + print("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + + # ControlNetの処理 + control_nets: List[ControlNetInfo] = [] + if args.control_net_models: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + + if args.opt_channels_last: + print(f"set optimizing: channels last") + text_encoder1.to(memory_format=torch.channels_last) + text_encoder2.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.unet.to(memory_format=torch.channels_last) + cn.net.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + device, + vae, + [text_encoder1, text_encoder2], + [tokenizer1, tokenizer2], + unet, + scheduler, + args.clip_skip, + ) + pipe.set_control_nets(control_nets) + print("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds1 = [] + token_ids_embeds2 = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + embeds1 = data["clip_l"] + embeds2 = data["clip_g"] + + num_vectors_per_token = embeds1.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens1 = tokenizer1.add_tokens(token_strings) + num_added_tokens2 = tokenizer2.add_tokens(token_strings) # not working now + assert ( + num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token + ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" + + token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) + token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) + print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + assert ( + min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 + ), f"token ids1 is not ordered" + assert ( + min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 + ), f"token ids2 is not ordered" + assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}" + assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(0, token_ids1[0], token_ids1) + pipe.add_token_replacement(1, token_ids2[0], token_ids2) + + token_ids_embeds1.append((token_ids1, embeds1)) + token_ids_embeds2.append((token_ids2, embeds2)) + + text_encoder1.resize_token_embeddings(len(tokenizer1)) + text_encoder2.resize_token_embeddings(len(tokenizer2)) + token_embeds1 = text_encoder1.get_input_embeddings().weight.data + token_embeds2 = text_encoder2.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds1: + for token_id, embed in zip(token_ids, embeds): + token_embeds1[token_id] = embed + for token_ids, embeds in token_ids_embeds2: + for token_id, embed in zip(token_ids, embeds): + token_embeds2[token_id] = embed + + # promptを取得する + if args.from_file is not None: + print(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + print(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + print(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + print(f"loaded {len(init_images)} images for img2img") + else: + init_images = None + + if args.mask_path is not None: + print(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + print(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + print("get prompts from images' meta data") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + # highres fix を考慮に入れる + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + + if init_images is not None: + print(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) + if mask_images is not None: + print(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) + + regional_network = False + if networks and mask_images: + # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 + regional_network = True + print("use mask as region") + + size = None + for i, network in enumerate(networks): + if i < 3: + np_mask = np.array(mask_images[0]) + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + print(f"load image for ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + print(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + guide_images = None + else: + guide_images = None + + # seed指定時はseedを決めておく + if args.seed is not None: + # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None + + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 1024 + if args.H is None: + args.H = 1024 + + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + + print("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + + def scale_and_round(x): + if x is None: + return None + return int(x * args.highres_fix_scale + 0.5) + + width_1st = scale_and_round(ext.width) + height_1st = scale_and_round(ext.height) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + original_width_1st = scale_and_round(ext.original_width) + original_height_1st = scale_and_round(ext.original_height) + crop_left_1st = scale_and_round(ext.crop_left) + crop_top_1st = scale_and_round(ext.crop_top) + + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + + ext_1st = BatchDataExt( + width_1st, + height_1st, + original_width_1st, + original_height_1st, + crop_left_1st, + crop_top_1st, + args.highres_fix_steps, + ext.scale, + ext.negative_scale, + strength_1st, + ext.network_muls, + ext.num_sub_prompts, + ) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + print("process 2nd stage") + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # upscalerを使って画像を拡大する + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + ) + vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image), + ( + width, + height, + original_width, + original_height, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + network_muls, + num_sub_prompts, + ), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? + guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + # 追加ネットワークの処理 + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) + + if not regional_network and network_pre_calc: + for n in networks: + n.restore_weights() + for n in networks: + n.pre_calculation() + print("pre-calculation... done") + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + original_height, + original_width, + crop_top, + crop_left, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + ) + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts) + ): + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while args.interactive or prompt_index < len(prompt_list): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + print("\nType prompt:") + try: + raw_prompt = input() + except EOFError: + break + + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + raw_prompt = prompt_list[prompt_index] + + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + + # repeat prompt + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + + if pi == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + print(f"original width: {width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + print(f"original height: {height}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + print(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + print(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + print("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) + if args.interactive: + print(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + print( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataExt( + width, + height, + original_width, + original_height, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? + process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + ) + parser.add_argument( + "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument( + "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値" + ) + parser.add_argument( + "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" + ) + parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") + parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", + ) + parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない") + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + ) + parser.add_argument( + "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" + ) + parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") + parser.add_argument( + "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + ) + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) + parser.add_argument( + "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数", + ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) + + parser.add_argument( + "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + ) + + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + ) + parser.add_argument( + "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + ) + parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) + # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) From 8371a7a3aadb12dd1f6ea8edb3d9caa6066154f9 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 9 Jul 2023 13:38:48 +0900 Subject: [PATCH 05/20] update readme --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 234301705..503427a5a 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ Summary of the feature: - The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with `--no_half_vae` option. It works with `bf16` or without mixed precision. - `--weighted_captions` option is not supported yet. - `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000. +- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage. `requirements.txt` is updated to support SDXL training. @@ -47,6 +48,7 @@ Summary of the feature: - The LoRA training can be done with 12GB GPU memory. - `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected. - PyTorch 2 seems to use slightly less GPU memory than PyTorch 1. +- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training. Example of the optimizer settings for Adafactor with the fixed learning rate: ``` @@ -57,6 +59,12 @@ lr_warmup_steps = 100 learning_rate = 4e-7 # SDXL original learning rate ``` +### TODO + +- [ ] Support Textual Inversion training. +- [ ] Support `--weighted_captions` option. +- [ ] Change `--output_config` option to continue the training. + ## About requirements.txt These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.) From 5f348579d115e289422732b92fad3d5228deeb35 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 9 Jul 2023 12:46:35 +0800 Subject: [PATCH 06/20] Update sdxl_train.py --- sdxl_train.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sdxl_train.py b/sdxl_train.py index 9cf20252d..06cbc5710 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -267,6 +267,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): unet.to(weight_dtype) text_encoder1.to(weight_dtype) text_encoder2.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + unet.to(weight_dtype) + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい if args.train_text_encoder: From d974959738d89b6d62e7dc60c6e10ee0d8f8ff4c Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 9 Jul 2023 12:47:26 +0800 Subject: [PATCH 07/20] Update train_util.py for full_bf16 support --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 62cd145e1..b9b4eecf8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2416,6 +2416,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" ) parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する") parser.add_argument( "--clip_skip", type=int, From 0416f26a76c39911afe7aae09f2e628e02922a1c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jul 2023 16:02:56 +0900 Subject: [PATCH 08/20] support multi gpu in caching text encoder outputs --- README.md | 6 +++++- library/sdxl_train_util.py | 8 ++++---- sdxl_train.py | 26 ++++++++++++++++---------- sdxl_train_network.py | 4 ++-- train_network.py | 10 +++++----- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 503427a5a..8a47ee3b5 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,10 @@ The feature of SDXL training is now available in sdxl branch as an experimental Summary of the feature: - `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. - - `prepare_buckets_latents.py` now supports SDXL fine-tuning. + - `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage. + - However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer. + - I cannot find bitsandbytes>0.35.0 that works correctly on Windows. +- `prepare_buckets_latents.py` now supports SDXL fine-tuning. - `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. - Both scripts has following additional options: - `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions. @@ -64,6 +67,7 @@ learning_rate = 4e-7 # SDXL original learning rate - [ ] Support Textual Inversion training. - [ ] Support `--weighted_captions` option. - [ ] Change `--output_config` option to continue the training. +- [ ] Extend `--full_bf16` for all the scripts. ## About requirements.txt diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index c67a70431..675aac3da 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -319,7 +319,7 @@ def diffusers_saver(out_dir): # TextEncoderの出力をキャッシュする # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる -def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype): +def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype): print("caching text encoder outputs") tokenizer1, tokenizer2 = tokenizers @@ -332,9 +332,9 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat text_encoder1_cache = {} text_encoder2_cache = {} - for batch in tqdm(data_loader): - input_ids1_batch = batch["input_ids"] - input_ids2_batch = batch["input_ids2"] + for batch in tqdm(dataset): + input_ids1_batch = batch["input_ids"].to(accelerator.device) + input_ids2_batch = batch["input_ids2"].to(accelerator.device) # split batch to avoid OOM # TODO specify batch size by args diff --git a/sdxl_train.py b/sdxl_train.py index 06cbc5710..dd5b74dda 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -204,12 +204,25 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder2.gradient_checkpointing_enable() training_models.append(text_encoder1) training_models.append(text_encoder2) + + text_encoder1_cache = None + text_encoder2_cache = None + + # set require_grad=True later else: text_encoder1.requires_grad_(False) text_encoder2.requires_grad_(False) text_encoder1.eval() text_encoder2.eval() + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( + args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None + ) + accelerator.wait_for_everyone() + if not cache_latents: vae.requires_grad_(False) vae.eval() @@ -289,23 +302,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): (unet,) = train_util.transform_models_if_DDP([unet]) text_encoder1.to(weight_dtype) text_encoder2.to(weight_dtype) - text_encoder1.eval() - text_encoder2.eval() - # TextEncoderの出力をキャッシュする + # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: - text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( - args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None - ) - accelerator.wait_for_everyone() - # Text Encoder doesn't work on CPU with fp16 + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) if torch.cuda.is_available(): torch.cuda.empty_cache() else: - text_encoder1_cache = None - text_encoder2_cache = None + # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index ec15ce4bd..0c3c0cc5b 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -47,7 +47,7 @@ def is_text_encoder_outputs_cached(self, args): return args.cache_text_encoder_outputs def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype + self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype ): if args.cache_text_encoder_outputs: if not args.lowram: @@ -61,7 +61,7 @@ def cache_text_encoder_outputs_if_needed( torch.cuda.empty_cache() text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( - args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype + args, accelerator, tokenizers, text_encoders, dataset, weight_dtype ) accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU diff --git a/train_network.py b/train_network.py index 3c9515b55..f7ee451b1 100644 --- a/train_network.py +++ b/train_network.py @@ -255,6 +255,11 @@ def train(self, args): accelerator.wait_for_everyone() + # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + self.cache_text_encoder_outputs_if_needed( + args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype + ) + # prepare network net_kwargs = {} if args.network_args is not None: @@ -419,11 +424,6 @@ def train(self, args): vae.eval() vae.to(accelerator.device, dtype=vae_dtype) - # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される - self.cache_text_encoder_outputs_if_needed( - args, accelerator, unet, vae, tokenizers, text_encoders, train_dataloader, weight_dtype - ) - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: train_util.patch_accelerator_for_fp16_training(accelerator) From a380502c01a9d99281f9ea0d7486ac2b03c7147c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jul 2023 18:13:49 +0900 Subject: [PATCH 09/20] fix pad token is not handled --- library/sdxl_lpw_stable_diffusion.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index e806bc61e..10038e6a7 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -185,14 +185,14 @@ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], m return tokens, weights -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): r""" Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. """ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) if no_boseos_middle: weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) else: @@ -363,6 +363,7 @@ def get_weighted_text_embeddings( max_length, bos, eos, + pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) @@ -374,6 +375,7 @@ def get_weighted_text_embeddings( max_length, bos, eos, + pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) @@ -711,7 +713,7 @@ def decode_latents(self, latents): # self.vae.set_use_memory_efficient_attention_xformers(False) # image = self.vae.decode(latents.to("cpu")).sample - image = self.vae.decode(latents).sample + image = self.vae.decode(latents.to(self.vae.dtype)).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() From 77ec70d145deb30cba0a9d972d9aee762fbb7268 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jul 2023 19:00:38 +0900 Subject: [PATCH 10/20] fix conditioning --- library/sdxl_lpw_stable_diffusion.py | 36 ++++++++++++++++++---------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 10038e6a7..7ab609d81 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -285,7 +285,7 @@ def get_unweighted_text_embeddings( def get_weighted_text_embeddings( - pipe: StableDiffusionPipeline, + pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline, prompt: Union[str, List[str]], uncond_prompt: Optional[Union[str, List[str]]] = None, max_embeddings_multiples: Optional[int] = 3, @@ -657,11 +657,9 @@ def _encode_prompt( uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - if text_pool is not None: - text_pool = torch.cat([uncond_pool, text_pool]) + return text_embeddings, text_pool, uncond_embeddings, uncond_pool - return text_embeddings, text_pool + return text_embeddings, text_pool, None, None def check_inputs(self, prompt, height, width, strength, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): @@ -671,7 +669,6 @@ def check_inputs(self, prompt, height, width, strength, callback_steps): raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if height % 8 != 0 or width % 8 != 0: - print(height, width) raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( @@ -901,12 +898,14 @@ def __call__( # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す # To simplify the implementation, switch the tokenzer/text encoder and call it twice text_embeddings_list = [] - text_pools = [] + text_pool = None + uncond_embeddings_list = [] + uncond_pool = None for i in range(len(self.tokenizers)): self.tokenizer = self.tokenizers[i] self.text_encoder = self.text_encoders[i] - text_embeddings, text_pool = self._encode_prompt( + text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt( prompt, device, num_images_per_prompt, @@ -916,7 +915,12 @@ def __call__( is_sdxl_text_encoder2=i == 1, ) text_embeddings_list.append(text_embeddings) - text_pools.append(text_pool) + uncond_embeddings_list.append(uncond_embeddings) + + if tp1 is not None: + text_pool = tp1 + if up1 is not None: + uncond_pool = up1 dtype = text_embeddings_list[0].dtype @@ -965,11 +969,19 @@ def __call__( crop_size = torch.zeros_like(orig_size) target_size = orig_size embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) + + # make conditionings if do_classifier_free_guidance: - embs = torch.cat([embs] * 2) + text_embeddings = torch.cat(text_embeddings_list, dim=2) + uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) + text_embedding = torch.cat([text_embeddings, uncond_embeddings]).to(dtype) - vector_embedding = torch.cat([text_pools[1], embs], dim=1).to(dtype) - text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) + cond_vector = torch.cat([text_pool, embs], dim=1) + uncond_vector = torch.cat([uncond_pool, embs], dim=1) + vector_embedding = torch.cat([cond_vector, uncond_vector]).to(dtype) + else: + text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) + vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) # 8. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): From c2ceb6de5fc861513582642edf87834a01ce84a2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jul 2023 21:14:12 +0900 Subject: [PATCH 11/20] fix uncond/cond order --- library/sdxl_lpw_stable_diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 7ab609d81..d44b3cf8c 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -974,11 +974,11 @@ def __call__( if do_classifier_free_guidance: text_embeddings = torch.cat(text_embeddings_list, dim=2) uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) - text_embedding = torch.cat([text_embeddings, uncond_embeddings]).to(dtype) + text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) cond_vector = torch.cat([text_pool, embs], dim=1) uncond_vector = torch.cat([uncond_pool, embs], dim=1) - vector_embedding = torch.cat([cond_vector, uncond_vector]).to(dtype) + vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype) else: text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) From 5c80117fbdcbb3f6470a95c9602323f9e28dd5e2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jul 2023 21:37:46 +0900 Subject: [PATCH 12/20] update readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8a47ee3b5..bf1d61c6e 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ learning_rate = 4e-7 # SDXL original learning rate ### TODO - [ ] Support Textual Inversion training. +- [ ] Support conversion of Diffusers SDXL models. - [ ] Support `--weighted_captions` option. - [ ] Change `--output_config` option to continue the training. - [ ] Extend `--full_bf16` for all the scripts. From b6e328ea8f22b03355ebd86ebc5190e0477864f4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 10 Jul 2023 08:46:15 +0900 Subject: [PATCH 13/20] don't hold latent on memory for finetuning dataset --- library/train_util.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 746a7f9dd..809f0af03 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -91,6 +91,7 @@ try: import pillow_avif + IMAGE_EXTENSIONS.extend([".avif", ".AVIF"]) except: pass @@ -853,16 +854,11 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # split by resolution batches = [] batch = [] - for info in image_infos: + print("checking cache validity...") + for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] - if info.latents_npz is not None: - info.latents, info.latents_original_size, info.latents_crop_left_top = self.load_latents_from_npz(info, False) - info.latents = torch.FloatTensor(info.latents) - - info.latents_flipped, _, _ = self.load_latents_from_npz(info, True) # might be None - if info.latents_flipped is not None: - info.latents_flipped = torch.FloatTensor(info.latents_flipped) + if info.latents_npz is not None: # fine tuning dataset continue # check disk cache exists and size of latents From f54b784d88246a7b3f60a46c3782f29cd892d0c7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 10 Jul 2023 22:04:02 +0900 Subject: [PATCH 14/20] support textual inversion training --- README.md | 30 +- library/sdxl_train_util.py | 41 +- sdxl_gen_img.py | 33 +- sdxl_train_textual_inversion.py | 142 +++++ train_textual_inversion.py | 1031 +++++++++++++++++-------------- 5 files changed, 809 insertions(+), 468 deletions(-) create mode 100644 sdxl_train_textual_inversion.py diff --git a/README.md b/README.md index bf1d61c6e..4e4150b60 100644 --- a/README.md +++ b/README.md @@ -25,18 +25,31 @@ The feature of SDXL training is now available in sdxl branch as an experimental Summary of the feature: - `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. - - `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage. + - `--full_bf16` option is added. Thanks to KohakuBlueleaf! + - This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage. - However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer. - I cannot find bitsandbytes>0.35.0 that works correctly on Windows. + - In addition, the full bfloat16 training might be unstable. Please use it at your own risk. - `prepare_buckets_latents.py` now supports SDXL fine-tuning. - `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. - Both scripts has following additional options: - `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions. - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs. - The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with `--no_half_vae` option. It works with `bf16` or without mixed precision. -- `--weighted_captions` option is not supported yet. + +- `--weighted_captions` option is not supported yet for both scripts. - `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000. + +- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`. + - `--cache_text_encoder_outputs` is not supported. + - `token_string` must be alphabet only currently, due to the limitation of the open-clip tokenizer. + - There are two options for captions: + 1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens. + 2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored. + - See below for the format of the embeddings. + - `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage. + - Textual Inversion is supported, but the name for the embeds in the caption becomes alphabet only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`. `requirements.txt` is updated to support SDXL training. @@ -54,7 +67,7 @@ Summary of the feature: - `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training. Example of the optimizer settings for Adafactor with the fixed learning rate: -``` +```toml optimizer_type = "adafactor" optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] lr_scheduler = "constant_with_warmup" @@ -62,13 +75,22 @@ lr_warmup_steps = 100 learning_rate = 4e-7 # SDXL original learning rate ``` +### Format of Textual Inversion embeddings + +```python +from safetensors.torch import save_file + +state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768} +save_file(state_dict, file) +``` + ### TODO -- [ ] Support Textual Inversion training. - [ ] Support conversion of Diffusers SDXL models. - [ ] Support `--weighted_captions` option. - [ ] Change `--output_config` option to continue the training. - [ ] Extend `--full_bf16` for all the scripts. +- [x] Support Textual Inversion training. ## About requirements.txt diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 675aac3da..0ce097158 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp class WrapperTokenizer: # open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする + # make open clip tokenizer compatible with HuggingFace tokenizer def __init__(self): open_clip_tokenizer = open_clip.tokenizer._tokenizer self.model_max_length = 77 self.bos_token_id = open_clip_tokenizer.all_special_ids[0] self.eos_token_id = open_clip_tokenizer.all_special_ids[1] - self.pad_token_id = 0 # 結果から推定している + self.pad_token_id = 0 # 結果から推定している assumption from result def __call__(self, *args: Any, **kwds: Any) -> Any: return self.tokenize(*args, **kwds) @@ -107,6 +108,42 @@ def tokenize(self, text, padding=False, truncation=None, max_length=None, return input_ids = input_ids[: eos_index + 1] # include eos return SimpleNamespace(**{"input_ids": input_ids}) + # for Textual Inversion + # わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI? + + def encode(self, text, add_special_tokens=False): + assert not add_special_tokens + input_ids = open_clip.tokenizer._tokenizer.encode(text) + return input_ids + + def add_tokens(self, new_tokens): + tokens_to_add = [] + for token in new_tokens: + token = token.lower() + if token + "" not in open_clip.tokenizer._tokenizer.encoder: + tokens_to_add.append(token) + + # open clipのtokenizerに直接追加する / add tokens to open clip tokenizer + for token in tokens_to_add: + open_clip.tokenizer._tokenizer.encoder[token + ""] = len(open_clip.tokenizer._tokenizer.encoder) + open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "" + open_clip.tokenizer._tokenizer.vocab_size += 1 + + # open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする + # めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる + # set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe + # this is very rough, so it will not work if the specification of open clip tokenizer changes + open_clip.tokenizer._tokenizer.cache[token] = token + "" + + return len(tokens_to_add) + + def convert_tokens_to_ids(self, tokens): + input_ids = [open_clip.tokenizer._tokenizer.encoder[token + ""] for token in tokens] + return input_ids + + def __len__(self): + return open_clip.tokenizer._tokenizer.vocab_size + def load_tokenizers(args: argparse.Namespace): print("prepare tokenizers") @@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace): print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") assert ( - not args.weighted_captions + not hasattr(args, "weighted_captions") or not args.weighted_captions ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 8f1c17d60..1e20595cc 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -320,7 +320,7 @@ def __init__( self.scheduler = scheduler self.safety_checker = None - # Textual Inversion # not tested yet + # Textual Inversion self.token_replacements_list = [] for _ in range(len(self.text_encoders)): self.token_replacements_list.append({}) @@ -341,6 +341,10 @@ def get_token_replacer(self, tokenizer): token_replacements = self.token_replacements_list[tokenizer_index] def replace_tokens(tokens): + # print("replace_tokens", tokens, "=>", token_replacements) + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + new_tokens = [] for token in tokens: if token in token_replacements: @@ -1594,19 +1598,26 @@ def __getattr__(self, item): if "string_to_param" in data: data = data["string_to_param"] - embeds1 = data["clip_l"] - embeds2 = data["clip_g"] + + embeds1 = data["clip_l"] # text encoder 1 + embeds2 = data["clip_g"] # text encoder 2 num_vectors_per_token = embeds1.size()[0] token_string = os.path.splitext(os.path.basename(embeds_file))[0] - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # remove non-alphabet characters to avoid splitting by tokenizer + # TODO make random alphabet string + token_string = "".join([c for c in token_string if c.isalpha()]) + + token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)] # add new word to tokenizer, count is num_vectors_per_token num_added_tokens1 = tokenizer1.add_tokens(token_strings) - num_added_tokens2 = tokenizer2.add_tokens(token_strings) # not working now - assert ( - num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token - ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" + num_added_tokens2 = tokenizer2.add_tokens(token_strings) + assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( + f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}" + ) token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) @@ -1617,11 +1628,11 @@ def __getattr__(self, item): assert ( min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 ), f"token ids2 is not ordered" - assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}" - assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}" + assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}" + assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}" if num_vectors_per_token > 1: - pipe.add_token_replacement(0, token_ids1[0], token_ids1) + pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... pipe.add_token_replacement(1, token_ids2[0], token_ids2) token_ids_embeds1.append((token_ids1, embeds1)) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py new file mode 100644 index 000000000..9df370927 --- /dev/null +++ b/sdxl_train_textual_inversion.py @@ -0,0 +1,142 @@ +import argparse +import os + +import regex +import torch +import open_clip +from library import sdxl_model_util, sdxl_train_util, train_util + +import train_textual_inversion + + +class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer): + def __init__(self): + super().__init__() + self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + sdxl_train_util.verify_sdxl_training_args(args) + + def load_target_model(self, args, weight_dtype, accelerator): + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype) + + self.load_stable_diffusion_format = load_stable_diffusion_format + self.logit_scale = logit_scale + self.ckpt_info = ckpt_info + + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet + + def load_tokenizer(self, args): + tokenizer = sdxl_train_util.load_tokenizers(args) + return tokenizer + + def assert_token_string(self, token_string, tokenizers): + # tokenizer 1 is seems to be ok + + # count words for token string: regular expression from open_clip + pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE) + words = regex.findall(pat, token_string) + word_count = len(words) + assert word_count == 1, ( + f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters" + + f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください" + ) + + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.enable_grad(): + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states( + args, + input_ids1, + input_ids2, + tokenizers[0], + tokenizers[1], + text_encoders[0], + text_encoders[1], + None if not args.full_fp16 else weight_dtype, + ) + return encoder_hidden_states1, encoder_hidden_states2, pool2 + + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + sdxl_train_util.sample_images( + accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + def save_weights(self, file, updated_embs, save_dtype): + state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]} + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file) + else: + torch.save(state_dict, file) + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + data = load_file(file) + else: + data = torch.load(file, map_location="cpu") + + emb_l = data.get("clib_l", None) # ViT-L text encoder 1 + emb_g = data.get("clib_g", None) # BiG-G text encoder 2 + + assert ( + emb_l is not None or emb_g is not None + ), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}" + + return [emb_l, emb_g] + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_textual_inversion.setup_parser() + # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching + # sdxl_train_util.add_sdxl_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + trainer = SdxlTextualInversionTrainer() + trainer.train(args) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index ecfaeb4fa..09294048f 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -8,6 +8,7 @@ import torch from accelerate.utils import set_seed from diffusers import DDPMScheduler +from library import model_util import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -20,8 +21,6 @@ from library.custom_train_functions import ( apply_snr_weight, prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, ) @@ -78,503 +77,632 @@ ] -def train(args): - if args.output_name is None: - args.output_name = args.token_string - use_template = args.use_object_template or args.use_style_template +class TextualInversionTrainer: + def __init__(self): + self.vae_scale_factor = 0.18215 - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) + def assert_extra_args(self, args, train_dataset_group): + pass - cache_latents = args.cache_latents + def load_target_model(self, args, weight_dtype, accelerator): + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - if args.seed is not None: - set_seed(args.seed) + def load_tokenizer(self, args): + tokenizer = train_util.load_tokenizer(args) + return tokenizer - tokenizer = train_util.load_tokenizer(args) + def assert_token_string(self, token_string, tokenizers): + pass + + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + with torch.enable_grad(): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None) + return encoder_hidden_states - # acceleratorを準備する - print("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + noise_pred = unet(noisy_latents, timesteps, text_conds).sample + return noise_pred - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + train_util.sample_images( + accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + def save_weights(self, file, updated_embs, save_dtype): + state_dict = {"emb_params": updated_embs[0]} - # Convert the init_word to token_id - if args.init_word is not None: - init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) - if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - accelerator.print( - f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" - ) - else: - init_token_ids = None - - # add new word to tokenizer, count is num_vectors_per_token - token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] - num_added_tokens = tokenizer.add_tokens(token_strings) - assert ( - num_added_tokens == args.num_vectors_per_token - ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - accelerator.print(f"tokens are added: {token_ids}") - assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" - assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - - # Resize the token embeddings as we are adding new special tokens to the tokenizer - text_encoder.resize_token_embeddings(len(tokenizer)) - - # Initialise the newly added placeholder token with the embeddings of the initializer token - token_embeds = text_encoder.get_input_embeddings().weight.data - if init_token_ids is not None: - for i, token_id in enumerate(token_ids): - token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] - # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - - # load weights - if args.weights is not None: - embeddings = load_weights(args.weights) - assert len(token_ids) == len( - embeddings - ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # accelerator.print(token_ids, embeddings.size()) - for token_id, embedding in zip(token_ids, embeddings): - token_embeds[token_id] = embedding - # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - accelerator.print(f"weighs loaded") - - accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) - if args.dataset_config is not None: - accelerator.print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - accelerator.print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - use_dreambooth_method = args.in_json is None - if use_dreambooth_method: - accelerator.print("Use DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - print("Train with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) - - # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 - if use_template: - accelerator.print("use template for training captions. is object: {args.use_object_template}") - templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small - replace_to = " ".join(token_strings) - captions = [] - for tmpl in templates: - captions.append(tmpl.format(replace_to)) - train_dataset_group.add_replacement("", captions) - - if args.num_vectors_per_token > 1: - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - else: - if args.num_vectors_per_token > 1: - replace_to = " ".join(token_strings) - train_dataset_group.add_replacement(args.token_string, replace_to) - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group, show_input_ids=True) - return - if len(train_dataset_group) == 0: - accelerator.print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - trainable_params = text_encoder.get_input_embeddings().parameters() - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collater, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) + save_file(state_dict, file) + else: + torch.save(state_dict, file) # can be loaded in Web UI - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file - # acceleratorがなんかよろしくやってくれるらしい - text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, lr_scheduler - ) + data = load_file(file) + else: + # compatible to Web UI's file format + data = torch.load(file, map_location="cpu") + if type(data) != dict: + raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") - # transform DDP after prepare - text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet) - - index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - # accelerator.print(len(index_no_updates), torch.sum(index_no_updates)) - orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() - - # Freeze all parameters except for the token embeddings in text encoder - text_encoder.requires_grad_(True) - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - else: - unet.eval() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - text_encoder.to(weight_dtype) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - accelerator.print( - f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" - ) - accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + if "string_to_param" in data: # textual inversion embeddings + data = data["string_to_param"] + if hasattr(data, "_parameters"): # support old PyTorch? + data = getattr(data, "_parameters") - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 + emb = next(iter(data.values())) + if type(emb) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - - if accelerator.is_main_process: - accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) - - # function for saving/removing - def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - save_weights(ckpt_file, embs, save_dtype) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # training loop - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - text_encoder.train() - - loss_total = 0 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(text_encoder): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - # use float instead of fp16/bf16 because text encoder is float - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = text_encoder.get_input_embeddings().parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Let's make sure we don't update any embedding weights besides the newly added token - with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ - index_no_updates - ] - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - updated_embs = ( - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() - ) + if len(emb.size()) == 1: + emb = emb.unsqueeze(0) - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, updated_embs, global_step, epoch) + return [emb] - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + def train(self, args): + if args.output_name is None: + args.output_name = args.token_string + use_template = args.use_object_template or args.use_style_template - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] - ) - accelerator.log(logs, step=global_step) + cache_latents = args.cache_latents - loss_total += current_loss - avr_loss = loss_total / (step + 1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + if args.seed is not None: + set_seed(args.seed) - if global_step >= args.max_train_steps: - break + tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer + tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch + 1) + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) - accelerator.wait_for_everyone() + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + # モデルを読み込む + model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) + text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if accelerator.is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, updated_embs, epoch + 1, global_step) + if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1: + accelerator.print( + "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / " + + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです" + ) - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) + # Convert the init_word to token_id + init_token_ids_list = [] + if args.init_word is not None: + for i, tokenizer in enumerate(tokenizers): + init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) + if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: + accelerator.print( + f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / " + + f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}" + ) + init_token_ids_list.append(init_token_ids) + else: + init_token_ids_list = [None] * len(tokenizers) + + # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token + # add new word to tokenizer, count is num_vectors_per_token + + # token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される + # 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした + + # if token_string is hoge, "hoge", "hogea", "hogeb", ... are added + # originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used + + self.assert_token_string(args.token_string, tokenizers) + + token_strings = [args.token_string] + [ + f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1) + ] + token_ids_list = [] + token_embeds_list = [] + for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)): + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == args.num_vectors_per_token + ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: tokenizer {i+1}, {args.token_string}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + accelerator.print(f"tokens are added for tokenizer {i+1}: {token_ids}") + assert ( + min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 + ), f"token ids is not ordered : tokenizer {i+1}, {token_ids}" + assert ( + len(tokenizer) - 1 == token_ids[-1] + ), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}" + token_ids_list.append(token_ids) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_token_ids is not None: + for i, token_id in enumerate(token_ids): + token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] + # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + token_embeds_list.append(token_embeds) + + # load weights + if args.weights is not None: + embeddings_list = self.load_weights(args.weights) + assert len(token_ids) == len( + embeddings_list[0] + ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" + # accelerator.print(token_ids, embeddings.size()) + for token_ids, embeddings, token_embeds in zip(token_ids_list, embeddings_list, token_embeds_list): + for token_id, embedding in zip(token_ids, embeddings): + token_embeds[token_id] = embedding + # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + accelerator.print(f"weighs loaded") + + accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) + if args.dataset_config is not None: + accelerator.print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + accelerator.print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + accelerator.print("Use DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list) - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + self.assert_extra_args(args, train_dataset_group) - train_util.sample_images( - accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 + if use_template: + accelerator.print("use template for training captions. is object: {args.use_object_template}") + templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small + replace_to = " ".join(token_strings) + captions = [] + for tmpl in templates: + captions.append(tmpl.format(replace_to)) + train_dataset_group.add_replacement("", captions) + + # サンプル生成用 + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + else: + # サンプル生成用 + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, show_input_ids=True) + return + if len(train_dataset_group) == 0: + accelerator.print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + vae.set_use_memory_efficient_attention_xformers(args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + for text_encoder in text_encoders: + text_encoder.gradient_checkpointing_enable() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + trainable_params = [] + for text_encoder in text_encoders: + trainable_params += text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, ) - # end of epoch + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) - is_main_process = accelerator.is_main_process - if is_main_process: - text_encoder = accelerator.unwrap_model(text_encoder) + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) - accelerator.end_training() + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - if args.save_state and is_main_process: - train_util.save_state_on_train_end(args, accelerator) + # acceleratorがなんかよろしくやってくれるらしい + if len(text_encoders) == 1: + text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder_or_list, optimizer, train_dataloader, lr_scheduler + ) + # transform DDP after prepare + text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() + elif len(text_encoders) == 2: + text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler + ) + # transform DDP after prepare + text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet) - del accelerator # この後メモリを使うのでこれは消す + text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) + else: + raise NotImplementedError() + + index_no_updates_list = [] + orig_embeds_params_list = [] + for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders): + index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] + index_no_updates_list.append(index_no_updates) + + # accelerator.print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + orig_embeds_params_list.append(orig_embeds_params) + + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す + unet.train() + else: + unet.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + for text_encoder in text_encoders: + text_encoder.to(weight_dtype) + if args.full_bf16: + for text_encoder in text_encoders: + text_encoder.to(weight_dtype) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - print("model saved.") + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) + + # function for saving/removing + def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + self.save_weights(ckpt_file, embs_list, save_dtype) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for text_encoder in text_encoders: + text_encoder.train() + + loss_total = 0 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(text_encoders[0]): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * self.vae_scale_factor + + # Get the text embedding for conditioning + text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) -def save_weights(file, updated_embs, save_dtype): - state_dict = {"emb_params": updated_embs} + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + for text_encoder, orig_embeds_params, index_no_updates in zip( + text_encoders, orig_embeds_params_list, index_no_updates_list + ): + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds_params[index_no_updates] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + self.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + updated_embs_list = [] + for text_encoder, token_ids in zip(text_encoders, token_ids_list): + updated_embs = ( + accelerator.unwrap_model(text_encoder) + .get_input_embeddings() + .weight[token_ids] + .data.detach() + .clone() + ) + updated_embs_list.append(updated_embs) + + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, updated_embs_list, global_step, epoch) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) - save_file(state_dict, file) - else: - torch.save(state_dict, file) # can be loaded in Web UI + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + updated_embs_list = [] + for text_encoder, token_ids in zip(text_encoders, token_ids_list): + updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + updated_embs_list.append(updated_embs) + + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if accelerator.is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, updated_embs_list, epoch + 1, global_step) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + self.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, + ) + # end of epoch -def load_weights(file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = accelerator.unwrap_model(text_encoder) - data = load_file(file) - else: - # compatible to Web UI's file format - data = torch.load(file, map_location="cpu") - if type(data) != dict: - raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") + accelerator.end_training() - if "string_to_param" in data: # textual inversion embeddings - data = data["string_to_param"] - if hasattr(data, "_parameters"): # support old PyTorch? - data = getattr(data, "_parameters") + if args.save_state and is_main_process: + train_util.save_state_on_train_end(args, accelerator) - emb = next(iter(data.values())) - if type(emb) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - if len(emb.size()) == 1: - emb = emb.unsqueeze(0) + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) - return emb + print("model saved.") def setup_parser() -> argparse.ArgumentParser: @@ -626,4 +754,5 @@ def setup_parser() -> argparse.ArgumentParser: args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) + trainer = TextualInversionTrainer() + trainer.train(args) From 68ca0ea995434a331209fbc34938e1bbe7ccb083 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 10 Jul 2023 22:28:26 +0900 Subject: [PATCH 15/20] Fix to show template type --- train_textual_inversion.py | 4 ++-- train_textual_inversion_XTI.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 09294048f..1f085643b 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -94,7 +94,7 @@ def load_tokenizer(self, args): def assert_token_string(self, token_string, tokenizers): pass - + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): with torch.enable_grad(): input_ids = batch["input_ids"].to(accelerator.device) @@ -311,7 +311,7 @@ def train(self, args): # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: - accelerator.print("use template for training captions. is object: {args.use_object_template}") + accelerator.print(f"use template for training captions. is object: {args.use_object_template}") templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small replace_to = " ".join(token_strings) captions = [] diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index a08c3a824..0e91c71c3 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -234,7 +234,7 @@ def train(args): # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: - print("use template for training captions. is object: {args.use_object_template}") + print(f"use template for training captions. is object: {args.use_object_template}") templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small replace_to = " ".join(token_strings) captions = [] From 2e67d74df46df88a9e9b3cae59373b7053b7a40c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Jul 2023 22:19:14 +0900 Subject: [PATCH 16/20] add no_half_vae option --- train_textual_inversion.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 1f085643b..cbfd48ce7 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -173,6 +173,7 @@ def train(self, args): # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) @@ -351,7 +352,7 @@ def train(self, args): # 学習を準備する if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() with torch.no_grad(): @@ -447,10 +448,10 @@ def train(self, args): else: unet.eval() - if not cache_latents: + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: @@ -529,7 +530,7 @@ def remove_model(old_ckpt_name): latents = batch["latents"].to(accelerator.device) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() latents = latents * self.vae_scale_factor # Get the text embedding for conditioning @@ -744,6 +745,11 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) return parser From 814996b14f39fb7144c37089a20e05c120dceeaa Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Jul 2023 23:18:35 +0900 Subject: [PATCH 17/20] fix NaN in sampling image --- library/sdxl_lpw_stable_diffusion.py | 2 +- library/train_util.py | 232 +++++++++++++-------------- 2 files changed, 117 insertions(+), 117 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index d44b3cf8c..99b0bc8d5 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -922,7 +922,7 @@ def __call__( if up1 is not None: uncond_pool = up1 - dtype = text_embeddings_list[0].dtype + dtype = self.unet.dtype # 4. Preprocess image and mask if isinstance(image, PIL.Image.Image): diff --git a/library/train_util.py b/library/train_util.py index 809f0af03..9438a1895 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3874,127 +3874,127 @@ def sample_images_common( cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None with torch.no_grad(): - with accelerator.autocast(): - for i, prompt in enumerate(prompts): - if not accelerator.is_main_process: - continue + # with accelerator.autocast(): + for i, prompt in enumerate(prompts): + if not accelerator.is_main_process: + continue - if isinstance(prompt, dict): - negative_prompt = prompt.get("negative_prompt") - sample_steps = prompt.get("sample_steps", 30) - width = prompt.get("width", 512) - height = prompt.get("height", 512) - scale = prompt.get("scale", 7.5) - seed = prompt.get("seed") - controlnet_image = prompt.get("controlnet_image") - prompt = prompt.get("prompt") - else: - # prompt = prompt.strip() - # if len(prompt) == 0 or prompt[0] == "#": - # continue - - # subset of gen_img_diffusers - prompt_args = prompt.split(" --") - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - controlnet_image = None - for parg in prompt_args: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue - - m = re.match(r"d (\d+)", parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue - - m = re.match(r"cn (.+)", parg, re.IGNORECASE) - if m: # negative prompt - controlnet_image = m.group(1) - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - image = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ).images[0] - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" - ) + if isinstance(prompt, dict): + negative_prompt = prompt.get("negative_prompt") + sample_steps = prompt.get("sample_steps", 30) + width = prompt.get("width", 512) + height = prompt.get("height", 512) + scale = prompt.get("scale", 7.5) + seed = prompt.get("seed") + controlnet_image = prompt.get("controlnet_image") + prompt = prompt.get("prompt") + else: + # prompt = prompt.strip() + # if len(prompt) == 0 or prompt[0] == "#": + # continue + + # subset of gen_img_diffusers + prompt_args = prompt.split(" --") + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + controlnet_image = None + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue + + m = re.match(r"cn (.+)", parg, re.IGNORECASE) + if m: # negative prompt + controlnet_image = m.group(1) + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"prompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + image = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ).images[0] + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + ) - image.save(os.path.join(save_dir, img_filename)) + image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass # clear pipeline and cache to reduce vram usage del pipeline From 689721cba5f0e4b421bcaf08669657b27f74297f Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 11 Jul 2023 11:40:42 -0400 Subject: [PATCH 18/20] Updates --- .release | 2 +- README.md | 6 +- dreambooth_gui.py | 326 +++++++++++----------- finetune_gui.py | 9 + kohya_gui.py | 14 +- library/class_lora_tab.py | 42 +++ library/class_sdxl_parameters.py | 6 +- library/class_source_model.py | 4 +- library/svd_merge_lora_gui.py | 7 +- lora_gui.py | 25 +- textual_inversion_gui.py | 456 ++++++++++++++++--------------- 11 files changed, 471 insertions(+), 426 deletions(-) create mode 100644 library/class_lora_tab.py diff --git a/.release b/.release index 042b7490f..3f7713958 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v21.8.1 \ No newline at end of file +v21.8.2 \ No newline at end of file diff --git a/README.md b/README.md index 33e38beeb..96b2968bb 100644 --- a/README.md +++ b/README.md @@ -462,4 +462,8 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is * 2023/07/10 (v21.8.1) - Let Tensorboard works in docker #1137 - - Fix for accelerate issue \ No newline at end of file + - Fix for accelerate issue + - Add SDXL TI training support + - Rework gui for common layout + - More LoRA tools to class + - Add no_half_vae option to TI \ No newline at end of file diff --git a/dreambooth_gui.py b/dreambooth_gui.py index ef6d11bb0..dd8fa145a 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -659,184 +659,186 @@ def dreambooth_tab( dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - gr.Markdown('Train a custom model using kohya dreambooth python code...') - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel(headless=headless) - - with gr.Tab('Folders'): - folders = Folders(headless=headless) - with gr.Tab('Parameters'): - basic_training = BasicTraining( - learning_rate_value='1e-5', - lr_scheduler_value='cosine', - lr_warmup_value='10', - ) - with gr.Accordion('Advanced Configuration', open=False): - advanced_training = AdvancedTraining(headless=headless) - advanced_training.color_aug.change( - color_aug_changed, - inputs=[advanced_training.color_aug], - outputs=[basic_training.cache_latents], + with gr.Tab('Training'): + gr.Markdown('Train a custom model using kohya dreambooth python code...') + + # Setup Configuration Files Gradio + config = ConfigurationFile(headless) + + source_model = SourceModel(headless=headless) + + with gr.Tab('Folders'): + folders = Folders(headless=headless) + with gr.Tab('Parameters'): + basic_training = BasicTraining( + learning_rate_value='1e-5', + lr_scheduler_value='cosine', + lr_warmup_value='10', ) + with gr.Accordion('Advanced Configuration', open=False): + advanced_training = AdvancedTraining(headless=headless) + advanced_training.color_aug.change( + color_aug_changed, + inputs=[advanced_training.color_aug], + outputs=[basic_training.cache_latents], + ) - sample = SampleImages() - - with gr.Tab('Tools'): - gr.Markdown( - 'This section provide Dreambooth tools to help setup your dataset...' - ) - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, - headless=headless, - ) + sample = SampleImages() - button_run = gr.Button('Train model', variant='primary') + with gr.Tab('Tools'): + gr.Markdown( + 'This section provide Dreambooth tools to help setup your dataset...' + ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=folders.train_data_dir, + reg_data_dir_input=folders.reg_data_dir, + output_dir_input=folders.output_dir, + logging_dir_input=folders.logging_dir, + headless=headless, + ) - button_print = gr.Button('Print training command') + button_run = gr.Button('Train model', variant='primary') - # Setup gradio tensorboard buttons - button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + button_print = gr.Button('Print training command') - button_start_tensorboard.click( - start_tensorboard, - inputs=folders.logging_dir, - show_progress=False, - ) + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() - button_stop_tensorboard.click( - stop_tensorboard, - show_progress=False, - ) + button_start_tensorboard.click( + start_tensorboard, + inputs=folders.logging_dir, + show_progress=False, + ) - settings_list = [ - source_model.pretrained_model_name_or_path, - source_model.v2, - source_model.v_parameterization, - source_model.sdxl_checkbox, - folders.logging_dir, - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - basic_training.max_resolution, - basic_training.learning_rate, - basic_training.lr_scheduler, - basic_training.lr_warmup, - basic_training.train_batch_size, - basic_training.epoch, - basic_training.save_every_n_epochs, - basic_training.mixed_precision, - basic_training.save_precision, - basic_training.seed, - basic_training.num_cpu_threads_per_process, - basic_training.cache_latents, - basic_training.cache_latents_to_disk, - basic_training.caption_extension, - basic_training.enable_bucket, - advanced_training.gradient_checkpointing, - advanced_training.full_fp16, - advanced_training.no_token_padding, - basic_training.stop_text_encoder_training, - advanced_training.xformers, - source_model.save_model_as, - advanced_training.shuffle_caption, - advanced_training.save_state, - advanced_training.resume, - advanced_training.prior_loss_weight, - advanced_training.color_aug, - advanced_training.flip_aug, - advanced_training.clip_skip, - advanced_training.vae, - folders.output_name, - advanced_training.max_token_length, - advanced_training.max_train_epochs, - advanced_training.max_data_loader_n_workers, - advanced_training.mem_eff_attn, - advanced_training.gradient_accumulation_steps, - source_model.model_list, - advanced_training.keep_tokens, - advanced_training.persistent_data_loader_workers, - advanced_training.bucket_no_upscale, - advanced_training.random_crop, - advanced_training.bucket_reso_steps, - advanced_training.caption_dropout_every_n_epochs, - advanced_training.caption_dropout_rate, - basic_training.optimizer, - basic_training.optimizer_args, - advanced_training.noise_offset_type, - advanced_training.noise_offset, - advanced_training.adaptive_noise_scale, - advanced_training.multires_noise_iterations, - advanced_training.multires_noise_discount, - sample.sample_every_n_steps, - sample.sample_every_n_epochs, - sample.sample_sampler, - sample.sample_prompts, - advanced_training.additional_parameters, - advanced_training.vae_batch_size, - advanced_training.min_snr_gamma, - advanced_training.weighted_captions, - advanced_training.save_every_n_steps, - advanced_training.save_last_n_steps, - advanced_training.save_last_n_steps_state, - advanced_training.use_wandb, - advanced_training.wandb_api_key, - advanced_training.scale_v_pred_loss_like_noise_pred, - advanced_training.min_timestep, - advanced_training.max_timestep, - ] + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) - config.button_open_config.click( - open_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) + settings_list = [ + source_model.pretrained_model_name_or_path, + source_model.v2, + source_model.v_parameterization, + source_model.sdxl_checkbox, + folders.logging_dir, + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + basic_training.max_resolution, + basic_training.learning_rate, + basic_training.lr_scheduler, + basic_training.lr_warmup, + basic_training.train_batch_size, + basic_training.epoch, + basic_training.save_every_n_epochs, + basic_training.mixed_precision, + basic_training.save_precision, + basic_training.seed, + basic_training.num_cpu_threads_per_process, + basic_training.cache_latents, + basic_training.cache_latents_to_disk, + basic_training.caption_extension, + basic_training.enable_bucket, + advanced_training.gradient_checkpointing, + advanced_training.full_fp16, + advanced_training.no_token_padding, + basic_training.stop_text_encoder_training, + advanced_training.xformers, + source_model.save_model_as, + advanced_training.shuffle_caption, + advanced_training.save_state, + advanced_training.resume, + advanced_training.prior_loss_weight, + advanced_training.color_aug, + advanced_training.flip_aug, + advanced_training.clip_skip, + advanced_training.vae, + folders.output_name, + advanced_training.max_token_length, + advanced_training.max_train_epochs, + advanced_training.max_data_loader_n_workers, + advanced_training.mem_eff_attn, + advanced_training.gradient_accumulation_steps, + source_model.model_list, + advanced_training.keep_tokens, + advanced_training.persistent_data_loader_workers, + advanced_training.bucket_no_upscale, + advanced_training.random_crop, + advanced_training.bucket_reso_steps, + advanced_training.caption_dropout_every_n_epochs, + advanced_training.caption_dropout_rate, + basic_training.optimizer, + basic_training.optimizer_args, + advanced_training.noise_offset_type, + advanced_training.noise_offset, + advanced_training.adaptive_noise_scale, + advanced_training.multires_noise_iterations, + advanced_training.multires_noise_discount, + sample.sample_every_n_steps, + sample.sample_every_n_epochs, + sample.sample_sampler, + sample.sample_prompts, + advanced_training.additional_parameters, + advanced_training.vae_batch_size, + advanced_training.min_snr_gamma, + advanced_training.weighted_captions, + advanced_training.save_every_n_steps, + advanced_training.save_last_n_steps, + advanced_training.save_last_n_steps_state, + advanced_training.use_wandb, + advanced_training.wandb_api_key, + advanced_training.scale_v_pred_loss_like_noise_pred, + advanced_training.min_timestep, + advanced_training.max_timestep, + ] + + config.button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) - config.button_load_config.click( - open_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) + config.button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) - config.button_save_config.click( - save_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + config.button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + config.button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) - button_run.click( - train_model, - inputs=[dummy_headless] + [dummy_db_false] + settings_list, - show_progress=False, - ) + button_run.click( + train_model, + inputs=[dummy_headless] + [dummy_db_false] + settings_list, + show_progress=False, + ) - button_print.click( - train_model, - inputs=[dummy_headless] + [dummy_db_true] + settings_list, - show_progress=False, - ) + button_print.click( + train_model, + inputs=[dummy_headless] + [dummy_db_true] + settings_list, + show_progress=False, + ) - return ( - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - folders.logging_dir, - ) + return ( + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + folders.logging_dir, + ) def UI(**kwargs): diff --git a/finetune_gui.py b/finetune_gui.py index ca1e51f0f..5253ac07c 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -78,6 +78,7 @@ def save_configuration( seed, num_cpu_threads_per_process, train_text_encoder, + full_bf16, create_caption, create_buckets, save_model_as, @@ -197,6 +198,7 @@ def open_configuration( seed, num_cpu_threads_per_process, train_text_encoder, + full_bf16, create_caption, create_buckets, save_model_as, @@ -313,6 +315,7 @@ def train_model( seed, num_cpu_threads_per_process, train_text_encoder, + full_bf16, generate_caption_database, generate_image_buckets, save_model_as, @@ -495,6 +498,8 @@ def train_model( run_cmd += ' --v_parameterization' if train_text_encoder: run_cmd += ' --train_text_encoder' + if full_bf16: + run_cmd += ' --full_bf16' if weighted_captions: run_cmd += ' --weighted_captions' run_cmd += ( @@ -788,6 +793,9 @@ def finetune_tab(headless=False): train_text_encoder = gr.Checkbox( label='Train text encoder', value=True ) + full_bf16 = gr.Checkbox( + label='Full bf16', value = False + ) with gr.Accordion('Advanced parameters', open=False): with gr.Row(): gradient_accumulation_steps = gr.Number( @@ -848,6 +856,7 @@ def finetune_tab(headless=False): basic_training.seed, basic_training.num_cpu_threads_per_process, train_text_encoder, + full_bf16, create_caption, create_buckets, source_model.save_model_as, diff --git a/kohya_gui.py b/kohya_gui.py index 0ac0e15b3..a6043e4b9 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -5,13 +5,8 @@ from finetune_gui import finetune_tab from textual_inversion_gui import ti_tab from library.utilities import utilities_tab -from library.extract_lora_gui import gradio_extract_lora_tab -from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab -from library.merge_lora_gui import gradio_merge_lora_tab -from library.resize_lora_gui import gradio_resize_lora_tab -from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab -from library.merge_lycoris_gui import gradio_merge_lycoris_tab from lora_gui import lora_tab +from library.class_lora_tab import LoRATools import os from library.custom_logging import setup_logging @@ -67,12 +62,7 @@ def UI(**kwargs): headless=headless, ) with gr.Tab('LoRA'): - gradio_extract_dylora_tab(headless=headless) - gradio_extract_lora_tab(headless=headless) - gradio_extract_lycoris_locon_tab(headless=headless) - gradio_merge_lora_tab(headless=headless) - gradio_merge_lycoris_tab(headless=headless) - gradio_resize_lora_tab(headless=headless) + _ = LoRATools(headless=headless) with gr.Tab('About'): gr.Markdown(f'kohya_ss GUI release {release}') with gr.Tab('README'): diff --git a/library/class_lora_tab.py b/library/class_lora_tab.py new file mode 100644 index 000000000..a19f34a9f --- /dev/null +++ b/library/class_lora_tab.py @@ -0,0 +1,42 @@ +import gradio as gr +from library.merge_lora_gui import gradio_merge_lora_tab +from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab +from library.verify_lora_gui import gradio_verify_lora_tab +from library.resize_lora_gui import gradio_resize_lora_tab +from library.extract_lora_gui import gradio_extract_lora_tab +from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab +from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab +from library.merge_lycoris_gui import gradio_merge_lycoris_tab + +# Deprecated code +from library.dataset_balancing_gui import gradio_dataset_balancing_tab +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) + +class LoRATools: + def __init__(self, folders = "", headless:bool = False): + self.headless = headless + self.folders = folders + + gr.Markdown( + 'This section provide LoRA tools to help setup your dataset...' + ) + gradio_extract_dylora_tab(headless=headless) + gradio_extract_lora_tab(headless=headless) + gradio_extract_lycoris_locon_tab(headless=headless) + gradio_merge_lora_tab(headless=headless) + gradio_merge_lycoris_tab(headless=headless) + gradio_svd_merge_lora_tab(headless=headless) + gradio_resize_lora_tab(headless=headless) + gradio_verify_lora_tab(headless=headless) + if folders: + with gr.Tab('Deprecated'): + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=folders.train_data_dir, + reg_data_dir_input=folders.reg_data_dir, + output_dir_input=folders.output_dir, + logging_dir_input=folders.logging_dir, + headless=headless, + ) + gradio_dataset_balancing_tab(headless=headless) \ No newline at end of file diff --git a/library/class_sdxl_parameters.py b/library/class_sdxl_parameters.py index 8f7883e8b..33af8631b 100644 --- a/library/class_sdxl_parameters.py +++ b/library/class_sdxl_parameters.py @@ -2,8 +2,9 @@ ### SDXL Parameters class class SDXLParameters: - def __init__(self, sdxl_checkbox): + def __init__(self, sdxl_checkbox, show_sdxl_cache_text_encoder_outputs:bool = True): self.sdxl_checkbox = sdxl_checkbox + self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs with gr.Accordion(visible=False, open=True, label='SDXL Specific Parameters') as self.sdxl_row: with gr.Row(): @@ -11,11 +12,12 @@ def __init__(self, sdxl_checkbox): label='Cache text encoder outputs', info='Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.', value=False, + visible=show_sdxl_cache_text_encoder_outputs ) self.sdxl_no_half_vae = gr.Checkbox( label='No half VAE', info='Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.', - value=False + value=True ) self.sdxl_checkbox.change(lambda sdxl_checkbox: gr.Accordion.update(visible=sdxl_checkbox), inputs=[self.sdxl_checkbox], outputs=[self.sdxl_row]) diff --git a/library/class_source_model.py b/library/class_source_model.py index 4080f0498..509bc41e5 100644 --- a/library/class_source_model.py +++ b/library/class_source_model.py @@ -33,8 +33,8 @@ def __init__( label='Model Quick Pick', choices=[ 'custom', - 'stabilityai/stable-diffusion-xl-base-0.9', - 'stabilityai/stable-diffusion-xl-refiner-0.9', + # 'stabilityai/stable-diffusion-xl-base-0.9', + # 'stabilityai/stable-diffusion-xl-refiner-0.9', 'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned', 'stabilityai/stable-diffusion-2-1-base', 'stabilityai/stable-diffusion-2-base', diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index c42227216..9b5cce95e 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -36,6 +36,11 @@ def svd_merge_lora( new_conv_rank, device, ): + # Check if the output file already exists + if os.path.isfile(save_to): + print(f"Output file '{save_to}' already exists. Aborting.") + return + # Check if the ratio total is equal to one. If not mormalise to 1 total_ratio = ratio_a + ratio_b + ratio_c + ratio_d if total_ratio != 1: @@ -78,7 +83,7 @@ def svd_merge_lora( run_cmd_ratios += f' {ratio_d}' run_cmd += run_cmd_models - run_cmd += run_cmd_ratiosacti + run_cmd += run_cmd_ratios run_cmd += f' --device {device}' run_cmd += f' --new_rank "{new_rank}"' run_cmd += f' --new_conv_rank "{new_conv_rank}"' diff --git a/lora_gui.py b/lora_gui.py index 2e9becfbc..0a6fc49d0 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -32,21 +32,14 @@ from library.class_advanced_training import AdvancedTraining from library.class_sdxl_parameters import SDXLParameters from library.class_folders import Folders -from library.dreambooth_folder_creation_gui import ( - gradio_dreambooth_folder_creation_tab, -) from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) -from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab -from library.merge_lora_gui import gradio_merge_lora_tab -from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab -from library.verify_lora_gui import gradio_verify_lora_tab -from library.resize_lora_gui import gradio_resize_lora_tab from library.class_sample_images import SampleImages, run_cmd_sample +from library.class_lora_tab import LoRATools from library.custom_logging import setup_logging @@ -1576,21 +1569,7 @@ def update_LoRA_settings(LoRA_type): ) with gr.Tab('Tools'): - gr.Markdown( - 'This section provide LoRA tools to help setup your dataset...' - ) - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, - headless=headless, - ) - gradio_dataset_balancing_tab(headless=headless) - gradio_merge_lora_tab(headless=headless) - gradio_svd_merge_lora_tab(headless=headless) - gradio_resize_lora_tab(headless=headless) - gradio_verify_lora_tab(headless=headless) + lora_tools = LoRATools(folders=folders, headless=headless) with gr.Tab('Guides'): gr.Markdown( diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index f91b8dda9..702241ce6 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -30,6 +30,7 @@ from library.class_basic_training import BasicTraining from library.class_advanced_training import AdvancedTraining from library.class_folders import Folders +from library.class_sdxl_parameters import SDXLParameters from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, @@ -129,6 +130,7 @@ def save_configuration( scale_v_pred_loss_like_noise_pred, min_timestep, max_timestep, + sdxl_no_half_vae ): # Get list of function parameters and values parameters = list(locals().items()) @@ -245,6 +247,7 @@ def open_configuration( scale_v_pred_loss_like_noise_pred, min_timestep, max_timestep, + sdxl_no_half_vae ): # Get list of function parameters and values parameters = list(locals().items()) @@ -358,6 +361,7 @@ def train_model( scale_v_pred_loss_like_noise_pred, min_timestep, max_timestep, + sdxl_no_half_vae ): # Get list of function parameters and values parameters = list(locals().items()) @@ -421,13 +425,6 @@ def train_model( ): return - if sdxl: - output_message( - msg='TI training is not compatible with an SDXL model.', - headless=headless_bool, - ) - return - # if float(noise_offset) > 0 and ( # multires_noise_iterations > 0 or multires_noise_discount > 0 # ): @@ -520,7 +517,12 @@ def train_model( lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) log.info(f'lr_warmup_steps = {lr_warmup_steps}') - run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_textual_inversion.py"' + run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process}' + if sdxl: + run_cmd += f' "./sdxl_train_textual_inversion.py"' + else: + run_cmd += f' "./train_textual_inversion.py"' + if v2: run_cmd += ' --v2' if v_parameterization: @@ -563,6 +565,9 @@ def train_model( ) if int(gradient_accumulation_steps) > 1: run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' + + if sdxl_no_half_vae: + run_cmd += f' --no_half_vae' run_cmd += run_cmd_training( learning_rate=learning_rate, @@ -679,237 +684,244 @@ def ti_tab( dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - gr.Markdown('Train a TI using kohya textual inversion python code...') - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel( - save_model_as_choices=[ - 'ckpt', - 'safetensors', - ], - headless=headless, - ) - - with gr.Tab('Folders'): - folders = Folders(headless=headless) - with gr.Tab('Parameters'): - with gr.Row(): - weights = gr.Textbox( - label='Resume TI training', - placeholder='(Optional) Path to existing TI embeding file to keep training', - ) - weights_file_input = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) - ) - weights_file_input.click( - get_file_path, - outputs=weights, - show_progress=False, - ) - with gr.Row(): - token_string = gr.Textbox( - label='Token string', - placeholder='eg: cat', - ) - init_word = gr.Textbox( - label='Init word', - value='*', - ) - num_vectors_per_token = gr.Slider( - minimum=1, - maximum=75, - value=1, - step=1, - label='Vectors', - ) - max_train_steps = gr.Textbox( - label='Max train steps', - placeholder='(Optional) Maximum number of steps', - ) - template = gr.Dropdown( - label='Template', - choices=[ - 'caption', - 'object template', - 'style template', - ], - value='caption', - ) - basic_training = BasicTraining( - learning_rate_value='1e-5', - lr_scheduler_value='cosine', - lr_warmup_value='10', + with gr.Tab('Training'): + gr.Markdown('Train a TI using kohya textual inversion python code...') + + # Setup Configuration Files Gradio + config = ConfigurationFile(headless) + + source_model = SourceModel( + save_model_as_choices=[ + 'ckpt', + 'safetensors', + ], + headless=headless, ) - with gr.Accordion('Advanced Configuration', open=False): - advanced_training = AdvancedTraining(headless=headless) - advanced_training.color_aug.change( - color_aug_changed, - inputs=[advanced_training.color_aug], - outputs=[basic_training.cache_latents], - ) - sample = SampleImages() + with gr.Tab('Folders'): + folders = Folders(headless=headless) + with gr.Tab('Parameters'): + with gr.Row(): + weights = gr.Textbox( + label='Resume TI training', + placeholder='(Optional) Path to existing TI embeding file to keep training', + ) + weights_file_input = gr.Button( + '📂', elem_id='open_folder_small', visible=(not headless) + ) + weights_file_input.click( + get_file_path, + outputs=weights, + show_progress=False, + ) + with gr.Row(): + token_string = gr.Textbox( + label='Token string', + placeholder='eg: cat', + ) + init_word = gr.Textbox( + label='Init word', + value='*', + ) + num_vectors_per_token = gr.Slider( + minimum=1, + maximum=75, + value=1, + step=1, + label='Vectors', + ) + max_train_steps = gr.Textbox( + label='Max train steps', + placeholder='(Optional) Maximum number of steps', + ) + template = gr.Dropdown( + label='Template', + choices=[ + 'caption', + 'object template', + 'style template', + ], + value='caption', + ) + basic_training = BasicTraining( + learning_rate_value='1e-5', + lr_scheduler_value='cosine', + lr_warmup_value='10', + ) + + # Add SDXL Parameters + sdxl_params = SDXLParameters(source_model.sdxl_checkbox, show_sdxl_cache_text_encoder_outputs=False) + + with gr.Accordion('Advanced Configuration', open=False): + advanced_training = AdvancedTraining(headless=headless) + advanced_training.color_aug.change( + color_aug_changed, + inputs=[advanced_training.color_aug], + outputs=[basic_training.cache_latents], + ) - with gr.Tab('Tools'): - gr.Markdown( - 'This section provide Dreambooth tools to help setup your dataset...' - ) - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, - headless=headless, - ) + sample = SampleImages() - button_run = gr.Button('Train model', variant='primary') + with gr.Tab('Tools'): + gr.Markdown( + 'This section provide Dreambooth tools to help setup your dataset...' + ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=folders.train_data_dir, + reg_data_dir_input=folders.reg_data_dir, + output_dir_input=folders.output_dir, + logging_dir_input=folders.logging_dir, + headless=headless, + ) - button_print = gr.Button('Print training command') + button_run = gr.Button('Train model', variant='primary') - # Setup gradio tensorboard buttons - button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + button_print = gr.Button('Print training command') - button_start_tensorboard.click( - start_tensorboard, - inputs=folders.logging_dir, - show_progress=False, - ) + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() - button_stop_tensorboard.click( - stop_tensorboard, - show_progress=False, - ) + button_start_tensorboard.click( + start_tensorboard, + inputs=folders.logging_dir, + show_progress=False, + ) - settings_list = [ - source_model.pretrained_model_name_or_path, - source_model.v2, - source_model.v_parameterization, - source_model.sdxl_checkbox, - folders.logging_dir, - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - basic_training.max_resolution, - basic_training.learning_rate, - basic_training.lr_scheduler, - basic_training.lr_warmup, - basic_training.train_batch_size, - basic_training.epoch, - basic_training.save_every_n_epochs, - basic_training.mixed_precision, - basic_training.save_precision, - basic_training.seed, - basic_training.num_cpu_threads_per_process, - basic_training.cache_latents, - basic_training.cache_latents_to_disk, - basic_training.caption_extension, - basic_training.enable_bucket, - advanced_training.gradient_checkpointing, - advanced_training.full_fp16, - advanced_training.no_token_padding, - basic_training.stop_text_encoder_training, - advanced_training.xformers, - source_model.save_model_as, - advanced_training.shuffle_caption, - advanced_training.save_state, - advanced_training.resume, - advanced_training.prior_loss_weight, - advanced_training.color_aug, - advanced_training.flip_aug, - advanced_training.clip_skip, - advanced_training.vae, - folders.output_name, - advanced_training.max_token_length, - advanced_training.max_train_epochs, - advanced_training.max_data_loader_n_workers, - advanced_training.mem_eff_attn, - advanced_training.gradient_accumulation_steps, - source_model.model_list, - token_string, - init_word, - num_vectors_per_token, - max_train_steps, - weights, - template, - advanced_training.keep_tokens, - advanced_training.persistent_data_loader_workers, - advanced_training.bucket_no_upscale, - advanced_training.random_crop, - advanced_training.bucket_reso_steps, - advanced_training.caption_dropout_every_n_epochs, - advanced_training.caption_dropout_rate, - basic_training.optimizer, - basic_training.optimizer_args, - advanced_training.noise_offset_type, - advanced_training.noise_offset, - advanced_training.adaptive_noise_scale, - advanced_training.multires_noise_iterations, - advanced_training.multires_noise_discount, - sample.sample_every_n_steps, - sample.sample_every_n_epochs, - sample.sample_sampler, - sample.sample_prompts, - advanced_training.additional_parameters, - advanced_training.vae_batch_size, - advanced_training.min_snr_gamma, - advanced_training.save_every_n_steps, - advanced_training.save_last_n_steps, - advanced_training.save_last_n_steps_state, - advanced_training.use_wandb, - advanced_training.wandb_api_key, - advanced_training.scale_v_pred_loss_like_noise_pred, - advanced_training.min_timestep, - advanced_training.max_timestep - ] + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) - config.button_open_config.click( - open_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) + settings_list = [ + source_model.pretrained_model_name_or_path, + source_model.v2, + source_model.v_parameterization, + source_model.sdxl_checkbox, + folders.logging_dir, + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + basic_training.max_resolution, + basic_training.learning_rate, + basic_training.lr_scheduler, + basic_training.lr_warmup, + basic_training.train_batch_size, + basic_training.epoch, + basic_training.save_every_n_epochs, + basic_training.mixed_precision, + basic_training.save_precision, + basic_training.seed, + basic_training.num_cpu_threads_per_process, + basic_training.cache_latents, + basic_training.cache_latents_to_disk, + basic_training.caption_extension, + basic_training.enable_bucket, + advanced_training.gradient_checkpointing, + advanced_training.full_fp16, + advanced_training.no_token_padding, + basic_training.stop_text_encoder_training, + advanced_training.xformers, + source_model.save_model_as, + advanced_training.shuffle_caption, + advanced_training.save_state, + advanced_training.resume, + advanced_training.prior_loss_weight, + advanced_training.color_aug, + advanced_training.flip_aug, + advanced_training.clip_skip, + advanced_training.vae, + folders.output_name, + advanced_training.max_token_length, + advanced_training.max_train_epochs, + advanced_training.max_data_loader_n_workers, + advanced_training.mem_eff_attn, + advanced_training.gradient_accumulation_steps, + source_model.model_list, + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + advanced_training.keep_tokens, + advanced_training.persistent_data_loader_workers, + advanced_training.bucket_no_upscale, + advanced_training.random_crop, + advanced_training.bucket_reso_steps, + advanced_training.caption_dropout_every_n_epochs, + advanced_training.caption_dropout_rate, + basic_training.optimizer, + basic_training.optimizer_args, + advanced_training.noise_offset_type, + advanced_training.noise_offset, + advanced_training.adaptive_noise_scale, + advanced_training.multires_noise_iterations, + advanced_training.multires_noise_discount, + sample.sample_every_n_steps, + sample.sample_every_n_epochs, + sample.sample_sampler, + sample.sample_prompts, + advanced_training.additional_parameters, + advanced_training.vae_batch_size, + advanced_training.min_snr_gamma, + advanced_training.save_every_n_steps, + advanced_training.save_last_n_steps, + advanced_training.save_last_n_steps_state, + advanced_training.use_wandb, + advanced_training.wandb_api_key, + advanced_training.scale_v_pred_loss_like_noise_pred, + advanced_training.min_timestep, + advanced_training.max_timestep, + sdxl_params.sdxl_no_half_vae, + ] + + config.button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) - config.button_load_config.click( - open_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) + config.button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) - config.button_save_config.click( - save_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + config.button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + config.button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) - button_run.click( - train_model, - inputs=[dummy_headless] + [dummy_db_false] + settings_list, - show_progress=False, - ) + button_run.click( + train_model, + inputs=[dummy_headless] + [dummy_db_false] + settings_list, + show_progress=False, + ) - button_print.click( - train_model, - inputs=[dummy_headless] + [dummy_db_true] + settings_list, - show_progress=False, - ) + button_print.click( + train_model, + inputs=[dummy_headless] + [dummy_db_true] + settings_list, + show_progress=False, + ) - return ( - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - folders.logging_dir, - ) + return ( + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + folders.logging_dir, + ) def UI(**kwargs): From 15c33d945548a777e215cc89b221fb2a96a5971c Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 11 Jul 2023 19:53:26 -0400 Subject: [PATCH 19/20] Update torch choice for windows --- README.md | 2 +- setup/setup_windows.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 96b2968bb..8f01bf7d7 100644 --- a/README.md +++ b/README.md @@ -460,7 +460,7 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is ## Change History -* 2023/07/10 (v21.8.1) +* 2023/07/11 (v21.8.1) - Let Tensorboard works in docker #1137 - Fix for accelerate issue - Add SDXL TI training support diff --git a/setup/setup_windows.py b/setup/setup_windows.py index 179f0f44e..8f2c1720a 100644 --- a/setup/setup_windows.py +++ b/setup/setup_windows.py @@ -165,8 +165,8 @@ def main_menu(): if choice == '1': while True: - print('1. Torch 1') - print('2. Torch 2') + print('1. Torch 1 (legacy)') + print('2. Torch 2 (recommended)') print('3. Cancel') choice_torch = input('\nEnter your choice: ') print('') From b602aa2ec0dd5d8c47b07a4bc239ab802fdb30ca Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 11 Jul 2023 19:54:47 -0400 Subject: [PATCH 20/20] Update version --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8f01bf7d7..4f7744420 100644 --- a/README.md +++ b/README.md @@ -460,7 +460,7 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is ## Change History -* 2023/07/11 (v21.8.1) +* 2023/07/11 (v21.8.2) - Let Tensorboard works in docker #1137 - Fix for accelerate issue - Add SDXL TI training support