
Fix RAM leak when loading SDXL model in lowram device #676

Merged (9 commits) on Jul 30, 2023

Conversation

Isotr0py (Contributor)

On low-RAM devices like Colab, the usual workflow for loading the U-Net may exhaust RAM, since it takes about 12.7 GiB of RAM to load the checkpoint and initialize SdxlUNet2DConditionModel().

Using the init_empty_weights() context manager, the model can be initialized without allocating any RAM for its weights, which fixes the RAM spike on low-RAM devices.

However, since load_state_dict() is no longer used when loading the U-Net from the checkpoint, no _IncompatibleKeys warning will be raised for a corrupted checkpoint with missing_keys or unexpected_keys...
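
Roughly, the loading pattern looks like this (a minimal sketch, not the exact PR code; the helper name is mine, and state_dict is assumed to already be converted to the SDXL key layout):

from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device

from library import sdxl_original_unet


def load_unet_without_ram_spike(state_dict, device="cpu"):
    # Build the module structure on the "meta" device: no real weight tensors are allocated.
    with init_empty_weights():
        unet = sdxl_original_unet.SdxlUNet2DConditionModel()
    # Materialize each parameter directly from the checkpoint tensors, popping keys
    # so that only one copy of each tensor is alive at a time.
    for key in list(state_dict.keys()):
        set_module_tensor_to_device(unet, key, device, value=state_dict.pop(key))
    return unet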

Isotr0py changed the title from "Fix RAM leak when loading model in lowram device" to "Fix RAM leak when loading SDXL model in lowram device" on Jul 23, 2023
kohya-ss (Owner)

Thank you for this! Sorry for the late reply. This is very interesting and useful.

However, when I checked this PR in my environment, I found that applying it causes OOM in cases where memory was already borderline without the PR.

Do you have any ideas?

Isotr0py (Contributor, Author) commented Jul 27, 2023

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from accelerate.utils.modeling import set_module_tensor_to_device

from diffusers import StableDiffusionXLPipeline
import torch
import gc

from library import sdxl_model_util, sdxl_original_unet


def load_target_model(device):
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
    )

    text_encoder1 = pipe.text_encoder
    text_encoder2 = pipe.text_encoder_2
    vae = pipe.vae
    unet = pipe.unet
    del pipe
    gc.collect()

    print("convert U-Net to original U-Net")

    # convert the diffusers U-Net state dict to the original (SDXL) key layout once
    state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())

    # original method: instantiate the full model, then load the converted weights
    original_unet = sdxl_original_unet.SdxlUNet2DConditionModel()
    original_unet.load_state_dict(state_dict)
    unet = original_unet

    # this PR's method: instantiate on the meta device, then assign tensors directly
    with init_empty_weights():
        unet = sdxl_original_unet.SdxlUNet2DConditionModel()
    for k in list(state_dict.keys()):
        set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k))

    print("U-Net converted to original U-Net")
    return text_encoder1, text_encoder2, unet

text_encoder1, text_encoder2, unet = load_target_model("cuda")
unet.to("cuda")
text_encoder1.to("cuda")
text_encoder2.to("cuda")

I tested with the code above. The PR's method and the original method both cost about 12.2 GB of VRAM (10.6 GB for the U-Net alone) to load the U-Net and the text encoders.

It seems that VRAM usage didn't change during model loading. I'd like to know at which step the OOM occurs, so I can figure out what is going wrong.
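
To narrow down the step, a quick way is to log the allocated and peak VRAM between stages, e.g. (a small helper of mine, not part of the PR):

import torch


def log_vram(stage: str):
    # Report current and peak allocated VRAM so we can see which stage pushes memory over the edge.
    allocated = torch.cuda.memory_allocated() / 1024**3
    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[{stage}] allocated: {allocated:.2f} GiB, peak: {peak:.2f} GiB")

# e.g. call log_vram("after loading U-Net"), log_vram("after moving text encoders to cuda"),
# and log_vram("after first training step") to see where the spike happens.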

kohya-ss (Owner)

Thank you for the detailed explanation. In my case, the OOM occurs when the training starts.

However, I ran the same configuration as yesterday, and for some reason it worked today without going OOM.
I am very confused. My guess is that even with the same seed, the batch order changes (since skipping the weight initialization changes the RNG state), and a particular batch order was causing the OOM.

Sorry to bother you.

Since the OOM doesn't seem to be related to your PR, I would like to merge it after a little more verification, just to be sure.

Isotr0py (Contributor, Author)

BTW, for set_module_tensor_to_device, if we pass dtype=torch.float16 or dtype=torch.bfloat16, the model is loaded directly in fp16/bf16 instead of being loaded in fp32 and then cast to fp16/bf16.

Maybe we can integrate full_fp16/full_bf16 into model loading to reduce the peak VRAM usage. 🤔
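
Something along these lines, for illustration only (the full_fp16/full_bf16 arguments here just stand in for the trainer's existing flags):

import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device

from library import sdxl_original_unet


def load_unet_with_dtype(state_dict, device, full_fp16=False, full_bf16=False):
    # Pick the target dtype from the precision flags (flag names are illustrative).
    dtype = torch.float16 if full_fp16 else torch.bfloat16 if full_bf16 else torch.float32
    with init_empty_weights():
        unet = sdxl_original_unet.SdxlUNet2DConditionModel()
    for key in list(state_dict.keys()):
        # With dtype given, each tensor is cast as it is placed, so a full fp32 copy
        # of the model never has to exist in memory.
        set_module_tensor_to_device(unet, key, device, value=state_dict.pop(key), dtype=dtype)
    return unet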

Isotr0py (Contributor, Author) commented Jul 28, 2023

Well, _load_state_dict works like model.load_state_dict() with strict=True: it finds the mismatched keys and raises an error before the weights are loaded.

RuntimeError: Error(s) in loading state_dict for SdxlUNet2DConditionModel:
        Missing key(s) in state_dict: "input_blocks.0.0.bias", "time_embed.2.weight", "time_embed.0.weight", 
"time_embed.0.bias", "input_blocks.0.0.weight". 
        Unexpected key(s) in state_dict: "extra_keys". 

As shown above, it did work for a corrupted SDXL 1.0 checkpoint created by adding/deleting some keys from the original one.
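
The check itself boils down to comparing the checkpoint keys against the model's own keys, roughly like this (a simplified illustration, not the actual _load_state_dict implementation):

def check_state_dict_keys(model, state_dict):
    # Mirror load_state_dict(strict=True): report both missing and unexpected keys at once.
    expected = set(model.state_dict().keys())
    provided = set(state_dict.keys())
    missing = sorted(expected - provided)
    unexpected = sorted(provided - expected)
    if missing or unexpected:
        lines = [f"Error(s) in loading state_dict for {model.__class__.__name__}:"]
        if missing:
            lines.append("\tMissing key(s) in state_dict: " + ", ".join(f'"{k}"' for k in missing) + ".")
        if unexpected:
            lines.append("\tUnexpected key(s) in state_dict: " + ", ".join(f'"{k}"' for k in unexpected) + ".")
        raise RuntimeError("\n".join(lines))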

kohya-ss (Owner)

Thanks for the update! It has become quite complicated, but I understand that it is unavoidable.

After merging, I may refactor a little, but I would appreciate your understanding.

kohya-ss changed the base branch from sdxl to dev on Jul 30, 2023 at 03:45
kohya-ss merged commit e20b6ac into kohya-ss:dev on Jul 30, 2023 (1 check passed)
kohya-ss (Owner)

Thank you for this! I've merged it into the dev branch; I will test it with other PRs and then merge it into the sdxl branch.
