Skip to content

Commit

Permalink
Merge pull request #907 from shirayu/add_option_sample_at_first
Browse files Browse the repository at this point in the history
Add option --sample_at_first
  • Loading branch information
kohya-ss committed Dec 3, 2023
2 parents df59822 + 2c73141 commit 383b4a2
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/train_README-ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ

サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。

- `--sample_at_first`

学習開始前にサンプル出力します。学習前との比較ができます。

- `--sample_prompts`

サンプル出力用プロンプトのファイルを指定します。
Expand Down
3 changes: 3 additions & 0 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

# For --sample_at_first
train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

for m in training_models:
m.train()

Expand Down
19 changes: 13 additions & 6 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2979,6 +2979,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
)
parser.add_argument(
"--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する"
)
parser.add_argument(
"--sample_every_n_epochs",
type=int,
Expand Down Expand Up @@ -4576,15 +4579,19 @@ def sample_images_common(
"""
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
"""
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
if steps == 0:
if not args.sample_at_first:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return

print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
Expand Down
13 changes: 13 additions & 0 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,19 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

# For --sample_at_first
sdxl_train_util.sample_images(
accelerator,
args,
epoch,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)

for m in training_models:
m.train()

Expand Down
14 changes: 14 additions & 0 deletions train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,20 @@ def remove_model(old_ckpt_name):

# training loop
for epoch in range(num_train_epochs):
# For --sample_at_first
train_util.sample_images(
accelerator,
args,
epoch,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
controlnet=controlnet,
)

if is_main_process:
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
Expand Down
2 changes: 2 additions & 0 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

# 指定したステップ数までText Encoderを学習する:epoch最初の状態
unet.train()
# train==True is required to enable gradient_checkpointing
Expand Down
4 changes: 3 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,9 @@ def remove_model(old_ckpt_name):
current_epoch.value = epoch + 1

metadata["ss_epoch"] = str(epoch + 1)


# For --sample_at_first
self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
network.on_epoch_start(text_encoder, unet)

for step, batch in enumerate(train_dataloader):
Expand Down
14 changes: 14 additions & 0 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,20 @@ def remove_model(old_ckpt_name):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

# For --sample_at_first
self.sample_images(
accelerator,
args,
epoch,
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)

for text_encoder in text_encoders:
text_encoder.train()

Expand Down

0 comments on commit 383b4a2

Please sign in to comment.