Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Quantization] Add quantization support for bitsandbytes #9213

Merged
merged 119 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
119 commits
Select commit Hold shift + click to select a range
e634ff2
quantization config.
sayakpaul Aug 19, 2024
02a6dff
fix-copies
sayakpaul Aug 19, 2024
c385a2b
Merge branch 'main' into quantization-config
sayakpaul Aug 20, 2024
0355875
Merge branch 'main' into quantization-config
sayakpaul Aug 20, 2024
e41b494
Merge branch 'main' into quantization-config
sayakpaul Aug 20, 2024
dfb33eb
Merge branch 'main' into quantization-config
sayakpaul Aug 21, 2024
e492655
Merge branch 'main' into quantization-config
sayakpaul Aug 22, 2024
6e86cc0
fix
sayakpaul Aug 22, 2024
58a3d15
modules_to_not_convert
sayakpaul Aug 22, 2024
1d477f9
Merge branch 'main' into quantization-config
sayakpaul Aug 22, 2024
bd7f46d
Merge branch 'main' into quantization-config
sayakpaul Aug 23, 2024
d5d7bb6
Merge branch 'main' into quantization-config
sayakpaul Aug 28, 2024
44c8a75
Merge branch 'main' into quantization-config
sayakpaul Aug 28, 2024
6a0fcdc
add bitsandbytes utilities.
sayakpaul Aug 28, 2024
e4590fa
make progress.
sayakpaul Aug 28, 2024
77a1438
Merge branch 'main' into quantization-config
sayakpaul Aug 29, 2024
335ab6b
fixes
sayakpaul Aug 29, 2024
d44ef85
quality
sayakpaul Aug 29, 2024
210fa1e
up
sayakpaul Aug 29, 2024
f4feee1
up
sayakpaul Aug 29, 2024
e8c1722
Merge branch 'main' into quantization-config
sayakpaul Aug 29, 2024
7f86a71
Merge branch 'main' into quantization-config
sayakpaul Aug 29, 2024
ba671b6
minor
sayakpaul Aug 30, 2024
c1a9f13
up
sayakpaul Aug 30, 2024
4489c54
Merge branch 'main' into quantization-config
sayakpaul Aug 30, 2024
f2ca5e2
up
sayakpaul Aug 30, 2024
d6b8954
fix
sayakpaul Aug 30, 2024
45029e2
provide credits where due.
sayakpaul Aug 30, 2024
4eb468a
make configurations work.
sayakpaul Aug 30, 2024
939965d
fixes
sayakpaul Aug 30, 2024
8557166
Merge branch 'main' into quantization-config
sayakpaul Aug 30, 2024
d098d07
fix
sayakpaul Aug 30, 2024
c4a0074
update_missing_keys
sayakpaul Aug 30, 2024
ee45612
fix
sayakpaul Aug 30, 2024
b24c0a7
fix
sayakpaul Aug 31, 2024
473505c
make it work.
sayakpaul Aug 31, 2024
c795c82
fix
sayakpaul Aug 31, 2024
c1d5b96
Merge branch 'main' into quantization-config
sayakpaul Aug 31, 2024
af7caca
provide credits to transformers.
sayakpaul Aug 31, 2024
80967f5
empty commit
sayakpaul Sep 1, 2024
3bdf25a
handle to() better.
sayakpaul Sep 2, 2024
27415cc
tests
sayakpaul Sep 2, 2024
51cac09
change to bnb from bitsandbytes
sayakpaul Sep 2, 2024
15f3032
fix tests
sayakpaul Sep 2, 2024
77c9fdb
better safeguard.
sayakpaul Sep 2, 2024
ddc9f29
change merging status
sayakpaul Sep 2, 2024
44c4109
courtesy to transformers.
sayakpaul Sep 2, 2024
27666a8
move upper.
sayakpaul Sep 2, 2024
3464d83
better
sayakpaul Sep 2, 2024
b106124
Merge branch 'main' into quantization-config
sayakpaul Sep 2, 2024
330fa0a
Merge branch 'main' into quantization-config
sayakpaul Sep 2, 2024
abc8607
make the unused kwargs warning friendlier.
sayakpaul Sep 3, 2024
31725aa
harmonize changes with https://github.com/huggingface/transformers/pu…
sayakpaul Sep 3, 2024
e5938a6
style
sayakpaul Sep 3, 2024
444588f
trainin tests
sayakpaul Sep 3, 2024
d3360ce
Merge branch 'main' into quantization-config
sayakpaul Sep 3, 2024
d8b35f4
Merge branch 'main' into quantization-config
sayakpaul Sep 3, 2024
859f2d7
Merge branch 'main' into quantization-config
sayakpaul Sep 4, 2024
3b2d6e1
feedback part i.
sayakpaul Sep 4, 2024
5799954
Add Flux inpainting and Flux Img2Img (#9135)
Gothos Sep 4, 2024
8e4bd08
Revert "Add Flux inpainting and Flux Img2Img (#9135)"
sayakpaul Sep 6, 2024
835d4ad
tests
sayakpaul Sep 6, 2024
27075fe
don
sayakpaul Sep 6, 2024
5c00c1c
Merge branch 'main' into quantization-config
sayakpaul Sep 6, 2024
5d633a0
Merge branch 'main' into quantization-config
sayakpaul Sep 8, 2024
c381fe0
Apply suggestions from code review
sayakpaul Sep 10, 2024
3c92878
Merge branch 'main' into quantization-config
sayakpaul Sep 10, 2024
acdeb25
contribution guide.
sayakpaul Sep 11, 2024
aa295b7
Merge branch 'main' into quantization-config
sayakpaul Sep 11, 2024
7f7c9ce
Merge branch 'main' into quantization-config
sayakpaul Sep 15, 2024
55f96d8
Merge branch 'main' into quantization-config
sayakpaul Sep 15, 2024
b28cc65
changes
sayakpaul Sep 17, 2024
8328e86
Merge branch 'main' into quantization-config
sayakpaul Sep 17, 2024
9758942
empty
sayakpaul Sep 17, 2024
b1a9878
fix tests
sayakpaul Sep 17, 2024
971305b
harmonize with https://github.com/huggingface/transformers/pull/33546.
sayakpaul Sep 18, 2024
f41adf1
numpy_cosine_distance
sayakpaul Sep 19, 2024
0bcb88b
Merge branch 'main' into quantization-config
sayakpaul Sep 19, 2024
55b3696
Merge branch 'main' into quantization-config
sayakpaul Sep 20, 2024
4cb3a6d
Merge branch 'main' into quantization-config
sayakpaul Sep 23, 2024
8a03eae
Merge branch 'main' into quantization-config
sayakpaul Sep 24, 2024
53f0a92
Merge branch 'main' into quantization-config
sayakpaul Sep 26, 2024
6aab47c
Merge branch 'main' into quantization-config
sayakpaul Sep 27, 2024
9b9a610
resolved conflicts,
sayakpaul Sep 29, 2024
510d57a
Merge branch 'main' into quantization-config
sayakpaul Oct 10, 2024
555a5ae
config_dict modification.
sayakpaul Oct 10, 2024
da10365
remove if config comment.
sayakpaul Oct 10, 2024
71316a6
note for load_state_dict changes.
sayakpaul Oct 10, 2024
12f5c59
float8 check.
sayakpaul Oct 10, 2024
5e722cd
quantizer.
sayakpaul Oct 10, 2024
c78dd0c
raise an error for non-True low_cpu_mem_usage values when using quant.
sayakpaul Oct 10, 2024
af3ecea
low_cpu_mem_usage shenanigans when using fp32 modules.
sayakpaul Oct 10, 2024
a473d28
don't re-assign _pre_quantization_type.
sayakpaul Oct 10, 2024
870d74f
make comments clear.
sayakpaul Oct 10, 2024
3e6cfeb
remove comments.
sayakpaul Oct 10, 2024
673993c
handle mixed types better when moving to cpu.
sayakpaul Oct 10, 2024
0d5f2f7
add tests to check if we're throwing warning rightly.
sayakpaul Oct 10, 2024
3cb20fe
better check.
sayakpaul Oct 10, 2024
10940a9
fix 8bit test_quality.
sayakpaul Oct 10, 2024
c0a88ae
Merge branch 'main' into quantization-config
sayakpaul Oct 10, 2024
dcc5bc5
Merge branch 'main' into quantization-config
sayakpaul Oct 12, 2024
5e0b4eb
Merge branch 'main' into quantization-config
sayakpaul Oct 12, 2024
569dd96
Merge branch 'main' into quantization-config
sayakpaul Oct 13, 2024
8bdc846
Merge branch 'main' into quantization-config
sayakpaul Oct 15, 2024
ff8ddef
handle dtype more robustly.
sayakpaul Oct 15, 2024
de6394a
better message when keep_in_fp32_modules.
sayakpaul Oct 15, 2024
81bb48a
handle dtype casting.
sayakpaul Oct 15, 2024
c5e62ae
Merge branch 'main' into quantization-config
sayakpaul Oct 15, 2024
d023b40
Merge branch 'main' into quantization-config
sayakpaul Oct 16, 2024
a3d2655
Merge branch 'main' into quantization-config
sayakpaul Oct 16, 2024
700b0f3
Merge branch 'main' into quantization-config
sayakpaul Oct 18, 2024
0ae70fe
fix dtype checks in pipeline.
sayakpaul Oct 18, 2024
ecdf1d0
fix warning message.
sayakpaul Oct 18, 2024
aea3398
Update src/diffusers/models/modeling_utils.py
sayakpaul Oct 18, 2024
3a91974
Merge branch 'main' into quantization-config
sayakpaul Oct 18, 2024
5d8e844
Merge branch 'main' into quantization-config
sayakpaul Oct 19, 2024
501a6ba
mitigate the confusing cpu warning
sayakpaul Oct 19, 2024
1a931cb
Merge branch 'main' into quantization-config
sayakpaul Oct 21, 2024
2fa8fb9
Merge branch 'main' into quantization-config
sayakpaul Oct 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers": [],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
Expand Down Expand Up @@ -122,7 +123,6 @@
"VQModel",
]
)

