
Reduce peak RAM usage #332

Merged
merged 1 commit into kohya-ss:dev on Mar 30, 2023

Conversation

@guaneec (Contributor) commented Mar 27, 2023

These are the changes I had to make to have LoRA training work in the Kaggle T4 x 2 environment.

What this PR does:

  • Loads the SD model sequentially, one process (GPU) at a time, so that peak RAM usage doesn't grow linearly with the number of GPUs
  • Loads state dicts directly onto the correct device (see the sketch after this description)

I haven't tested whether this breaks anything else.
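
To illustrate the second bullet, here is a minimal sketch (not the actual patch) of loading a safetensors checkpoint directly onto the target device rather than onto CPU first. The tiny checkpoint written at the top is a stand-in so the snippet runs end to end; the real code loads the Stable Diffusion checkpoint instead.

    import torch
    from safetensors.torch import load_file, save_file

    # Stand-in checkpoint so the example is self-contained.
    save_file({"weight": torch.zeros(4, 4)}, "tiny.safetensors")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Deserialize tensors straight onto the target device, so a full CPU copy
    # of the state dict never has to coexist with the GPU copy.
    state_dict = load_file("tiny.safetensors", device=device)

    # The same idea for a plain PyTorch checkpoint uses map_location:
    # state_dict = torch.load("model.ckpt", map_location=device)

As noted further down in the thread, passing device="cuda" to load_file raised errors in kohya-ss's testing, so the merged fix loads on CPU and moves the modules to the GPU afterwards when lowram is set.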

@kohya-ss kohya-ss changed the base branch from main to dev March 30, 2023 12:36
@kohya-ss kohya-ss merged commit 6c28dfb into kohya-ss:dev Mar 30, 2023
@kohya-ss (Owner)

I've merged. Thank you for this!

@feffy380 (Contributor) commented Mar 30, 2023

This seems to cause out-of-memory errors on GPUs with 8 GB of VRAM (3060 Ti):
[screenshot of a CUDA out-of-memory error]
(Screenshot not mine)
Reverting to an earlier commit (b996f5a) fixed the user's issue, and this is the most recent change affecting model loading.

@kohya-ss (Owner) commented Mar 30, 2023

Thank you for reporting. I will fix it soon.

@guaneec
I will fix it like this. The device option of safetensors.torch.load_file may cause an error in some environments, so the option is set only when args.lowram is specified. Please let me know if you notice anything.

EDIT: load_file seems to always raise an error with device="cuda", so I removed the argument from load_file; with lowram, the models are moved to CUDA after loading instead.

    # Load the model on one process at a time so peak RAM does not scale with the number of GPUs.
    for pi in range(accelerator.state.num_processes):
        # TODO: modify other training scripts as well
        if pi == accelerator.state.local_process_index:
            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

            text_encoder, vae, unet, _ = train_util.load_target_model(
                args, weight_dtype, accelerator.device if args.lowram else "cpu"
            )

            # work on low-ram device
            if args.lowram:
                text_encoder.to(accelerator.device)
                unet.to(accelerator.device)
                vae.to(accelerator.device)

            gc.collect()
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()  # barrier: the next process starts loading only after this one finishes
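
For completeness, a self-contained sketch of the pattern described in the EDIT: load the state dict on CPU without a device argument, then move the module to the GPU and release CPU memory only when running in low-RAM mode. The tiny Linear module and the hard-coded lowram flag are stand-ins for the real SD modules and args.lowram.

    import gc

    import torch
    from safetensors.torch import load_file, save_file

    lowram = True  # stand-in for args.lowram
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stand-in module; write a checkpoint so the example actually runs.
    model = torch.nn.Linear(8, 8)
    save_file(model.state_dict(), "tiny.safetensors")

    # Load on CPU (no device= argument), since device="cuda" errored in testing.
    state_dict = load_file("tiny.safetensors")
    model.load_state_dict(state_dict)

    if lowram:
        # Move to the GPU right away and drop the CPU-side allocations.
        model.to(device)
        gc.collect()
        torch.cuda.empty_cache()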
