
[Quantization] Add quantization support for bitsandbytes #9213

Open · wants to merge 84 commits into main
Conversation

@sayakpaul (Member) commented Aug 19, 2024

What does this PR do?

Come back later.

  • Quantization config class (base and bitsandbytes)
  • Quantizer class (base and bitsandbytes)
  • Utilities related to bitsandbytes
  • from_pretrained() at the ModelMixin level and related changes
  • save_pretrained()
  • NF4 tests
  • INT8 (llm.int8()) tests
  • Docs

Notes

  • Even though I alluded to having a separate QuantizationLoaderMixin in [Quantization] bring quantization to diffusers core #9174, I realized that is not an approach we can take because loading and saving a quantized model is very much baked into the arguments of ModelMixin.save_pretrained() and ModelMixin.from_pretrained(). It is deeply entangled.
  • For the initial quantization support, I think it's okay to not allow passing device_map, because for a pipeline, multiple device_maps can get ugly. This will be dealt with in a follow-up PR by @SunMarc and myself.
  • For the point above, for checkpoints that are found to be sharded (Flux, for example), I have decided to merge them on CPU to simplify the implementation (see the sketch after this list). This will be dealt with in a follow-up PR by @SunMarc.
  • The PR has an extensive testing suite covering training, too. However, I have decided not to add it to our CI yet. We should first let this feature flow into the community and then add the tests to our nightly CI.
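
To give an idea of the shard merging mentioned above, here is a minimal sketch. _merge_sharded_checkpoints is the name used later in this thread; the weight_map layout follows the standard Hub sharding index, and the actual implementation in the PR may differ.

import os
import safetensors.torch

def _merge_sharded_checkpoints(cached_folder, sharded_metadata):
    # every shard file referenced by the index, deduplicated
    shard_files = set(sharded_metadata["weight_map"].values())
    merged_state_dict = {}
    for shard_file in shard_files:
        # load each shard on CPU and fold it into a single state dict
        shard = safetensors.torch.load_file(os.path.join(cached_folder, shard_file), device="cpu")
        merged_state_dict.update(shard)
    return merged_state_dict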

No-frills code snippets

Serialization
import torch 
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
from accelerate.utils import compute_module_sizes

model_id = "black-forest-labs/FLUX.1-dev"

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_nf4 = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.bfloat16
)
assert model_nf4.dtype == torch.uint8, model_nf4.dtype
print(model_nf4.dtype)
print(model_nf4.config.quantization_config)
print(compute_module_sizes(model_nf4)[""] / 1024 / 1024)

push_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"
model_nf4.push_to_hub(push_id)

Serialized checkpoint: https://huggingface.co/sayakpaul/flux.1-dev-nf4-with-bnb-integration.

NF4 checkpoints of Flux transformer and T5: https://huggingface.co/sayakpaul/flux.1-dev-nf4-pkg (has Colab Notebooks, too).
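
For the llm.int8() path mentioned in the checklist, the config swap is minimal. A sketch reusing the same model_id as above; the NF4-specific arguments simply drop away:

int8_config = BitsAndBytesConfig(load_in_8bit=True)
model_int8 = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=int8_config, torch_dtype=torch.float16
)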

Inference
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
nf4_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"
model_nf4 = FluxTransformer2DModel.from_pretrained(nf4_id, torch_dtype=torch.bfloat16)
print(model_nf4.dtype)
print(model_nf4.config.quantization_config)

pipe = FluxPipeline.from_pretrained(model_id, transformer=model_nf4, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A mystic cat with a sign that says hello world!"
image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
image.save("flux-nf4-dev-loaded.png")


@SunMarc (Member) left a comment

Thanks for adding this! I see that you used a lot of things from transformers. Do you think it is possible to import these (or inherit) from transformers? This would help reduce the maintenance burden. I'm also fine doing it this way, since there aren't too many follow-up PRs after a quantizer has been added. About the HfQuantizer class: there are a lot of methods that were created to fit the transformers structure, and I'm not sure we will need every one of them in diffusers. Of course, we can still do a follow-up PR to clean up.

src/diffusers/quantizers/base.py (review thread, resolved)
@sayakpaul (Member, Author) commented Aug 20, 2024

@SunMarc I am guilty as charged, but we don't have transformers as a hard dependency for loading models in diffusers. Pinging @DN6 to seek his opinion.

Update: Chatted with @DN6 as well. We think it's better to redefine inside diffusers without the transformers specific bits which we can clean in this PR.

@sayakpaul (Member, Author) commented:

@SunMarc I think this PR is ready for another review.

@SunMarc (Member) left a comment

Thanks for adding this @sayakpaul !

src/diffusers/quantizers/base.py (review thread)
@yiyixuxu (Collaborator) left a comment

I don't think it makes sense to have a separate PR just to add a base class, because it's hard to understand what methods are needed - we should only introduce a minimal base class and gradually add functionality as needed.

Can we have a PR with a minimal working example?

@sayakpaul (Member, Author) commented Aug 22, 2024

Okay, so do you want me to add everything needed for the bitsandbytes integration in this PR? But do note that this won't be very different from what we have in transformers.

@yiyixuxu (Collaborator) commented Aug 22, 2024

@sayakpaul
I think so because:

  1. It is easier to review that way.
  2. We don't need this class in diffusers on its own because it cannot be used yet, no?

@bghira (Contributor) commented Aug 22, 2024

Sometimes we can make a feature branch that a bunch of PRs get merged into before one big honkin' PR is pushed to main at the end; the pieces are all individually reviewed and can be tested. Is this a viable approach for including quantisation?

@sayakpaul (Member, Author) commented:

Okay I will update this branch. @yiyixuxu

@SunMarc (Member) commented Aug 23, 2024

cc @MekkCyber for visibility

@DN6 (Collaborator) commented Aug 28, 2024

Just a few considerations for the quantization design.

I would say the initial design should start with loading/inference at just the model level and then proceed to add functionality (pipeline-level loading, etc.).

The feature needs to perform the following functions:

  1. Perform on-the-fly quantization of large models so that they can be loaded in a low-memory dtype
     a. with from_pretrained
     b. with from_single_file
  2. Dynamically upcast to the appropriate compute dtype when running inference
  3. Save/load already-quantized versions of these large models (FP8, NF4)
  4. Allow loading/inference with LoRAs in these quantized models (this we have to figure out in more detail)

At the moment, the most common ask seems to be the ability to load models into GPU using the FP8 dtype and run inference in a supported dtype by dynamically upcasting the necessary layers. NF4 is another format that's gaining attention.

So perhaps we should focus on this first. This mostly applies to the DiT models, but large models like CogVideo might also benefit from this approach.

Some example quantized versions of models have been doing the rounds.

To cover these initial cases, we can rely on Quanto (FP8) and BitsandBytes (NF4).

Example API:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, DiffusersQuantoConfig

# Load the model in FP8 with Quanto and perform compute in the configured dtype.
quantization_config = DiffusersQuantoConfig(weights="float8", compute_dtype=torch.bfloat16)

transformer = FluxTransformer2DModel.from_pretrained(
    "<either diffusers format or quanto format weights>", quantization_config=quantization_config
)

pipe = FluxPipeline.from_pretrained("...", transformer=transformer)

The quantization config should probably take the following arguments

DiffusersQuantoConfig(
    weights_dtype="",  # dtype to store weights
    compute_dtype="",  # dtype to perform inference
    skip_quantize_modules=["ResBlock"],
)

I think initially we can rely on the dynamic upcasting operations performed by Quanto and BnB under the hood to start and then expand on them if needed.
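
As a rough illustration of what "dynamic upcasting" means here, a minimal hand-rolled sketch follows. UpcastLinear is a hypothetical name; the real backends (Quanto, bitsandbytes) do this under the hood with far more efficient fused kernels.

import torch
import torch.nn as nn
import torch.nn.functional as F

class UpcastLinear(nn.Module):
    def __init__(self, linear: nn.Linear, compute_dtype=torch.bfloat16):
        super().__init__()
        # weights stay stored in a low-memory dtype between calls
        self.weight = linear.weight
        self.bias = linear.bias
        self.compute_dtype = compute_dtype

    def forward(self, x):
        # upcast to the compute dtype only for the duration of the matmul
        w = self.weight.to(self.compute_dtype)
        b = self.bias.to(self.compute_dtype) if self.bias is not None else None
        return F.linear(x.to(self.compute_dtype), w, b)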

Some other considerations

  1. Since we have transformers models in diffusers that can also benefit from quantized loading, we might want to consider adding a Diffusers prefix to the quantization configs, e.g. DiffusersQuantoConfig, so that when we import quantization configs from transformers there aren't any conflicts.
  2. For saving and loading models, we can start with models saved in Quanto/BnB format.
  3. One possible challenge with pipeline-level quantized loading is that we have a mix of transformers/diffusers models, so a single config to quantize/load both types might not be possible.
  4. Single-file loading has its own set of issues, such as dealing with checkpoints that have been naively quantized, e.g. safetensors.torch.save_file(model.to(torch.float8_e4m3fn), "model-fp8.safetensors"), which applies to some of the Flux single-file checkpoints, and loading full-pipeline single-file checkpoints. But we can address these later.

@sayakpaul (Member, Author) commented Aug 28, 2024

This PR will stay at the model level. And we should not add multiple backends in a single PR: this PR aims to add bitsandbytes, and we can do other backends using this PR as a reference. I would like us to mutually agree on this before I start making progress on this PR.

Concretely, I would like to stick to the outline of the changes laid out in #9174 (along with anything related) for this PR.

The feature needs to perform the following functions

I won't advocate doing all of that in a single PR because it makes things very hard to review. We'd rather move faster with something more minimal and confirm its effectiveness.

Allow loading/inference with LoRAs in these quantized models. (This we have to figure out in more detail)

Well, note that if the underlying LoRA wasn't trained with the base quantization precision, it might not perform as expected.

So perhaps we should focus on this first. This mostly applies to the DiT models but large models like CogVideo might also benefit with this approach.

Please note that bitsandbytes quantization mostly applies to nn.Linear, whereas quanto is broader in scope (i.e., quanto can be applied to an nn.Conv2d as well); see the sketch below.
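
A concrete way to see the scope difference (a sketch, not the libraries' actual replacement logic):

import torch.nn as nn

def modules_bnb_would_convert(model):
    # bitsandbytes int8/4-bit replaces nn.Linear layers only
    return [name for name, m in model.named_modules() if isinstance(m, nn.Linear)]

def modules_quanto_could_convert(model):
    # quanto can also quantize convolutions
    return [name for name, m in model.named_modules() if isinstance(m, (nn.Linear, nn.Conv2d))]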

@DN6 (Collaborator) commented Aug 28, 2024

This PR will be at the model-level itself. And we should not add multiple backends in a single PR. This PR aims to add bitsandbytes. We can do other backends taking this PR as a reference. I would like us to mutually agree on this before I start making progress on this PR.

Sounds good to me.

For this PR let's do:

  1. from_pretrained only
  2. bnb quantization

@stevhliu (Member) left a comment

Thanks, this looks really good! 🔥

Review threads (resolved) on:
  • docs/source/en/api/quantization.md
  • docs/source/en/quantization/bitsandbytes.md (several threads)
  • docs/source/en/quantization/overview.md
@@ -526,7 +526,8 @@ def extract_init_dict(cls, config_dict, **kwargs):
init_dict[key] = config_dict.pop(key)

# 4. Give nice warning if unexpected values have been passed
if len(config_dict) > 0:
only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
A Collaborator commented:
I think it is better not to add to config_dict if it is not going into __init__, i.e. at line 511:

# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
# remove quantization_config
config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}

src/diffusers/models/modeling_utils.py (review thread)
if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
raise ValueError(
f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
" the logger on the traceback to understand the reason why the quantized model is not serializable."
A Collaborator commented:

But we raised a ValueError here; they are not going to get a traceback, no?

@sayakpaul (Member, Author) replied:

I think it would still print the warnings to the console, hence the wording.

src/diffusers/models/modeling_utils.py (review thread, resolved)
@@ -99,6 +131,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
if isinstance(checkpoint_file, dict):
A Collaborator commented:

Why are we making this change? When will checkpoint_file be passed as a dict?

@sayakpaul (Member, Author) replied:

We merge the sharded checkpoints (as stated in the PR description and mutually agreed upon internally) in case we're doing quantization:

model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)

^ model_file becomes a state dict, which is then loaded by load_state_dict:

state_dict = load_state_dict(model_file, variant=variant)

Hence this change; a sketch of the resulting passthrough follows.
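
Concretely, based on the diff context above, the change amounts to a small passthrough at the top of the function (a sketch; the surrounding file-reading logic is elided):

def load_state_dict(checkpoint_file, variant=None):
    """Reads a checkpoint file, returning properly formatted errors if they arise."""
    if isinstance(checkpoint_file, dict):
        # a merged state dict was passed in directly (sharded-checkpoint case); nothing to read
        return checkpoint_file
    # ... existing file-reading logic follows ...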

src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py (three review threads, resolved)
for k, v in state_dict.items():
# `startswith` to counter for edge cases where `param_name`
# substring can be present in multiple places in the `state_dict`
if param_name + "." in k and k.startswith(param_name):
A Collaborator commented:

k.split('.')[0] == param_name ?

@sayakpaul (Member, Author) replied:

Do you mean if param_name + "." in k and k.split('.')[0] == param_name:?
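
A tiny example of the edge cases the double check guards against (hypothetical keys):

param_name = "proj"
state_dict = {
    "proj.weight": 1,        # contains "proj." and starts with "proj" -> matched
    "block.proj.weight": 2,  # contains "proj." but doesn't start with "proj" -> skipped
    "proj_out.weight": 3,    # starts with "proj" but lacks "proj." -> skipped
}
matched = [k for k in state_dict if param_name + "." in k and k.startswith(param_name)]
assert matched == ["proj.weight"]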

# Unlike `transformers`, we don't know if we should always keep certain modules in FP32
# in case of diffusion transformer models. For language models and others alike, `lm_head`
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = list(filter(None.__ne__, self.modules_to_not_convert))
A Collaborator commented:

Can you provide examples of when this list would contain None?

@sayakpaul (Member, Author) replied:

It is configured via llm_int8_skip_modules within the BitsAndBytesConfig object. It defaults to None in our case because, unlike for language models, we don't know if a default is required. An example follows.
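
For illustration, skipping a module from conversion would look like this (the module name here is made up):

from diffusers import BitsAndBytesConfig

config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["proj_out"],  # hypothetical module kept in higher precision
)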

@sayakpaul (Member, Author) commented:

@yiyixuxu thanks for your reviews; they were very helpful. I have gone ahead and re-run the tests on audace, and everything is green.

I have addressed your comments and made changes. PTAL.

@chuck-ma commented:
Hi, looks like everything is great. I don't know why the approving review is still processing.
