
Commit

Added sample_images() for --sample_at_first
shirayu committed Oct 29, 2023
1 parent 5c15067 commit 2c73141
Showing 5 changed files with 36 additions and 1 deletion.
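
Each training script now calls sample_images() at the top of its epoch loop, so that a run started with --sample_at_first can generate sample images before the first training step. The guard that otherwise makes these calls a no-op lives in library/train_util.py, which is not touched by this commit, so the following is only a rough, simplified sketch of the assumed gating, not the repository's implementation:

    # Rough sketch of the assumed gating inside train_util.sample_images().
    # library/train_util.py is not part of this commit, so the logic below is
    # an illustrative assumption, not the actual repository code.
    def sample_images(accelerator, args, epoch, steps, device, vae,
                      tokenizer, text_encoder, unet, **kwargs):
        if steps == 0:
            # Before any training step: only sample when --sample_at_first is set.
            if not getattr(args, "sample_at_first", False):
                return
        elif args.sample_every_n_steps is None or steps % args.sample_every_n_steps != 0:
            # Otherwise follow the regular step-based sampling schedule.
            return
        # ... build a pipeline from vae / text_encoder / unet and render the
        # prompts given via --sample_prompts ...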
3 changes: 3 additions & 0 deletions fine_tune.py
@@ -303,6 +303,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
+        # For --sample_at_first
+        train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
         for m in training_models:
             m.train()
 
14 changes: 14 additions & 0 deletions train_controlnet.py
@@ -373,6 +373,20 @@ def remove_model(old_ckpt_name):
 
     # training loop
     for epoch in range(num_train_epochs):
+        # For --sample_at_first
+        train_util.sample_images(
+            accelerator,
+            args,
+            epoch,
+            global_step,
+            accelerator.device,
+            vae,
+            tokenizer,
+            text_encoder,
+            unet,
+            controlnet=controlnet,
+        )
+
         if is_main_process:
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
             current_epoch.value = epoch + 1
2 changes: 2 additions & 0 deletions train_db.py
@@ -279,6 +279,8 @@ def train(args):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
+        train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
         # Train the Text Encoder up to the specified number of steps: state at the start of the epoch
         unet.train()
         # train==True is required to enable gradient_checkpointing
4 changes: 3 additions & 1 deletion train_network.py
@@ -749,7 +749,9 @@ def remove_model(old_ckpt_name):
             current_epoch.value = epoch + 1
 
             metadata["ss_epoch"] = str(epoch + 1)
-
+
+            # For --sample_at_first
+            self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
             network.on_epoch_start(text_encoder, unet)
 
             for step, batch in enumerate(train_dataloader):
14 changes: 14 additions & 0 deletions train_textual_inversion.py
@@ -534,6 +534,20 @@ def remove_model(old_ckpt_name):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
             current_epoch.value = epoch + 1
 
+            # For --sample_at_first
+            self.sample_images(
+                accelerator,
+                args,
+                epoch,
+                global_step,
+                accelerator.device,
+                vae,
+                tokenizer_or_list,
+                text_encoder_or_list,
+                unet,
+                prompt_replacement,
+            )
+
             for text_encoder in text_encoders:
                 text_encoder.train()
 
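
Note that the textual-inversion trainer passes tokenizer_or_list and text_encoder_or_list (the SDXL variants train with two text encoders) plus prompt_replacement, while the other scripts pass single objects. A minimal sketch of the kind of normalization the callee presumably performs is below; this is an assumption for illustration, since that code is not part of this diff:

    # Illustrative assumption only: normalize single objects vs. lists so the
    # same sampling code can serve both the SD and SDXL trainers.
    tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
    text_encoders = text_encoder_or_list if isinstance(text_encoder_or_list, list) else [text_encoder_or_list]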
