
Enable distributed sample image generation on multi-GPU environment #1061

Merged

Conversation

DKnight54
Contributor

Modified the sample_images_common function to make use of Accelerate's PartialState to distribute the list of sample image prompts across all available GPUs.

Tested and working fine in a single-GPU Google Colab environment and a dual-GPU Kaggle environment.

A possible side effect of using multiple GPUs to generate sample images is that file creation times may not match the order of the prompts in the original prompt file. I attempted some mitigation by splitting the prompts and passing them to each GPU process in the order the processes are called. However, if the sample image prompts use different samplers and/or numbers of steps, this workaround would likely break, as generation times would drift out of sync.

It might be possible to artificially force synchronization by making each process wait for all the others to finish the current image (using accelerator.wait_for_everyone()) before moving on to the next sample image, but I imagine efficient use of GPU time matters more than sample images perfectly sorted by file creation time. A rough sketch of the approach is below.
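For reference, here's a minimal sketch of the idea (not the exact sd-scripts code), assuming Accelerate's PartialState; the prompt list is a placeholder and the loop body just prints which process handles which prompt:

    from accelerate import PartialState

    distributed_state = PartialState()
    prompts = ["prompt 1", "prompt 2", "prompt 3", "prompt 4"]  # placeholder prompts

    # Each process gets its own contiguous slice of the prompt list, so order is
    # preserved per process, but file creation times can still interleave across GPUs.
    with distributed_state.split_between_processes(prompts) as prompts_for_this_gpu:
        for prompt in prompts_for_this_gpu:
            # the actual sample image generation for this prompt would run here,
            # on this process's GPU (distributed_state.device)
            print(f"process {distributed_state.process_index} handles: {prompt}")
            # distributed_state.wait_for_everyone()  # optional: lock-step generation at the cost of idle GPUs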

@kohya-ss
Owner

Thanks for the PR. However, do we really need to distribute the sample image generation across multiple GPUs? I don't think it would take that long. I am worried about increasing the complexity of the code.

@DKnight54
Contributor Author

Thanks for the PR. However, do we really need to distribute the sample image generation across multiple GPUs? I don't think it would take that long. I am worried about increasing the complexity of the code.

Honestly? In most use cases with only one or two sample images, probably not much effect. However, in scenarios where we are training multiple concepts, or for people like me who are OCD and keep multiple sample prompts, being able to spread the load to the otherwise idle GPUs does cut down the downtime during sample generation. That's especially true since SDXL (and possibly future, more complex and heavier models) requires more steps to generate images.

For context, samplers like k_dpm_2_a and k_dpm_2 automatically double the number of steps, and with some SDXL models recommending 30~60 sampling steps, a single image can take around 5 minutes to render, depending on the GPU and image resolution.

Additionally, for users like me running on free resources like Colab and Kaggle, every minute we can save is one more minute we can put toward training LoRAs and models.

@kohya-ss
Owner

I am still skeptical about distributed sample generation. If we are doing large-scale training, we would probably have separate resources to evaluate the saved models, and with Colab and Kaggle, we would probably want to spend the time on more training steps instead of sample output...

However, I think the code is much simpler now.

May I merge this PR?

However, there are a few points of concern, and I would like to rewrite the code after merging, even if it ends up somewhat redundant, for the sake of clarity. I would appreciate your understanding, and I would ask you to test the changes again afterwards.

@DKnight54
Contributor Author

Sure! Of course you are free to modify the code as you like! It's your code that I modified in the first place.

@kohya-ss kohya-ss changed the base branch from main to dev_multi_gpu_sample_gen February 3, 2024 12:45
@kohya-ss kohya-ss merged commit 1567ce1 into kohya-ss:dev_multi_gpu_sample_gen Feb 3, 2024
1 check failed
@kohya-ss
Owner

kohya-ss commented Feb 3, 2024

I merged this into the new branch dev_multi_gpu_sample_gen and updated it a bit to simplify the code. I would be happy if you could test it with multi-GPU training.

Thank you again for this PR!

@DKnight54
Contributor Author

Hey, I did a test run in the Kaggle environment on textual inversion, and at the first sample generation run I encountered a VRAM OOM when it got to the latents-to-image step.

As mentioned in #1019, the workaround I came up with was to insert a call to torch.cuda.empty_cache() after the latents have been generated and before they are converted into images, like below:

    with accelerator.autocast():
        latents = pipeline(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=sample_steps,
            guidance_scale=scale,
            negative_prompt=negative_prompt,
            controlnet=controlnet,
            controlnet_image=controlnet_image,
        )
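    # free cached VRAM on this GPU before the VAE decode (latents -> images) to avoid OOM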
    with torch.cuda.device(torch.cuda.current_device()):
        torch.cuda.empty_cache()
    image = pipeline.latents_to_image(latents)[0]
Traceback (most recent call last):
  File "/kaggle/temp/sd-scripts/sdxl_train_textual_inversion.py", line 137, in <module>
    trainer.train(args)
  File "/kaggle/temp/sd-scripts/train_textual_inversion.py", line 530, in train
    self.sample_images(
  File "/kaggle/temp/sd-scripts/sdxl_train_textual_inversion.py", line 85, in sample_images
    sdxl_train_util.sample_images(
  File "/kaggle/temp/sd-scripts/library/sdxl_train_util.py", line 367, in sample_images
    return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
  File "/kaggle/temp/sd-scripts/library/train_util.py", line 4753, in sample_images_common
    sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
  File "/kaggle/temp/sd-scripts/library/train_util.py", line 4823, in sample_image_inference
    image = pipeline.latents_to_image(latents)[0]
  File "/kaggle/temp/sd-scripts/library/sdxl_lpw_stable_diffusion.py", line 1040, in latents_to_image
    image = self.decode_latents(latents.to(self.vae.dtype))
  File "/kaggle/temp/sd-scripts/library/sdxl_lpw_stable_diffusion.py", line 714, in decode_latents
    image = self.vae.decode(latents.to(self.vae.dtype)).sample
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 304, in decode
    decoded = self._decode(z).sample
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 275, in _decode
    dec = self.decoder(z)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/autoencoders/vae.py", line 338, in forward
    sample = up_block(sample, latent_embeds)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py", line 2535, in forward
    hidden_states = upsampler(hidden_states)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/upsampling.py", line 184, in forward
    hidden_states = self.conv(hidden_states, scale)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/lora.py", line 358, in forward
    return F.conv2d(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 14.75 GiB of which 831.06 MiB is free. Process 6503 has 13.93 GiB memory in use. Of the allocated memory 11.87 GiB is allocated by PyTorch, and 1.86 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
steps:   0%|          | 0/360 [01:57<?, ?it/s]
Traceback (most recent call last):
  File "/kaggle/temp/sd-scripts/sdxl_train_textual_inversion.py", line 137, in <module>
    trainer.train(args)
  File "/kaggle/temp/sd-scripts/train_textual_inversion.py", line 530, in train
    self.sample_images(
  File "/kaggle/temp/sd-scripts/sdxl_train_textual_inversion.py", line 85, in sample_images
    sdxl_train_util.sample_images(
  File "/kaggle/temp/sd-scripts/library/sdxl_train_util.py", line 367, in sample_images
    return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
  File "/kaggle/temp/sd-scripts/library/train_util.py", line 4753, in sample_images_common
    sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
  File "/kaggle/temp/sd-scripts/library/train_util.py", line 4823, in sample_image_inference
    image = pipeline.latents_to_image(latents)[0]
  File "/kaggle/temp/sd-scripts/library/sdxl_lpw_stable_diffusion.py", line 1040, in latents_to_image
    image = self.decode_latents(latents.to(self.vae.dtype))
  File "/kaggle/temp/sd-scripts/library/sdxl_lpw_stable_diffusion.py", line 714, in decode_latents
    image = self.vae.decode(latents.to(self.vae.dtype)).sample
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 304, in decode
    decoded = self._decode(z).sample
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 275, in _decode
    dec = self.decoder(z)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/autoencoders/vae.py", line 338, in forward
    sample = up_block(sample, latent_embeds)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py", line 2535, in forward
    hidden_states = upsampler(hidden_states)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/upsampling.py", line 184, in forward
    hidden_states = self.conv(hidden_states, scale)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/diffusers/models/lora.py", line 358, in forward
    return F.conv2d(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 1 has a total capacity of 14.75 GiB of which 831.06 MiB is free. Process 6504 has 13.93 GiB memory in use. Of the allocated memory 11.87 GiB is allocated by PyTorch, and 1.86 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[2024-02-03 17:28:39,889] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 699 closing signal SIGTERM
[2024-02-03 17:28:40,254] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 698) of binary: /kaggle/temp/venv/bin/python
Traceback (most recent call last):
  File "/kaggle/temp/venv/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/kaggle/temp/venv/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1008, in launch_command
    multi_gpu_launcher(args)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 666, in multi_gpu_launcher
    distrib_run.run(args)
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/kaggle/temp/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
sdxl_train_textual_inversion.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-02-03_17:28:39
  host      : f6fb502e6534
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 698)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Name: torch
Version: 2.2.0+cu118
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3
Location: /kaggle/temp/venv/lib/python3.10/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu11, nvidia-cuda-cupti-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-runtime-cu11, nvidia-cudnn-cu11, nvidia-cufft-cu11, nvidia-curand-cu11, nvidia-cusolver-cu11, nvidia-cusparse-cu11, nvidia-nccl-cu11, nvidia-nvtx-cu11, sympy, triton, typing-extensions
Required-by: accelerate, open-clip-torch, pytorch-lightning, timm, torchaudio, torchmetrics, torchvision

@DKnight54
Contributor Author

Testing on Kaggle when training LoRAs works fine, though.

@kohya-ss
Owner

kohya-ss commented Feb 4, 2024

Thank you for testing! I've added cuda.empty_cache at that position.

@FurkanGozukara

Yesterday I helped one of my Patreon supporters, who had dual RTX 4090s on Linux.

With 1 GPU, the training speed was 1.2 it/s.

With 2 GPUs, the training speed dropped to 2 s/it.

Cumulatively, it literally became slower than a single card.

wkpark pushed a commit to wkpark/sd-scripts that referenced this pull request Feb 27, 2024
Revert bitsandbytes-windows update