From fea810b437e0b4ea448e0ffa7d5933437bac6cae Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 21:44:57 +0900 Subject: [PATCH 1/3] Added --sample_at_first to generate sample images before training --- library/train_util.py | 19 +++++++++++++------ sdxl_train.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0f5033413..926e956ca 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2968,6 +2968,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" ) + parser.add_argument( + "--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する" + ) parser.add_argument( "--sample_every_n_epochs", type=int, @@ -4429,15 +4432,19 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ - if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: - return - if args.sample_every_n_epochs is not None: - # sample_every_n_steps は無視する - if epoch is None or epoch % args.sample_every_n_epochs != 0: + if steps == 0: + if not args.sample_at_first: return else: - if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): diff --git a/sdxl_train.py b/sdxl_train.py index 47bc6a420..a25da42d1 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -477,6 +477,19 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + ) + for m in training_models: m.train() From 5c150675bf1a4a0153b7fa404515ebe76f3e1698 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 21:46:47 +0900 Subject: [PATCH 2/3] Added --sample_at_first description --- docs/train_README-ja.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index c871f0769..d186bf243 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。 +- `--sample_at_first` + + 学習開始前にサンプル出力します。学習前との比較ができます。 + - `--sample_prompts` サンプル出力用プロンプトのファイルを指定します。 From 2c731418add79c303213ea884eb3d66bfe6b19d7 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 22:08:42 +0900 Subject: [PATCH 3/3] Added sample_images() for --sample_at_first --- fine_tune.py | 3 +++ train_controlnet.py | 14 ++++++++++++++ train_db.py | 2 ++ train_network.py | 4 +++- train_textual_inversion.py | 14 ++++++++++++++ 5 files changed, 36 insertions(+), 1 deletion(-) diff --git a/fine_tune.py b/fine_tune.py index a86a483a0..597678403 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -303,6 +303,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + for m in training_models: m.train() diff --git a/train_controlnet.py b/train_controlnet.py index bbd915cb3..d054d32eb 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -373,6 +373,20 @@ def remove_model(old_ckpt_name): # training loop for epoch in range(num_train_epochs): + # For --sample_at_first + train_util.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + if is_main_process: accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 diff --git a/train_db.py b/train_db.py index fd8e466e5..443cc5bfc 100644 --- a/train_db.py +++ b/train_db.py @@ -279,6 +279,8 @@ def train(args): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() # train==True is required to enable gradient_checkpointing diff --git a/train_network.py b/train_network.py index d50916b74..bf6597236 100644 --- a/train_network.py +++ b/train_network.py @@ -749,7 +749,9 @@ def remove_model(old_ckpt_name): current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) - + + # For --sample_at_first + self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 6b6e7f5a0..2a347afa2 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -534,6 +534,20 @@ def remove_model(old_ckpt_name): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + self.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, + ) + for text_encoder in text_encoders: text_encoder.train()