diff --git a/fine_tune.py b/fine_tune.py
index 3f3da5b57..319088cce 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -253,9 +253,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

-    # transform DDP after prepare
-    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
-
     # experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
     if args.full_fp16:
         train_util.patch_accelerator_for_fp16_training(accelerator)
diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index f637d9931..5ad748d15 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -51,8 +51,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

-    text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
-
    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
diff --git a/library/train_util.py b/library/train_util.py
index ee334840e..2b051e1f1 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3914,17 +3914,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
     return text_encoder, vae, unet, load_stable_diffusion_format


-# TODO remove this function in the future
-def transform_if_model_is_DDP(text_encoder, unet, network=None):
-    # Transform text_encoder, unet and network from DistributedDataParallel
-    return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
-
-
-def transform_models_if_DDP(models):
-    # Transform text_encoder, unet and network from DistributedDataParallel
-    return [model.module if type(model) == DDP else model for model in models if model is not None]
-
-
 def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
     # load models for each process
     for pi in range(accelerator.state.num_processes):
@@ -3948,8 +3937,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

-    text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)
-
    return text_encoder, vae, unet, load_stable_diffusion_format
diff --git a/sdxl_train.py b/sdxl_train.py
index 45e290be6..65e74b9f9 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -397,13 +397,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     # accelerator apparently takes care of this for us
     if train_unet:
         unet = accelerator.prepare(unet)
-        (unet,) = train_util.transform_models_if_DDP([unet])
     if train_text_encoder1:
         text_encoder1 = accelerator.prepare(text_encoder1)
-        (text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
     if train_text_encoder2:
         text_encoder2 = accelerator.prepare(text_encoder2)
-        (text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])

     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 44447d1f0..cb97859fa 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -283,9 +283,6 @@ def train(args):
     # accelerator apparently takes care of this for us
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

-    # transform DDP after prepare (train_network here only)
-    unet = train_util.transform_models_if_DDP([unet])[0]
-
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> this applied to the original U-Net, so it can actually be removed
     else:
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 91cbacc6a..87f303018 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -254,9 +254,6 @@ def train(args):
     )
     network: control_net_lllite.ControlNetLLLite

-    # transform DDP after prepare (train_network here only)
-    unet, network = train_util.transform_models_if_DDP([unet, network])
-
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> this applied to the original U-Net, so it can actually be removed
     else:
diff --git a/train_db.py b/train_db.py
index 4eabed0f4..936cd0bb4 100644
--- a/train_db.py
+++ b/train_db.py
@@ -112,6 +112,7 @@ def train(args):

     # prepare dtypes that match the mixed precision setting and cast as needed
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
+    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

     # load models
     text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -136,7 +137,7 @@ def train(args):

     # prepare for training
     if cache_latents:
-        vae.to(accelerator.device, dtype=weight_dtype)
+        vae.to(accelerator.device, dtype=vae_dtype)
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
@@ -225,9 +226,6 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

-    # transform DDP after prepare
-    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
-
     if not train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error

@@ -484,6 +482,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
     )
+    parser.add_argument(
+        "--no_half_vae",
+        action="store_true",
+        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
+    )

     return parser
diff --git a/train_network.py b/train_network.py
index c9a37fb6a..e570d3f28 100644
--- a/train_network.py
+++ b/train_network.py
@@ -12,6 +12,7 @@ from tqdm import tqdm

 import torch
+from torch.nn.parallel import DistributedDataParallel as DDP

 try:
     import intel_extension_for_pytorch as ipex
@@ -127,6 +128,11 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond
         noise_pred = unet(noisy_latents, timesteps, text_conds).sample
         return noise_pred

+    def all_reduce_network(self, accelerator, network):
+        for param in network.parameters():
+            if param.grad is not None:
+                param.grad = accelerator.reduce(param.grad, reduction="mean")
+
     def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
         train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
@@ -390,47 +396,23 @@ def train(self, args):
         # accelerator apparently takes care of this for us
         # TODO this is very redundant; clean up the code
-        if train_unet and train_text_encoder:
+        if train_unet:
+            unet = accelerator.prepare(unet)
+        else:
+            unet.to(accelerator.device, dtype=weight_dtype)  # move to device because unet is not prepared by accelerator
+        if train_text_encoder:
             if len(text_encoders) > 1:
-                unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
-                )
-                text_encoder = text_encoders = [t_enc1, t_enc2]
-                del t_enc1, t_enc2
+                text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
             else:
-                unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
-                )
+                text_encoder = accelerator.prepare(text_encoder)
                 text_encoders = [text_encoder]
-        elif train_unet:
-            unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, network, optimizer, train_dataloader, lr_scheduler
-            )
+        else:
             for t_enc in text_encoders:
                 t_enc.to(accelerator.device, dtype=weight_dtype)
-        elif train_text_encoder:
-            if len(text_encoders) > 1:
-                t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
-                )
-                text_encoder = text_encoders = [t_enc1, t_enc2]
-                del t_enc1, t_enc2
-            else:
-                text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    text_encoder, network, optimizer, train_dataloader, lr_scheduler
-                )
-                text_encoders = [text_encoder]
-
-            unet.to(accelerator.device, dtype=weight_dtype)  # move to device because unet is not prepared by accelerator
-        else:
-            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             network, optimizer, train_dataloader, lr_scheduler
         )

-        # transform DDP after prepare (train_network here only)
-        text_encoders = train_util.transform_models_if_DDP(text_encoders)
-        unet, network = train_util.transform_models_if_DDP([unet, network])
-
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
             unet.train()
@@ -451,7 +433,7 @@ def train(self, args):

         del t_enc

-        network.prepare_grad_etc(text_encoder, unet)
+        accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)

         if not cache_latents:  # if latents are not cached, the VAE is used, so prepare it
             vae.requires_grad_(False)
@@ -714,8 +696,8 @@ def train(self, args):
            del train_dataset_group

        # callback for step start
-        if hasattr(network, "on_step_start"):
-            on_step_start = network.on_step_start
+        if hasattr(accelerator.unwrap_model(network), "on_step_start"):
+            on_step_start = accelerator.unwrap_model(network).on_step_start
        else:
            on_step_start = lambda *args, **kwargs: None
@@ -749,10 +731,10 @@ def remove_model(old_ckpt_name):
            current_epoch.value = epoch + 1

            metadata["ss_epoch"] = str(epoch + 1)
-
+            # For --sample_at_first
            self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

-            network.on_epoch_start(text_encoder, unet)
+            accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

            for step, batch in enumerate(train_dataloader):
                current_step.value = global_step
@@ -825,8 +807,9 @@ def remove_model(old_ckpt_name):
                loss = loss.mean()  # this is already the mean, so no need to divide by batch_size

                accelerator.backward(loss)
+                self.all_reduce_network(accelerator, network)  # sync DDP grad manually
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                    params_to_clip = network.get_trainable_params()
+                    params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
@@ -834,7 +817,7 @@ def remove_model(old_ckpt_name):
                optimizer.zero_grad(set_to_none=True)

                if args.scale_weight_norms:
-                    keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
+                    keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
                        args.scale_weight_norms, accelerator.device
                    )
                    max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 3b0aec24f..8422edfac 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -415,15 +415,11 @@ def train(self, args):
            text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
            )
-            # transform DDP after prepare
-            text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)

        elif len(text_encoders) == 2:
            text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
            )
-            # transform DDP after prepare
-            text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)

            text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index f77ad2eb2..42d69d2de 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -333,9 +333,6 @@ def train(args):
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )
-
-    # transform DDP after prepare
-    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
-
    index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
    # print(len(index_no_updates), torch.sum(index_no_updates))
    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
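Note on the DDP handling this patch adopts in train_network.py: each trainable model is passed to accelerator.prepare on its own, the network's custom methods (prepare_grad_etc, on_step_start, get_trainable_params, apply_max_norm_regularization) are reached through accelerator.unwrap_model, and the network gradients are averaged across processes by hand with accelerator.reduce. The following is a minimal, self-contained sketch of that pattern, not part of the patch; ToyNetwork, the optimizer, and the dummy batch are illustrative stand-ins.

import torch
from accelerate import Accelerator


class ToyNetwork(torch.nn.Module):
    # stand-in for the LoRA/LLLite network; only the custom accessor matters here
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

    def get_trainable_params(self):
        # custom method: visible only on the unwrapped module, not on the DDP wrapper
        return self.parameters()


def all_reduce_network(accelerator, network):
    # same idea as NetworkTrainer.all_reduce_network: average grads over all ranks
    for param in network.parameters():
        if param.grad is not None:
            param.grad = accelerator.reduce(param.grad, reduction="mean")


accelerator = Accelerator()
network = ToyNetwork()
optimizer = torch.optim.AdamW(network.parameters(), lr=1e-4)

# prepare each object; under a multi-GPU launch the module comes back DDP-wrapped
network, optimizer = accelerator.prepare(network, optimizer)

x = torch.randn(4, 8, device=accelerator.device)
loss = network(x).pow(2).mean()

accelerator.backward(loss)
all_reduce_network(accelerator, network)  # manual gradient sync, as in the diff
if accelerator.sync_gradients:
    # custom attributes live on the original module, hence unwrap_model
    params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
    accelerator.clip_grad_norm_(params_to_clip, 1.0)
optimizer.step()
optimizer.zero_grad(set_to_none=True)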