[Quantization] Add quantization support for bitsandbytes
#9213
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for adding this! I see that you used a lot of things from transformers. Do you think it is possible to import these (or inherit) from transformers? This would help reduce maintenance. I'm also fine with doing that, since there are not too many follow-up PRs after a quantizer has been added. About the HfQuantizer class: there are a lot of methods that were created to fit the transformers structure. I'm not sure we will need every one of these methods in diffusers. Of course, we can still do a follow-up PR to clean up.
@SunMarc I am guilty as charged, but we don't have transformers as a hard dependency for loading models in Diffusers. Pinging @DN6 to seek his opinion. Update: chatted with @DN6 as well. We think it's better to redefine these inside diffusers.
@SunMarc I think this PR is ready for another review.
Thanks for adding this, @sayakpaul!
I don't think it makes sense to have a separate PR just to add a base class, because it's hard to understand which methods are needed. We should only introduce a minimal base class and gradually add functionality as needed.
Can we have a PR with a minimal working example?
Okay, so, do you want me to add everything needed for the bitsandbytes integration in this PR? But do note that this won't be very different from what we have in transformers.
@sayakpaul
Sometimes we can make a feature branch where a bunch of PRs can be merged before one big honkin' PR is pushed to main at the end, and the pieces are all individually reviewed and can be tested. Is this a viable approach for including quantisation?
Okay, I will update this branch. @yiyixuxu
cc @MekkCyber for visibility
Just a few considerations for the quantization design. I would say the initial design should start with loading/inference at just the model level and then proceed to add functionality (pipeline-level loading, etc.). The feature needs to perform the following functions:
At the moment, the most common ask seems to be the ability to load models onto the GPU in the FP8 dtype and run inference in a supported dtype by dynamically upcasting the necessary layers. NF4 is another format that's gaining attention. So perhaps we should focus on this first. This mostly applies to the DiT models, but large models like CogVideo might also benefit from this approach. Some example quantized versions of models that have been doing the rounds:
To cover these initial cases, we can rely on Quanto (FP8) and BitsandBytes (NF4). Example API:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, DiffusersQuantoConfig

# Load the model in FP8 with Quanto and perform compute in the configured dtype.
quantization_config = DiffusersQuantoConfig(weights="float8", compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "<either diffusers format or quanto format weights>", quantization_config=quantization_config
)
pipe = FluxPipeline.from_pretrained("...", transformer=transformer)

The quantization config should probably take the following arguments:
I think initially we can rely on the dynamic upcasting operations performed by Quanto and BnB under the hood to start and then expand on them if needed. Some other considerations
This PR will be at the model level itself. And we should not add multiple backends in a single PR. This PR aims to add bitsandbytes support. Concretely, I would like to stick to the outline of the changes laid out in #9174 (along with anything related) for this PR.
I won't advocate doing all of that in a single PR because it makes things very hard to review. We would rather move faster with something more minimal and confirm its effectiveness.
Well, note that if the underlying LoRA wasn't trained with the base quantization precision, it might not perform as expected.
Please note that ...
Sounds good to me. For this PR let's do:
Thanks, this looks really good! 🔥
Co-authored-by: Steven Liu <[email protected]>
@@ -526,7 +526,8 @@ def extract_init_dict(cls, config_dict, **kwargs):
    init_dict[key] = config_dict.pop(key)

    # 4. Give nice warning if unexpected values have been passed
    if len(config_dict) > 0:
        only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
I think it is better to not add it to config_dict if it is not going into __init__, i.e. at line 511:

# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
# remove quantization_config
config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}
if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
    raise ValueError(
        f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
        " the logger on the traceback to understand the reason why the quantized model is not serializable."
But we raised a ValueError here, so they are not going to get a traceback, no?
I think it would still show the warnings on the console, hence the phrasing.
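For context, here is a small self-contained sketch (plain Python logging, not the actual diffusers code or messages) of why warnings emitted by the logger remain visible on the console even though the ValueError itself carries no traceback detail:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def save_quantized_model(quantization_serializable: bool) -> None:
    if not quantization_serializable:
        # The warning is printed *before* the exception is raised,
        # so it appears on the console right above the traceback.
        logger.warning("This quantization method does not support serialization because ...")
        raise ValueError(
            "The model is quantized and is not serializable - check out the warnings above."
        )


save_quantized_model(False)
```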
@@ -99,6 +131,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
    """
    if isinstance(checkpoint_file, dict):
Why are we making this change? When will checkpoint_file be passed as a dict?
We merge the sharded checkpoints (as stated in the PR description and mutually agreed upon internally) in case we're doing quantization:

model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)

^ model_file becomes a state dict, which is loaded by load_state_dict:

state_dict = load_state_dict(model_file, variant=variant)

and hence this change.
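For illustration, a rough sketch (an assumption, not the exact diffusers implementation) of how load_state_dict can simply pass through an already-merged state dict:

```python
import os
from typing import Optional, Union

import safetensors.torch
import torch


def load_state_dict(checkpoint_file: Union[str, os.PathLike, dict], variant: Optional[str] = None):
    """Reads a checkpoint file, or passes through an already-loaded state dict."""
    # If the sharded checkpoints were already merged into a single state dict upstream,
    # there is nothing left to read from disk.
    if isinstance(checkpoint_file, dict):
        return checkpoint_file
    if str(checkpoint_file).endswith(".safetensors"):
        return safetensors.torch.load_file(checkpoint_file, device="cpu")
    return torch.load(checkpoint_file, map_location="cpu")
```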
for k, v in state_dict.items():
    # `startswith` to guard against edge cases where the `param_name`
    # substring can be present in multiple places in the `state_dict`
    if param_name + "." in k and k.startswith(param_name):
k.split('.')[0] == param_name ?
Do you mean if param_name + "." in k and k.split('.')[0] == param_name: ?
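To make the edge case concrete, a small illustration with made-up parameter names showing why the substring check alone is not enough and how the two proposed checks compare:

```python
state_dict = {
    "proj.weight": 1,
    "proj.bias": 2,
    "out_proj.weight": 3,  # contains "proj." as a substring but is a different parameter
}
param_name = "proj"

# The substring check alone would wrongly match "out_proj.weight".
naive = [k for k in state_dict if param_name + "." in k]
assert naive == ["proj.weight", "proj.bias", "out_proj.weight"]

# Adding the prefix check (as in the PR) keeps only the intended parameter's entries.
exact = [k for k in state_dict if param_name + "." in k and k.startswith(param_name)]
assert exact == ["proj.weight", "proj.bias"]

# The reviewer's suggestion is equivalent here: compare the first dotted component.
alt = [k for k in state_dict if k.split(".")[0] == param_name]
assert exact == alt
```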
# Unlike `transformers`, we don't know if we should always keep certain modules in FP32
# in case of diffusion transformer models. For language models and others alike, `lm_head`
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = list(filter(None.__ne__, self.modules_to_not_convert))
Can you provide examples of when this list would contain None?
It is configured via llm_int8_skip_modules within the BitsandBytesConfig object. It defaults to None in our case because, unlike language models, we don't know if there should be a default.
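A toy illustration (with hypothetical module names) of what the filter is doing here:

```python
# Hypothetical example: when llm_int8_skip_modules is left as None and merged with
# other candidate lists, None entries can end up in the combined list.
modules_to_not_convert = ["proj_out", None, "norm_out"]

# filter(None.__ne__, ...) drops the None entries and keeps everything else.
modules_to_not_convert = list(filter(None.__ne__, modules_to_not_convert))
assert modules_to_not_convert == ["proj_out", "norm_out"]
```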
@yiyixuxu thanks for your reviews, they were very nice and helpful. I have re-run the tests, addressed your comments, and made changes. PTAL.
Hi, looks like everything is great. I don't know why the approving review is still processing.
What does this PR do?
Come back later.
- … (bitsandbytes)
- … (bitsandbytes)
- … bitsandbytes from_pretrained() at the ModelMixin level and related changes
- save_pretrained()

Notes

- Regarding QuantizationLoaderMixin in [Quantization] bring quantization to diffusers core #9174: I realized that is not an approach we can take, because loading and saving a quantized model is very much baked into the arguments of ModelMixin.save_pretrained() and ModelMixin.from_pretrained(). It is deeply entangled.
- device_map: for a pipeline, multiple device_maps can get ugly. This will be dealt with in a follow-up PR by @SunMarc and myself.

No-frills code snippets
Serialization
Serialized checkpoint: https://huggingface.co/sayakpaul/flux.1-dev-nf4-with-bnb-integration.
NF4 checkpoints of Flux transformer and T5: https://huggingface.co/sayakpaul/flux.1-dev-nf4-pkg (has Colab Notebooks, too).
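A minimal sketch of what the serialization snippet might look like; the config arguments and output path are assumptions modeled on the linked NF4 checkpoint, not copied verbatim from the PR:

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Quantize the Flux transformer to NF4 on load, then serialize the quantized weights.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", quantization_config=nf4_config
)
model.save_pretrained("flux.1-dev-nf4-with-bnb-integration")
```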
Inference
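And a hedged sketch of the inference side, loading the pre-quantized NF4 transformer from the repo linked above into a Flux pipeline (prompt and generation parameters are illustrative):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Load the already-quantized NF4 transformer and plug it into the pipeline.
transformer = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/flux.1-dev-nf4-with-bnb-integration", torch_dtype=torch.bfloat16
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

image = pipe(
    "a photo of a dog wearing sunglasses on a beach", num_inference_steps=28, guidance_scale=3.5
).images[0]
image.save("flux_nf4.png")
```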