Problem with Flux Schnell bfloat16 multiGPU #9195

Closed · OlegRuban-ai opened this issue Aug 16, 2024 · 26 comments
Labels: bug (Something isn't working)

@OlegRuban-ai commented Aug 16, 2024

Describe the bug

Hello! I set device_map='balanced' and images are generated in 2.5 minutes (I expected 12-20 seconds), while pipe.hf_device_map shows the devices distributed like this:

{
  "transformer": "cuda:0",
  "text_encoder_2": "cuda:2",
  "text_encoder": "cuda:0",
  "vae": "cuda:1"
}

I have three 3090 Ti 24 GB GPUs and I can't get it to run quickly on them.

I also tried this way:
pipe.transformer.to('cuda:2')
pipe.text_encoder.to('cuda:2')
pipe.text_encoder_2.to('cuda:1')
pipe.vae.to('cuda:0')

What is the best way to launch it so that generation runs on the GPUs and is fast?

Reproduction

            pipe = FluxPipeline.from_pretrained(
                path_chkpt,
                torch_dtype=torch.bfloat16,
                device_map='balanced',
            )

Logs

No response

System Info

ubuntu 22.04 3 GPU: 3090 TI 24 GB

accelerate==0.30.1
addict==2.4.0
apscheduler==3.9.1
autocorrect==2.5.0
chardet==4.0.0
cryptography==37.0.2
curl_cffi
diffusers==0.30.0
beautifulsoup4==4.11.2
einops
facexlib>=0.2.5
fastapi==0.92.0
hidiffusion==0.1.6
invisible-watermark>=0.2.0
numpy==1.24.3
opencv-python==4.8.0.74
pandas==2.0.3
pycocotools==2.0.6
pymystem3==0.2.0
pyyaml==6.0
pyjwt==2.6.0
python-multipart==0.0.5
pytrends==4.9.1
psycopg2-binary
realesrgan==0.3.0
redis==4.5.1
sacremoses==0.0.53
selenium==4.2.0
sentencepiece==0.1.97
scipy==1.10.1
scikit-learn==0.24.1
supervision==0.16.0
tb-nightly==2.14.0a20230629
tensorboard>=2.13.0
tomesd
transformers==4.40.1
timm==0.9.16
yapf==0.32.0
uvicorn==0.20.0

spacy==3.7.2
nest_asyncio==1.5.8
httpx==0.25.0

torchvision==0.15.2

insightface==0.7.3
psutil==5.9.6
tk==0.1.0
customtkinter==5.2.1
tensorflow==2.13.0
opennsfw2==0.10.2
protobuf==4.24.4
gfpgan==1.3.8

Who can help?

No response

OlegRuban-ai added the bug label Aug 16, 2024
@OlegRuban-ai (Author)

Currently device_map can only be 'balanced' and does not support passing a dictionary that maps modules to GPUs. When I set it to 'balanced', the transformer is sent to the CPU, which may be the reason for the severe slowdown. It doesn't fit into one GPU, and smaller formats such as fp8 aren't supported yet. Please help solve the problem.

@asomoza (Member) commented Aug 16, 2024

In both examples you're putting the transformer, which is the largest model, on the same GPU as a text encoder. The transformer can fit (alone) on a 24 GB GPU if you're not running a GUI, so the best solution here is to change the distribution so that both text encoders are on the same GPU and the transformer is alone on another.

Also with this PR you can use a sharded transformer Flux model.

@OlegRuban-ai (Author)

In both examples you're putting the transformer, which is the largest model, on the same GPU as a text encoder. The transformer can fit (alone) on a 24 GB GPU if you're not running a GUI, so the best solution here is to change the distribution so that both text encoders are on the same GPU and the transformer is alone on another.

Also with this PR you can use a sharded transformer Flux model.

The problem is that when I put it like this:

        pipe.transformer.to('cuda:0')
        pipe.text_encoder.to('cuda:1')
        pipe.text_encoder_2.to('cuda:1')
        pipe.vae.to('cuda:2')

then I run into the error RuntimeError('Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)')

When I put it like this:
pipe.transformer.to('cuda:0')
pipe.text_encoder.to('cuda:0')
pipe.text_encoder_2.to('cuda:1')
pipe.vae.to('cuda:2')

I get: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.69 GiB total capacity; 22.29 GiB already allocated; 8.75 MiB free; 22.30 GiB reserved in total by PyTorch)

@asomoza (Member) commented Aug 16, 2024

Oh, this means that with the second configuration it fails before inference because it OOMs when loading the model; it doesn't mean the second configuration works better than the first one.

One solution is to separate the stages like in the PR (get text embeddings, denoising, decode latents) so the tensors are on the same device.
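For illustration, here is a rough, untested sketch of that idea on a single fully free 24 GB card: load the pipeline without the transformer and VAE to compute the text embeddings, free it, then load a second pipeline without the text encoders for denoising and decoding. The path_chkpt name follows the snippets above; treat this as an assumption-laden outline, not the exact code from the PR.

import gc
import torch
from diffusers import FluxPipeline

# Stage 1: text embeddings only (transformer and vae are not loaded at all).
pipe = FluxPipeline.from_pretrained(
    path_chkpt, transformer=None, vae=None, torch_dtype=torch.bfloat16
).to("cuda")
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt="a photo of a dog with cat-like look",
        prompt_2=None,
        max_sequence_length=256,
    )
del pipe
gc.collect()
torch.cuda.empty_cache()

# Stage 2: denoising + decoding, with the text encoders left out entirely.
pipe = FluxPipeline.from_pretrained(
    path_chkpt,
    text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipe(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    guidance_scale=0.0,
    num_inference_steps=4,
).images[0]

If VAE decoding still OOMs on 24 GB, pipe.vae.enable_tiling() should reduce the decode peak.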

I'm not sure if this is supposed to be done automatically in the pipeline when using multi GPUs so ccing @sayakpaul for more insights.

Also, I'm curious whether you're going to get better performance than just using one GPU with CPU offloading; you can use a single 3090 with Flux in bfloat16 if you have all 24 GB available.

@sayakpaul (Member)

In this case, it's best to use enable_model_cpu_offload() and run it on a single GPU:

pipe = FluxPipeline.from_pretrained(
    path_chkpt,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

Make sure you install accelerate from source. I am also not sure if bfloat16 is a preferred data-type for 3090.

In any case, the slow-down is somewhat expected because of the data movements.

@OlegRuban-ai (Author)

In this case, it's best to use enable_model_cpu_offload() and run it on a single GPU:

pipe = FluxPipeline.from_pretrained(
    path_chkpt,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

I tried this option, but the generation speed of 80 seconds per image does not suit me. Are there any options to speed it up?

@sayakpaul (Member)

Did you try with FP16?

@OlegRuban-ai (Author)

Did you try with FP16?

Yes

@sayakpaul (Member)

Hmm, then there are a couple of things to keep in mind.

Let's first summarize what we have:

1. device_map="balanced" won't solve the latency problem because of the data movements.
2. We cannot use a single 24GB card with Flux because not all the modules would fit in VRAM.
3. enable_model_cpu_offload() doesn't meet your inference latency requirements (cc: @asomoza do you observe similar trends in timing too?).

We can perform quantization with optimum.quanto to solve 2. https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9 shows a thorough example. But that might lead to a slowdown in latency because the compute dtype is still FP16/BF16.
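As a rough, untested sketch of the optimum.quanto route (the linked gist is the authoritative version), assuming the usual quantize/freeze API and the pipe object from the earlier snippets:

from optimum.quanto import freeze, qfloat8, quantize

# Quantize the two largest modules to 8-bit float weights; compute still runs in bf16.
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)

pipe.to("cuda")  # the quantized pipeline should now fit on a single 24 GB card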

The next option would be to try out torchao, more specifically int8wo (as in int8 weight-only) quantization. This will hopefully help us fit the modules in memory.

Now, to solve the latency problem we can make use of torch.compile() as discussed here.

Would you like to give this a try?

Code (haven't tested):

from torchao.quantization import int8_weight_only, quantize_

... # pipeline loading code

quantize_(pipeline.transformer, int8_weight_only())
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)

You can potentially apply the same scheme to other modules of the pipeline, too, and see if it helps. Note that torch.compile() support may not be there yet for the CLIP and T5 models.
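For example (untested, assuming the same pipeline object as above), the big T5 encoder could get the same weight-only treatment while skipping torch.compile() for it:

from torchao.quantization import int8_weight_only, quantize_

# Same int8 weight-only scheme for the T5 text encoder; no torch.compile() here since
# compile support for the T5/CLIP encoders may be missing.
quantize_(pipeline.text_encoder_2, int8_weight_only())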

You can also try out the other quantization schemes from torchao, which you can find in the repo.

We will be integrating NF4 support soon through bitsandbytes. Hopefully, that should solve your issues better. See: #9174
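For reference, once that lands, loading an NF4-quantized transformer might look roughly like the sketch below. This is hypothetical until #9174 is merged; the config name and arguments are assumed to mirror the transformers BitsAndBytesConfig API.

import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()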

@OlegRuban-ai (Author)

I used this code:

from torchao.quantization import int8_weight_only, quantize_

# pipeline loading code

quantize_(pipe.transformer, int8_weight_only())
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)

and it fails with:

Unsupported: hasattr ConstDictVariable to

from user code:
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 164, in new_forward
args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 363, in pre_forward
return send_to_device(args, self.execution_device), send_to_device(
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 148, in send_to_device
if is_torch_tensor(tensor) or hasattr(tensor, "to"):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

1. In fp16, the generation speed is just as low (80 sec/it), but RAM and VRAM consumption is minimal.

2. It doesn't work via torch.compile().

3. transformer and text_encoder must be on the same video card, but together they do not fit into 24 GB.

4. Quantization failed. There is hope that it will be possible to implement this model in diffusers: https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4

Thank you for your time.

@sayakpaul (Member)

For torch.compile() and quantization things, please always use the nightlies.

transformer and text_encoder must be on the same video card, but together they do not fit into 24 GB

What if you first keep them on CPU, quantize, and then move to GPU?
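Something along these lines (an untested sketch, reusing the torchao import and the pipe object from the earlier snippets):

from torchao.quantization import int8_weight_only, quantize_

# Quantize while the weights still live in system RAM, then move the now-smaller
# module onto the GPU; int8 weights take roughly half the bf16 footprint.
quantize_(pipe.transformer, int8_weight_only())
pipe.transformer.to("cuda:0")
pipe.text_encoder.to("cuda:0")  # the CLIP encoder is small enough to keep in bf16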

Quantization failed. There is hope that it will be possible to implement this model in diffusers:

See here: #9165

@asomoza (Member) commented Aug 16, 2024

There's something wrong in your env though. I can run Flux dev with bfloat16 (I don't have space to test Schnell right now) on a single 3090 (not Ti) without a GUI, just using enable_model_cpu_offload, and with 20 steps it takes 29 seconds to generate.

Taking 80 seconds using Schnell with 4 steps on a 3090 Ti means that something else is wrong.

Since you got OOMs, it's not that you're using RAM instead of VRAM (as happens on Windows), so my guess is that you're using really slow RAM, or maybe you're swapping to disk?

As a reference, with bfloat16 or float16 and CPU offloading you'll need more than 32 GB of RAM; in my case I use more than 50 GB with what I had loaded from before, so take that into consideration.

@yiyixuxu (Collaborator) commented Aug 17, 2024

@OlegRuban-ai can you confirm what your output is for this script?

from diffusers import DiffusionPipeline
import torch
path_chkpt = "black-forest-labs/FLUX.1-schnell"
pipe = DiffusionPipeline.from_pretrained(
    path_chkpt,
    torch_dtype=torch.bfloat16,
    device_map='balanced',
)

print(pipe.hf_device_map)

The device map you pasted here (#9195 (comment)) is not what we would expect; the algorithm should have automatically put the vae and text_encoder on the same GPU:

{
  "transformer": "cuda:0",
  "text_encoder_2": "cuda:2",
  "text_encoder": "cuda:0",
  "vae": "cuda:1"
}

@OlegRuban-ai (Author)

@yiyixuxu this is what I get in the output. It is impossible to set the placement directly with a dictionary, as was previously possible in diffusers, because it only accepts device_map='balanced':

{'transformer': 'cpu', 'text_encoder_2': 0, 'text_encoder': 1, 'vae': 2}

As you can see, the transformer ends up on the CPU, even though I have one 3090 that remains completely free.

@OlegRuban-ai (Author) commented Aug 17, 2024

@asomoza

The RAM specifications (x6 = 96 GB):

Array Handle: 0x0013
Error Information Handle: Not Provided
Total Width: 72 bits
Data Width: 64 bits
Size: 16384 MB
Form Factor: DIMM
Set: None
Locator: PROC 1 DIMM 1
Bank Locator: Not Specified
Type: DDR4
Type Detail: Synchronous Registered (Buffered)
Speed: 2666 MT/s
Manufacturer: HPE
Serial Number: Not Specified
Asset Tag: Not Specified
Part Number: 840757-091
Rank: 1
Configured Memory Speed: 2400 MT/s
Minimum Voltage: 1.2 V
Maximum Voltage: 1.2 V
Configured Voltage: 1.2 V
Memory Technology: DRAM
Memory Operating Mode Capability: Volatile memory
Firmware Version: Not Specified
Module Manufacturer ID: Bank 10, Hex 0x83
Module Product ID: Unknown
Memory Subsystem Controller Manufacturer ID: Unknown
Memory Subsystem Controller Product ID: Unknown
Non-Volatile Size: None
Volatile Size: 16 GB
Cache Size: None
Logical Size: None

Could you tell me what your model launch code looks like? In Google Colab with an A100 or L4 I also get lower speed than you mentioned.

@asomoza (Member) commented Aug 17, 2024

@OlegRuban-ai I'm sorry, that was from memory. Now I remember that 29s was with FP8; with bfloat16 it takes 42s in the generation loop and 54.06 seconds including the VAE decoding. I'm not counting the loading, though, which is done just the first time.

My code is really simple; I just add some lines so I can save the elapsed time, but I can't provide screenshots as usual because I have to kill my GUI to do the tests.

I now have space to test Schnell, and I got 8s in the denoising loop and 14.58s with the VAE decoding.

Here's the code for Schnell:

import time

import torch

import diffusers


pipeline = diffusers.DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
pipeline.enable_model_cpu_offload()


prompt = "a photo of a dog with cat-like look"

start_time = time.time()

image = pipeline(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]

elapsed_time = time.time() - start_time

with open("elapsed_time_schnell.txt", "w") as file:
    file.write(f"Total elapsed time: {elapsed_time:.2f} seconds\n")

Also, as @sayakpaul mentioned, you'll need to install accelerate from source to be able to run it like this with 24 GB VRAM, and as I mentioned, the VRAM has to be 100% free.

This is the image I got from schnell:

[generated image attachment]

@yiyixuxu (Collaborator)

@OlegRuban-ai the fact that the transformer was put on the CPU with one GPU still free indicates that it does not fit on any of these 3 GPUs.

can you run this?

from accelerate.utils import get_max_memory
print(get_max_memory())

@OlegRuban-ai (Author)

@asomoza

I tested it on A100 and L4 cards in Google Colab with your code and found that the model does not fit on the L4 (22.5 GB VRAM) and throws an out-of-memory error. It fits on the A100 (40 GB VRAM), but at startup it takes up to 40 GB of RAM, and the peak load on the card during generation, before memory is freed back to the CPU, is 26 GB; therefore it will not fit in 24 GB on the 3090 during generation.

Generation speed on A100:
Total elapsed time: 39.82 seconds

But the 4 steps of image generation take 15 seconds; for the remaining 25 seconds the code frees VRAM and transfers it back to RAM, which is quite long.

[screenshots: L4_cuda_out_of_memory_flux_schnell, a100_flux_schnell]

@OlegRuban-ai (Author)

On the A100, if I place the whole pipeline on the GPU with pipeline.to('cuda') and start generating without changing anything else in the above settings, I get the following:
RAM: 3.8 Gb
VRAM: 36.5 Gb
Speed: 4.4 sec

Is it possible to somehow speed up the process of moving back to the CPU with pipeline.enable_model_cpu_offload() so that generation takes at most 20 seconds, not 39?

[screenshot: a100_flux_schnell_cuda]

@yiyixuxu (Collaborator)

@sayakpaul let's look into this?

@yiyixuxu (Collaborator)

I ran it on a Colab A100. I think it takes about 10 seconds to move all the components to the GPU and 24 seconds to move them back to the CPU, so I would say the slowdown with enable_model_cpu_offload is most likely expected, just more noticeable with Flux Schnell given that it is only 4 steps and that its transformer is the largest we have.
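For anyone wanting to reproduce that measurement, a quick hand-rolled timing loop (an untested sketch, assuming a pipe loaded in bf16 on the CPU) could look like this:

import time
import torch

def timed_move(module, device):
    # Synchronize around the transfer so we measure the actual copy time.
    torch.cuda.synchronize()
    start = time.time()
    module.to(device)
    torch.cuda.synchronize()
    return time.time() - start

for name in ["text_encoder", "text_encoder_2", "transformer", "vae"]:
    module = getattr(pipe, name)
    to_gpu = timed_move(module, "cuda")
    to_cpu = timed_move(module, "cpu")
    print(f"{name}: to GPU {to_gpu:.2f}s, back to CPU {to_cpu:.2f}s")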

@sayakpaul (Member)

Yeah, I am in the same boat as you. I'm not sure the CPU movement can be minimized; the movement costs just get compounded when the underlying model is this big. So using a quantized model could be useful here.

@OlegRuban-ai (Author)

@yiyixuxu

Sometimes it distributes the components across the GPUs incorrectly. Once I loaded it and the transformer and an encoder ended up on the same GPU, but then I restarted and got this picture:

{'text_encoder_2': 1, 'transformer': 0, 'text_encoder': 2, 'vae': 2}

Before the restart, when everything worked, it was like this:
{'text_encoder_2': 1, 'transformer': 2, 'text_encoder': 0, 'vae': 0}

I get the error:
RuntimeError('Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:2! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)')

The code did not change at all, and the load on the GPUs did not change either.

@sayakpaul (Member)

https://x.com/risingsayak/status/1825500686345207897?s=46 might be an interesting thread for you. Give it a try and let us know :)

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label Sep 15, 2024
sayakpaul removed the stale label Sep 15, 2024