NF4 Flux params in diffusers #9165
It looks like the state_dict comparison comes from Forge. They use double quantization, which adds even more keys. |
@feffy380 thank you, but see this bit in ...

```python
if pre_quantized:
    if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
        param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
    ):
        raise ValueError(
            f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
        )

    quantized_stats = {}
    for k, v in state_dict.items():
        # `startswith` to counter for edge cases where `param_name`
        # substring can be present in multiple places in the `state_dict`
        if param_name + "." in k and k.startswith(param_name):
            quantized_stats[k] = v
            if unexpected_keys is not None and k in unexpected_keys:
                unexpected_keys.remove(k)

    new_value = bnb.nn.Params4bit.from_prequantized(
        data=param_value,
        quantized_stats=quantized_stats,
        requires_grad=False,
        device=target_device,
    )
```

It already packs those quantization stats. But I will look into it. |
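To see why the `startswith` guard in that loop matters, here is a small pure-Python sketch of the same key-collection step on a toy state dict. The suffixes `absmax`, `quant_map`, `nested_absmax`, `nested_quant_map` and `quant_state.bitsandbytes__nf4` are my assumption of how bitsandbytes packs a 4-bit quant state (the `nested_*` entries being the extra keys that double quantization adds). The helper `collect_quantized_stats` and the key layout are hypothetical, and strings stand in for tensors.

```python
def collect_quantized_stats(state_dict, param_name, use_startswith=True):
    """Gather the quantization-state entries belonging to `param_name`.

    Mirrors the quoted loop: a key belongs to the parameter if it contains
    `param_name + "."` and (optionally) also starts with `param_name`.
    """
    format_keys = (
        param_name + ".quant_state.bitsandbytes__fp4",
        param_name + ".quant_state.bitsandbytes__nf4",
    )
    if not any(k in state_dict for k in format_keys):
        raise ValueError(f"Supplied state dict for {param_name} has no `bitsandbytes__*` entry.")

    stats = {}
    for k, v in state_dict.items():
        if param_name + "." in k and (k.startswith(param_name) or not use_startswith):
            stats[k] = v
    return stats


# Hypothetical layout: a top-level `proj_out` plus a block-level `proj_out`
# whose full name contains the top-level name as a substring.
suffixes = ["absmax", "quant_map", "nested_absmax", "nested_quant_map", "quant_state.bitsandbytes__nf4"]
state_dict = {}
for prefix in ("proj_out.weight", "single_transformer_blocks.0.proj_out.weight"):
    state_dict[prefix] = "packed uint8 data"
    for s in suffixes:
        state_dict[f"{prefix}.{s}"] = f"<{s}>"

top = collect_quantized_stats(state_dict, "proj_out.weight")
print(len(top))  # 5: only the top-level entries
naive = collect_quantized_stats(state_dict, "proj_out.weight", use_startswith=False)
print(len(naive))  # 10: the block-level entries leak in
```

Without `startswith`, a substring match alone would attach another layer's absmax and quant map to `proj_out.weight`.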
Update: NF4 serialization and loading are working fine. @DN6, let's brainstorm how we can support this more easily. This would also help us unlock LoRAs on the quantized weights (cc: @BenjaminBossan for PEFT). I think this will become clearly critical for larger models.
Cc: @SunMarc for jamming on this together. |
First of all, I'm very grateful for your efforts. However, I still don't understand where the problem was in the NF4 code above. Could you share the final code that serializes and loads NF4? |
It is working now, no issues there :) Please test and let me know if you run into anything weird. |
Thank you, Alvaro! Of course we can additionally improve things by:
|
In case someone's looking for a way to get the text encoders and the VAE from https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4 converted to the diffusers format:

```python
"""
Main model in bnb-nf4
T5xxl in fp8e4m3fn
CLIP-L in fp16
VAE in bf16
"""
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextModel, T5EncoderModel

from diffusers import AutoencoderKL
from diffusers.loaders.single_file_utils import (
    convert_ldm_vae_checkpoint,
    create_diffusers_clip_model_from_ldm,
    create_diffusers_t5_model_from_checkpoint,
)

ckpt_id = "lllyasviel/flux1-dev-bnb-nf4"
filename = "flux1-dev-bnb-nf4.safetensors"
ckpt_path = hf_hub_download(repo_id=ckpt_id, filename=filename)
sd = safetensors.torch.load_file(ckpt_path)

# T5 (text_encoder_2)
t5 = create_diffusers_t5_model_from_checkpoint(
    cls=T5EncoderModel, checkpoint=sd, config="black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2"
)

# CLIP (text_encoder)
clip_ = create_diffusers_clip_model_from_ldm(
    cls=CLIPTextModel, checkpoint=sd, config="black-forest-labs/FLUX.1-dev", subfolder="text_encoder"
)

# VAE: reuse the SD3 VAE config with Flux's scaling/shift factors
config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(torch.bfloat16)
# Keep only keys that start with "vae." and strip that prefix once
vae_sd = {k[len("vae."):]: v for k, v in sd.items() if k.startswith("vae.")}
converted_vae_sd = convert_ldm_vae_checkpoint(vae_sd, vae.config)
vae.load_state_dict(converted_vae_sd, strict=True)
```

For the diffusion model, as in keys prefixed with |
Looks good. Let's get this in 👍🏽 |
I tried to use this with the pipeline, but it failed.
|
Since the purpose of this issue was to show a PoC of NF4 serialization and loading, I will close this. |
Wondering if there's a chance this method could be used with Flux ControlNet? I think it needs it... |
Should be orthogonal after #9174 is done. |
```
ValueError                                Traceback (most recent call last)
Cell In[9], line 14
     11 ckpt_path = hf_hub_download("lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4-v2.safetensors")
     13 original_state_dict = safetensors.torch.load_file(ckpt_path)
---> 14 converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
     16 del original_state_dict
     17 gc.collect()

File ~/miniconda3/lib/python3.10/site-packages/diffusers/loaders/single_file_utils.py:1885, in convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs)
   1882 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
   1883     converted_state_dict = {}
-> 1885     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
   1886     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
   1887     mlp_ratio = 4.0

File ~/miniconda3/lib/python3.10/site-packages/diffusers/loaders/single_file_utils.py:1885, in <genexpr>(.0)
   1882 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
   1883     converted_state_dict = {}
-> 1885     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
   1886     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
   1887     mlp_ratio = 4.0

ValueError: invalid literal for int() with base 10: 'diffusion_model'
```

When I run the code below, it throws the error above. How could I use NF4 v2 from lllyasviel/stable-diffusion-webui-forge#1079?

```python
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
from diffusers import FluxTransformer2DModel, FluxPipeline
import safetensors.torch
import gc
import torch

dtype = torch.bfloat16
ckpt_path = hf_hub_download("lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4-v2.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
del original_state_dict
gc.collect()
```
|
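The error comes from `int(k.split(".", 2)[1])`, which assumes keys start with `double_blocks.<i>.` or `single_blocks.<i>.`. The `'diffusion_model'` in the message suggests the keys in this checkpoint carry an extra `model.diffusion_model.` prefix. Below is a minimal pure-Python sketch, assuming that prefix, of stripping it before counting blocks. `strip_prefix` and `count_blocks` are hypothetical helpers and the keys are toy strings. I take `max` of the indices instead of the last element of a `set`, because set iteration order is not guaranteed to be sorted. This only gets past the layer count. Whether the rest of the conversion handles the NF4 quant-state keys is a separate question.

```python
PREFIX = "model.diffusion_model."  # assumed prefix, inferred from the error message


def strip_prefix(state_dict, prefix=PREFIX):
    """Keep only diffusion-model entries and drop the prefix exactly once."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def count_blocks(keys, block_name):
    """Number of blocks = highest index seen after `block_name.` plus one."""
    indices = {int(k.split(".", 2)[1]) for k in keys if k.startswith(block_name + ".")}
    return max(indices) + 1 if indices else 0


# Toy keys standing in for the real checkpoint (values are irrelevant here)
raw = {
    "model.diffusion_model.double_blocks.0.img_attn.qkv.weight": None,
    "model.diffusion_model.double_blocks.1.img_attn.qkv.weight": None,
    "model.diffusion_model.single_blocks.0.linear1.weight": None,
    "vae.decoder.conv_in.weight": None,
}

# Reproduces the failure: the second dotted field is "diffusion_model", not an index
try:
    int(next(iter(raw)).split(".", 2)[1])
except ValueError as err:
    print(err)  # invalid literal for int() with base 10: 'diffusion_model'

sd = strip_prefix(raw)
print(count_blocks(sd, "double_blocks"), count_blocks(sd, "single_blocks"))  # 2 1
```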
See what I said in #9165 (comment)
|
Wait a moment. According to lllyasviel/stable-diffusion-webui-forge#1079, how could we also skip the second compression stage and compute faster? It seems that we need to update the earlier NF4 conversion method for the transformer. |
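For context on the "second compression stage": with double quantization, the per-block absmax constants are themselves quantized. Here is a back-of-the-envelope sketch of the storage overhead per weight. It assumes the QLoRA paper's defaults: a first-level block size of 64 with fp32 absmax, and, with double quantization, 8-bit absmax values in blocks of 256 that share one fp32 constant. These block sizes are assumptions and may not match what Forge's v1/v2 files actually use. The sketch ignores small fixed-size extras such as the code tables.

```python
def overhead_bits_per_param(block_size=64, double_quant=False, dq_block_size=256):
    """Extra bits per weight spent on quantization constants.

    Without double quant: one fp32 absmax per `block_size` weights.
    With double quant: one 8-bit absmax per block, plus one fp32 constant
    per `dq_block_size` absmax values.
    """
    if not double_quant:
        return 32 / block_size
    return 8 / block_size + 32 / (block_size * dq_block_size)


single = overhead_bits_per_param()
double = overhead_bits_per_param(double_quant=True)
print(f"{single:.3f} vs {double:.3f} bits/param")  # 0.500 vs 0.127 bits/param

# Rough size difference for a ~12B-parameter transformer (parameter count assumed)
n = 12e9
print(f"{n * (single - double) / 8 / 1e6:.0f} MB saved by the second stage")  # 560 MB ...
```

In other words, skipping the second stage costs a few hundred MB of file size but removes one dequantization step.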
@sayakpaul Any idea on how to get the FLUX.1-dev-gguf models running in diffusers? Thank you |
I've been testing with a diffusers dev GGUF model. I got the quantized model easily, but I'm still trying to figure out how to load it and use it correctly at inference. Just a heads up: the GGUF in ComfyUI is used just as storage for the model, so you will only see VRAM savings, not better speed. I think I saw someone say that Q8_0 is the best in terms of quality and VRAM, but that NF4 is still faster. |
Ah, I'm not too familiar with llama.cpp, which that comment refers to. What does llama.cpp do to get higher speed with GGUF quants vs. just dequantizing on the fly like the ComfyUI extension does? |
I'm not an expert on GGUF, so maybe someone else has a better explanation, but first: the author of llama.cpp is the person who created the GGUF format, which is why it is relevant. The difference in speed is that ComfyUI has to upscale the tensors back to bfloat16 for each computation, while llama.cpp |
Wait, really? I didn't know llama.cpp operates directly on quantized weights without dequantizing first, or that the model would still work given that different weight blocks are quantized individually in GGUF. Do you have a source for this? |
I probably oversimplified it. It's not that easy and I don't understand it completely, which is why I said someone else could probably explain it better. You can read about it in this PR, but it's more complex than that; some useful information is:
There's a lot of math and tables to read and understand, and I'm just learning about this as I try to use it with diffusers, so I'm not really the best person to explain it yet. But you're right, and to correct my previous post: they don't work with the quantized values directly. Instead they do a really fast upscale, which I suppose the ComfyUI version doesn't do, and that's why it's slower. |
Oh, and before I forget: when I said "just as storage", as in the issue, I also meant the memory optimizations and all the other optimizations done in llama.cpp that let you actually run LLMs with just the CPU. People shouldn't expect to be able to run Flux with just the CPU in diffusers. When I get GGUF working, I can test whether it's really slower than NF4. |
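To make "fast upscale" concrete, here is a pure-Python sketch of Q8_0-style block quantization as I understand it from ggml. Each block has 32 weights stored as int8 values, plus one fp16 scale `d = max|x| / 127`, and dequantization is `q * d`. This is only an illustration of the storage format. It does not model how llama.cpp or the ComfyUI extension schedule the dequantization, which is where the speed difference discussed above would come from.

```python
import math
import struct

QK8_0 = 32  # weights per block (assumed from ggml's Q8_0 layout)


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_half_away(v):
    """C `roundf` semantics (Python's round() uses banker's rounding)."""
    return int(math.copysign(math.floor(abs(v) + 0.5), v))


def quantize_q8_0(values):
    """Return a list of (fp16 scale, [int8, ...]) blocks."""
    blocks = []
    for start in range(0, len(values), QK8_0):
        block = values[start:start + QK8_0]
        amax = max(abs(x) for x in block)
        d = amax / 127
        inv_d = 1 / d if d else 0.0
        q = [round_half_away(x * inv_d) for x in block]
        blocks.append((to_fp16(d), q))
    return blocks


def dequantize_q8_0(blocks):
    """The 'upscale' step: every weight becomes q * d again before the matmul."""
    out = []
    for d, q in blocks:
        out.extend(qi * d for qi in q)
    return out


weights = [math.sin(i * 0.37) * (1 + i % 5) for i in range(64)]
blocks = quantize_q8_0(weights)
restored = dequantize_q8_0(blocks)
print(len(blocks), max(abs(a - b) for a, b in zip(weights, restored)))
```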
```python
from transformers import BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
```
|
Does anyone have this working with the recently merged QKV projection fusion for Flux? I get an error in torch.nn.Linear because the attention weights seem to be in uint8. Casting to float seems to work, but then it gets caught with |
Where exactly do you find the error? Do you have a reproducible snippet? |
Hmm, this probably won't work yet because the separate Q, K, V projections are not removed, for bookkeeping purposes. Perhaps we could try to first fuse the QKV projections and then apply the NF4 shenanigans? |
This errors out because there are no weights yet :( |
I don't understand. Why won't you have the weights?
What do you mean here? |
Maybe you could give it another try as soon as #9213 becomes a bit more polished (I plan on working on it this week) and I can let you know once it's in a testable state? |
Sure! I think I misinterpreted the error torch was giving me: What I meant by "from source" is the generate.py script, as in making NF4 weights from the flux repo instead of using your repo in the load_from_nf4_and_generate.py script. |
Maybe you could try this in isolation and report a bug? You could use this minimal test as a reference:
Text encoder? Could you elaborate a bit more? How are you quantizing the text encoder? |
I'll do this, but I think it has something to do with the NF4 weights specifically
The text encoder as in the one in the pipeline: when it has taken a few minutes and I give up with Ctrl+C, the traceback shows it was still on the text_encoder_2 stage. It's probably a product of the "meta" (CPU?) device being so slow. I'll try encoding separately from the pipeline and passing the embeddings in instead of the prompt when I have a bit more time. |
Okay. Maybe it would be better to try it out after our bitsandbytes support becomes more robust through the PRs I mentioned above, because currently it's all a bit hacky, without proper checks and guards in place.
Well, I think it doesn't quite give the full picture, IMO. |
Should be:

```diff
- pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype).to(device)
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
```
|
If Flux NF4 runs everything on the GPU, how much GPU memory is required? |
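A rough way to answer this is weights-only arithmetic: bytes ≈ parameters × bits per parameter / 8. The sketch below uses approximate parameter counts that are my assumptions: about 12B for the Flux transformer, about 4.7B for the T5-XXL encoder, about 0.12B for CLIP-L, and under 0.1B for the VAE. It also assumes about 4.5 bits per NF4 weight (4 bits plus one fp32 absmax per 64 weights, without double quantization). Activations, the CUDA context and framework overhead come on top of this, so real peak usage will be higher.

```python
GIB = 2**30

# Approximate parameter counts (assumptions, not measured from the checkpoints)
COMPONENTS = {
    "transformer": 12.0e9,
    "t5_xxl_encoder": 4.7e9,
    "clip_l": 0.12e9,
    "vae": 0.08e9,
}


def weight_gib(num_params, bits_per_param):
    """Memory for the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / GIB


def pipeline_weight_gib(transformer_bits, text_encoder_bits=16, vae_bits=16):
    """Weights of all four components, with the transformer at `transformer_bits`."""
    return (
        weight_gib(COMPONENTS["transformer"], transformer_bits)
        + weight_gib(COMPONENTS["t5_xxl_encoder"], text_encoder_bits)
        + weight_gib(COMPONENTS["clip_l"], text_encoder_bits)
        + weight_gib(COMPONENTS["vae"], vae_bits)
    )


for label, bits in [("bf16", 16), ("fp8", 8), ("nf4 (~4.5 bits)", 4.5)]:
    print(
        f"transformer {label}: {weight_gib(COMPONENTS['transformer'], bits):5.1f} GiB, "
        f"whole pipeline with bf16 text encoders: {pipeline_weight_gib(bits):5.1f} GiB"
    )
```

With these assumptions, the all-bf16 total lands above 30 GB, in line with the figure in the original post.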
I am getting the same error @Ednaordinary mentioned above: 'Error in FluxImageGenerator initialization: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.' I am using your convert_nf4_flux.py script and loading it this way; you can see the code snippet:
You can see my console log below. Any ideas on how to solve this? I've tried a bunch of stuff but had no luck. |
Please note that moving the models to CUDA and then calling enable_model_cpu_offload is counterproductive, as they will be moved to the GPU and then back to the CPU. You should only call enable_model_cpu_offload, without the GPU transfer. This won't fix your error, as the models will end up being loaded to the GPU anyway, but it should help your load times. As for your actual error, make sure you are doing all of this to the transformer before loading to CUDA, as in the generate.py script:

```python
import gc

import safetensors.torch
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from huggingface_hub import hf_hub_download
from diffusers import FluxTransformer2DModel
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers

# Helpers from the convert_nf4_flux.py script discussed in this thread
from convert_nf4_flux import _replace_with_bnb_linear, check_quantized_param, create_quantized_param

dtype = torch.bfloat16
ckpt_path = hf_hub_download("black-forest-labs/flux.1-dev", filename="flux1-dev.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
del original_state_dict
gc.collect()

# Build the model skeleton on the meta device, then swap its linear layers for bnb NF4 ones
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/flux.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)
    _replace_with_bnb_linear(model, "nf4")

# Materialize every parameter directly on GPU 0: plain params are copied, quantizable ones are quantized
for param_name, param in converted_state_dict.items():
    param = param.to(dtype)
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(model, param, param_name, target_device=0)

del converted_state_dict
gc.collect()
```

I cannot tell your exact error, as you didn't include the full transformer loading code in your snippet. If you have QKV fusion enabled, make sure to disable it, as it doesn't currently work with these scripts. |
You could now refer to #9213 for more improved support. |
Tried it (latest commit), though QKV fusion still doesn't work as expected :( The part that matters:

```python
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
nf4_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"
model_nf4 = FluxTransformer2DModel.from_pretrained(nf4_id, torch_dtype=torch.bfloat16)
model_nf4.fuse_qkv_projections()
```

I think this is still due to the attention weights being in uint8. Awesome work though; loading is quick and very streamlined. |
Thanks for your patience and insights. As a temporary solution, I can point you to the code snippet in #8829 (comment). Maybe that could be a reasonable reference for now and then once #9213 is merged, we can polish our QKV fusion compatibility? |
Sure! I see the main solution now is fusing qkv before quantization, so I'll likely wait for better post-quant support since saving and loading fused qkv would be confusing. Great work, regardless |
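Here is why fusing before quantization is the easier route. Fusing Q, K and V amounts to stacking the three weight matrices (and biases) along the output dimension, so that one matmul produces the three outputs concatenated. On float weights, that is a plain concatenation. On 4-bit weights, the data is packed uint8 with separate per-block quantization state. That is presumably also why a plain `torch.nn.Linear` path fails on them, as reported above. The pure-Python sketch below uses lists in place of tensors. `linear` and `fuse_qkv` are hypothetical helpers, not the diffusers implementation.

```python
def linear(weight, bias, x):
    """y = W x + b, with the weight stored as a list of output rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def fuse_qkv(wq, bq, wk, bk, wv, bv):
    """Stack the three projections along the output dimension."""
    return wq + wk + wv, bq + bk + bv


dim = 3
wq = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0], [0.0, 0.0, 1.0]]
wk = [[0.0, 1.0, 0.0], [2.0, 0.0, -1.0], [1.0, 1.0, 1.0]]
wv = [[-1.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, -2.0]]
bq, bk, bv = [0.1] * dim, [0.0] * dim, [-0.5] * dim
x = [1.0, 2.0, 3.0]

w_fused, b_fused = fuse_qkv(wq, bq, wk, bk, wv, bv)
fused_out = linear(w_fused, b_fused, x)
# Split the fused output back into the three projections
q, k, v = fused_out[:dim], fused_out[dim:2 * dim], fused_out[2 * dim:]
print(q == linear(wq, bq, x), k == linear(wk, bk, x), v == linear(wv, bv, x))  # True True True
```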
Thanks for this, I was able to get pipeline loading and inference working now! |
|
Sure, here it is, but if you don't save it, it takes more time to load. Also, I didn't quantize the text encoder because we can run it like this with 24 GB of VRAM. |
Thanks a lot @asomoza. From your results, FP8 inference is a bit faster than NF4. Can you think of why this could happen? Also, for a GPU with enough memory (like 40 GB), will quantized models in general help with latency? In my experiments, bfloat16 inference is 1-2 seconds faster than NF4. Could this be because the compute is still happening in bf16 (even with the NF4 model)? |
in two steps 😅 |
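On the question above about NF4 being slower than bf16 or FP8, here is my understanding, hedged. bitsandbytes' 4-bit layers store 4-bit indices plus per-block absmax scales. For inputs with many tokens, such as Flux's, they dequantize the weights to the compute dtype and then run a regular bf16 matmul. The matmul cost is then the same as bf16, with the dequantization added on top, so NF4 mainly saves memory rather than time.

The pure-Python sketch below illustrates only the dequantization side. It builds a 16-level NF4-style code table from normal quantiles, following the construction I recall from bitsandbytes' `create_normal_map`; the `0.9677083` offset is taken from there and should be treated as an assumption. It then quantizes and dequantizes one block of 64 weights with absmax scaling. `quantize_block` and `dequantize_block` are hypothetical helpers, not the bitsandbytes API.

```python
from statistics import NormalDist


def linspace(start, stop, num):
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def nf4_code(offset=0.9677083):
    """16 levels: 8 positive and 7 negative normal quantiles plus an exact zero,
    normalized to [-1, 1] (modeled on bitsandbytes' create_normal_map)."""
    ppf = NormalDist().inv_cdf
    pos = [ppf(p) for p in linspace(offset, 0.5, 9)[:-1]]
    neg = [-ppf(p) for p in linspace(offset, 0.5, 8)[:-1]]
    levels = sorted(pos + [0.0] + neg)
    top = max(levels)
    return [v / top for v in levels]


CODE = nf4_code()


def quantize_block(block):
    """Return (absmax, 4-bit indices) for one block of weights."""
    absmax = max(abs(x) for x in block)
    scale = absmax if absmax else 1.0
    idx = [min(range(16), key=lambda i: abs(CODE[i] - x / scale)) for x in block]
    return absmax, idx


def dequantize_block(absmax, idx):
    """The extra work before the bf16 matmul: look up each level and rescale."""
    return [CODE[i] * absmax for i in idx]


block = [0.02 * ((i * 7) % 13 - 6) for i in range(64)]  # toy weights, block size 64
absmax, idx = quantize_block(block)
restored = dequantize_block(absmax, idx)
print(len(CODE), round(CODE[1], 4), max(abs(a - b) for a, b in zip(block, restored)))
```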
Is there any way to use this method with the

```python
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device
from accelerate import init_empty_weights
from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
from diffusers import FluxTransformer2DModel, FluxControlNetModel, FluxControlNetPipeline
import safetensors.torch
import torch
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
from diffusers.utils import load_image

dtype = torch.bfloat16
ckpt_path = hf_hub_download("black-forest-labs/flux.1-schnell", filename="flux1-schnell.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
del original_state_dict

with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/flux.1-schnell", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)
    _replace_with_bnb_linear(model, "nf4")

for param_name, param in converted_state_dict.items():
    param = param.to(dtype)
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(model, param, param_name, target_device=0)
del converted_state_dict

controlnet_model = "InstantX/FLUX.1-dev-Controlnet-Canny"
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16, guidance_embeds=False)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/flux.1-schnell", transformer=model, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

control_image = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny/resolve/main/canny.jpg")
prompt = "A girl in city, 25 years old, cool, futuristic"
image = pipe(
    prompt,
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=3,
    guidance_scale=1,
).images[0]
image.save("image.jpg")
```
|
Thank you very much for your great work. I would like to ask if there is any way to speed up image generation. When I use num_inference_steps=50 on a T4 GPU (16 GB VRAM), it takes me more than 30 minutes to generate an image. Thank you very much again. |
Did you get GGUF working? Sorry to resurrect an old thread; I'm just very interested in using Flux in GGUF format in diffusers. |
I converted and loaded the diffusers model, but then I stopped because I'm waiting for the final quantization PR to be merged, to see what steps we should take and whether it's worth doing. |
I'll keep an eye on it then. Thank you very much 🫡 |
Hi, did you solve this problem? |
check this black-forest-labs/flux#185 |
@SunMarc
The Flux params are quite huge: if we include text encoder 2, the autoencoder, and the diffusion model itself, they total more than 30 GB.
https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4 ships a single safetensors file that has the diffusion model in NF4.
Now, I was able to get this converted and loaded into our FluxTransformer2DModel. At first I was not seeing any (state dict) size benefits; now I am (yay!). But loading does not seem to work yet. What am I missing? I'd appreciate feedback. Here is a detailed rundown of what I have done so far:
convert_nf4_flux.py
generate.py
The image generates just fine.
But loading seems broken: the generated image is noise. Advice? I have uploaded the NF4-serialized state dict here: https://huggingface.co/sayakpaul/flux.1-dev-nf4. The loading script is below:
load_from_nf4_and_generate.py
NF4 serialization and loading is working fine!