Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v22.3.0 #1745

Merged
merged 20 commits into from
Dec 6, 2023
Merged

v22.3.0 #1745

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .release
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v22.2.2
v22.3.0
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -651,10 +651,24 @@


## Change History
* 2023/12/06 (v22.3.0)
- Merge sd-scripts updates:
- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
- Min SNR Gamma with V-predicition (SD 2.1) is fixed. Thanks to feffy380! PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934)
- See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details.
- `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
- The default values are same as the previous version.
- Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`.
- `--ds_timesteps_1` and `--ds_timesteps_2` options denote the timesteps of the Deep Shrink for the first and second stages.
- `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages.
- `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means the half of the original latent size for the Deep Shrink.
- `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available.
- Add GLoRA support

* 2023/12/03 (v22.2.2)
- Update Lycoris module to 2.0.0 (https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/README.md)
- Update Lycoris merge and extract tools
- Remove anoying warning about local pip modules that is not necessary.

Check warning on line 671 in README.md

View workflow job for this annotation

GitHub Actions / build

"anoying" should be "annoying".

Check warning on line 671 in README.md

View workflow job for this annotation

GitHub Actions / build

"anoying" should be "annoying".
- Adding support for LyCORIS presets
- Adding Support for LyCORIS Native Fine-Tuning
- Adding support for Lycoris Diag-OFT
Expand Down Expand Up @@ -687,5 +701,5 @@
- kohya_ss gui updates:
- Implement GUI support for SDXL finetune TE1 and TE2 training LR parameters and for non SDXL finetune TE training parameter
- Implement GUI support for Dreambooth TE LR parameter
- Implement Debiased Estimation loss at the botom of the Advanced Parameters tab.

Check warning on line 704 in README.md

View workflow job for this annotation

GitHub Actions / build

"botom" should be "bottom".

Check warning on line 704 in README.md

View workflow job for this annotation

GitHub Actions / build

"botom" should be "bottom".

2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
Expand Down
24 changes: 16 additions & 8 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def main(args):

tag_freq = {}

undesired_tags = set(args.undesired_tags.split(","))
caption_separator = args.caption_separator
stripped_caption_separator = caption_separator.strip()
undesired_tags = set(args.undesired_tags.split(stripped_caption_separator))

def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
Expand Down Expand Up @@ -194,7 +196,7 @@ def run_batch(path_imgs):

if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
Expand All @@ -203,18 +205,18 @@ def run_batch(path_imgs):

if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name
character_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)

# 先頭のカンマを取る
if len(general_tag_text) > 0:
general_tag_text = general_tag_text[2:]
general_tag_text = general_tag_text[len(caption_separator) :]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
character_tag_text = character_tag_text[len(caption_separator) :]

caption_file = os.path.splitext(image_path)[0] + args.caption_extension

tag_text = ", ".join(combined_tags)
tag_text = caption_separator.join(combined_tags)

if args.append_tags:
# Check if file exists
Expand All @@ -224,13 +226,13 @@ def run_batch(path_imgs):
existing_content = f.read().strip("\n") # Remove newlines

# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]

# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]

# Create new tag_text
tag_text = ", ".join(existing_tags + new_tags)
tag_text = caption_separator.join(existing_tags + new_tags)

with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
Expand Down Expand Up @@ -350,6 +352,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
parser.add_argument(
"--caption_separator",
type=str,
default=", ",
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
)

return parser

Expand Down
88 changes: 86 additions & 2 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.original_unet import FlashAttentionFunction

from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
Expand Down Expand Up @@ -378,7 +378,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
unet: InferUNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
clip_skip: int,
clip_model: CLIPModel,
Expand Down Expand Up @@ -2196,6 +2196,7 @@ def main(args):
)
original_unet.load_state_dict(unet.state_dict())
unet = original_unet
unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)

# VAEを読み込む
if args.vae is not None:
Expand Down Expand Up @@ -2352,13 +2353,20 @@ def __getattr__(self, item):
vae = sli_vae
del sli_vae
vae.to(dtype).to(device)
vae.eval()

text_encoder.to(dtype).to(device)
unet.to(dtype).to(device)

text_encoder.eval()
unet.eval()

if clip_model is not None:
clip_model.to(dtype).to(device)
clip_model.eval()
if vgg16_model is not None:
vgg16_model.to(dtype).to(device)
vgg16_model.eval()

# networkを組み込む
if args.network_module:
Expand Down Expand Up @@ -2501,6 +2509,10 @@ def __getattr__(self, item):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()

# Deep Shrink
if args.ds_depth_1 is not None:
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)

# Extended Textual Inversion および Textual Inversionを処理する
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
Expand Down Expand Up @@ -3085,6 +3097,13 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
clip_prompt = None
network_muls = None

# Deep Shrink
ds_depth_1 = None # means no override
ds_timesteps_1 = args.ds_timesteps_1
ds_depth_2 = args.ds_depth_2
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio

prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
Expand Down Expand Up @@ -3156,10 +3175,51 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
print(f"network mul: {network_muls}")
continue

# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}")
continue

m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue

m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}")
continue

m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue

m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}")
continue

except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)

# override Deep Shrink
if ds_depth_1 is not None:
if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)

# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
Expand Down Expand Up @@ -3509,6 +3569,30 @@ def setup_parser() -> argparse.ArgumentParser:
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )

# Deep Shrink
parser.add_argument(
"--ds_depth_1",
type=int,
default=None,
help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする",
)
parser.add_argument(
"--ds_timesteps_1",
type=int,
default=650,
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
)
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
parser.add_argument(
"--ds_timesteps_2",
type=int,
default=650,
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
)
parser.add_argument(
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
)

return parser


Expand Down
1 change: 1 addition & 0 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class BaseSubsetParams:
image_dir: Optional[str] = None
num_repeats: int = 1
shuffle_caption: bool = False
caption_separator: str = ',',
keep_tokens: int = 0
color_aug: bool = False
flip_aug: bool = False
Expand Down
9 changes: 6 additions & 3 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,13 @@ def enforce_zero_terminal_snr(betas):
noise_scheduler.alphas_cumprod = alphas_cumprod


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
else:
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
return loss

Expand Down
Loading
Loading