Skip to content

Commit

Permalink
Debias Estimation loss (#889)
Browse files Browse the repository at this point in the history
* update for bnb 0.41.1

* fixed generate_controlnet_subsets_config for training

* Revert "update for bnb 0.41.1"

This reverts commit 70bd3612d84778d491fc8006b8b9f9e21c4d2eb8.

* add debiased_estimation_loss

* add train_network

* Revert "add train_network"

This reverts commit 6539363c5c13a3e63fc0e52adf7fc26fb566d491.

* Update train_network.py
  • Loading branch information
sdbds authored Oct 23, 2023
1 parent 681034d commit 202f2c3
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 2 deletions.
5 changes: 4 additions & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)


Expand Down Expand Up @@ -339,7 +340,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
target = noise

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss,:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
Expand All @@ -348,6 +349,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

loss = loss.mean() # mean over batch dimension
else:
Expand Down
11 changes: 11 additions & 0 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
loss = loss + loss / scale * v_pred_like_loss
return loss

def apply_debiased_estimation(loss, timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
weight = 1/torch.sqrt(snr_t)
loss = weight * loss
return loss

# TODO train_utilと分散しているのでどちらかに寄せる

Expand All @@ -108,6 +114,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
default=None,
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
)
parser.add_argument(
"--debiased_estimation_loss",
action="store_true",
help="debiased estimation loss / debiased estimation loss",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
Expand Down
5 changes: 4 additions & 1 deletion sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel

Expand Down Expand Up @@ -548,7 +549,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

target = noise

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss:
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
Expand All @@ -559,6 +560,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

loss = loss.mean() # mean over batch dimension
else:
Expand Down
3 changes: 3 additions & 0 deletions sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import networks.control_net_lllite_for_train as control_net_lllite_for_train

Expand Down Expand Up @@ -465,6 +466,8 @@ def remove_model(old_ckpt_name):
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

Expand Down
3 changes: 3 additions & 0 deletions sdxl_train_control_net_lllite_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import networks.control_net_lllite as control_net_lllite

Expand Down Expand Up @@ -435,6 +436,8 @@ def remove_model(old_ckpt_name):
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

Expand Down
3 changes: 3 additions & 0 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)

# perlin_noise,
Expand Down Expand Up @@ -336,6 +337,8 @@ def train(args):
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

Expand Down
4 changes: 4 additions & 0 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
)


Expand Down Expand Up @@ -528,6 +529,7 @@ def train(self, args):
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
}

if use_user_config:
Expand Down Expand Up @@ -811,6 +813,8 @@ def remove_model(old_ckpt_name):
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

Expand Down
3 changes: 3 additions & 0 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
)

imagenet_templates_small = [
Expand Down Expand Up @@ -582,6 +583,8 @@ def remove_model(old_ckpt_name):
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

Expand Down
3 changes: 3 additions & 0 deletions train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
Expand Down Expand Up @@ -471,6 +472,8 @@ def remove_model(old_ckpt_name):
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

Expand Down

0 comments on commit 202f2c3

Please sign in to comment.