diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index c871f0769..d186bf243 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。 +- `--sample_at_first` + + 学習開始前にサンプル出力します。学習前との比較ができます。 + - `--sample_prompts` サンプル出力用プロンプトのファイルを指定します。 diff --git a/fine_tune.py b/fine_tune.py index b07876776..3f3da5b57 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -303,6 +303,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + for m in training_models: m.train() diff --git a/library/train_util.py b/library/train_util.py index 510708245..a94562a33 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2979,6 +2979,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" ) + parser.add_argument( + "--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する" + ) parser.add_argument( "--sample_every_n_epochs", type=int, @@ -4576,15 +4579,19 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ - if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: - return - if args.sample_every_n_epochs is not None: - # sample_every_n_steps は無視する - if epoch is None or epoch % args.sample_every_n_epochs != 0: + if steps == 0: + if not args.sample_at_first: return else: - if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): diff --git a/sdxl_train.py b/sdxl_train.py index fd775624e..45e290be6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -466,6 +466,19 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + ) + for m in training_models: m.train() diff --git a/train_controlnet.py b/train_controlnet.py index e0118d1c5..c45083625 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -373,6 +373,20 @@ def remove_model(old_ckpt_name): # training loop for epoch in range(num_train_epochs): + # For --sample_at_first + train_util.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + if is_main_process: accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 diff --git a/train_db.py b/train_db.py index 966999dfb..4eabed0f4 100644 --- a/train_db.py +++ b/train_db.py @@ -279,6 +279,8 @@ def train(args): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() # train==True is required to enable gradient_checkpointing diff --git a/train_network.py b/train_network.py index 1cbed2e7b..c9a37fb6a 100644 --- a/train_network.py +++ b/train_network.py @@ -749,7 +749,9 @@ def remove_model(old_ckpt_name): current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) - + + # For --sample_at_first + self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 45a437b91..3b0aec24f 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -534,6 +534,20 @@ def remove_model(old_ckpt_name): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + self.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, + ) + for text_encoder in text_encoders: text_encoder.train()