Is there any way to reduce system RAM usage? I am trying to make a tutorial for free Colab SDXL LoRA training #788

Closed
FurkanGozukara opened this issue Aug 25, 2023 · 16 comments

Comments

@FurkanGozukara

Hello. I am trying to make a tutorial for free Colab.

It has a good GPU, but the problem is that system RAM is only 13 GB.

We need some options or a way to reduce RAM usage.

@FurkanGozukara
Author

Kaggle provides 2 GPUs.

Can we load the U-Net and text encoders onto one of the GPUs with the --lowram option?

[screenshot]

@kohya-ss

@FurkanGozukara
Author

loading U-Net from checkpoint
IOStream.flush timed out
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /opt/conda/bin/accelerate:8 in <module>                                      │
│                                                                              │
│   5 from accelerate.commands.accelerate_cli import main                      │
│   6 if __name__ == '__main__':                                               │
│   7 │   sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])     │
│ ❱ 8 │   sys.exit(main())                                                     │
│   9                                                                          │
│                                                                              │
│ /opt/conda/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.p │
│ y:45 in main                                                                 │
│                                                                              │
│   42 │   │   exit(1)                                                         │
│   43 │                                                                       │
│   44 │   # Run                                                               │
│ ❱ 45 │   args.func(args)                                                     │
│   46                                                                         │
│   47                                                                         │
│   48 if __name__ == "__main__":                                              │
│                                                                              │
│ /opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py:909 in │
│ launch_command                                                               │
│                                                                              │
│   906 │   elif args.use_megatron_lm and not args.cpu:                        │
│   907 │   │   multi_gpu_launcher(args)                                       │
│   908 │   elif args.multi_gpu and not args.cpu:                              │
│ ❱ 909 │   │   multi_gpu_launcher(args)                                       │
│   910 │   elif args.tpu and not args.cpu:                                    │
│   911 │   │   if args.tpu_use_cluster:                                       │
│   912 │   │   │   tpu_pod_launcher(args)                                     │
│                                                                              │
│ /opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py:604 in │
│ multi_gpu_launcher                                                           │
│                                                                              │
│   601 │   )                                                                  │
│   602 │   with patch_environment(**current_env):                             │
│   603 │   │   try:                                                           │
│ ❱ 604 │   │   │   distrib_run.run(args)                                      │
│   605 │   │   except Exception:                                              │
│   606 │   │   │   if is_rich_available() and debug:                          │
│   607 │   │   │   │   console = get_console()                                │
│                                                                              │
│ /opt/conda/lib/python3.10/site-packages/torch/distributed/run.py:785 in run  │
│                                                                              │
│   782 │   │   )                                                              │
│   783 │                                                                      │
│   784 │   config, cmd, cmd_args = config_from_args(args)                     │
│ ❱ 785 │   elastic_launch(                                                    │
│   786 │   │   config=config,                                                 │
│   787 │   │   entrypoint=cmd,                                                │
│   788 │   )(*cmd_args)                                                       │
│                                                                              │
│ /opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py:13 │
│ 4 in __call__                                                                │
│                                                                              │
│   131 │   │   self._entrypoint = entrypoint                                  │
│   132 │                                                                      │
│   133 │   def __call__(self, *args):                                         │
│ ❱ 134 │   │   return launch_agent(self._config, self._entrypoint, list(args) │
│   135                                                                        │
│   136                                                                        │
│   137 def _get_entrypoint_name(                                              │
│                                                                              │
│ /opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py:25 │
│ 0 in launch_agent                                                            │
│                                                                              │
│   247 │   │   │   # if the error files for the failed children exist         │
│   248 │   │   │   # @record will copy the first error (root cause)           │
│   249 │   │   │   # to the error file of the launcher process.               │
│ ❱ 250 │   │   │   raise ChildFailedError(                                    │
│   251 │   │   │   │   name=entrypoint_name,                                  │
│   252 │   │   │   │   failures=result.failures,                              │
│   253 │   │   │   )                                                          │
╰──────────────────────────────────────────────────────────────────────────────╯
ChildFailedError: 
====================================================
./sdxl_train_network.py FAILED
----------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
----------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-08-25_21:56:09
  host      : a33a9a9c87bb
  rank      : 0 (local_rank: 0)
  exitcode  : -9 (pid: 705)
  error_file: <N/A>
  traceback : Signal 9 (SIGKILL) received by PID 705

@Isotr0py
Contributor

Isotr0py commented Aug 26, 2023

@FurkanGozukara If you use Kaggle to train an SDXL LoRA, I recommend using the converted pretrained checkpoint from the HF repo instead of a safetensors-format file, because loading from safetensors additionally requires converting and transferring the text_encoders and vae.

pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0"

Due to the limited RAM, it's a bit too extreme for Kaggle's kernel to initialize the models on 2 GPUs.
If you use the safetensors-format file, RAM usage will increase during the text_encoders and vae conversion (which may kill the kernel).
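
Not from the thread, but as a rough illustration of the difference between the two loading paths, here is a minimal sketch that watches resident memory with psutil. The single-file call and the local checkpoint path are shown only for illustration; sd-scripts uses its own single-file loader, but the RAM pattern (full load plus a conversion pass) is similar.

import psutil
import torch
from diffusers import StableDiffusionXLPipeline

def rss_gb():
    # Resident memory of the current process, in GB
    return psutil.Process().memory_info().rss / 1024**3

print(f"baseline RSS: {rss_gb():.1f} GB")

# Diffusers-format repo: each component ships as its own fp16 safetensors
# shard, so no conversion pass is needed after loading.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
)
print(f"after from_pretrained: {rss_gb():.1f} GB")

# Single-file checkpoint instead (path is just an example): the whole state
# dict is loaded and then converted, which temporarily holds extra copies of
# the text encoders and VAE in system RAM.
# pipe = StableDiffusionXLPipeline.from_single_file(
#     "/kaggle/working/sd_xl_base_1.0.safetensors", torch_dtype=torch.float16
# )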

[screenshot]

@Isotr0py
Contributor

Isotr0py commented Aug 26, 2023

Here is test code that simulates how the scripts load the model from the HF repo.

Setting aside RAM-heavy operations such as caching latents, it costs about 7.2 GB of RAM to load the model on the first GPU and 11.1 GB on the second GPU.

from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device

from diffusers import StableDiffusionXLPipeline
import torch
import gc

from library import sdxl_model_util, sdxl_original_unet


def load_target_model(device):
    # Load the Diffusers-format SDXL base model (fp16 safetensors shards).
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
    )

    text_encoder1 = pipe.text_encoder
    text_encoder2 = pipe.text_encoder_2
    vae = pipe.vae
    unet = pipe.unet
    del pipe
    gc.collect()

    print("convert U-Net to original U-Net")

    # Convert the Diffusers U-Net state dict to the original (sd-scripts) layout,
    # build the target U-Net on the meta device, then materialize each tensor
    # directly on the GPU so a second full copy never sits in system RAM.
    state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
    with init_empty_weights():
        unet = sdxl_original_unet.SdxlUNet2DConditionModel()
    for k in list(state_dict.keys()):
        set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k))

    print("U-Net converted to original U-Net")
    return text_encoder1, text_encoder2, unet

for device in ["cuda:0", "cuda:1"]:
    text_encoder1, text_encoder2, unet = load_target_model(device)
    text_encoder1.to(device)
    text_encoder2.to(device)
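
The RAM saving in the loop above comes from the init_empty_weights() + set_module_tensor_to_device() pattern: the target U-Net is built on the meta device with no storage, and each tensor is materialized straight onto the GPU. A self-contained sketch of the same pattern on a toy module (the module and sizes are illustrative, not part of the thread):

import torch
from torch import nn
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device

# Pretend this fp16 state dict came from disk (in the real script it is the
# converted SDXL U-Net state dict).
source = nn.Linear(4096, 4096).half()
state_dict = source.state_dict()

# Build the target module on the meta device: parameters have shapes but no
# storage, so this costs essentially no CPU RAM.
with init_empty_weights():
    target = nn.Linear(4096, 4096)

# Materialize each tensor directly on the target device, popping it from the
# source dict so only one copy exists at a time.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
for name in list(state_dict.keys()):
    set_module_tensor_to_device(target, name, device, value=state_dict.pop(name))

print(next(target.parameters()).device)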

@FurkanGozukara
Author

I am trying to make this work for the GUI version.

So what is the logic of loading from Diffusers? Do I just need to give the Stability AI repo name? Or are there any CLI arguments to load into 2 GPUs?

@Isotr0py
Contributor

Isotr0py commented Aug 26, 2023

It has been merged in #676. It initializes the unet on the meta device to reduce RAM usage. You can find more details in the PR.

So you just need to change pretrained_model_name_or_path from "/kaggle/working/sd_xl_base_1.0.safetensors" to "stabilityai/stable-diffusion-xl-base-1.0":

pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0"

(I'm not sure whether loading from an HF repo works in bmaltais's GUI or not; I seldom use the GUI version.) (;´ヮ`)7

@FurkanGozukara
Author

Thank you, I tested this.

It is better than before; at least it starts loading into VRAM.

But it still crashes.

I tried both with --lowram and without it, and with all optimizations.

When --lowram is used, I think the crash this time happens due to an out-of-VRAM error.

[screenshot]

@FurkanGozukara
Author

Caching latents is causing the VRAM error.

Can you make it happen before the other stuff?

Without caching latents, this is on Kaggle:

[screenshot]

@Isotr0py
Contributor

Isotr0py commented Aug 26, 2023

As I showed above, using multi-GPU to train a LoRA with only 47 images on Kaggle already costs about 12.8 GB of RAM, because loading the models onto the two GPUs takes about 11.1 GB (and I think it's hard to optimize that part's RAM usage much further).

So if the dataset is large (typically 1000+ images), RAM usage will easily reach 13 GB and crash the kernel.

However, it seems that training works normally on one GPU, so I think we can enable only one GPU on Kaggle to make sure the training runs steadily.
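
As a practical aside (my sketch, not from the thread): one simple way to make sure only one GPU is used is to hide the second device before anything initializes CUDA, for example at the top of the notebook:

import os

# Hide the second Kaggle GPU from this process and its children; this must run
# before torch/CUDA is initialized. Configuring accelerate for a single process
# (e.g. passing --num_processes=1 to accelerate launch) has a similar effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
print(torch.cuda.device_count())  # expected to print 1 on a two-GPU Kaggle kernel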

[screenshot]

@FurkanGozukara
Author

FurkanGozukara commented Aug 26, 2023

@Isotr0py I made it work on Kaggle.

1024x1024, latest version of Kohya.

Hopefully I will make a tutorial.

By the way, I couldn't use caching; that part is causing the VRAM error. It has to be fixed @kohya-ss

@Isotr0py
Contributor

Isotr0py commented Aug 26, 2023

Maybe you can cache the latents to disk for your dataset before starting training.

Once the .npz latents exist, the VAE will skip those images, so you can avoid VAE encoding entirely if all images are pre-encoded.
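
If it helps, here is a small sketch of mine to check how much of a dataset is already pre-encoded. It assumes the disk cache stores one .npz per image next to it, using the image's base name, which is how sd-scripts' latent cache normally works; the dataset path is a placeholder.

from pathlib import Path

image_exts = {".png", ".jpg", ".jpeg", ".webp"}
train_dir = Path("/kaggle/working/dataset")  # placeholder dataset path

images = [p for p in train_dir.rglob("*") if p.suffix.lower() in image_exts]
missing = [p for p in images if not p.with_suffix(".npz").exists()]

print(f"{len(images) - len(missing)}/{len(images)} images already have cached latents")
for p in missing[:10]:
    print("not cached:", p)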

Anyway, a possible way to reduce RAM and VRAM further is integrating full_fp16 and full_bf16 support into model loading, but that needs a partial refactor of the training workflow.

@FurkanGozukara
Author

@Isotr0py So how do I cache to disk before starting?

@Isotr0py
Contributor

@FurkanGozukara Use tools/cache_latents.py

@Isotr0py
Contributor

@FurkanGozukara The PR should reduce RAM and VRAM usage with --lowram and --full_fp16/bf16.

The training should then run normally with latent caching on Kaggle.
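
For a rough sense of why keeping everything in 16-bit matters, here is a back-of-envelope calculation of mine (the parameter counts are approximations: roughly 2.6B for the SDXL U-Net and 0.8B for the two text encoders combined):

# Approximate parameter counts; assumptions, not exact figures
unet_params = 2.6e9
text_encoder_params = 0.8e9
total = unet_params + text_encoder_params

for name, bytes_per_param in [("fp32", 4), ("fp16/bf16", 2)]:
    gb = total * bytes_per_param / 1024**3
    print(f"{name}: ~{gb:.1f} GB just for the weights")

# fp32: ~12.7 GB, fp16/bf16: ~6.3 GB, so keeping the weights in 16-bit roughly
# halves the memory they need.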

@FurkanGozukara
Author

FurkanGozukara commented Aug 27, 2023

Thank you so much. I will test it when @bmaltais pulls it into the Gradio version.

@FurkanGozukara
Author

FurkanGozukara commented Sep 3, 2023

@Isotr0py Amazing update. It was just merged into the Kohya GUI today.

Here I made a speed comparison:

https://twitter.com/GozukaraFurkan/status/1698471340032872721

I also updated my tutorial GitHub readme file:

How To Do SDXL LoRA Training On RunPod With Kohya SS GUI Trainer & Use LoRAs With Automatic1111 UI

[screenshot]
