NF4 Flux params in diffusers #9165

Closed
sayakpaul opened this issue Aug 13, 2024 · 56 comments

Comments

@sayakpaul (Member) commented Aug 13, 2024

@SunMarc

Since the Flux params are quite huge (if we include text encoder 2, the autoencoder, and the diffusion model itself), they total more than 30GB.

https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4 ships a single safetensors file that has the diffusion model in NF4.

Now, I was able to get this converted and load it into our FluxTransformer2DModel, and I am seeing the size (state dict size) benefits (yay!). But loading the serialized NF4 state dict does not seem to work yet. What am I missing? Will appreciate feedback.

Here is a detailed rundown of what I have done so far.

convert_nf4_flux.py
"""
Utilities adapted from

* https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_bnb_4bit.py
* https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py
"""

import torch
import bitsandbytes as bnb
from transformers.quantizers.quantizers_utils import get_module_from_name
import torch.nn as nn
from accelerate import init_empty_weights


def _replace_with_bnb_linear(
    model,
    method="nf4",
    has_been_replaced=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            with init_empty_weights():
                in_features = module.in_features
                out_features = module.out_features

                if method == "llm_int8":
                    model._modules[name] = bnb.nn.Linear8bitLt(
                        in_features,
                        out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=6.0,
                    )
                    has_been_replaced = True
                else:
                    model._modules[name] = bnb.nn.Linear4bit(
                        in_features,
                        out_features,
                        module.bias is not None,
                        compute_dtype=torch.bfloat16,
                        compress_statistics=False,
                        quant_type="nf4",
                    )
                    has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module)
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
    return model, has_been_replaced


def check_quantized_param(
    model,
    param_name: str,
) -> bool:
    module, tensor_name = get_module_from_name(model, param_name)
    if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
        # Add here check for loaded components' dtypes once serialization is implemented
        return True
    elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
        # bias could be loaded by regular set_module_tensor_to_device() from accelerate,
        # but it would wrongly use uninitialized weight there.
        return True
    else:
        return False


def create_quantized_param(
    model,
    param_value: "torch.Tensor",
    param_name: str,
    target_device: "torch.device",
    state_dict=None,
    unexpected_keys=None,
    pre_quantized=False
):
    module, tensor_name = get_module_from_name(model, param_name)

    if tensor_name not in module._parameters:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

    old_value = getattr(module, tensor_name)

    if tensor_name == "bias":
        if param_value is None:
            new_value = old_value.to(target_device)
        else:
            new_value = param_value.to(target_device)

        new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
        module._parameters[tensor_name] = new_value
        return

    if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
        raise ValueError("this function only loads `Linear4bit components`")
    if (
        old_value.device == torch.device("meta")
        and target_device not in ["meta", torch.device("meta")]
        and param_value is None
    ):
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")

    if pre_quantized:
        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
                param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
            ):
                raise ValueError(
                    f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
                )

        quantized_stats = {}
        for k, v in state_dict.items():
            # `startswith` to counter for edge cases where `param_name`
            # substring can be present in multiple places in the `state_dict`
            if param_name + "." in k and k.startswith(param_name):
                quantized_stats[k] = v
                if unexpected_keys is not None and k in unexpected_keys:
                    unexpected_keys.remove(k)

        new_value = bnb.nn.Params4bit.from_prequantized(
            data=param_value,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=target_device,
        )

    else:
        new_value = param_value.to("cpu")
        kwargs = old_value.__dict__
        new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)

    module._parameters[tensor_name] = new_value
generate.py
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
from diffusers import FluxTransformer2DModel, FluxPipeline
import safetensors.torch
import gc
import torch

dtype = torch.bfloat16
ckpt_path = hf_hub_download("black-forest-labs/flux.1-dev", filename="flux1-dev.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)

del original_state_dict
gc.collect()

with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/flux.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)

_replace_with_bnb_linear(model, "nf4")
for param_name, param in converted_state_dict.items():
    param = param.to(dtype)
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(model, param, param_name, target_device=0)

del converted_state_dict
gc.collect()

print(compute_module_sizes(model)[""] / 1024 / 1024)

pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
pipe.enable_model_cpu_offload()

prompt = "A mystic cat with a sign that says hello world!"
image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
image.save("flux-nf4-dev.png")

model.push_to_hub("sayakpaul/flux.1-dev-nf4")

The image generates just fine, and the size benefits are there.

[screenshot]

But the loading seems broken (the generated image is noise). Any advice? I have uploaded the NF4 serialized state dict here: https://huggingface.co/sayakpaul/flux.1-dev-nf4

Loading script is below:

load_from_nf4_and_generate.py
"""
Some bits are from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
"""

from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
from diffusers import FluxTransformer2DModel, FluxPipeline
import safetensors.torch
import gc
import torch

dtype = torch.bfloat16
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
ckpt_path = hf_hub_download("sayakpaul/flux.1-dev-nf4", filename="diffusion_pytorch_model.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)

with init_empty_weights():
    config = FluxTransformer2DModel.load_config("sayakpaul/flux.1-dev-nf4")
    model = FluxTransformer2DModel.from_config(config).to(dtype)
    expected_state_dict_keys = list(model.state_dict().keys())

_replace_with_bnb_linear(model, "nf4")

for param_name, param in original_state_dict.items():
    if param_name not in expected_state_dict_keys:
        continue
    
    is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
    if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
        param = param.to(dtype)
    
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(
            model, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True
        )

del original_state_dict
gc.collect()

print(compute_module_sizes(model)[""] / 1024 / 1024)

pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
pipe.enable_model_cpu_offload()

prompt = "A mystic cat with a sign that says hello world!"
image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
image.save("flux-nf4-dev-loaded.png")

Update: NF4 serialization and loading are working fine!

@feffy380 commented Aug 13, 2024

It looks like expected_state_dict_keys could be the culprit. NF4 adds a few extra keys for each weight, but you're building expected_state_dict_keys before quantization, so these keys are discarded even though they're present in the state_dict.

state_dict comparison from Forge. They use double quantization, which adds even more keys.
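To make that concrete, here is an illustration of the extra per-weight entries NF4 serialization adds; the key names below are assumptions based on the bitsandbytes Params4bit serialization format (with a hypothetical Flux layer as the prefix), not values read from the uploaded checkpoint:

prefix = "transformer_blocks.0.attn.to_q.weight"
nf4_extra_keys = [
    f"{prefix}.absmax",                         # per-block scales
    f"{prefix}.quant_map",                      # NF4 codebook
    f"{prefix}.quant_state.bitsandbytes__nf4",  # packed quant-state metadata
    # with double quantization (compress_statistics=True) there are also e.g.:
    f"{prefix}.nested_absmax",
    f"{prefix}.nested_quant_map",
]

# If `expected_state_dict_keys` is built from the un-quantized model, none of
# these keys are in it, so a key filter based on it would drop them.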

@sayakpaul (Member, Author)

@feffy380 thank you but see this bit in convert_nf4_flux.py:

...
    if pre_quantized:
        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
                param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
            ):
                raise ValueError(
                    f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
                )

        quantized_stats = {}
        for k, v in state_dict.items():
            # `startswith` to counter for edge cases where `param_name`
            # substring can be present in multiple places in the `state_dict`
            if param_name + "." in k and k.startswith(param_name):
                quantized_stats[k] = v
                if unexpected_keys is not None and k in unexpected_keys:
                    unexpected_keys.remove(k)

        new_value = bnb.nn.Params4bit.from_prequantized(
            data=param_value,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=target_device,
        )

It already packs those quantization stats. But will look into it.

@sayakpaul (Member, Author) commented Aug 14, 2024

Update: NF4 serialization and loading are working fine. @DN6 let's brainstorm how we can support it more easily? This would help us unlock doing LoRAs on the quantized weights, too (cc: @BenjaminBossan for PEFT). I think this will become evidently critical for larger models.

transformers has a nice reference for us to follow. Additionally, accelerate has: https://huggingface.co/docs/accelerate/en/usage_guides/quantization, but it doesn't support NF4 serialization yet.

Cc: @SunMarc for jamming on this together.
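For the LoRA angle, here is a rough QLoRA-style sketch (untested) of what attaching adapters to the NF4 transformer could look like with PEFT; `model` is the quantized FluxTransformer2DModel from the scripts above, and the target module names are assumptions:

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # attention projections (assumed names)
    lora_dropout=0.0,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA weights remain trainable

The 4-bit base weights stay frozen; only the adapter parameters would receive gradients.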

@chuck-ma commented Aug 14, 2024

Update: NF4 serialization and loading are working fine. @DN6 let's brainstorm how we can support it more easily? This would help us unlock doing LoRAs on the quantized weights, too (cc: @BenjaminBossan for PEFT). I think this will become evidently critical for larger models.

transformers has a nice reference for us to follow. Additionally, accelerate has: https://huggingface.co/docs/accelerate/en/usage_guides/quantization, but it doesn't support NF4 serialization yet.

Cc: @SunMarc for jamming on this together.

First of all, I'm very grateful for your efforts. However, I still don't understand where the problem lies in the code using nf4 above. Could you share the final code that can serialize and load nf4?

@sayakpaul (Member, Author)

It is working now. No issues there :) please test and let me know if you run into something weird.

@asomoza (Member) commented Aug 14, 2024

Really cool, I got it working without issues, here's my benchmark with this:

prompt:

high quality photo of a dog sitting beside a tree with a red soccer ball at its side, the background it's a lake with a boat sailing in it and an airplane flying in the cloudy sky.

[FP8 vs. NF4 comparison images]

FP8

Total elapsed time (with model loading): 55.00 seconds
Generation time + vae decoding: 33.35 seconds

NF4

Total elapsed time (with model loading): 40.24 seconds
Generation time + vae decoding: 35.66 seconds

And I see the memory savings.

@sayakpaul (Member, Author)

Thank you, Alvaro!

Of course we can additionally improve things by:

  • Applying the same quantization to T5 (I doubt CLIP would benefit much from this; VAEs don't apply here because bnb works best with models composed mostly of nn.Linear layers); a sketch follows this list.
  • Exploring the compute dtype for the hardware we're on, because the dynamics of BF16 + NF4 vs. FP16 + NF4 vary from card to card.
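For the first point, quantizing T5 shouldn't need any manual module surgery since transformers already has a bitsandbytes integration; a minimal sketch (untested here):

import torch
from transformers import T5EncoderModel, BitsAndBytesConfig

# Load the Flux T5 text encoder in 4-bit NF4 via transformers' bnb support.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# `text_encoder_2` can then be passed to FluxPipeline.from_pretrained(...).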

@sayakpaul (Member, Author) commented Aug 14, 2024

In case someone's looking for a way to get the text encoders and the VAE from https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4 converted to the diffusers format:

"""
Main model in bnb-nf4

T5xxl in fp8e4m3fn

CLIP-L in fp16

VAE in bf16
"""

from huggingface_hub import hf_hub_download
from diffusers.loaders.single_file_utils import (
    create_diffusers_clip_model_from_ldm,
    create_diffusers_t5_model_from_checkpoint,
    convert_ldm_vae_checkpoint,
)
from diffusers import AutoencoderKL
import safetensors.torch
from transformers import T5EncoderModel, CLIPTextModel, AutoConfig
import torch

ckpt_id = "lllyasviel/flux1-dev-bnb-nf4"
filename = "flux1-dev-bnb-nf4.safetensors"

ckpt_path = hf_hub_download(repo_id=ckpt_id, filename=filename)
sd = safetensors.torch.load_file(ckpt_path)

# T5
t5 = create_diffusers_t5_model_from_checkpoint(
    cls=T5EncoderModel, checkpoint=sd, config="black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2"
)

# CLIP
clip_ = create_diffusers_clip_model_from_ldm(
    cls=CLIPTextModel, checkpoint=sd, config="black-forest-labs/FLUX.1-dev", subfolder="text_encoder"
)

# VAE
config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(torch.bfloat16)
vae_sd = {k.replace("vae.", ""): v for k, v in sd.items() if "vae." in k}
vae_sd = convert_ldm_vae_checkpoint(vae_sd, vae.config)
vae.load_state_dict(vae_sd, strict=True)

For the diffusion model (i.e., the keys prefixed with model.diffusion_model), we suggest following the saving and loading approach in the OP, because we cannot define a clear mechanism to load the quantization stats for the attention modules from those keys and associated tensors.
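With those components in hand, assembling the pipeline could look roughly like this (a sketch; `model` is the NF4 transformer built as in the OP's generate.py, and the tokenizers and scheduler come from the base repo):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=model,   # NF4 transformer from the OP's scripts
    text_encoder=clip_,
    text_encoder_2=t5,
    vae=vae,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe("A mystic cat with a sign that says hello world!", num_inference_steps=50).images[0]
image.save("flux-nf4-assembled.png")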

@DN6 (Collaborator) commented Aug 14, 2024

Looks good. Let's get this in 👍🏽

@iwaitu commented Aug 14, 2024

[Quoted the text encoder / VAE conversion snippet from the comment above in full.]

I tried to use this for the pipeline and it failed:

pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=sd, vae=vae,torch_dtype=dtype)
pipe.enable_model_cpu_offload()

@sayakpaul (Member, Author)

Since the purpose of this issue was to show a PoC of NF4 serialization and loading, I will close this.

@Skquark commented Aug 15, 2024

Wondering if there's a chance this method could be used with Flux ControlNet? I think it needs it...

@sayakpaul (Member, Author)

Should be orthogonal after #9174 is done.

@chuck-ma commented Aug 16, 2024

ValueError                                Traceback (most recent call last)
Cell In[9], line 14
     11 ckpt_path = hf_hub_download("lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4-v2.safetensors")
     13 original_state_dict = safetensors.torch.load_file(ckpt_path)
---> 14 converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
     16 del original_state_dict
     17 gc.collect()

File ~/miniconda3/lib/python3.10/site-packages/diffusers/loaders/single_file_utils.py:1885, in convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs)
   1882 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
   1883     converted_state_dict = {}
-> 1885     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
   1886     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
   1887     mlp_ratio = 4.0

File ~/miniconda3/lib/python3.10/site-packages/diffusers/loaders/single_file_utils.py:1885, in <genexpr>(.0)
   1882 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
   1883     converted_state_dict = {}
-> 1885     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
   1886     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
   1887     mlp_ratio = 4.0

ValueError: invalid literal for int() with base 10: 'diffusion_model'

When I run the code below, it throws the error above. How could I use the NF4 V2 checkpoint from lllyasviel/stable-diffusion-webui-forge#1079?

from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
from diffusers import FluxTransformer2DModel, FluxPipeline
import safetensors.torch
import gc
import torch

dtype = torch.bfloat16
ckpt_path = hf_hub_download("lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4-v2.safetensors")

original_state_dict = safetensors.torch.load_file(ckpt_path)
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)

del original_state_dict
gc.collect()

@sayakpaul (Member, Author)

See what I said in #9165 (comment):

For the diffusion model (i.e., the keys prefixed with model.diffusion_model), we suggest following the saving and loading approach in the OP, because we cannot define a clear mechanism to load the quantization stats for the attention modules from those keys and associated tensors.

@chuck-ma commented Aug 16, 2024

See what I said in #9165 (comment):

For the diffusion model (i.e., the keys prefixed with model.diffusion_model), we suggest following the saving and loading approach in the OP, because we cannot define a clear mechanism to load the quantization stats for the attention modules from those keys and associated tensors.

Wait a moment. According to the link (lllyasviel/stable-diffusion-webui-forge#1079):

V2 is quantized in a better way to turn off the second stage of double quant.

V2 is 0.5 GB larger than the previous version, since the chunk 64 norm is now stored in full precision float32, making it much more precise than the previous version. Also, since V2 does not have second compression stage, it now has less computation overhead for on-the-fly decompression, making the inference a bit faster.

How could we also avoid the second compression stage and compute faster? It seems we would need to optimize the NF4 conversion method for the transformer above.
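For reference, is this just the compress_statistics flag on bnb.nn.Linear4bit (which the conversion in the OP already sets to False)? A minimal sketch of what I mean:

import torch
import bitsandbytes as bnb

# compress_statistics toggles the "second compression stage" (double
# quantization of the absmax statistics); False should match V2's behavior.
layer = bnb.nn.Linear4bit(
    3072,
    3072,
    bias=True,
    compute_dtype=torch.bfloat16,
    compress_statistics=False,  # no nested/double quantization
    quant_type="nf4",
)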

@billnye2

@sayakpaul Any idea on how to get the FLUX.1-dev-gguf models running in diffusers? Thank you

@asomoza (Member) commented Aug 19, 2024

I've been testing with a diffusers-format dev GGUF model. I got the quantized model easily, but I'm still trying to figure out how to load it and use it correctly at inference.

Just a heads up: in ComfyUI, GGUF is used just as a storage format for the model, so you will only see VRAM savings, not better speed. I think I saw someone saying that Q8_0 is the best in terms of quality and VRAM, but that NF4 is still faster.

@billnye2

Ah, I'm not too familiar with llama.cpp, which that comment refers to. What does llama.cpp do to get higher speed with GGUF quants vs. just dequantizing on the fly like the ComfyUI extension does?

@asomoza (Member) commented Aug 19, 2024

I'm no expert at GGUF, so maybe someone else has a better explanation of this, but first: the author of llama.cpp is the person who created the GGUF format, which is why it is relevant here.

The difference in speed is that ComfyUI has to upcast the tensors back to bfloat16 for each computation, while llama.cpp doesn't do this and works directly with the quantized values, so in addition to the memory savings it is also a lot faster.

@billnye2

Wait, really? I didn't know llama.cpp operates directly on quantized weights without dequantizing first, or that the model would still work since different weight blocks are quantized individually in GGUF. Do you have a source for this?

@asomoza (Member) commented Aug 19, 2024

I probably oversimplified it too much; it's not that easy and I don't understand it completely, which is why I said someone else could probably explain it better.

You can read about it in this PR, but it's more complex than that. Some useful information:

  • Quantized weights are easily unpacked using a bit shift, AND, and multiplication (and addition in the _1 variants); see the rough sketch at the end of this comment.
  • With the K models bits are allocated in a smarter way than in legacy quants.
  • They also lower the quantization error.

There's a lot of math and tables to read and understand and I'm just learning about this as I'm trying to use it with diffusers, so I'm not really the best person to explain this yet.

But you're right, and to correct my previous post: it doesn't work directly with the quantized values; instead it does a really fast upcast, which I suppose the ComfyUI version doesn't do as efficiently, and that's why it's slower.
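To make the first point above concrete, here is a rough, simplified sketch (not llama.cpp code; the block layout and names are made up) of how a legacy Q4_0-style block can be dequantized with a shift, an AND, and a multiply:

import numpy as np

def dequant_q4_0_block(packed: np.ndarray, scale: float) -> np.ndarray:
    """packed: 16 uint8 bytes holding 32 4-bit values; scale: the block's scale."""
    lo = (packed & 0x0F).astype(np.int8) - 8  # low nibble, re-centered to [-8, 7]
    hi = (packed >> 4).astype(np.int8) - 8    # high nibble
    return np.concatenate([lo, hi]).astype(np.float32) * scale

block = np.random.randint(0, 256, size=16, dtype=np.uint8)
print(dequant_q4_0_block(block, scale=0.02))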

@asomoza (Member) commented Aug 19, 2024

Oh, and also before I forget: when I said it is used "just as storage" (same as in this issue), I also meant the memory optimizations and all the other optimizations done in llama.cpp that let you actually run LLMs on just the CPU, so people shouldn't expect to be able to use Flux on CPU alone with diffusers.

When I get the gguf working I can test if it's really slower than nf4 or not.

@al-swaiti

https://huggingface.co/docs/transformers/v4.44.0/quantization/bitsandbytes?bnb=4-bit#nested-quantization

from transformers import BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

Is this nested (double) quantization what is called V2 (https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4/blob/main/flux1-dev-bnb-nf4-v2.safetensors)? @sayakpaul, also, why didn't you load it directly with diffusers without conversion here [screenshot]? And a last question: is it possible to save it for ComfyUI and WebUI before converting to diffusers? A big thanks to you.

@Ednaordinary

Anyone have this working with the recently merged QKV projection fusion for Flux? I encounter an error in torch.nn.Linear since the attention weights seem to be in uint8. Casting to a float seems to work, but then it gets caught with `size of tensor a (14155776) must match tensor b (9216) at non-singleton dimension 0`, which is where I got stuck :(

@sayakpaul (Member, Author)

Where exactly do you find the error? Do you have a reproducible snippet?

@sayakpaul (Member, Author)

possibly a vram boost

Hmm, this probably won't be the case yet because the separate Q, K, V projections are not wiped out; they are kept around for bookkeeping purposes.

Perhaps we could try to first fuse the QKV projections and then apply the NF4 shenanigans?
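A rough sketch of that ordering (untested; note that fuse_qkv_projections() needs materialized weights, so the model cannot still be on the meta device when it runs):

import torch
from accelerate.utils import set_module_tensor_to_device
from diffusers import FluxTransformer2DModel
from convert_nf4_flux import _replace_with_bnb_linear, check_quantized_param, create_quantized_param

# Materialize the transformer in bf16, fuse QKV, then quantize the fused weights.
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
model.fuse_qkv_projections()

fused_sd = model.state_dict()            # references to the fused bf16 tensors
_replace_with_bnb_linear(model, "nf4")   # swaps nn.Linear for bnb Linear4bit (weights on meta)

for param_name, param in fused_sd.items():
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(model, param, param_name, target_device=0)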

@Ednaordinary

This errors out because there's no weights yet :(
Same thing in generating the weights from source, though I'm not sure why that happens

@sayakpaul (Member, Author)

This errors out because there's no weights yet :(

I don't understand. Why won't you have the weights?

Same thing in generating the weights from source,

What do you mean here?

@sayakpaul (Member, Author)

Maybe you could give it another try as soon as #9213 becomes a bit more polished (I plan on working on it this week) and I can let you know once it's in a testable state?

@Ednaordinary

Sure!

I think I misinterpreted the error torch was giving me: NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
This doesn't happen straight after qkv fusion, but actually when trying to transfer the model to cuda. Any sort of transfer post-fusion seems to give this error, probably fusion somehow left behind empty weights? I got as far as getting it to run on the meta device, but gets stuck at the text encoder for some reason.

What I meant by "from source" is the generate.py script, as in making NF4 weights from the flux repo instead of using your repo in the load_from_nf4_and_generate.py script.

@sayakpaul (Member, Author)

This doesn't happen straight after qkv fusion, but actually when trying to transfer the model to cuda. Any sort of transfer post-fusion seems to give this error, probably fusion somehow left behind empty weights?

Maybe you could try this in isolation and report a bug? You could use this minimal test as a reference:

def test_fused_qkv_projections(self):

I got as far as getting it to run on the meta device, but gets stuck at the text encoder for some reason.

Text encoder? Elaborate a bit more? How are you quantizing the text encoder? transformers already supports direct loading with bitsandbytes.

@Ednaordinary

Maybe you could try this in isolation and report a bug? You could use this minimal test as a reference:

I'll do this, but I think it has something to do with the NF4 weights specifically

Text encoder? Elaborate a bit more? How are you quantizing the text encoder? transformers already supports direct loading with bitsandbytes.

The text encoder as in the one in the pipeline: when it's taken a few minutes and I give up and Ctrl+C, the traceback shows it was still on the text_encoder_2 stage. It's probably a product of the "meta" (CPU?) device being so slow. I'll try encoding separately from the pipeline and passing embeddings in as the prompt when I have a bit more time.

@sayakpaul (Member, Author)

Okay. Maybe it would just be better to try it out after our bitsandbytes support becomes more robust through the PRs I mentioned above, because currently it's all a bit hacky without proper checks and guards in place.

The text encoder as in the one in the pipeline: when it's taken a few minutes and I give up and Ctrl+C, the traceback shows it was still on the text_encoder_2 stage. It's probably a product of the "meta" (CPU?) device being so slow. I'll try encoding separately from the pipeline and passing embeddings in as the prompt when I have a bit more time.

Well, I think it doesn't quite give the full picture, IMO.

@lonngxiang

T4: `OutOfMemoryError: CUDA out of memory`

[screenshot of the OOM traceback]

@sayakpaul (Member, Author)

Should be:

- pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype).to(device)
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)

@lonngxiang

Should be:

- pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype).to(device)
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)

If Flux NF4 runs everything on the GPU, how much GPU memory is required?

@dylanisreversing

I am getting the same error as @Ednaordinary mentioned above: `Error in FluxImageGenerator initialization: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.`

I am utilizing your convert_nf4_flux.py script and loading it this way; you can see the code snippet:

            # Move models to the correct device
            logger.info(f"Moving models to device: {self.device}")
            model = model.to(self.device)
            clip_ = clip_.to(self.device)
            t5 = t5.to(self.device)
            vae = vae.to(self.device)

            # Create FLUX pipeline
            logger.info("Creating FLUX pipeline")
            self.pipe = FluxPipeline.from_pretrained(
                bfl_repo,
                transformer=model,
                text_encoder=clip_,
                text_encoder_2=t5,
                vae=vae,
                tokenizer=tokenizer,
                tokenizer_2=tokenizer_2,
                torch_dtype=dtype,
                local_files_only=True
            )

            # Enable CPU offload and VAE slicing
            logger.info("Enabling CPU offload and VAE slicing")
            self.pipe.enable_model_cpu_offload()
            self.pipe.enable_vae_slicing()

            # Print debug information
            logger.info(f"Model device: {next(self.pipe.transformer.parameters()).device}")
            logger.info(f"Model dtype: {next(self.pipe.transformer.parameters()).dtype}")
            for name, module in self.pipe.transformer.named_modules():
                if isinstance(module, bnb.nn.Linear4bit):
                    logger.info(f"Module {name} quant_state: {module.weight.quant_state}")

            logger.info("FLUX pipeline loaded and optimized successfully.")

        except Exception as e:
            logger.error(f"Error in FluxImageGenerator initialization: {str(e)}")
            logger.error(traceback.format_exc())
            raise

You can see my console log below.

Any ideas on how to solve this? I've tried a bunch of things but no luck.

[console log screenshot]

@Ednaordinary commented Sep 2, 2024

@lonngxiang

Please note that moving the models to CUDA and then calling enable_model_cpu_offload is counterproductive, as they will be moved to the GPU and then back to the CPU. You should only call enable_model_cpu_offload, without the GPU transfer. This won't fix your error, since the models will end up being loaded to the GPU anyway, but it should help your load times.

As for your actual error, make sure you are doing all of this to the transformer before moving it to CUDA, as per the generate.py script:

dtype = torch.bfloat16
ckpt_path = hf_hub_download("black-forest-labs/flux.1-dev", filename="flux1-dev.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)

del original_state_dict
gc.collect()

with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/flux.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)

_replace_with_bnb_linear(model, "nf4")
for param_name, param in converted_state_dict.items():
    param = param.to(dtype)
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(model, param, param_name, target_device=0)

del converted_state_dict
gc.collect()

I cannot tell your exact error, as you didn't include the full transformer loading code in your snippet. If you have qkv fusion enabled, make sure to disable it, as it doesn't currently work with these scripts.

@sayakpaul (Member, Author)

You could now refer to #9213 for more improved support.

@Ednaordinary commented Sep 2, 2024

Tried it (latest commit), though qkv fusion still doesn't work as expected :(

[error traceback screenshot]

Part that matters:

import torch
from diffusers import FluxTransformer2DModel, FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
nf4_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"
model_nf4 = FluxTransformer2DModel.from_pretrained(nf4_id, torch_dtype=torch.bfloat16)
model_nf4.fuse_qkv_projections()

I think this is still due to attention blocks being in uint8

Awesome work though, loading is quick and very streamlined

@sayakpaul (Member, Author)

Thanks for your patience and insights. As a temporary solution, I can point you to the code snippet in #8829 (comment). Maybe that could be a reasonable reference for now and then once #9213 is merged, we can polish our QKV fusion compatibility?

@Ednaordinary

Sure! I see the main solution now is fusing qkv before quantization, so I'll likely wait for better post-quant support since saving and loading fused qkv would be confusing. Great work, regardless

@dylanisreversing

You could now refer to #9213 for more improved support.

Thanks for this, I was able to get the pipeline loading and inference working successfully now!

@badhri-suresh

FP8

@asomoza I got the NF4 script working without issues. Could you please provide the working inference script for FP8 precision? Thanks.

@asomoza (Member) commented Sep 4, 2024

Sure, here it is, but if you don't save it, it takes more time to load. Also, I didn't quantize the text encoder because we can run it like this with 24GB of VRAM.

@badhri-suresh

Thanks a lot @asomoza. From your results, FP8 inference is a bit faster than NF4. Can you think of why this could happen? Also, for a GPU with enough memory (like 40GB), will quantized models in general help with latency? From my experiments, bfloat16 inference is 1-2 seconds faster than NF4. Could this be because the compute is still happening in bf16 (even with the NF4 model)?

@al-swaiti commented Sep 6, 2024

flux-nf4-dev in two steps 😅
Uploading...
https://huggingface.co/ABDALLALSWAITI/Maxwell/

@bradley-pearson6597

Is there any way to use this method with the FluxControlNetPipeline? I have tried the code below but I am receiving this error: `Input type (unsigned char) and bias type (c10::BFloat16) should be the same`

from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device
from accelerate import init_empty_weights
from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
from diffusers import FluxTransformer2DModel, FluxControlNetModel, FluxControlNetPipeline
import safetensors.torch
import torch
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
from diffusers.utils import load_image

dtype = torch.bfloat16
ckpt_path = hf_hub_download("black-forest-labs/flux.1-schnell", filename="flux1-schnell.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)

del original_state_dict

with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/flux.1-schnell", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config).to(dtype)

_replace_with_bnb_linear(model, "nf4")
for param_name, param in converted_state_dict.items():
    param = param.to(dtype)
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(model, param, param_name, target_device=0)

del converted_state_dict
controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Canny'
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16, guidance_embeds = False)
pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/flux.1-schnell", transformer=model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to('cuda')

control_image = load_image("https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny/resolve/main/canny.jpg")
prompt = "A girl in city, 25 years old, cool, futuristic"
image = pipe(
    prompt, 
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=3, 
    guidance_scale=1,
).images[0]
image.save("image.jpg")

@VTaPo commented Sep 15, 2024

[Quoted the OP's convert_nf4_flux.py, generate.py, and load_from_nf4_and_generate.py scripts in full.]

Thank you very much for your great work. I would like to ask if there is any way to speed up image generation. When I use num_inference_steps=50 on a T4 GPU with 16GB of VRAM, it takes more than 30 minutes to generate an image. Thank you very much again.

@spejamas

When I get the gguf working I can test if it's really slower than nf4 or not.

Did you get gguf working? Sorry to resurrect an old thread—just very interested in using flux in gguf format in diffusers.

@asomoza (Member) commented Sep 29, 2024

I converted and loaded the diffusers model, but then I stopped because I'm waiting for the final quantization PR to be merged; then we'll see what steps we should take and whether it's worth doing.

@spejamas

I'll keep an eye on it then. Thank you very much 🫡

@Oguzhanercan

[Quoted @bradley-pearson6597's FluxControlNetPipeline snippet above in full.]

Hi, did you solve this problem?

@VadimPoliakov

[Quoted @bradley-pearson6597's FluxControlNetPipeline snippet and @Oguzhanercan's question above.]

Check this: black-forest-labs/flux#185
