From 81a2792d578cc882171adb223612cfadc334b936 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 27 Sep 2024 20:28:02 +0530
Subject: [PATCH 1/4] fix: retain memory utility.

---
 examples/cogvideo/train_cogvideox_lora.py     |  8 +++-----
 examples/controlnet/train_controlnet_flux.py  |  8 +++++---
 examples/controlnet/train_controlnet_sd3.py   |  9 ++++++---
 .../dreambooth/train_dreambooth_lora_flux.py  | 11 +++++++----
 .../dreambooth/train_dreambooth_lora_sd3.py   | 18 ++++++++----------
 src/diffusers/training_utils.py               |  8 ++------
 6 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 137f3222f6d9..90aa688b6e86 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -38,10 +38,7 @@
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, retain_memory
 from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -726,7 +723,8 @@ def log_validation(
                 }
             )

-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    retain_memory()

     return videos

diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index e344a9b1e2a5..61bd36ced193 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -54,7 +54,7 @@
 from diffusers.models.controlnet_flux import FluxControlNetModel
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
-from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling
+from diffusers.training_utils import compute_density_for_timestep_sampling, retain_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
@@ -193,7 +193,8 @@ def log_validation(
         else:
             logger.warning(f"image logging not implemented for {tracker.name}")

-    clear_objs_and_retain_memory([pipeline])
+    del pipeline
+    retain_memory()

     return image_logs

@@ -1103,7 +1104,8 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
             compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
         )

-    clear_objs_and_retain_memory([text_encoders, tokenizers])
+    del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+    retain_memory()

     # Then get the training dataset ready to be passed to the dataloader.
     train_dataset = prepare_train_dataset(train_dataset, accelerator)
diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py
index 4b255c501d99..f06d812016ca 100644
--- a/examples/controlnet/train_controlnet_sd3.py
+++ b/examples/controlnet/train_controlnet_sd3.py
@@ -50,9 +50,9 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
-    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    retain_memory,
 )
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -174,7 +174,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
         else:
             logger.warning(f"image logging not implemented for {tracker.name}")

-    clear_objs_and_retain_memory(pipeline)
+    del pipeline
+    retain_memory()

     if not is_final_validation:
         controlnet.to(accelerator.device)
@@ -1131,7 +1132,9 @@ def compute_text_embeddings(batch, text_encoders, tokenizers):
         new_fingerprint = Hasher.hash(args)
         train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)

-    clear_objs_and_retain_memory(text_encoders + tokenizers)
+    del text_encoder_one, text_encoder_two, text_encoder_three
+    del tokenizer_one, tokenizer_two, tokenizer_three
+    retain_memory()

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 6091622719ee..594a60b05e73 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -55,9 +55,9 @@
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
-    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    retain_memory,
 )
 from diffusers.utils import (
     check_min_version,
@@ -1437,7 +1437,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
+        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+        retain_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1480,7 +1481,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

         if args.validation_prompt is None:
-            clear_objs_and_retain_memory([vae])
+            del vae
+            retain_memory()

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -1817,7 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     torch_dtype=weight_dtype,
                 )
                 if not args.train_text_encoder:
-                    clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
+                    del text_encoder_one, text_encoder_two
+                    retain_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 3060813bbbdc..b9fe45efda4f 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -55,9 +55,9 @@
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
-    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    retain_memory,
 )
 from diffusers.utils import (
     check_min_version,
@@ -211,7 +211,8 @@ def log_validation(
                 }
             )

-    clear_objs_and_retain_memory(objs=[pipeline])
+    del pipeline
+    retain_memory()

     return images

@@ -1106,7 +1107,8 @@ def main(args):
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)

-        clear_objs_and_retain_memory(objs=[pipeline])
+        del pipeline
+        retain_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1453,9 +1455,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        clear_objs_and_retain_memory(
-            objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
-        )
+        retain_memory(objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three])

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1791,11 +1791,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 epoch=epoch,
                 torch_dtype=weight_dtype,
             )
-            objs = []
-            if not args.train_text_encoder:
-                objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])

-            clear_objs_and_retain_memory(objs=objs)
+            del text_encoder_one, text_encoder_two, text_encoder_three
+            retain_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 26d4a2a504c6..969df048f2c8 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -260,12 +260,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting


-def clear_objs_and_retain_memory(objs: List[Any]):
-    """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
-    if len(objs) >= 1:
-        for obj in objs:
-            del obj
-
+def retain_memory():
+    """Runs garbage collection. Then clears the cache of the available accelerator."""
     gc.collect()

     if torch.cuda.is_available():

From 33624f91a67239b9b408f04d6f6cb6d6ab87cf1b Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 28 Sep 2024 08:01:20 +0530
Subject: [PATCH 2/4] fix

---
 examples/dreambooth/train_dreambooth_lora_sd3.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index b9fe45efda4f..9e7995cc9774 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1455,7 +1455,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        retain_memory(objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three])
+        del tokenizers, text_encoders 
+        del text_encoder_one, text_encoder_two, text_encoder_three
+        retain_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't

From 747dc42d144e3580de00f728e3906b7933955653 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 28 Sep 2024 08:03:52 +0530
Subject: [PATCH 3/4] quality

---
 examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 9e7995cc9774..174550e755df 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1455,7 +1455,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        del tokenizers, text_encoders 
+        del tokenizers, text_encoders
         del text_encoder_one, text_encoder_two, text_encoder_three
         retain_memory()


From debb716b32a4fcbe7eaab429f03e17fd64abff15 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 28 Sep 2024 20:58:55 +0530
Subject: [PATCH 4/4] free_memory.

---
 examples/cogvideo/train_cogvideox_lora.py          |  4 ++--
 examples/controlnet/train_controlnet_flux.py       |  6 +++---
 examples/controlnet/train_controlnet_sd3.py        | 10 +++-------
 examples/dreambooth/train_dreambooth_lora_flux.py  |  8 ++++----
 examples/dreambooth/train_dreambooth_lora_sd3.py   | 10 +++++-----
 src/diffusers/training_utils.py                    |  2 +-
 6 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 90aa688b6e86..6787c37f93a8 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -38,7 +38,7 @@
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import cast_training_params, retain_memory
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -724,7 +724,7 @@ def log_validation(
             )

     del pipe
-    retain_memory()
+    free_memory()

     return videos

diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index 61bd36ced193..5969218f3c3e 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -54,7 +54,7 @@
 from diffusers.models.controlnet_flux import FluxControlNetModel
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
-from diffusers.training_utils import compute_density_for_timestep_sampling, retain_memory
+from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
@@ -194,7 +194,7 @@ def log_validation(
             logger.warning(f"image logging not implemented for {tracker.name}")

     del pipeline
-    retain_memory()
+    free_memory()

     return image_logs

@@ -1105,7 +1105,7 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
         )

     del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
-    retain_memory()
+    free_memory()

     # Then get the training dataset ready to be passed to the dataloader.
     train_dataset = prepare_train_dataset(train_dataset, accelerator)
diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py
index f06d812016ca..9ea78370f5e0 100644
--- a/examples/controlnet/train_controlnet_sd3.py
+++ b/examples/controlnet/train_controlnet_sd3.py
@@ -49,11 +49,7 @@
     StableDiffusion3ControlNetPipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import (
-    compute_density_for_timestep_sampling,
-    compute_loss_weighting_for_sd3,
-    retain_memory,
-)
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -175,7 +171,7 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
             logger.warning(f"image logging not implemented for {tracker.name}")

     del pipeline
-    retain_memory()
+    free_memory()

     if not is_final_validation:
         controlnet.to(accelerator.device)
@@ -1134,7 +1130,7 @@ def compute_text_embeddings(batch, text_encoders, tokenizers):

     del text_encoder_one, text_encoder_two, text_encoder_three
     del tokenizer_one, tokenizer_two, tokenizer_three
-    retain_memory()
+    free_memory()

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 594a60b05e73..fcc11386abcf 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -57,7 +57,7 @@
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
-    retain_memory,
+    free_memory,
 )
 from diffusers.utils import (
     check_min_version,
@@ -1438,7 +1438,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
-        retain_memory()
+        free_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1482,7 +1482,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

         if args.validation_prompt is None:
             del vae
-            retain_memory()
+            free_memory()

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -1820,7 +1820,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
                 if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two
-                    retain_memory()
+                    free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 174550e755df..02f5a7ee0f7a 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -57,7 +57,7 @@
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
-    retain_memory,
+    free_memory,
 )
 from diffusers.utils import (
     check_min_version,
@@ -212,7 +212,7 @@ def log_validation(
             )

     del pipeline
-    retain_memory()
+    free_memory()

     return images

@@ -1108,7 +1108,7 @@ def main(args):
                 image.save(image_filename)

         del pipeline
-        retain_memory()
+        free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1457,7 +1457,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
         del tokenizers, text_encoders
         del text_encoder_one, text_encoder_two, text_encoder_three
-        retain_memory()
+        free_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1795,7 +1795,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             )

             del text_encoder_one, text_encoder_two, text_encoder_three
-            retain_memory()
+            free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 969df048f2c8..57bd9074870c 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -260,7 +260,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting


-def retain_memory():
+def free_memory():
     """Runs garbage collection. Then clears the cache of the available accelerator."""
     gc.collect()
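
Note on the end state this series converges on (a minimal sketch, not taken verbatim from the patches): `free_memory()` only runs garbage collection and clears the accelerator cache, while each call site explicitly `del`s its own references first. The old `clear_objs_and_retain_memory(objs=[...])` only deleted the names bound inside the helper, so the caller's references kept the objects alive, which is the same point made by the "Explicitly delete the objects" comment in the diff. In the sketch below, the body under `if torch.cuda.is_available():` is cut off in the hunk above, so `torch.cuda.empty_cache()` is the assumed call; the `mps` branch and the `DummyModel` stand-in are illustrative assumptions as well.

import gc

import torch


def free_memory():
    """Runs garbage collection, then clears the cache of the available accelerator."""
    gc.collect()
    if torch.cuda.is_available():
        # Assumed body: the hunk above only shows the `if torch.cuda.is_available():` line.
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        # Assumed extension for Apple Silicon; not shown in the diff.
        torch.mps.empty_cache()


# Call-site pattern used throughout the patches: drop the reference first, then free caches.
# `DummyModel` stands in for a pipeline, text encoder, or VAE from the training scripts.
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)


pipeline = DummyModel().to("cuda" if torch.cuda.is_available() else "cpu")
del pipeline   # remove the caller's reference so the object becomes collectable
free_memory()  # then collect garbage and empty the accelerator cache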