Add Huber loss support #2221

Merged 1 commit on Apr 7, 2024
2 changes: 1 addition & 1 deletion .release
@@ -1 +1 @@
v23.1.0
v23.1.1
7 changes: 6 additions & 1 deletion README.md
@@ -42,6 +42,7 @@ The GUI allows you to set the training parameters and generate and run the requi
- [SDXL training](#sdxl-training)
- [Masked loss](#masked-loss)
- [Change History](#change-history)
- [2024/04/07 (v23.1.1)](#20240407-v2311)
- [2024/04/07 (v23.1.0)](#20240407-v2310)
- [2024/03/21 (v23.0.15)](#20240321-v23015)
- [2024/03/19 (v23.0.14)](#20240319-v23014)
@@ -403,6 +404,10 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG

## Change History

### 2024/04/07 (v23.1.1)

- Added support for Huber loss under the Parameters / Advanced tab.
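For reference, Huber loss behaves like L2 for small residuals and like L1 for large ones, which makes training less sensitive to occasional outlier noise predictions. Below is a minimal PyTorch sketch of the textbook element-wise form; it is only an illustration, and sd-scripts parameterizes its scheduled Huber and smooth-L1 variants with `huber_c`, which does not necessarily map one-to-one onto the `delta` used here.

```python
import torch

def huber_loss(pred: torch.Tensor, target: torch.Tensor, delta: float = 0.1) -> torch.Tensor:
    """Textbook element-wise Huber loss: quadratic for |error| <= delta, linear beyond it."""
    error = pred - target
    abs_error = error.abs()
    quadratic = 0.5 * error**2                  # L2-like region near zero
    linear = delta * (abs_error - 0.5 * delta)  # L1-like region for large errors
    return torch.where(abs_error <= delta, quadratic, linear)
```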

### 2024/04/07 (v23.1.0)

- Update sd-scripts to 0.8.7
@@ -444,7 +449,7 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
- See [Dataset config](./docs/config_README-en.md) for details.
- The dataset with DreamBooth method supports caching image information (size, caption). PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178) and [#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) Thanks to KohakuBlueleaf! See [DreamBooth method specific options](./docs/config_README-en.md#dreambooth-specific-options) for details.

- Image tagging
- Image tagging (not implemented yet in the GUI)
- The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds!
- Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`.
- The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`.
3 changes: 3 additions & 0 deletions config example.toml
@@ -74,11 +74,14 @@ full_bf16 = false # Full bf16 training (experimental)
full_fp16 = false # Full fp16 training (experimental)
gradient_accumulation_steps = 1 # Gradient accumulation steps
gradient_checkpointing = false # Gradient checkpointing
huber_c = 0.1 # The Huber loss parameter. Only used when loss_type is set to huber or smooth_l1
huber_schedule = "snr" # How the Huber loss parameter is scheduled across timesteps (constant, exponential, snr)
ip_noise_gamma = 0 # IP noise gamma
ip_noise_gamma_random_strength = false # IP noise gamma random strength (true, false)
keep_tokens = 0 # Keep tokens
log_tracker_config_dir = "./logs" # Log tracker configs directory
log_tracker_name = "" # Log tracker name
loss_type = "l2" # Loss type (l2, huber, smooth_l1)
masked_loss = false # Masked loss
max_data_loader_n_workers = "0" # Max data loader n workers (string)
max_timestep = 1000 # Max timestep
27 changes: 27 additions & 0 deletions kohya_gui/class_advanced_training.py
@@ -128,6 +128,33 @@ def list_vae_files(path):
placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"',
value=self.config.get("advanced.additional_parameters", ""),
)
with gr.Accordion("Scheduled Huber Loss", open=False):
with gr.Row():
self.loss_type = gr.Dropdown(
label="Loss type",
choices=["huber", "smooth_l1", "l2"],
value=self.config.get("advanced.loss_type", "l2"),
info="The type of loss to use and whether it's scheduled based on the timestep",
)
self.huber_schedule = gr.Dropdown(
label="Huber schedule",
choices=[
"constant",
"exponential",
"snr",
],
value=self.config.get("advanced.huber_schedule", "snr"),
info="The type of loss to use and whether it's scheduled based on the timestep",
)
self.huber_c = gr.Number(
label="Huber C",
value=self.config.get("advanced.huber_c", 0.1),
minimum=0.0,
maximum=1.0,
step=0.01,
info="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type",
)

with gr.Row():
self.save_every_n_steps = gr.Number(
label="Save every N steps",
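The `huber_schedule` dropdown above selects how sd-scripts varies the Huber parameter across diffusion timesteps. The actual formulas live in sd-scripts and are not part of this diff; the sketch below is only an intuition for the three choices (a fixed value, an exponential decay toward `huber_c`, or an SNR-dependent interpolation) and should not be read as the real implementation.

```python
import math

def scheduled_huber_c(schedule: str, huber_c: float, timestep: int,
                      num_train_timesteps: int = 1000, snr: float | None = None) -> float:
    """Illustrative-only sketch of a timestep-dependent Huber parameter (assumed shapes)."""
    if schedule == "constant":
        return huber_c  # same transition point at every timestep
    if schedule == "exponential":
        # assumed shape: decay from 1.0 at timestep 0 down to huber_c at the last timestep
        alpha = -math.log(huber_c) / num_train_timesteps
        return math.exp(-alpha * timestep)
    if schedule == "snr":
        # assumed shape: interpolate between 1.0 and huber_c based on the timestep's
        # signal-to-noise ratio (noisier timesteps get a value closer to huber_c)
        if snr is None:
            raise ValueError("the snr schedule needs the timestep's signal-to-noise ratio")
        sigma = (1.0 / max(snr, 1e-8)) ** 0.5
        return (1.0 - huber_c) / (1.0 + sigma) ** 2 + huber_c
    raise ValueError(f"unknown huber_schedule: {schedule}")
```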
9 changes: 9 additions & 0 deletions kohya_gui/common_gui.py
@@ -1098,6 +1098,12 @@ def run_cmd_advanced_training(**kwargs):

if kwargs.get("gradient_checkpointing"):
run_cmd += " --gradient_checkpointing"

if kwargs.get("huber_c"):
run_cmd += fr' --huber_c="{kwargs.get("huber_c")}"'

if kwargs.get("huber_schedule"):
run_cmd += fr' --huber_schedule="{kwargs.get("huber_schedule")}"'

if kwargs.get("ip_noise_gamma"):
if float(kwargs["ip_noise_gamma"]) > 0:
@@ -1152,6 +1158,9 @@ def run_cmd_advanced_training(**kwargs):
lora_network_weights = kwargs.get("lora_network_weights")
if lora_network_weights:
run_cmd += f' --network_weights="{lora_network_weights}"' # Yes, the parameter is now called network_weights instead of lora_network_weights

if "loss_type" in kwargs:
run_cmd += fr' --loss_type="{kwargs.get("loss_type")}"'

lr_scheduler = kwargs.get("lr_scheduler")
if lr_scheduler:
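Putting the three additions above together: when Huber loss is selected in the GUI, the generated training command gains `--loss_type`, `--huber_schedule`, and `--huber_c` flags. The stand-alone sketch below replicates just that fragment of the flag building; in the real `run_cmd_advanced_training` the `--loss_type` flag is appended further down the function, so its position in the final command differs.

```python
def build_huber_flags(loss_type: str = "huber", huber_schedule: str = "snr", huber_c: float = 0.1) -> str:
    """Mirrors only the three new additions shown in the diff above."""
    cmd = ""
    if huber_c:
        cmd += rf' --huber_c="{huber_c}"'
    if huber_schedule:
        cmd += rf' --huber_schedule="{huber_schedule}"'
    if loss_type:
        cmd += rf' --loss_type="{loss_type}"'
    return cmd

print(build_huber_flags())
# ' --huber_c="0.1" --huber_schedule="snr" --loss_type="huber"'
```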
15 changes: 15 additions & 0 deletions kohya_gui/dreambooth_gui.py
@@ -135,6 +135,9 @@ def save_configuration(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
weighted_captions,
@@ -275,6 +278,9 @@ def open_configuration(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
weighted_captions,
@@ -410,6 +416,9 @@ def train_model(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
weighted_captions,
@@ -665,6 +674,9 @@ def train_model(
"weighted_captions": weighted_captions,
"xformers": xformers,
"additional_parameters": additional_parameters,
"loss_type": loss_type,
"huber_schedule": huber_schedule,
"huber_c": huber_c,
}

# Conditionally include specific keyword arguments based on sdxl
@@ -897,6 +909,9 @@ def dreambooth_tab(
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.loss_type,
advanced_training.huber_schedule,
advanced_training.huber_c,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
advanced_training.weighted_captions,
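The remaining GUI modules below (`finetune_gui.py`, `lora_gui.py`, `textual_inversion_gui.py`) repeat the same plumbing shown here for `dreambooth_gui.py`: `loss_type`, `huber_schedule`, and `huber_c` are added to `save_configuration`, `open_configuration`, and `train_model`, then packed into the keyword-argument dict handed to `run_cmd_advanced_training`. A heavily trimmed, hypothetical view of that dict construction:

```python
# Hypothetical, trimmed-down helper; each real train_model builds a much larger dict.
def collect_huber_kwargs(loss_type: str, huber_schedule: str, huber_c: float) -> dict:
    return {
        "loss_type": loss_type,            # l2, huber, or smooth_l1
        "huber_schedule": huber_schedule,  # constant, exponential, or snr
        "huber_c": huber_c,                # Huber transition parameter
    }
```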
15 changes: 15 additions & 0 deletions kohya_gui/finetune_gui.py
@@ -144,6 +144,9 @@ def save_configuration(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
weighted_captions,
@@ -291,6 +294,9 @@ def open_configuration(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
weighted_captions,
@@ -445,6 +451,9 @@ def train_model(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
weighted_captions,
@@ -695,6 +704,9 @@ def train_model(
"weighted_captions": weighted_captions,
"xformers": xformers,
"additional_parameters": additional_parameters,
"loss_type": loss_type,
"huber_schedule": huber_schedule,
"huber_c": huber_c,
}

# Conditionally include specific keyword arguments based on sdxl_checkbox
@@ -999,6 +1011,9 @@ def list_presets(path):
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.loss_type,
advanced_training.huber_schedule,
advanced_training.huber_c,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
weighted_captions,
15 changes: 15 additions & 0 deletions kohya_gui/lora_gui.py
@@ -190,6 +190,9 @@ def save_configuration(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
down_lr_weight,
@@ -378,6 +381,9 @@ def open_configuration(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
down_lr_weight,
@@ -594,6 +600,9 @@ def train_model(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
down_lr_weight,
@@ -1025,6 +1034,9 @@ def train_model(
"weighted_captions": weighted_captions,
"xformers": xformers,
"additional_parameters": additional_parameters,
"loss_type": loss_type,
"huber_schedule": huber_schedule,
"huber_c": huber_c,
}

# Use the ** syntax to unpack the dictionary when calling the function
@@ -2042,6 +2054,9 @@ def update_LoRA_settings(
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.loss_type,
advanced_training.huber_schedule,
advanced_training.huber_c,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
down_lr_weight,
15 changes: 15 additions & 0 deletions kohya_gui/textual_inversion_gui.py
@@ -135,6 +135,9 @@ def save_configuration(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
save_every_n_steps,
@@ -276,6 +279,9 @@ def open_configuration(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
save_every_n_steps,
@@ -410,6 +416,9 @@ def train_model(
sample_sampler,
sample_prompts,
additional_parameters,
loss_type,
huber_schedule,
huber_c,
vae_batch_size,
min_snr_gamma,
save_every_n_steps,
@@ -643,6 +652,9 @@ def train_model(
wandb_run_name=wandb_run_name,
xformers=xformers,
additional_parameters=additional_parameters,
loss_type=loss_type,
huber_schedule=huber_schedule,
huber_c=huber_c,
)
run_cmd += f' --token_string="{token_string}"'
run_cmd += f' --init_word="{init_word}"'
@@ -963,6 +975,9 @@ def list_embedding_files(path):
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.loss_type,
advanced_training.huber_schedule,
advanced_training.huber_c,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
advanced_training.save_every_n_steps,