Fix issue with lora_network_weights not being loaded (#2357)
bmaltais committed Apr 20, 2024
Commit 58e57a3 (1 parent: bcd20a0)
Showing 6 changed files with 38 additions and 29 deletions.
README.md: 1 addition & 0 deletions
@@ -480,6 +480,7 @@ The `gui.bat` and `gui.sh` scripts now include the `--do_not_use_shell` argument
#### Miscellaneous
- Made various other minor improvements and bug fixes to enhance overall functionality and user experience.
- Fixed an issue where existing LoRA network weights were not properly loaded prior to training
### 2024/04/10 (v23.1.5)
kohya_gui/common_gui.py: 9 additions & 1 deletion
@@ -429,7 +429,15 @@ def update_my_data(my_data):
pass

my_data.pop(key, None)

# Replace the lora_network_weights key with network_weights, keeping the original value
for key in ["lora_network_weights"]:
value = my_data.get(key) # Get original value
if value is not None: # Check if the key exists in the dictionary
my_data["network_weights"] = value
my_data.pop(key, None)

return my_data


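The hunk above is the core of the fix: when a saved configuration is loaded, `update_my_data` now rewrites the legacy `lora_network_weights` key to `network_weights`. A minimal, self-contained sketch of that migration, with a hypothetical saved config:

```python
# Sketch of the migration performed by update_my_data above: configs saved
# with the old "lora_network_weights" key are rewritten to "network_weights".
# The sample dict is hypothetical, not taken from the repository.

def migrate_config(my_data: dict) -> dict:
    for key in ["lora_network_weights"]:
        value = my_data.get(key)  # original value, or None if the key is absent
        if value is not None:
            my_data["network_weights"] = value  # carry the value over
            my_data.pop(key, None)              # drop the legacy key
    return my_data

old_config = {"lora_network_weights": "/path/to/last.safetensors", "network_dim": 32}
print(migrate_config(old_config))
# {'network_dim': 32, 'network_weights': '/path/to/last.safetensors'}
```

Using `my_data.get(key)` instead of indexing means configs that never set the old key pass through untouched.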
kohya_gui/dreambooth_gui.py: 7 additions & 7 deletions
@@ -669,8 +669,8 @@ def train_model(
# def save_huggingface_to_toml(self, toml_file_path: str):
config_toml_data = {
# Update the values in the TOML data
"async_upload": async_upload,
"adaptive_noise_scale": adaptive_noise_scale if not 0 else None,
"async_upload": async_upload,
"bucket_no_upscale": bucket_no_upscale,
"bucket_reso_steps": bucket_reso_steps,
"cache_latents": cache_latents,
@@ -686,18 +686,17 @@ def train_model(
"enable_bucket": enable_bucket,
"epoch": int(epoch),
"flip_aug": flip_aug,
"masked_loss": masked_loss,
"full_bf16": full_bf16,
"full_fp16": full_fp16,
"gradient_accumulation_steps": int(gradient_accumulation_steps),
"gradient_checkpointing": gradient_checkpointing,
"huber_c": huber_c,
"huber_schedule": huber_schedule,
"huggingface_path_in_repo": huggingface_path_in_repo,
"huggingface_repo_id": huggingface_repo_id,
"huggingface_token": huggingface_token,
"huggingface_repo_type": huggingface_repo_type,
"huggingface_repo_visibility": huggingface_repo_visibility,
"huggingface_path_in_repo": huggingface_path_in_repo,
"huggingface_token": huggingface_token,
"ip_noise_gamma": ip_noise_gamma if ip_noise_gamma != 0 else None,
"ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
"keep_tokens": int(keep_tokens),
@@ -712,8 +711,9 @@ def train_model(
learning_rate_te2 if sdxl and not 0 else None
), # only for sdxl and not 0
"logging_dir": logging_dir,
"log_tracker_name": log_tracker_name,
"log_tracker_config": log_tracker_config,
"log_tracker_name": log_tracker_name,
"log_with": log_with,
"loss_type": loss_type,
"lr_scheduler": lr_scheduler,
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
@@ -722,6 +722,7 @@
),
"lr_scheduler_power": lr_scheduler_power,
"lr_warmup_steps": lr_warmup_steps,
"masked_loss": masked_loss,
"max_bucket_reso": max_bucket_reso,
"max_timestep": max_timestep if max_timestep != 0 else None,
"max_token_length": int(max_token_length),
@@ -743,12 +744,12 @@ def train_model(
"noise_offset": noise_offset if not 0 else None,
"noise_offset_random_strength": noise_offset_random_strength,
"noise_offset_type": noise_offset_type,
"optimizer_type": optimizer,
"optimizer_args": (
str(optimizer_args).replace('"', "").split()
if optimizer_args != ""
else None
),
"optimizer_type": optimizer,
"output_dir": output_dir,
"output_name": output_name,
"persistent_data_loader_workers": persistent_data_loader_workers,
@@ -789,7 +790,6 @@ def train_model(
),
"train_batch_size": train_batch_size,
"train_data_dir": train_data_dir,
"log_with": log_with,
"v2": v2,
"v_parameterization": v_parameterization,
"v_pred_like_loss": v_pred_like_loss if v_pred_like_loss != 0 else None,
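One pattern worth flagging while reviewing this file: `adaptive_noise_scale if not 0 else None` (and the same construct for `noise_offset`) never tests the variable, because `not 0` is a constant `True`. The corrected comparison, which finetune_gui.py in this same commit uses, is illustrated below in a standalone snippet (not repository code):

```python
# `x if not 0 else None` always evaluates to x: `not 0` is simply True.
adaptive_noise_scale = 0.0

broken = adaptive_noise_scale if not 0 else None                      # -> 0.0, never None
fixed = adaptive_noise_scale if adaptive_noise_scale != 0 else None   # -> None when zero

print(broken, fixed)  # 0.0 None
```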
kohya_gui/finetune_gui.py: 2 additions & 2 deletions
@@ -746,10 +746,10 @@ def train_model(

config_toml_data = {
# Update the values in the TOML data
"async_upload": async_upload,
"adaptive_noise_scale": (
adaptive_noise_scale if adaptive_noise_scale != 0 else None
),
"async_upload": async_upload,
"block_lr": block_lr,
"bucket_no_upscale": bucket_no_upscale,
"bucket_reso_steps": bucket_reso_steps,
@@ -767,7 +767,6 @@ def train_model(
"dynamo_backend": dynamo_backend,
"enable_bucket": True,
"flip_aug": flip_aug,
"masked_loss": masked_loss,
"full_bf16": full_bf16,
"full_fp16": full_fp16,
"gradient_accumulation_steps": int(gradient_accumulation_steps),
@@ -800,6 +799,7 @@ def train_model(
"lr_scheduler": lr_scheduler,
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
"lr_warmup_steps": lr_warmup_steps,
"masked_loss": masked_loss,
"max_bucket_reso": int(max_bucket_reso),
"max_timestep": max_timestep if max_timestep != 0 else None,
"max_token_length": int(max_token_length),
kohya_gui/lora_gui.py: 18 additions & 18 deletions
@@ -105,7 +105,7 @@ def save_configuration(
text_encoder_lr,
unet_lr,
network_dim,
lora_network_weights,
network_weights,
dim_from_weights,
color_aug,
flip_aug,
@@ -310,7 +310,7 @@ def open_configuration(
text_encoder_lr,
unet_lr,
network_dim,
lora_network_weights,
network_weights,
dim_from_weights,
color_aug,
flip_aug,
@@ -545,7 +545,7 @@ def train_model(
text_encoder_lr,
unet_lr,
network_dim,
lora_network_weights,
network_weights,
dim_from_weights,
color_aug,
flip_aug,
@@ -689,7 +689,7 @@ def train_model(
log_tracker_config=log_tracker_config,
resume=resume,
vae=vae,
lora_network_weights=lora_network_weights,
network_weights=network_weights,
dataset_config=dataset_config,
):
return TRAIN_BUTTON_VISIBLE
@@ -1024,10 +1024,10 @@ def train_model(
network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0

config_toml_data = {
"async_upload": async_upload,
"adaptive_noise_scale": (
adaptive_noise_scale if adaptive_noise_scale != 0 else None
),
"async_upload": async_upload,
"bucket_no_upscale": bucket_no_upscale,
"bucket_reso_steps": bucket_reso_steps,
"cache_latents": cache_latents,
@@ -1047,7 +1047,6 @@ def train_model(
"enable_bucket": enable_bucket,
"epoch": int(epoch),
"flip_aug": flip_aug,
"masked_loss": masked_loss,
"fp8_base": fp8_base,
"full_bf16": full_bf16,
"full_fp16": full_fp16,
@@ -1067,7 +1066,6 @@ def train_model(
"logging_dir": logging_dir,
"log_tracker_name": log_tracker_name,
"log_tracker_config": log_tracker_config,
"lora_network_weights": lora_network_weights,
"loss_type": loss_type,
"lr_scheduler": lr_scheduler,
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
@@ -1076,6 +1074,7 @@
),
"lr_scheduler_power": lr_scheduler_power,
"lr_warmup_steps": lr_warmup_steps,
"masked_loss": masked_loss,
"max_bucket_reso": max_bucket_reso,
"max_grad_norm": max_grad_norm,
"max_timestep": max_timestep if max_timestep != 0 else None,
@@ -1103,6 +1102,7 @@ def train_model(
"network_module": network_module,
"network_train_unet_only": network_train_unet_only,
"network_train_text_encoder_only": network_train_text_encoder_only,
"network_weights": network_weights,
"no_half_vae": True if sdxl and sdxl_no_half_vae else None,
"noise_offset": noise_offset if noise_offset != 0 else None,
"noise_offset_random_strength": noise_offset_random_strength,
@@ -1360,21 +1360,21 @@ def list_presets(path):
)
with gr.Group():
with gr.Row():
lora_network_weights = gr.Textbox(
label="LoRA network weights",
network_weights = gr.Textbox(
label="Network weights",
placeholder="(Optional)",
info="Path to an existing LoRA network weights to resume training from",
)
lora_network_weights_file = gr.Button(
network_weights_file = gr.Button(
document_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
lora_network_weights_file.click(
network_weights_file.click(
get_any_file_path,
inputs=[lora_network_weights],
outputs=lora_network_weights,
inputs=[network_weights],
outputs=network_weights,
show_progress=False,
)
dim_from_weights = gr.Checkbox(
@@ -1627,7 +1627,7 @@ def update_LoRA_settings(
},
},
},
"lora_network_weights": {
"network_weights": {
"gr_type": gr.Textbox,
"update_params": {
"visible": LoRA_type
@@ -1647,7 +1647,7 @@
},
},
},
"lora_network_weights_file": {
"network_weights_file": {
"gr_type": gr.Button,
"update_params": {
"visible": LoRA_type
@@ -2061,8 +2061,8 @@ def update_LoRA_settings(
network_row,
convolution_row,
kohya_advanced_lora,
lora_network_weights,
lora_network_weights_file,
network_weights,
network_weights_file,
dim_from_weights,
factor,
conv_dim,
@@ -2140,7 +2140,7 @@ def update_LoRA_settings(
text_encoder_lr,
unet_lr,
network_dim,
lora_network_weights,
network_weights,
dim_from_weights,
advanced_training.color_aug,
advanced_training.flip_aug,
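Taken together, the lora_gui.py changes rename the Gradio textbox, the callback parameters, and, most importantly, the key written into the training TOML from `lora_network_weights` to `network_weights`, the option name the underlying sd-scripts trainer actually reads. The old key was written to the config but silently ignored, which is why existing weights were never loaded. A minimal sketch of the corrected config write follows; the helper name and paths are hypothetical, and it assumes the `toml` package the GUI already uses:

```python
# Sketch: write a training config using the key the trainer reads.
import toml  # assumes the `toml` package is installed

def write_training_config(network_weights: str, path: str = "config.toml") -> None:
    config_toml_data = {
        "network_module": "networks.lora",
        "network_weights": network_weights,  # correct key; "lora_network_weights" is ignored
        "network_dim": 32,
    }
    # Drop None values, mirroring how the GUI omits unset options
    config = {k: v for k, v in config_toml_data.items() if v is not None}
    with open(path, "w", encoding="utf-8") as f:
        toml.dump(config, f)

write_training_config("/models/lora/last.safetensors")
```

Pruning `None` values mirrors the GUI's own behavior of omitting unset options so the trainer falls back to its defaults.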
kohya_gui/textual_inversion_gui.py: 1 addition & 1 deletion
@@ -695,10 +695,10 @@ def train_model(
# def save_huggingface_to_toml(self, toml_file_path: str):
config_toml_data = {
# Update the values in the TOML data
"async_upload": async_upload,
"adaptive_noise_scale": (
adaptive_noise_scale if adaptive_noise_scale != 0 else None
),
"async_upload": async_upload,
"bucket_no_upscale": bucket_no_upscale,
"bucket_reso_steps": bucket_reso_steps,
"cache_latents": cache_latents,
