diff --git a/fine_tune.py b/fine_tune.py
index 319088cce..f72e618b1 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -295,14 +295,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    # For --sample_at_first
+    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
-        # For --sample_at_first
-        train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
-
         for m in training_models:
             m.train()
 
diff --git a/sdxl_train.py b/sdxl_train.py
index 65e74b9f9..05ad08788 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -458,24 +458,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    # For --sample_at_first
+    sdxl_train_util.sample_images(
+        accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
+    )
+
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
-        # For --sample_at_first
-        sdxl_train_util.sample_images(
-            accelerator,
-            args,
-            epoch,
-            global_step,
-            accelerator.device,
-            vae,
-            [tokenizer1, tokenizer2],
-            [text_encoder1, text_encoder2],
-            unet,
-        )
-
         for m in training_models:
             m.train()
 
diff --git a/train_controlnet.py b/train_controlnet.py
index c45083625..1f3dbae30 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -11,10 +11,13 @@
 from tqdm import tqdm
 
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -335,7 +338,9 @@ def train(args):
         init_kwargs = {}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers(
+            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+        )
 
     loss_recorder = train_util.LossRecorder()
     del train_dataset_group
@@ -371,22 +376,13 @@ def remove_model(old_ckpt_name):
             accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
             os.remove(old_ckpt_file)
 
+    # For --sample_at_first
+    train_util.sample_images(
+        accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
+    )
+
     # training loop
     for epoch in range(num_train_epochs):
-        # For --sample_at_first
-        train_util.sample_images(
-            accelerator,
-            args,
-            epoch,
-            global_step,
-            accelerator.device,
-            vae,
-            tokenizer,
-            text_encoder,
-            unet,
-            controlnet=controlnet,
-        )
-
         if is_main_process:
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
             current_epoch.value = epoch + 1
diff --git a/train_db.py b/train_db.py
index 936cd0bb4..5518740f1 100644
--- a/train_db.py
+++ b/train_db.py
@@ -272,13 +272,14 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    # For --sample_at_first
+    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
-        train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
-
         # 指定したステップ数までText Encoderを学習する:epoch最初の状態
         unet.train()
         # train==True is required to enable gradient_checkpointing
diff --git a/train_network.py b/train_network.py
index e570d3f28..378a3390a 100644
--- a/train_network.py
+++ b/train_network.py
@@ -409,9 +409,7 @@ def train(self, args):
         else:
             for t_enc in text_encoders:
                 t_enc.to(accelerator.device, dtype=weight_dtype)
-            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                network, optimizer, train_dataloader, lr_scheduler
-            )
+            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
 
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
@@ -725,6 +723,9 @@ def remove_model(old_ckpt_name):
                 accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
                 os.remove(old_ckpt_file)
 
+        # For --sample_at_first
+        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
         # training loop
         for epoch in range(num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -732,8 +733,6 @@ def remove_model(old_ckpt_name):
 
             metadata["ss_epoch"] = str(epoch + 1)
 
-            # For --sample_at_first
-            self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
 
             for step, batch in enumerate(train_dataloader):
@@ -807,7 +806,7 @@ def remove_model(old_ckpt_name):
                     loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
                     accelerator.backward(loss)
-                    self.all_reduce_network(accelerator, network) # sync DDP grad manually
+                    self.all_reduce_network(accelerator, network)  # sync DDP grad manually
                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                         params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 8422edfac..877ac838e 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -7,10 +7,13 @@
 from tqdm import tqdm
 
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -525,25 +528,25 @@ def remove_model(old_ckpt_name):
                 accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
                 os.remove(old_ckpt_file)
 
+        # For --sample_at_first
+        self.sample_images(
+            accelerator,
+            args,
+            0,
+            global_step,
+            accelerator.device,
+            vae,
+            tokenizer_or_list,
+            text_encoder_or_list,
+            unet,
+            prompt_replacement,
+        )
+
         # training loop
         for epoch in range(num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
             current_epoch.value = epoch + 1
 
-            # For --sample_at_first
-            self.sample_images(
-                accelerator,
-                args,
-                epoch,
-                global_step,
-                accelerator.device,
-                vae,
-                tokenizer_or_list,
-                text_encoder_or_list,
-                unet,
-                prompt_replacement,
-            )
-
             for text_encoder in text_encoders:
                 text_encoder.train()