
Add option --sample_at_first #907

Merged
merged 3 commits on Dec 3, 2023
4 changes: 4 additions & 0 deletions docs/train_README-ja.md
@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ

Specifies the number of steps or epochs between sample image outputs; samples are generated at that interval. If both are given, the epoch count takes precedence.

- `--sample_at_first`

Generates sample output before training starts, so the results can be compared with the untrained state.

- `--sample_prompts`

Specifies the file containing the prompts used for sample output.
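For reference, the three sampling options described above are ordinary argparse flags; their actual definitions are the ones added to library/train_util.py below. A minimal, runnable sketch of parsing them together (defaults and help texts are abbreviated and partly assumed here):

```python
# Minimal sketch of the sampling-related options; the authoritative definitions
# live in library/train_util.py's add_training_arguments (see the diff below).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--sample_at_first", action="store_true",
                    help="generate sample images before training")
parser.add_argument("--sample_every_n_steps", type=int, default=None,
                    help="generate sample images every N steps")
parser.add_argument("--sample_every_n_epochs", type=int, default=None,
                    help="generate sample images every N epochs (takes precedence over steps)")
parser.add_argument("--sample_prompts", type=str, default=None,
                    help="file with prompts for sample output")

# Example: sample once before training starts and then after every epoch.
args = parser.parse_args(
    ["--sample_at_first", "--sample_every_n_epochs", "1", "--sample_prompts", "prompts.txt"]
)
print(args.sample_at_first, args.sample_every_n_epochs)  # True 1
```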
3 changes: 3 additions & 0 deletions fine_tune.py
@@ -303,6 +303,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

# For --sample_at_first
train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

for m in training_models:
m.train()

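Each training script touched by this PR (fine_tune.py above, and sdxl_train.py, train_controlnet.py, train_db.py, train_network.py, train_textual_inversion.py below) adds the same kind of call at the top of its epoch loop, before the models are switched to train(). The toy, runnable sketch below (stub objects only, not the repository's code) shows why that placement is what enables --sample_at_first: only on the very first pass through the loop is global_step still 0.

```python
# Toy stand-in for the epoch loop: the sample hook runs at the start of every
# epoch, but only the first call ever sees steps == 0, which is the case the
# new --sample_at_first branch in train_util.sample_images checks for.
calls = []

def sample_images_stub(epoch, steps):
    calls.append((epoch, steps))  # the real gating lives in train_util.sample_images

num_train_epochs = 3
global_step = 0
for epoch in range(num_train_epochs):
    sample_images_stub(epoch, global_step)  # placed before any model.train(), as in the PR
    global_step += 100                      # pretend 100 optimizer steps happen per epoch

print(calls)  # [(0, 0), (1, 100), (2, 200)] -> only the first call has steps == 0
```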
19 changes: 13 additions & 6 deletions library/train_util.py
@@ -2968,6 +2968,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
)
parser.add_argument(
"--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する"
)
parser.add_argument(
"--sample_every_n_epochs",
type=int,
@@ -4429,15 +4432,19 @@ def sample_images_common(
"""
A modified version of StableDiffusionLongPromptWeightingPipeline is used, so clip skip and prompt weighting are supported
"""
-    if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
-        return
-    if args.sample_every_n_epochs is not None:
-        # sample_every_n_steps は無視する
-        if epoch is None or epoch % args.sample_every_n_epochs != 0:
+    if steps == 0:
+        if not args.sample_at_first:
             return
     else:
-        if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
+        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
             return
+        if args.sample_every_n_epochs is not None:
+            # sample_every_n_steps は無視する
+            if epoch is None or epoch % args.sample_every_n_epochs != 0:
+                return
+        else:
+            if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
+                return

print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
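The heart of the change is the new gating at the top of sample_images_common: a call with steps == 0 is now decided solely by --sample_at_first, while every other call keeps the previous cadence rules. The helper below is an illustrative restatement of that logic (it does not exist in the repository; the real code reads the same values from args):

```python
# Illustrative restatement of the gating in sample_images_common (hypothetical helper).
from typing import Optional

def should_sample(steps: int, epoch: Optional[int], sample_at_first: bool,
                  sample_every_n_steps: Optional[int],
                  sample_every_n_epochs: Optional[int]) -> bool:
    if steps == 0:
        # before the first optimizer step: sample only if --sample_at_first was given
        return sample_at_first
    if sample_every_n_steps is None and sample_every_n_epochs is None:
        return False
    if sample_every_n_epochs is not None:
        # the epoch cadence takes precedence; sample_every_n_steps is ignored
        return epoch is not None and epoch % sample_every_n_epochs == 0
    # step cadence: only on exact multiples, and not on the end-of-epoch call
    return steps % sample_every_n_steps == 0 and epoch is None

print(should_sample(0, None, True, None, None))    # True  -> --sample_at_first fires
print(should_sample(0, None, False, 200, None))    # False -> no pre-training sample
print(should_sample(400, None, False, 200, None))  # True  -> regular every-N-steps sample
```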
13 changes: 13 additions & 0 deletions sdxl_train.py
@@ -477,6 +477,19 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

# For --sample_at_first
sdxl_train_util.sample_images(
accelerator,
args,
epoch,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)

for m in training_models:
m.train()

14 changes: 14 additions & 0 deletions train_controlnet.py
@@ -373,6 +373,20 @@ def remove_model(old_ckpt_name):

# training loop
for epoch in range(num_train_epochs):
# For --sample_at_first
train_util.sample_images(
accelerator,
args,
epoch,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
controlnet=controlnet,
)

if is_main_process:
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
2 changes: 2 additions & 0 deletions train_db.py
@@ -279,6 +279,8 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

# 指定したステップ数までText Encoderを学習する:epoch最初の状態 / train the Text Encoder up to the specified number of steps; this is the state at the start of the epoch
unet.train()
# train==True is required to enable gradient_checkpointing
4 changes: 3 additions & 1 deletion train_network.py
@@ -749,7 +749,9 @@ def remove_model(old_ckpt_name):
current_epoch.value = epoch + 1

metadata["ss_epoch"] = str(epoch + 1)


# For --sample_at_first
self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
network.on_epoch_start(text_encoder, unet)

for step, batch in enumerate(train_dataloader):
14 changes: 14 additions & 0 deletions train_textual_inversion.py
@@ -534,6 +534,20 @@ def remove_model(old_ckpt_name):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

# For --sample_at_first
self.sample_images(
accelerator,
args,
epoch,
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)

for text_encoder in text_encoders:
text_encoder.train()
