Skip to content

Commit

Permalink
enable cache_latents when _to_disk #438
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Apr 24, 2023
1 parent 9bb52ac commit 1890535
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2185,6 +2185,12 @@ def verify_training_args(args: argparse.Namespace):
if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

if args.cache_latents_to_disk and not args.cache_latents:
args.cache_latents = True
print(
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
)


def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
Expand Down Expand Up @@ -2963,7 +2969,7 @@ def get_remove_step_no(args: argparse.Namespace, step_no: int):

# last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
# save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する
remove_step_no = step_no - args.save_last_n_steps - 1
remove_step_no = step_no - args.save_last_n_steps - 1
remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
if remove_step_no < 0:
return None
Expand Down Expand Up @@ -3005,7 +3011,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
os.makedirs(args.output_dir, exist_ok=True)
if save_stable_diffusion_format:
ext = ".safetensors" if use_safetensors else ".ckpt"

if on_epoch_end:
ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no)
else:
Expand Down

0 comments on commit 1890535

Please sign in to comment.