Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[chore] fix: retain memory utility. #9543

Merged
merged 6 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions examples/cogvideo/train_cogvideox_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import (
cast_training_params,
clear_objs_and_retain_memory,
)
from diffusers.training_utils import cast_training_params, free_memory
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
Expand Down Expand Up @@ -726,7 +723,8 @@ def log_validation(
}
)

clear_objs_and_retain_memory([pipe])
del pipe
free_memory()

return videos

Expand Down
8 changes: 5 additions & 3 deletions examples/controlnet/train_controlnet_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from diffusers.models.controlnet_flux import FluxControlNetModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling
from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
Expand Down Expand Up @@ -193,7 +193,8 @@ def log_validation(
else:
logger.warning(f"image logging not implemented for {tracker.name}")

clear_objs_and_retain_memory([pipeline])
del pipeline
free_memory()
return image_logs


Expand Down Expand Up @@ -1103,7 +1104,8 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
)

clear_objs_and_retain_memory([text_encoders, tokenizers])
del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()

# Then get the training dataset ready to be passed to the dataloader.
train_dataset = prepare_train_dataset(train_dataset, accelerator)
Expand Down
13 changes: 6 additions & 7 deletions examples/controlnet/train_controlnet_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@
StableDiffusion3ControlNetPipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
Expand Down Expand Up @@ -174,7 +170,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
else:
logger.warning(f"image logging not implemented for {tracker.name}")

clear_objs_and_retain_memory(pipeline)
del pipeline
free_memory()

if not is_final_validation:
controlnet.to(accelerator.device)
Expand Down Expand Up @@ -1131,7 +1128,9 @@ def compute_text_embeddings(batch, text_encoders, tokenizers):
new_fingerprint = Hasher.hash(args)
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)

clear_objs_and_retain_memory(text_encoders + tokenizers)
del text_encoder_one, text_encoder_two, text_encoder_three
del tokenizer_one, tokenizer_two, tokenizer_three
free_memory()

train_dataloader = torch.utils.data.DataLoader(
train_dataset,
Expand Down
11 changes: 7 additions & 4 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import (
check_min_version,
Expand Down Expand Up @@ -1437,7 +1437,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
Expand Down Expand Up @@ -1480,7 +1481,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

if args.validation_prompt is None:
clear_objs_and_retain_memory([vae])
del vae
free_memory()

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
Expand Down Expand Up @@ -1817,7 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
torch_dtype=weight_dtype,
)
if not args.train_text_encoder:
clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
del text_encoder_one, text_encoder_two
free_memory()

# Save the lora layers
accelerator.wait_for_everyone()
Expand Down
20 changes: 10 additions & 10 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import (
check_min_version,
Expand Down Expand Up @@ -211,7 +211,8 @@ def log_validation(
}
)

clear_objs_and_retain_memory(objs=[pipeline])
del pipeline
free_memory()

return images

Expand Down Expand Up @@ -1106,7 +1107,8 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)

clear_objs_and_retain_memory(objs=[pipeline])
del pipeline
free_memory()

# Handle the repository creation
if accelerator.is_main_process:
Expand Down Expand Up @@ -1453,9 +1455,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
clear_objs_and_retain_memory(
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
)
del tokenizers, text_encoders
del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
Expand Down Expand Up @@ -1791,11 +1793,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
epoch=epoch,
torch_dtype=weight_dtype,
)
objs = []
if not args.train_text_encoder:
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])

clear_objs_and_retain_memory(objs=objs)
del text_encoder_one, text_encoder_two, text_encoder_three
free_memory()

# Save the lora layers
accelerator.wait_for_everyone()
Expand Down
8 changes: 2 additions & 6 deletions src/diffusers/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting


def clear_objs_and_retain_memory(objs: List[Any]):
"""Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
if len(objs) >= 1:
for obj in objs:
del obj

def free_memory():
"""Runs garbage collection. Then clears the cache of the available accelerator."""
gc.collect()

if torch.cuda.is_available():
Expand Down
Loading