From 2c731418add79c303213ea884eb3d66bfe6b19d7 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 22:08:42 +0900 Subject: [PATCH] Added sample_images() for --sample_at_first --- fine_tune.py | 3 +++ train_controlnet.py | 14 ++++++++++++++ train_db.py | 2 ++ train_network.py | 4 +++- train_textual_inversion.py | 14 ++++++++++++++ 5 files changed, 36 insertions(+), 1 deletion(-) diff --git a/fine_tune.py b/fine_tune.py index a86a483a0..597678403 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -303,6 +303,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + for m in training_models: m.train() diff --git a/train_controlnet.py b/train_controlnet.py index bbd915cb3..d054d32eb 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -373,6 +373,20 @@ def remove_model(old_ckpt_name): # training loop for epoch in range(num_train_epochs): + # For --sample_at_first + train_util.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + if is_main_process: accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 diff --git a/train_db.py b/train_db.py index fd8e466e5..443cc5bfc 100644 --- a/train_db.py +++ b/train_db.py @@ -279,6 +279,8 @@ def train(args): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() # train==True is required to enable gradient_checkpointing diff --git a/train_network.py b/train_network.py index d50916b74..bf6597236 100644 --- a/train_network.py +++ b/train_network.py @@ -749,7 +749,9 @@ def remove_model(old_ckpt_name): current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) - + + # For --sample_at_first + self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 6b6e7f5a0..2a347afa2 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -534,6 +534,20 @@ def remove_model(old_ckpt_name): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + self.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, + ) + for text_encoder in text_encoders: text_encoder.train()