fix duplicated sample gen for every epoch ref #907
kohya-ss committed Dec 7, 2023
1 parent db84530 commit 912dca8
Showing 6 changed files with 44 additions and 53 deletions.
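
All six trainers get the same fix: the sample_images call for --sample_at_first used to sit at the top of the epoch loop (passing epoch), so the "initial" samples were regenerated at the start of every epoch; it now runs once, before the loop, passing 0. Below is a minimal sketch of the resulting control flow, using the names that appear in the fine_tune.py hunk that follows; it illustrates the pattern and is not the repository's exact code, and the step-based sampling that already happens inside the loop is unchanged.

# For --sample_at_first: generate the initial samples exactly once, before training begins
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

for epoch in range(num_train_epochs):
    accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
    current_epoch.value = epoch + 1
    # the per-epoch sample_images(..., epoch, ...) call that used to run here is removed,
    # so samples are no longer regenerated at the start of every epoch
    for m in training_models:
        m.train()
    # ... forward/backward, optimizer step, and the existing step-based sampling follow ...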
fine_tune.py (6 changes: 3 additions & 3 deletions)
@@ -295,14 +295,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

+ # For --sample_at_first
+ train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

- # For --sample_at_first
- train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
-
for m in training_models:
m.train()

sdxl_train.py (18 changes: 5 additions & 13 deletions)
@@ -458,24 +458,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

+ # For --sample_at_first
+ sdxl_train_util.sample_images(
+     accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
+ )
+
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

- # For --sample_at_first
- sdxl_train_util.sample_images(
-     accelerator,
-     args,
-     epoch,
-     global_step,
-     accelerator.device,
-     vae,
-     [tokenizer1, tokenizer2],
-     [text_encoder1, text_encoder2],
-     unet,
- )
-
for m in training_models:
m.train()

train_controlnet.py (26 changes: 11 additions & 15 deletions)
@@ -11,10 +11,13 @@

from tqdm import tqdm
import torch
+
try:
    import intel_extension_for_pytorch as ipex
+
    if torch.xpu.is_available():
        from library.ipex import ipex_init
+
        ipex_init()
except Exception:
pass
@@ -335,7 +338,9 @@ def train(args):
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)

loss_recorder = train_util.LossRecorder()
del train_dataset_group
@@ -371,22 +376,13 @@ def remove_model(old_ckpt_name):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)

+ # For --sample_at_first
+ train_util.sample_images(
+     accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
+ )
+
# training loop
for epoch in range(num_train_epochs):
- # For --sample_at_first
- train_util.sample_images(
-     accelerator,
-     args,
-     epoch,
-     global_step,
-     accelerator.device,
-     vae,
-     tokenizer,
-     text_encoder,
-     unet,
-     controlnet=controlnet,
- )
-
if is_main_process:
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
train_db.py (5 changes: 3 additions & 2 deletions)
@@ -272,13 +272,14 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

+ # For --sample_at_first
+ train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

- train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
-
# train the Text Encoder only up to the specified number of steps (state at the start of the epoch)
unet.train()
# train==True is required to enable gradient_checkpointing
train_network.py (11 changes: 5 additions & 6 deletions)
@@ -409,9 +409,7 @@ def train(self, args):
else:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
- network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-     network, optimizer, train_dataloader, lr_scheduler
- )
+ network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)

if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
@@ -725,15 +723,16 @@ def remove_model(old_ckpt_name):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)

+ # For --sample_at_first
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

metadata["ss_epoch"] = str(epoch + 1)

- # For --sample_at_first
- self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

for step, batch in enumerate(train_dataloader):
Expand Down Expand Up @@ -807,7 +806,7 @@ def remove_model(old_ckpt_name):
loss = loss.mean()  # already the mean, so no need to divide by batch_size

accelerator.backward(loss)
- self.all_reduce_network(accelerator, network) # sync DDP grad manually
+ self.all_reduce_network(accelerator, network)  # sync DDP grad manually
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
train_textual_inversion.py (31 changes: 17 additions & 14 deletions)
@@ -7,10 +7,13 @@

from tqdm import tqdm
import torch
+
try:
    import intel_extension_for_pytorch as ipex
+
    if torch.xpu.is_available():
        from library.ipex import ipex_init
+
        ipex_init()
except Exception:
pass
@@ -525,25 +528,25 @@ def remove_model(old_ckpt_name):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)

+ # For --sample_at_first
+ self.sample_images(
+     accelerator,
+     args,
+     0,
+     global_step,
+     accelerator.device,
+     vae,
+     tokenizer_or_list,
+     text_encoder_or_list,
+     unet,
+     prompt_replacement,
+ )
+
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

- # For --sample_at_first
- self.sample_images(
-     accelerator,
-     args,
-     epoch,
-     global_step,
-     accelerator.device,
-     vae,
-     tokenizer_or_list,
-     text_encoder_or_list,
-     unet,
-     prompt_replacement,
- )
-
for text_encoder in text_encoders:
text_encoder.train()

