From 58e57a336583078fc6444254fff18c9a31c10445 Mon Sep 17 00:00:00 2001
From: bmaltais
Date: Sat, 20 Apr 2024 15:05:35 -0400
Subject: [PATCH] Fix issue with lora_network_weights not being loaded (#2357)

---
 README.md                          |  1 +
 kohya_gui/common_gui.py            | 10 ++++++++-
 kohya_gui/dreambooth_gui.py        | 14 ++++++------
 kohya_gui/finetune_gui.py          |  4 ++--
 kohya_gui/lora_gui.py              | 36 +++++++++++++++---------------
 kohya_gui/textual_inversion_gui.py |  2 +-
 6 files changed, 38 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index 394868a70..e556dc9f0 100644
--- a/README.md
+++ b/README.md
@@ -480,6 +480,7 @@ The `gui.bat` and `gui.sh` scripts now include the `--do_not_use_shell` argument
 #### Miscellaneous
 
 - Made various other minor improvements and bug fixes to enhance overall functionality and user experience.
+- Fixed an issue where existing LoRA network weights were not properly loaded prior to training
 
 ### 2024/04/10 (v23.1.5)
 
diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py
index a00d3e681..411501a77 100644
--- a/kohya_gui/common_gui.py
+++ b/kohya_gui/common_gui.py
@@ -429,7 +429,15 @@ def update_my_data(my_data):
             pass
         my_data.pop(key, None)
 
-    
+
+    # Replace the lora_network_weights key with network_weights keeping the original value
+    for key in ["lora_network_weights"]:
+        value = my_data.get(key)  # Get original value
+        if value is not None:  # Check if the key exists in the dictionary
+            my_data["network_weights"] = value
+            my_data.pop(key, None)
+
     return my_data
diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py
index fbc5d21f2..d2dfd8ba5 100644
--- a/kohya_gui/dreambooth_gui.py
+++ b/kohya_gui/dreambooth_gui.py
@@ -669,8 +669,8 @@ def train_model(
     # def save_huggingface_to_toml(self, toml_file_path: str):
     config_toml_data = {
         # Update the values in the TOML data
-        "async_upload": async_upload,
         "adaptive_noise_scale": adaptive_noise_scale if not 0 else None,
+        "async_upload": async_upload,
         "bucket_no_upscale": bucket_no_upscale,
         "bucket_reso_steps": bucket_reso_steps,
         "cache_latents": cache_latents,
@@ -686,18 +686,17 @@ def train_model(
         "enable_bucket": enable_bucket,
         "epoch": int(epoch),
         "flip_aug": flip_aug,
-        "masked_loss": masked_loss,
         "full_bf16": full_bf16,
         "full_fp16": full_fp16,
         "gradient_accumulation_steps": int(gradient_accumulation_steps),
         "gradient_checkpointing": gradient_checkpointing,
         "huber_c": huber_c,
         "huber_schedule": huber_schedule,
+        "huggingface_path_in_repo": huggingface_path_in_repo,
         "huggingface_repo_id": huggingface_repo_id,
-        "huggingface_token": huggingface_token,
         "huggingface_repo_type": huggingface_repo_type,
         "huggingface_repo_visibility": huggingface_repo_visibility,
-        "huggingface_path_in_repo": huggingface_path_in_repo,
+        "huggingface_token": huggingface_token,
         "ip_noise_gamma": ip_noise_gamma if ip_noise_gamma != 0 else None,
         "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
         "keep_tokens": int(keep_tokens),
@@ -712,8 +711,9 @@ def train_model(
             learning_rate_te2 if sdxl and not 0 else None
         ),  # only for sdxl and not 0
         "logging_dir": logging_dir,
-        "log_tracker_name": log_tracker_name,
         "log_tracker_config": log_tracker_config,
+        "log_tracker_name": log_tracker_name,
+        "log_with": log_with,
         "loss_type": loss_type,
         "lr_scheduler": lr_scheduler,
         "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
@@ -722,6 +722,7 @@
         ),
         "lr_scheduler_power": lr_scheduler_power,
         "lr_warmup_steps": lr_warmup_steps,
+        "masked_loss": masked_loss,
         "max_bucket_reso": max_bucket_reso,
"max_timestep": max_timestep if max_timestep != 0 else None, "max_token_length": int(max_token_length), @@ -743,12 +744,12 @@ def train_model( "noise_offset": noise_offset if not 0 else None, "noise_offset_random_strength": noise_offset_random_strength, "noise_offset_type": noise_offset_type, - "optimizer_type": optimizer, "optimizer_args": ( str(optimizer_args).replace('"', "").split() if optimizer_args != "" else None ), + "optimizer_type": optimizer, "output_dir": output_dir, "output_name": output_name, "persistent_data_loader_workers": persistent_data_loader_workers, @@ -789,7 +790,6 @@ def train_model( ), "train_batch_size": train_batch_size, "train_data_dir": train_data_dir, - "log_with": log_with, "v2": v2, "v_parameterization": v_parameterization, "v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None, diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py index bfb6d0e16..743bb01cd 100644 --- a/kohya_gui/finetune_gui.py +++ b/kohya_gui/finetune_gui.py @@ -746,10 +746,10 @@ def train_model( config_toml_data = { # Update the values in the TOML data - "async_upload": async_upload, "adaptive_noise_scale": ( adaptive_noise_scale if adaptive_noise_scale != 0 else None ), + "async_upload": async_upload, "block_lr": block_lr, "bucket_no_upscale": bucket_no_upscale, "bucket_reso_steps": bucket_reso_steps, @@ -767,7 +767,6 @@ def train_model( "dynamo_backend": dynamo_backend, "enable_bucket": True, "flip_aug": flip_aug, - "masked_loss": masked_loss, "full_bf16": full_bf16, "full_fp16": full_fp16, "gradient_accumulation_steps": int(gradient_accumulation_steps), @@ -800,6 +799,7 @@ def train_model( "lr_scheduler": lr_scheduler, "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(), "lr_warmup_steps": lr_warmup_steps, + "masked_loss": masked_loss, "max_bucket_reso": int(max_bucket_reso), "max_timestep": max_timestep if max_timestep != 0 else None, "max_token_length": int(max_token_length), diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index cc8a918a7..d68df6f18 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -105,7 +105,7 @@ def save_configuration( text_encoder_lr, unet_lr, network_dim, - lora_network_weights, + network_weights, dim_from_weights, color_aug, flip_aug, @@ -310,7 +310,7 @@ def open_configuration( text_encoder_lr, unet_lr, network_dim, - lora_network_weights, + network_weights, dim_from_weights, color_aug, flip_aug, @@ -545,7 +545,7 @@ def train_model( text_encoder_lr, unet_lr, network_dim, - lora_network_weights, + network_weights, dim_from_weights, color_aug, flip_aug, @@ -689,7 +689,7 @@ def train_model( log_tracker_config=log_tracker_config, resume=resume, vae=vae, - lora_network_weights=lora_network_weights, + network_weights=network_weights, dataset_config=dataset_config, ): return TRAIN_BUTTON_VISIBLE @@ -1024,10 +1024,10 @@ def train_model( network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0 config_toml_data = { - "async_upload": async_upload, "adaptive_noise_scale": ( adaptive_noise_scale if adaptive_noise_scale != 0 else None ), + "async_upload": async_upload, "bucket_no_upscale": bucket_no_upscale, "bucket_reso_steps": bucket_reso_steps, "cache_latents": cache_latents, @@ -1047,7 +1047,6 @@ def train_model( "enable_bucket": enable_bucket, "epoch": int(epoch), "flip_aug": flip_aug, - "masked_loss": masked_loss, "fp8_base": fp8_base, "full_bf16": full_bf16, "full_fp16": full_fp16, @@ -1067,7 +1066,6 @@ def train_model( "logging_dir": logging_dir, "log_tracker_name": 
         "log_tracker_config": log_tracker_config,
-        "lora_network_weights": lora_network_weights,
         "loss_type": loss_type,
         "lr_scheduler": lr_scheduler,
         "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
@@ -1076,6 +1074,7 @@
         ),
         "lr_scheduler_power": lr_scheduler_power,
         "lr_warmup_steps": lr_warmup_steps,
+        "masked_loss": masked_loss,
         "max_bucket_reso": max_bucket_reso,
         "max_grad_norm": max_grad_norm,
         "max_timestep": max_timestep if max_timestep != 0 else None,
@@ -1103,6 +1102,7 @@
         "network_module": network_module,
         "network_train_unet_only": network_train_unet_only,
         "network_train_text_encoder_only": network_train_text_encoder_only,
+        "network_weights": network_weights,
         "no_half_vae": True if sdxl and sdxl_no_half_vae else None,
         "noise_offset": noise_offset if noise_offset != 0 else None,
         "noise_offset_random_strength": noise_offset_random_strength,
@@ -1360,21 +1360,21 @@ def list_presets(path):
                     )
                 with gr.Group():
                     with gr.Row():
-                        lora_network_weights = gr.Textbox(
-                            label="LoRA network weights",
+                        network_weights = gr.Textbox(
+                            label="Network weights",
                             placeholder="(Optional)",
                             info="Path to an existing LoRA network weights to resume training from",
                         )
-                        lora_network_weights_file = gr.Button(
+                        network_weights_file = gr.Button(
                             document_symbol,
                             elem_id="open_folder_small",
                             elem_classes=["tool"],
                             visible=(not headless),
                         )
-                        lora_network_weights_file.click(
+                        network_weights_file.click(
                             get_any_file_path,
-                            inputs=[lora_network_weights],
-                            outputs=lora_network_weights,
+                            inputs=[network_weights],
+                            outputs=network_weights,
                             show_progress=False,
                         )
                     dim_from_weights = gr.Checkbox(
@@ -1627,7 +1627,7 @@ def update_LoRA_settings(
                     },
                 },
             },
-            "lora_network_weights": {
+            "network_weights": {
                 "gr_type": gr.Textbox,
                 "update_params": {
                     "visible": LoRA_type
@@ -1647,7 +1647,7 @@
                     },
                 },
             },
-            "lora_network_weights_file": {
+            "network_weights_file": {
                 "gr_type": gr.Button,
                 "update_params": {
                     "visible": LoRA_type
@@ -2061,8 +2061,8 @@
             network_row,
             convolution_row,
             kohya_advanced_lora,
-            lora_network_weights,
-            lora_network_weights_file,
+            network_weights,
+            network_weights_file,
             dim_from_weights,
             factor,
             conv_dim,
@@ -2140,7 +2140,7 @@
         text_encoder_lr,
         unet_lr,
         network_dim,
-        lora_network_weights,
+        network_weights,
         dim_from_weights,
         advanced_training.color_aug,
         advanced_training.flip_aug,
diff --git a/kohya_gui/textual_inversion_gui.py b/kohya_gui/textual_inversion_gui.py
index d3daa8067..0e46368c8 100644
--- a/kohya_gui/textual_inversion_gui.py
+++ b/kohya_gui/textual_inversion_gui.py
@@ -695,10 +695,10 @@ def train_model(
     # def save_huggingface_to_toml(self, toml_file_path: str):
    config_toml_data = {
         # Update the values in the TOML data
-        "async_upload": async_upload,
         "adaptive_noise_scale": (
             adaptive_noise_scale if adaptive_noise_scale != 0 else None
         ),
+        "async_upload": async_upload,
         "bucket_no_upscale": bucket_no_upscale,
         "bucket_reso_steps": bucket_reso_steps,
         "cache_latents": cache_latents,
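
Note on the fix: the heart of this patch is the compatibility shim added to update_my_data() in kohya_gui/common_gui.py. Configurations saved by older GUI versions store the resume path under the legacy lora_network_weights key, while the training scripts now read network_weights; the shim carries the value across so the path actually reaches the trainer. A minimal standalone sketch of that migration (the helper name migrate_legacy_keys and the sample config are illustrative, not part of the patch):

    # Sketch of the backward-compatibility rename performed by
    # update_my_data(); runnable on its own.
    def migrate_legacy_keys(my_data: dict) -> dict:
        # Old configs stored the resume path under "lora_network_weights";
        # the training scripts expect "network_weights".
        for key in ["lora_network_weights"]:
            value = my_data.get(key)  # None if the legacy key is absent
            if value is not None:
                my_data["network_weights"] = value
                my_data.pop(key, None)  # drop the legacy key
        return my_data

    if __name__ == "__main__":
        # Hypothetical legacy config for illustration only.
        legacy = {"lora_network_weights": "/models/last.safetensors", "epoch": 10}
        print(migrate_legacy_keys(legacy))
        # -> {'epoch': 10, 'network_weights': '/models/last.safetensors'}

With this shim in place, configs saved before the rename still load cleanly, and the Gradio components renamed in lora_gui.py (network_weights, network_weights_file) line up with the key the generated TOML now emits.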