_import_structure["optimization"] = [
"get_constant_schedule",
"get_constant_schedule_with_warmup",
Expand Down Expand Up @@ -154,6 +154,7 @@
"StableDiffusionMixin",
]
)
_import_structure["quantizers"] = ["HfQuantizer"]
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
Expand Down Expand Up @@ -616,6 +617,7 @@
ScoreSdeVePipeline,
StableDiffusionMixin,
)
from .quantizers import HfQuantizer
from .schedulers import (
AmusedScheduler,
CMStochasticIterativeScheduler,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/quantizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .base import HfQuantizer
235 changes: 235 additions & 0 deletions src/diffusers/quantizers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Adapted from
https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/quantizers/base.py
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import is_torch_available
from .quantization_config import QuantizationConfigMixin


if TYPE_CHECKING:
from ..models.modeling_utils import ModelMixin

if is_torch_available():
import torch


class HfQuantizer(ABC):
"""
Abstract class of the HuggingFace quantizer. Supports for now quantizing HF diffusers models for inference and/or
quantization. This class is used only for diffusers.models.modeling_utils.ModelMixin.from_pretrained and cannot be
easily used outside the scope of that method yet.

Attributes
quantization_config (`diffusers.quantizers.quantization_config.QuantizationConfigMixin`):
The quantization config that defines the quantization parameters of your model that you want to quantize.
modules_to_not_convert (`List[str]`, *optional*):
The list of module names to not convert when quantizing the model.
required_packages (`List[str]`, *optional*):
The list of required pip packages to install prior to using the quantizer
requires_calibration (`bool`):
Whether the quantization method requires to calibrate the model before using it.
requires_parameters_quantization (`bool`):
Whether the quantization method requires to create a new Parameter. For example, for bitsandbytes, it is
required to create a new xxxParameter in order to properly quantize the model.
"""

requires_calibration = False
required_packages = None
requires_parameters_quantization = False
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
self.quantization_config = quantization_config

# -- Handle extra kwargs below --
self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
self.pre_quantized = kwargs.pop("pre_quantized", True)

if not self.pre_quantized and self.requires_calibration:
raise ValueError(
f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized."
f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to "
f"pass `pre_quantized=True` while knowing what you are doing."
)

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
"""
Some quantization methods require to explicitly set the dtype of the model to a target dtype. You need to
override this method in case you want to make sure that behavior is preserved

Args:
torch_dtype (`torch.dtype`):
The input dtype that is passed in `from_pretrained`
"""
return torch_dtype

def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""
Override this method if you want to pass a override the existing device map with a new one. E.g. for
bitsandbytes, since `accelerate` is a hard requirement, if no device_map is passed, the device_map is set to
`"auto"``

Args:
device_map (`Union[dict, str]`, *optional*):
The device_map that is passed through the `from_pretrained` method.
"""
return device_map

def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
"""
Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained` to compute the
device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype` to `torch.int8`
and for 4-bit we pass a custom enum `accelerate.CustomDtype.int4`.

Args:
torch_dtype (`torch.dtype`, *optional*):
The torch_dtype that is used to compute the device_map.
"""
return torch_dtype

def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
"""
Override this method if you want to adjust the `missing_keys`.

Args:
missing_keys (`List[str]`, *optional*):
The list of missing keys in the checkpoint compared to the state dict of the model
"""
return missing_keys

def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
"""
returns dtypes for modules that are not quantized - used for the computation of the device_map in case one
passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified in
`_process_model_before_weight_loading`.

Args:
model (`~diffusers.models.modeling_utils.ModelMixin`):
The model to quantize
torch_dtype (`torch.dtype`):
The dtype passed in `from_pretrained` method.
"""

return {
name: torch_dtype
for name, _ in model.named_parameters()
if any(m in name for m in self.modules_to_not_convert)
}

def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
"""adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
return max_memory

def check_quantized_param(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO check_is_quantized_param or check_if_quantized_param more explicitly conveys what this method does.

self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
"""
checks if a loaded state_dict component is part of quantized param + some validation; only defined if
requires_parameters_quantization == True for quantization methods that require to create a new parameters for
quantization.
"""
return False

def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
"""
takes needed components from state_dict and creates quantized param; only applicable if
requires_parameters_quantization == True
"""
if not self.requires_parameters_quantization:
raise AttributeError(
f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
)

def validate_environment(self, *args, **kwargs):
"""
This method is used to potentially check for potential conflicts with arguments that are passed in
`from_pretrained`. You need to define it for all future quantizers that are integrated with diffusers. If no
explicit check are needed, simply return nothing.
"""
return

def preprocess_model(self, model: "ModelMixin", **kwargs):
"""
Setting model attributes and/or converting model before weights loading. At this point the model should be
initialized on the meta device so you can freely manipulate the skeleton of the model in order to replace
modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`.

Args:
model (`~diffusers.models.modeling_utils.ModelMixin`):
The model to quantize
kwargs (`dict`, *optional*):
The keyword arguments that are passed along `_process_model_before_weight_loading`.
"""
model.is_quantized = True
model.quantization_method = self.quantization_config.quant_method
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
return self._process_model_before_weight_loading(model, **kwargs)

def postprocess_model(self, model: "ModelMixin", **kwargs):
"""
Post-process the model post weights loading. Make sure to override the abstract method
`_process_model_after_weight_loading`.

Args:
model (`~diffusers.models.modeling_utils.ModelMixin`):
The model to quantize
kwargs (`dict`, *optional*):
The keyword arguments that are passed along `_process_model_after_weight_loading`.
"""
return self._process_model_after_weight_loading(model, **kwargs)

def dequantize(self, model):
"""
Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance. Note
not all quantization schemes support this.
"""
model = self._dequantize(model)

# Delete quantizer and quantization config
del model.hf_quantizer

return model

def _dequantize(self, model):
raise NotImplementedError(
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
)

@abstractmethod
def _process_model_before_weight_loading(self, model, **kwargs):
...

@abstractmethod
def _process_model_after_weight_loading(self, model, **kwargs):
...

@property
@abstractmethod
def is_serializable(self):
...

@property
@abstractmethod
def is_trainable(self):
...
Loading
Loading