diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b331e4b13760..d2c2e14b6113 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -146,6 +146,12 @@ title: Reinforcement learning training with DDPO title: Methods title: Training +- sections: + - local: quantization/overview + title: Getting Started + - local: quantization/bitsandbytes + title: bitsandbytes + title: Quantization Methods - sections: - local: optimization/fp16 title: Speed up inference @@ -205,6 +211,8 @@ title: Logging - local: api/outputs title: Outputs + - local: api/quantization + title: Quantization title: Main Classes - isExpanded: false sections: diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md new file mode 100644 index 000000000000..2fbde9e707ea --- /dev/null +++ b/docs/source/en/api/quantization.md @@ -0,0 +1,33 @@ + + +# Quantization + +Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index). + +Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class. + + + +Learn how to quantize models in the [Quantization](../quantization/overview) guide. + + + + +## BitsAndBytesConfig + +[[autodoc]] BitsAndBytesConfig + +## DiffusersQuantizer + +[[autodoc]] quantizers.base.DiffusersQuantizer diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md new file mode 100644 index 000000000000..f272346aa2e2 --- /dev/null +++ b/docs/source/en/quantization/bitsandbytes.md @@ -0,0 +1,267 @@ + + +# bitsandbytes + +[bitsandbytes](https://huggingface.co/docs/bitsandbytes/index) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. + +4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs. + + +To use bitsandbytes, make sure you have the following libraries installed: + +```bash +pip install diffusers transformers accelerate bitsandbytes -U +``` + +Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. + + + + +Quantizing a model in 8-bit halves the memory-usage: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + +By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. 
You can change the data type of these modules with the `torch_dtype` parameter: + +```py +import torch +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.float32 +) +model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype +``` + +Once a model is quantized, you can push it to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) + +# the repository id below is a placeholder - replace it with your own +model_8bit.push_to_hub("your-username/FLUX.1-dev-bnb-8bit") +``` + + + + +Quantizing a model in 4-bit reduces your memory usage by 4x: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + +By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter: + +```py +import torch +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.float32 +) +model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype +``` + +To push a 4-bit model to the Hub, call [`~ModelMixin.push_to_hub`] after loading it. You can also save serialized 4-bit models locally with [`~ModelMixin.save_pretrained`]. + + + + + + +Training with 8-bit and 4-bit weights is only supported for training *extra* parameters. + + + +Check your memory footprint with the `get_memory_footprint` method: + +```py +print(model.get_memory_footprint()) +``` + +Quantized models can be loaded with the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameter: + +```py +from diffusers import FluxTransformer2DModel + +model_4bit = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer" +) +``` + +## 8-bit (LLM.int8() algorithm) + + + +Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)! + + + +This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion. + +### Outlier threshold + +An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty.
A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning). + +To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_threshold=10, +) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, +) +``` + +### Skip module conversion + +For some models, you don't need to quantize every module to 8-bit; in fact, quantizing every module can cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]: + +```py +from diffusers import SD3Transformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_skip_modules=["proj_out"], +) + +model_8bit = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=quantization_config, +) +``` + + +## 4-bit (QLoRA algorithm) + + + +Learn more about the details of 4-bit quantization in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). + + + +This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization. + + +### Compute data type + +To speed up computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]: + +```py +import torch +from diffusers import BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) +``` + +### Normal Float 4 (NF4) + +NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in [`BitsAndBytesConfig`]: + +```py +from diffusers import SD3Transformer2DModel, BitsAndBytesConfig + +nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", +) + +model_nf4 = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=nf4_config, +) +``` + +For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the same `bnb_4bit_compute_dtype` and `torch_dtype` values. + +### Nested quantization + +Nested quantization is a technique that can save additional memory at no additional performance cost. It performs a second quantization of the first quantization's constants, saving an additional ~0.4 bits/parameter: storing an fp32 scaling constant for every block of 64 weights costs 32/64 = 0.5 bits per parameter, and re-quantizing those constants to 8-bit with a second block size of 256 reduces that overhead to roughly 0.127 bits per parameter (see the [QLoRA](https://hf.co/papers/2305.14314) paper for details).
+ +```py +from diffusers import SD3Transformer2DModel, BitsAndBytesConfig + +double_quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, +) + +double_quant_model = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=double_quant_config, +) +``` + +## Dequantizing `bitsandbytes` models + +Once quantized, you can dequantize a model back to its original precision, but this might result in a small loss of quality. Make sure you have enough GPU RAM to fit the dequantized model. + +```python +from diffusers import SD3Transformer2DModel, BitsAndBytesConfig + +double_quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, +) + +double_quant_model = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=double_quant_config, +) +double_quant_model.dequantize() +``` \ No newline at end of file diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md new file mode 100644 index 000000000000..d8adbc85a259 --- /dev/null +++ b/docs/source/en/quantization/overview.md @@ -0,0 +1,35 @@ + + +# Quantization + +Quantization techniques focus on representing data with less information while also trying not to lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size, which makes it easier to store and reduces memory usage. Lower precision can also speed up inference because it takes less time to perform calculations with fewer bits. + + + +Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method. + + + + + +If you are new to the quantization field, we recommend checking out these beginner-friendly courses about quantization, created in collaboration with DeepLearning.AI: + +* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/) +* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/) + + + +## When to use what? + +This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
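To make the trade-off concrete, here is a minimal sketch (not part of the files changed in this PR) showing the `bitsandbytes` backend end to end: the Flux transformer is quantized to 4-bit NF4 and passed to a pipeline. The checkpoint name, dtype choices, and prompt are illustrative assumptions.

```py
# Assumes a CUDA GPU with bitsandbytes, accelerate, and transformers installed.
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Only the transformer is quantized; the text encoders and VAE keep their original precision.
transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_4bit,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # offload idle components to CPU to keep peak VRAM low

image = pipe("a photo of an astronaut riding a horse on the moon").images[0]
image.save("astronaut.png")
```

With 4-bit weights, the transformer occupies roughly a quarter of its bf16 footprint; calling `get_memory_footprint()` on the quantized model reports the exact number.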
\ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index dedb6f5c7f14..f55d7566db83 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,6 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], + "quantizers.quantization_config": ["BitsAndBytesConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -123,7 +124,6 @@ "VQModel", ] ) - _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", @@ -155,6 +155,7 @@ "StableDiffusionMixin", ] ) + _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ "AmusedScheduler", @@ -533,6 +534,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin + from .quantizers.quantization_config import BitsAndBytesConfig try: if not is_onnx_available(): @@ -626,6 +628,7 @@ ScoreSdeVePipeline, StableDiffusionMixin, ) + from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, CMStochasticIterativeScheduler, diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 3dccd785cae4..85728f10d560 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -526,7 +526,8 @@ def extract_init_dict(cls, config_dict, **kwargs): init_dict[key] = config_dict.pop(key) # 4. Give nice warning if unexpected values have been passed - if len(config_dict) > 0: + only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict + if len(config_dict) > 0 and not only_quant_config_remaining: logger.warning( f"The config attributes {config_dict} were passed to {cls.__name__}, " "but are not expected and will be ignored. Please verify your " @@ -586,10 +587,20 @@ def to_json_saveable(value): value = value.as_posix() return value + # IFWatermarker, for example, doesn't have a `config`. + if "quantization_config" in config_dict: + config_dict["quantization_config"] = ( + config_dict.quantization_config.to_dict() + if not isinstance(config_dict.quantization_config, dict) + else config_dict.quantization_config + ) + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} # Don't save "_ignore_files" or "_use_default_values" config_dict.pop("_ignore_files", None) config_dict.pop("_use_default_values", None) + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. 
+ _ = config_dict.pop("_pre_quantization_dtype", None) return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index c9eb664443b5..8b95a2780956 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -25,6 +25,7 @@ import torch from huggingface_hub.utils import EntryNotFoundError +from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, @@ -54,11 +55,36 @@ # Adapted from `transformers` (see modeling_utils.py) -def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype): +def _determine_device_map( + model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None +): if isinstance(device_map, str): + special_dtypes = {} + if hf_quantizer is not None: + special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) + special_dtypes.update( + { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in keep_in_fp32_modules) + } + ) + + target_dtype = torch_dtype + if hf_quantizer is not None: + target_dtype = hf_quantizer.adjust_target_dtype(target_dtype) + no_split_modules = model._get_no_split_modules(device_map) device_map_kwargs = {"no_split_module_classes": no_split_modules} + if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: + device_map_kwargs["special_dtypes"] = special_dtypes + elif len(special_dtypes) > 0: + logger.warning( + "This model has some weights that should be kept in higher precision, you need to upgrade " + "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." + ) + if device_map != "sequential": max_memory = get_balanced_memory( model, @@ -70,8 +96,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_ else: max_memory = get_max_memory(max_memory) + if hf_quantizer is not None: + max_memory = hf_quantizer.adjust_max_memory(max_memory) + device_map_kwargs["max_memory"] = max_memory - device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs) + device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) return device_map @@ -100,6 +132,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ """ Reads a checkpoint file, returning properly formatted errors if they arise. 
""" + if isinstance(checkpoint_file, dict): + return checkpoint_file try: file_extension = os.path.basename(checkpoint_file).split(".")[-1] if file_extension == SAFETENSORS_FILE_EXTENSION: @@ -137,29 +171,57 @@ def load_model_dict_into_meta( device: Optional[Union[str, torch.device]] = None, dtype: Optional[Union[str, torch.dtype]] = None, model_name_or_path: Optional[str] = None, + hf_quantizer=None, + keep_in_fp32_modules=None, ) -> List[str]: - device = device or torch.device("cpu") + if hf_quantizer is None: + device = device or torch.device("cpu") dtype = dtype or torch.float32 + is_quantized = hf_quantizer is not None + is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) - - unexpected_keys = [] empty_state_dict = model.state_dict() + unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + for param_name, param in state_dict.items(): if param_name not in empty_state_dict: - unexpected_keys.append(param_name) continue - if empty_state_dict[param_name].shape != param.shape: + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn + if torch.is_floating_point(param) and not is_param_float8_e4m3fn: + if ( + keep_in_fp32_modules is not None + and any( + module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) + and dtype == torch.float16 + ): + dtype = torch.float32 + param = param.to(dtype) + else: + param = param.to(dtype) + + # bnb params are flattened. + if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." 
) - if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) + if not is_quantized or ( + not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device) + ): + if accepts_dtype: + set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) + else: + set_module_tensor_to_device(model, param_name, device, value=param) else: - set_module_tensor_to_device(model, param_name, device, value=param) + hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) + return unexpected_keys @@ -231,6 +293,35 @@ def _fetch_index_file( return index_file +# Adapted from +# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 +def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata): + weight_map = sharded_metadata.get("weight_map", None) + if weight_map is None: + raise KeyError("'weight_map' key not found in the shard index file.") + + # Collect all unique safetensors files from weight_map + files_to_load = set(weight_map.values()) + is_safetensors = all(f.endswith(".safetensors") for f in files_to_load) + merged_state_dict = {} + + # Load tensors from each unique file + for file_name in files_to_load: + part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name) + if not os.path.exists(part_file_path): + raise FileNotFoundError(f"Part file {file_name} not found.") + + if is_safetensors: + with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: + for tensor_key in f.keys(): + if tensor_key in weight_map: + merged_state_dict[tensor_key] = f.get_tensor(tensor_key) + else: + merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu")) + + return merged_state_dict + + def _fetch_index_file_legacy( is_local, pretrained_model_name_or_path, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index ad3433889fca..d98854e39039 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import itertools import json import os import re from collections import OrderedDict -from functools import partial +from functools import partial, wraps from pathlib import Path from typing import Any, Callable, List, Optional, Tuple, Union @@ -31,6 +32,8 @@ from torch import Tensor, nn from .. 
import __version__ +from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer +from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, @@ -43,6 +46,8 @@ _get_model_file, deprecate, is_accelerate_available, + is_bitsandbytes_available, + is_bitsandbytes_version, is_torch_version, logging, ) @@ -56,6 +61,7 @@ _fetch_index_file, _fetch_index_file_legacy, _load_state_dict_into_model, + _merge_sharded_checkpoints, load_model_dict_into_meta, load_state_dict, ) @@ -125,6 +131,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _supports_gradient_checkpointing = False _keys_to_ignore_on_load_unexpected = None _no_split_modules = None + _keep_in_fp32_modules = None def __init__(self): super().__init__() @@ -308,6 +315,17 @@ def save_pretrained( logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return + hf_quantizer = getattr(self, "hf_quantizer", None) + quantization_serializable = ( + hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable + ) + + if hf_quantizer is not None and not quantization_serializable: + raise ValueError( + f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" + " the logger on the traceback to understand the reason why the quantized model is not serializable." + ) + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = _add_variant(weights_name, variant) weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( @@ -402,6 +420,18 @@ def save_pretrained( create_pr=create_pr, ) + def dequantize(self): + """ + Potentially dequantize the model in case it has been quantized by a quantization method that support + dequantization. + """ + hf_quantizer = getattr(self, "hf_quantizer", None) + + if hf_quantizer is None: + raise ValueError("You need to first quantize your model in order to dequantize it") + + return hf_quantizer.dequantize(self) + @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): @@ -524,6 +554,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) + quantization_config = kwargs.pop("quantization_config", None) allow_pickle = False if use_safetensors is None: @@ -594,6 +625,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") + if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") + # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path @@ -618,6 +653,52 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P user_agent=user_agent, **kwargs, ) + # no in-place modification of the original config. + config = copy.deepcopy(config) + + # determine initial quantization config. 
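+ # A `quantization_config` embedded in the checkpoint (pre-quantized model) takes precedence; a config passed to `from_pretrained` is only applied when the checkpoint has none, otherwise it is ignored with a warning.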
+ ####################################### + pre_quantized = "quantization_config" in config and config["quantization_config"] is not None + if pre_quantized or quantization_config is not None: + if pre_quantized: + config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs( + config["quantization_config"], quantization_config + ) + else: + config["quantization_config"] = quantization_config + hf_quantizer = DiffusersAutoQuantizer.from_config( + config["quantization_config"], pre_quantized=pre_quantized + ) + else: + hf_quantizer = None + + if hf_quantizer is not None: + if device_map is not None: + raise NotImplementedError( + "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future." + ) + hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + + # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + + # Force-set to `True` for more mem efficiency + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.") + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + else: + keep_in_fp32_modules = [] + ####################################### # Determine if we're loading from a directory of sharded checkpoints. is_sharded = False @@ -684,6 +765,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P revision=revision, subfolder=subfolder or "", ) + if hf_quantizer is not None: + model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) + logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") + is_sharded = False elif use_safetensors and not is_sharded: try: @@ -729,13 +814,36 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P with accelerate.init_empty_weights(): model = cls.from_config(config, **unused_kwargs) + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + + # We store the original dtype for quantized models as we cannot easily retrieve it + # once the weights have been quantized + # Note that once you have loaded a quantized model, you can't change its dtype so this will + # remain a single source of truth + config["_pre_quantization_dtype"] = torch_dtype + # if device_map is None, load the state dict and move the params from meta device to the cpu if device_map is None and not is_sharded: - param_device = "cpu" + # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. + # It would error out during the `validate_environment()` call above in the absence of cuda. 
+ is_quant_method_bnb = ( + getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + if hf_quantizer is None: + param_device = "cpu" + # TODO (sayakpaul, SunMarc): remove this after model loading refactor + elif is_quant_method_bnb: + param_device = torch.cuda.current_device() state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") if len(missing_keys) > 0: raise ValueError( f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" @@ -750,6 +858,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device=param_device, dtype=torch_dtype, model_name_or_path=pretrained_model_name_or_path, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, ) if cls._keys_to_ignore_on_load_unexpected is not None: @@ -765,7 +875,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU force_hook = True - device_map = _determine_device_map(model, device_map, max_memory, torch_dtype) + device_map = _determine_device_map( + model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer + ) if device_map is None and is_sharded: # we load the parameters on the cpu device_map = {"": "cpu"} @@ -843,14 +955,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "error_msgs": error_msgs, } + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." ) - elif torch_dtype is not None: + elif torch_dtype is not None and hf_quantizer is None: model = model.to(torch_dtype) - model.register_to_config(_name_or_path=pretrained_model_name_or_path) + if hf_quantizer is not None: + # We need to register the _pre_quantization_dtype separately for bookkeeping purposes. + # directly assigning `config["_pre_quantization_dtype"]` won't reflect `_pre_quantization_dtype` + # in `model.config`. We also make sure to purge `_pre_quantization_dtype` when we serialize + # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. + model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype) + else: + model.register_to_config(_name_or_path=pretrained_model_name_or_path) # Set model in evaluation mode to deactivate DropOut modules by default model.eval() @@ -859,6 +982,76 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model + # Adapted from `transformers`. + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. 
" + " Please use the model as it is, since the model has already been set to the correct devices." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().cuda(*args, **kwargs) + + # Adapted from `transformers`. + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if dtype_present_in_args: + raise ValueError( + "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" + " desired `dtype` by passing the correct `torch_dtype` argument." + ) + + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().to(*args, **kwargs) + + # Taken from `transformers`. + def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().half(*args) + + # Taken from `transformers`. + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().float(*args) + @classmethod def _load_pretrained_model( cls, @@ -1041,19 +1234,63 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool 859520964 ``` """ + is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False) + + if is_loaded_in_4bit: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + else: + raise ValueError( + "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong" + " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. 
" + ) if exclude_embeddings: embedding_param_names = [ - f"{name}.weight" - for name, module_type in self.named_modules() - if isinstance(module_type, torch.nn.Embedding) + f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding) ] - non_embedding_parameters = [ + total_parameters = [ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names ] - return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: - return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + total_parameters = list(self.parameters()) + + total_numel = [] + + for param in total_parameters: + if param.requires_grad or not only_trainable: + # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are + # used for the 4bit quantization (uint8 tensors are stored) + if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit): + if hasattr(param, "element_size"): + num_bytes = param.element_size() + elif hasattr(param, "quant_storage"): + num_bytes = param.quant_storage.itemsize + else: + num_bytes = 1 + total_numel.append(param.numel() * 2 * num_bytes) + else: + total_numel.append(param.numel()) + + return sum(total_numel) + + def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. + Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the + PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers + are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch + norm layers. 
Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: deprecated_attention_block_paths = [] diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6721706b5689..2c66cd43d342 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -44,6 +44,7 @@ from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin +from ..quantizers.bitsandbytes.utils import _check_bnb_status from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, @@ -54,6 +55,7 @@ is_accelerate_version, is_torch_npu_available, is_torch_version, + is_transformers_version, logging, numpy_to_pil, ) @@ -407,6 +409,8 @@ def module_is_offloaded(module): pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() ) + # pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items()) + # not pipeline_has_8bit_bnb_quant if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": raise ValueError( "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." @@ -431,18 +435,23 @@ def module_is_offloaded(module): is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: - is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit + _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) - if is_loaded_in_8bit and dtype is not None: + if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision." ) - if is_loaded_in_8bit and device is not None: + if is_loaded_in_8bit_bnb and device is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." ) - else: + + # This can happen for `transformer` models. CPU placement was added in + # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. 
+ if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): + module.to(device=device) + elif not is_loaded_in_8bit_bnb: module.to(device, dtype) if ( @@ -450,6 +459,7 @@ def module_is_offloaded(module): and str(device) in ["cpu"] and not silence_dtype_warnings and not is_offloaded + and not is_loaded_in_4bit_bnb ): logger.warning( "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" @@ -1038,9 +1048,18 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t hook = None for model_str in self.model_cpu_offload_seq.split("->"): model = all_model_components.pop(model_str, None) + if not isinstance(model, torch.nn.Module): continue + # This is because the model would already be placed on a CUDA device. + _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model) + if is_loaded_in_8bit_bnb: + logger.info( + f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit." + ) + continue + _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook) self._all_hooks.append(hook) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py new file mode 100644 index 000000000000..93852d29ef59 --- /dev/null +++ b/src/diffusers/quantizers/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .auto import DiffusersAutoQuantizationConfig, DiffusersAutoQuantizer +from .base import DiffusersQuantizer diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py new file mode 100644 index 000000000000..f231f279e13a --- /dev/null +++ b/src/diffusers/quantizers/auto.py @@ -0,0 +1,137 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py +""" +import warnings +from typing import Dict, Optional, Union + +from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer +from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod + + +AUTO_QUANTIZER_MAPPING = { + "bitsandbytes_4bit": BnB4BitDiffusersQuantizer, + "bitsandbytes_8bit": BnB8BitDiffusersQuantizer, +} + +AUTO_QUANTIZATION_CONFIG_MAPPING = { + "bitsandbytes_4bit": BitsAndBytesConfig, + "bitsandbytes_8bit": BitsAndBytesConfig, +} + + +class DiffusersAutoQuantizationConfig: + """ + The auto diffusers quantization config class that takes care of automatically dispatching to the correct + quantization config given a quantization config stored in a dictionary. + """ + + @classmethod + def from_dict(cls, quantization_config_dict: Dict): + quant_method = quantization_config_dict.get("quant_method", None) + # We need a special care for bnb models to make sure everything is BC .. + if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False): + suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit" + quant_method = QuantizationMethod.BITS_AND_BYTES + suffix + elif quant_method is None: + raise ValueError( + "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized" + ) + + if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys(): + raise ValueError( + f"Unknown quantization type, got {quant_method} - supported types are:" + f" {list(AUTO_QUANTIZER_MAPPING.keys())}" + ) + + target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method] + return target_cls.from_dict(quantization_config_dict) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + model_config = cls.load_config(pretrained_model_name_or_path, **kwargs) + if getattr(model_config, "quantization_config", None) is None: + raise ValueError( + f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized." + ) + quantization_config_dict = model_config.quantization_config + quantization_config = cls.from_dict(quantization_config_dict) + # Update with potential kwargs that are passed through from_pretrained. + quantization_config.update(kwargs) + return quantization_config + + +class DiffusersAutoQuantizer: + """ + The auto diffusers quantizer class that takes care of automatically instantiating to the correct + `DiffusersQuantizer` given the `QuantizationConfig`. 
+ """ + + @classmethod + def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs): + # Convert it to a QuantizationConfig if the q_config is a dict + if isinstance(quantization_config, dict): + quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config) + + quant_method = quantization_config.quant_method + + # Again, we need a special care for bnb as we have a single quantization config + # class for both 4-bit and 8-bit quantization + if quant_method == QuantizationMethod.BITS_AND_BYTES: + if quantization_config.load_in_8bit: + quant_method += "_8bit" + else: + quant_method += "_4bit" + + if quant_method not in AUTO_QUANTIZER_MAPPING.keys(): + raise ValueError( + f"Unknown quantization type, got {quant_method} - supported types are:" + f" {list(AUTO_QUANTIZER_MAPPING.keys())}" + ) + + target_cls = AUTO_QUANTIZER_MAPPING[quant_method] + return target_cls(quantization_config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + quantization_config = DiffusersAutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + return cls.from_config(quantization_config) + + @classmethod + def merge_quantization_configs( + cls, + quantization_config: Union[dict, QuantizationConfigMixin], + quantization_config_from_args: Optional[QuantizationConfigMixin], + ): + """ + handles situations where both quantization_config from args and quantization_config from model config are + present. + """ + if quantization_config_from_args is not None: + warning_msg = ( + "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading" + " already has a `quantization_config` attribute. The `quantization_config` from the model will be used." + ) + else: + warning_msg = "" + + if isinstance(quantization_config, dict): + quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config) + + if warning_msg != "": + warnings.warn(warning_msg) + + return quantization_config diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py new file mode 100644 index 000000000000..017136a98854 --- /dev/null +++ b/src/diffusers/quantizers/base.py @@ -0,0 +1,230 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adapted from +https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/quantizers/base.py +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from ..utils import is_torch_available +from .quantization_config import QuantizationConfigMixin + + +if TYPE_CHECKING: + from ..models.modeling_utils import ModelMixin + +if is_torch_available(): + import torch + + +class DiffusersQuantizer(ABC): + """ + Abstract class of the HuggingFace quantizer. Supports for now quantizing HF diffusers models for inference and/or + quantization. 
This class is used only for diffusers.models.modeling_utils.ModelMixin.from_pretrained and cannot be + easily used outside the scope of that method yet. + + Attributes + quantization_config (`diffusers.quantizers.quantization_config.QuantizationConfigMixin`): + The quantization config that defines the quantization parameters of your model that you want to quantize. + modules_to_not_convert (`List[str]`, *optional*): + The list of module names to not convert when quantizing the model. + required_packages (`List[str]`, *optional*): + The list of required pip packages to install prior to using the quantizer + requires_calibration (`bool`): + Whether the quantization method requires to calibrate the model before using it. + """ + + requires_calibration = False + required_packages = None + + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): + self.quantization_config = quantization_config + + # -- Handle extra kwargs below -- + self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) + self.pre_quantized = kwargs.pop("pre_quantized", True) + + if not self.pre_quantized and self.requires_calibration: + raise ValueError( + f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized." + f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to " + f"pass `pre_quantized=True` while knowing what you are doing." + ) + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + """ + Some quantization methods require to explicitly set the dtype of the model to a target dtype. You need to + override this method in case you want to make sure that behavior is preserved + + Args: + torch_dtype (`torch.dtype`): + The input dtype that is passed in `from_pretrained` + """ + return torch_dtype + + def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """ + Override this method if you want to pass a override the existing device map with a new one. E.g. for + bitsandbytes, since `accelerate` is a hard requirement, if no device_map is passed, the device_map is set to + `"auto"`` + + Args: + device_map (`Union[dict, str]`, *optional*): + The device_map that is passed through the `from_pretrained` method. + """ + return device_map + + def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + """ + Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained` to compute the + device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype` to `torch.int8` + and for 4-bit we pass a custom enum `accelerate.CustomDtype.int4`. + + Args: + torch_dtype (`torch.dtype`, *optional*): + The torch_dtype that is used to compute the device_map. + """ + return torch_dtype + + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + """ + Override this method if you want to adjust the `missing_keys`. + + Args: + missing_keys (`List[str]`, *optional*): + The list of missing keys in the checkpoint compared to the state dict of the model + """ + return missing_keys + + def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]: + """ + returns dtypes for modules that are not quantized - used for the computation of the device_map in case one + passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified in + `_process_model_before_weight_loading`. 
`diffusers` models don't have any `modules_to_not_convert` attributes + yet but this can change soon in the future. + + Args: + model (`~diffusers.models.modeling_utils.ModelMixin`): + The model to quantize + torch_dtype (`torch.dtype`): + The dtype passed in `from_pretrained` method. + """ + + return { + name: torch_dtype + for name, _ in model.named_parameters() + if any(m in name for m in self.modules_to_not_convert) + } + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization""" + return max_memory + + def check_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + """ + checks if a loaded state_dict component is part of quantized param + some validation; only defined for + quantization methods that require to create a new parameters for quantization. + """ + return False + + def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter": + """ + takes needed components from state_dict and creates quantized param. + """ + if not hasattr(self, "check_quantized_param"): + raise AttributeError( + f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}." + ) + + def validate_environment(self, *args, **kwargs): + """ + This method is used to potentially check for potential conflicts with arguments that are passed in + `from_pretrained`. You need to define it for all future quantizers that are integrated with diffusers. If no + explicit check are needed, simply return nothing. + """ + return + + def preprocess_model(self, model: "ModelMixin", **kwargs): + """ + Setting model attributes and/or converting model before weights loading. At this point the model should be + initialized on the meta device so you can freely manipulate the skeleton of the model in order to replace + modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`. + + Args: + model (`~diffusers.models.modeling_utils.ModelMixin`): + The model to quantize + kwargs (`dict`, *optional*): + The keyword arguments that are passed along `_process_model_before_weight_loading`. + """ + model.is_quantized = True + model.quantization_method = self.quantization_config.quant_method + return self._process_model_before_weight_loading(model, **kwargs) + + def postprocess_model(self, model: "ModelMixin", **kwargs): + """ + Post-process the model post weights loading. Make sure to override the abstract method + `_process_model_after_weight_loading`. + + Args: + model (`~diffusers.models.modeling_utils.ModelMixin`): + The model to quantize + kwargs (`dict`, *optional*): + The keyword arguments that are passed along `_process_model_after_weight_loading`. + """ + return self._process_model_after_weight_loading(model, **kwargs) + + def dequantize(self, model): + """ + Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance. Note + not all quantization schemes support this. + """ + model = self._dequantize(model) + + # Delete quantizer and quantization config + del model.hf_quantizer + + return model + + def _dequantize(self, model): + raise NotImplementedError( + f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub." 
+ ) + + @abstractmethod + def _process_model_before_weight_loading(self, model, **kwargs): + ... + + @abstractmethod + def _process_model_after_weight_loading(self, model, **kwargs): + ... + + @property + @abstractmethod + def is_serializable(self): + ... + + @property + @abstractmethod + def is_trainable(self): + ... diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py new file mode 100644 index 000000000000..9e745bc810fa --- /dev/null +++ b/src/diffusers/quantizers/bitsandbytes/__init__.py @@ -0,0 +1,2 @@ +from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer +from .utils import dequantize_and_replace, dequantize_bnb_weight, replace_with_bnb_linear diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py new file mode 100644 index 000000000000..e3041aba60ae --- /dev/null +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -0,0 +1,549 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py +""" + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from ...utils import get_module_from_name +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_bitsandbytes_available, + is_bitsandbytes_version, + is_torch_available, + logging, +) + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class BnB4BitDiffusersQuantizer(DiffusersQuantizer): + """ + 4-bit quantization from bitsandbytes.py quantization method: + before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the + layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call saving: + from state dict, as usual; saves weights and `quant_state` components + loading: + need to locate `quant_state` components and pass to Param4bit constructor + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + if self.quantization_config.llm_int8_skip_modules is not None: + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. 
A GPU is needed for quantization.") + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + ) + if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): + raise ImportError( + "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) + + if kwargs.get("from_flax", False): + raise ValueError( + "Converting into 4-bit weights from flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + device_map = kwargs.get("device_map", None) + if ( + device_map is not None + and isinstance(device_map, dict) + and not self.quantization_config.llm_int8_enable_fp32_cpu_offload + ): + device_map_without_no_convert = { + key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert + } + if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): + raise ValueError( + "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " + "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " + "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to " + "`from_pretrained`. Check " + "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu " + "for more details. " + ) + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.int8: + from accelerate.utils import CustomDtype + + logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") + return CustomDtype.INT4 + else: + raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.") + + def check_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit): + # Add here check for loaded components' dtypes once serialization is implemented + return True + elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias": + # bias could be loaded by regular set_module_tensor_to_device() from accelerate, + # but it would wrongly use uninitialized weight there. 
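The check above leans on `get_module_from_name` (added to `loading_utils.py` later in this diff) to walk a dotted parameter name down to the module that owns it. A small self-contained sketch of that resolution, with a made-up `Tiny` module purely for illustration:

```py
import torch.nn as nn


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


def resolve(module, tensor_name):
    # Mirrors get_module_from_name: "proj.weight" -> (Tiny().proj, "weight")
    if "." in tensor_name:
        *parents, tensor_name = tensor_name.split(".")
        for name in parents:
            module = getattr(module, name)
    return module, tensor_name


module, tensor_name = resolve(Tiny(), "proj.weight")
print(type(module).__name__, tensor_name)  # Linear weight
```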
+ return True + else: + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + + if tensor_name not in module._parameters: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + old_value = getattr(module, tensor_name) + + if tensor_name == "bias": + if param_value is None: + new_value = old_value.to(target_device) + else: + new_value = param_value.to(target_device) + + new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + return + + if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): + raise ValueError("this function only loads `Linear4bit components`") + if ( + old_value.device == torch.device("meta") + and target_device not in ["meta", torch.device("meta")] + and param_value is None + ): + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") + + # construct `new_value` for the module._parameters[tensor_name]: + if self.pre_quantized: + # 4bit loading. Collecting components for restoring quantized weight + # This can be expanded to make a universal call for any quantized weight loading + + if not self.is_serializable: + raise ValueError( + "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + + if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and ( + param_name + ".quant_state.bitsandbytes__nf4" not in state_dict + ): + raise ValueError( + f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components." + ) + + quantized_stats = {} + for k, v in state_dict.items(): + # `startswith` to counter for edge cases where `param_name` + # substring can be present in multiple places in the `state_dict` + if param_name + "." in k and k.startswith(param_name): + quantized_stats[k] = v + if unexpected_keys is not None and k in unexpected_keys: + unexpected_keys.remove(k) + + new_value = bnb.nn.Params4bit.from_prequantized( + data=param_value, + quantized_stats=quantized_stats, + requires_grad=False, + device=target_device, + ) + else: + new_value = param_value.to("cpu") + kwargs = old_value.__dict__ + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device) + + module._parameters[tensor_name] = new_value + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. 
" + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.float16 + return torch_dtype + + # (sayakpaul): I think it could be better to disable custom `device_map`s + # for the first phase of the integration in the interest of simplicity. + # Commenting this for discussions on the PR. + # def update_device_map(self, device_map): + # if device_map is None: + # device_map = {"": torch.cuda.current_device()} + # logger.info( + # "The device_map was not initialized. " + # "Setting device_map to {'':torch.cuda.current_device()}. " + # "If you want to use the model for inference, please set device_map ='auto' " + # ) + # return device_map + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + from .utils import replace_with_bnb_linear + + load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload + + # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload: + raise ValueError( + "If you want to offload some keys to `cpu` or `disk`, you need to set " + "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be " + " converted to 8-bit but kept in 32-bit." + ) + self.modules_to_not_convert.extend(keys_on_cpu) + + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model = replace_with_bnb_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + model.is_loaded_in_4bit = True + model.is_4bit_serializable = self.is_serializable + return model + + @property + def is_serializable(self): + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + @property + def is_trainable(self) -> bool: + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + def _dequantize(self, model): + from .utils import dequantize_and_replace + + is_model_on_cpu = model.device.type == "cpu" + if is_model_on_cpu: + logger.info( + "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device." 
+ ) + model.to(torch.cuda.current_device()) + + model = dequantize_and_replace( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + if is_model_on_cpu: + model.to("cpu") + return model + + +class BnB8BitDiffusersQuantizer(DiffusersQuantizer): + """ + 8-bit quantization from bitsandbytes quantization method: + before loading: converts transformer layers into Linear8bitLt during loading: load 16bit weight and pass to the + layer object after: quantizes individual weights in Linear8bitLt into 8bit at fitst .cuda() call + saving: + from state dict, as usual; saves weights and 'SCB' component + loading: + need to locate SCB component and pass to the Linear8bitLt object + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + if self.quantization_config.llm_int8_skip_modules is not None: + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit + def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + ) + if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): + raise ImportError( + "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) + + if kwargs.get("from_flax", False): + raise ValueError( + "Converting into 8-bit weights from flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + device_map = kwargs.get("device_map", None) + if ( + device_map is not None + and isinstance(device_map, dict) + and not self.quantization_config.llm_int8_enable_fp32_cpu_offload + ): + device_map_without_no_convert = { + key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert + } + if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): + raise ValueError( + "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " + "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " + "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to " + "`from_pretrained`. Check " + "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu " + "for more details. 
" + ) + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_torch_dtype + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.float16 + return torch_dtype + + # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map + # def update_device_map(self, device_map): + # if device_map is None: + # device_map = {"": torch.cuda.current_device()} + # logger.info( + # "The device_map was not initialized. " + # "Setting device_map to {'':torch.cuda.current_device()}. " + # "If you want to use the model for inference, please set device_map ='auto' " + # ) + # return device_map + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.int8: + logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization") + return torch.int8 + + def check_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params): + if self.pre_quantized: + if param_name.replace("weight", "SCB") not in state_dict.keys(): + raise ValueError("Missing quantization component `SCB`") + if param_value.dtype != torch.int8: + raise ValueError( + f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`." 
+ ) + return True + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + import bitsandbytes as bnb + + fp16_statistics_key = param_name.replace("weight", "SCB") + fp16_weights_format_key = param_name.replace("weight", "weight_format") + + fp16_statistics = state_dict.get(fp16_statistics_key, None) + fp16_weights_format = state_dict.get(fp16_weights_format_key, None) + + module, tensor_name = get_module_from_name(model, param_name) + if tensor_name not in module._parameters: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + old_value = getattr(module, tensor_name) + + if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params): + raise ValueError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.") + if ( + old_value.device == torch.device("meta") + and target_device not in ["meta", torch.device("meta")] + and param_value is None + ): + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") + + new_value = param_value.to("cpu") + if self.pre_quantized and not self.is_serializable: + raise ValueError( + "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + + kwargs = old_value.__dict__ + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device) + + module._parameters[tensor_name] = new_value + if fp16_statistics is not None: + setattr(module.weight, "SCB", fp16_statistics.to(target_device)) + if unexpected_keys is not None: + unexpected_keys.remove(fp16_statistics_key) + + # We just need to pop the `weight_format` keys from the state dict to remove unneeded + # messages. The correct format is correctly retrieved during the first forward pass. 
+ if fp16_weights_format is not None and unexpected_keys is not None: + unexpected_keys.remove(fp16_weights_format_key) + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + model.is_loaded_in_8bit = True + model.is_8bit_serializable = self.is_serializable + return model + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + from .utils import replace_with_bnb_linear + + load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload + + # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload: + raise ValueError( + "If you want to offload some keys to `cpu` or `disk`, you need to set " + "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be " + " converted to 8-bit but kept in 32-bit." + ) + self.modules_to_not_convert.extend(keys_on_cpu) + + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model = replace_with_bnb_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config + + @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable + def is_serializable(self): + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable + def is_trainable(self) -> bool: + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + def _dequantize(self, model): + from .utils import dequantize_and_replace + + model = dequantize_and_replace( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + return model diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py new file mode 100644 index 000000000000..03755db3d1ec --- /dev/null +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -0,0 +1,306 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py +""" + +import inspect +from inspect import signature +from typing import Union + +from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging +from ..quantization_config import QuantizationMethod + + +if is_torch_available(): + import torch + import torch.nn as nn + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + +if is_accelerate_available(): + import accelerate + from accelerate import init_empty_weights + from accelerate.hooks import add_hook_to_module, remove_hook_from_module + +logger = logging.get_logger(__name__) + + +def _replace_with_bnb_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." 
in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + + if quantization_config.quantization_method() == "llm_int8": + model._modules[name] = bnb.nn.Linear8bitLt( + in_features, + out_features, + module.bias is not None, + has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, + threshold=quantization_config.llm_int8_threshold, + ) + has_been_replaced = True + else: + if ( + quantization_config.llm_int8_skip_modules is not None + and name in quantization_config.llm_int8_skip_modules + ): + pass + else: + extra_kwargs = ( + {"quant_storage": quantization_config.bnb_4bit_quant_storage} + if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters) + else {} + ) + model._modules[name] = bnb.nn.Linear4bit( + in_features, + out_features, + module.bias is not None, + quantization_config.bnb_4bit_compute_dtype, + compress_statistics=quantization_config.bnb_4bit_use_double_quant, + quant_type=quantization_config.bnb_4bit_quant_type, + **extra_kwargs, + ) + has_been_replaced = True + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_bnb_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): + """ + Helper function to replace the `nn.Linear` layers within `model` with either `bnb.nn.Linear8bit` or + `bnb.nn.Linear4bit` using the `bitsandbytes` library. + + References: + * `bnb.nn.Linear8bit`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at + Scale](https://arxiv.org/abs/2208.07339) + * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[`str`]`, *optional*, defaults to `[]`): + Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `modules_to_not_convert` in + full precision for numerical stability reasons. + current_key_name (`List[`str`]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or + `disk`). + quantization_config ('transformers.utils.quantization_config.BitsAndBytesConfig'): + To configure and manage settings related to quantization, a technique used to compress neural network + models by reducing the precision of the weights and activations, thus making models more efficient in terms + of both storage and computation. + """ + model, has_been_replaced = _replace_with_bnb_linear( + model, modules_to_not_convert, current_key_name, quantization_config + ) + + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." 
+ " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model + + +# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41 +def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None): + """ + Helper function to dequantize 4bit or 8bit bnb weights. + + If the weight is not a bnb quantized weight, it will be returned as is. + """ + if not isinstance(weight, torch.nn.Parameter): + raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead") + + cls_name = weight.__class__.__name__ + if cls_name not in ("Params4bit", "Int8Params"): + return weight + + if cls_name == "Params4bit": + output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + logger.warning_once( + f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" + ) + return output_tensor + + if state.SCB is None: + state.SCB = weight.SCB + + im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device) + im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im) + im, Sim = bnb.functional.transform(im, "col32") + if state.CxB is None: + state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB) + out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB) + return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t() + + +def _create_accelerate_new_hook(old_hook): + r""" + Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of: + https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with + some changes + """ + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) + old_hook_attr = old_hook.__dict__ + filtered_old_hook_attr = {} + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) + for k in old_hook_attr.keys(): + if k in old_hook_init_signature.parameters: + filtered_old_hook_attr[k] = old_hook_attr[k] + new_hook = old_hook_cls(**filtered_old_hook_attr) + return new_hook + + +def _dequantize_and_replace( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """ + Converts a quantized model into its dequantized original version. The newly converted model will have some + performance drop compared to the original model before quantization - use it only for specific usecases such as + QLoRA adapters merging. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + quant_method = quantization_config.quantization_method() + + target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit + + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, target_cls) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + + if not any( + (key + "." 
in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + bias = getattr(module, "bias", None) + + device = module.weight.device + with init_empty_weights(): + new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None) + + if quant_method == "llm_int8": + state = module.state + else: + state = None + + new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state)) + + if bias is not None: + new_module.bias = bias + + # Create a new hook and attach it in case we use accelerate + if hasattr(module, "_hf_hook"): + old_hook = module._hf_hook + new_hook = _create_accelerate_new_hook(old_hook) + + remove_hook_from_module(module) + add_hook_to_module(new_module, new_hook) + + new_module.to(device) + model._modules[name] = new_module + has_been_replaced = True + if len(list(module.children())) > 0: + _, has_been_replaced = _dequantize_and_replace( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def dequantize_and_replace( + model, + modules_to_not_convert=None, + quantization_config=None, +): + model, has_been_replaced = _dequantize_and_replace( + model, + modules_to_not_convert=modules_to_not_convert, + quantization_config=quantization_config, + ) + + if not has_been_replaced: + logger.warning( + "For some reason the model has not been properly dequantized. You might see unexpected behavior." + ) + + return model + + +def _check_bnb_status(module) -> Union[bool, bool]: + is_loaded_in_4bit_bnb = ( + hasattr(module, "is_loaded_in_4bit") + and module.is_loaded_in_4bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + is_loaded_in_8bit_bnb = ( + hasattr(module, "is_loaded_in_8bit") + and module.is_loaded_in_8bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py new file mode 100644 index 000000000000..f521c5d717d6 --- /dev/null +++ b/src/diffusers/quantizers/quantization_config.py @@ -0,0 +1,391 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
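The replacement and dequantization helpers above share one traversal pattern: recurse over `named_children`, build the dotted key, skip anything matched by `modules_to_not_convert`, and swap the layer in `model._modules`. A simplified, bitsandbytes-free sketch of that pattern (the "quantized" layer here is just a fresh `nn.Linear`, so the snippet runs without bitsandbytes installed):

```py
import torch.nn as nn


def swap_linears(model, modules_to_not_convert=(), prefix=""):
    # Simplified stand-in for _replace_with_bnb_linear / _dequantize_and_replace.
    for name, module in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        skipped = any(full_name == key or full_name.startswith(key + ".") for key in modules_to_not_convert)
        if isinstance(module, nn.Linear) and not skipped:
            # replace_with_bnb_linear builds bnb.nn.Linear8bitLt / Linear4bit here;
            # dequantize_and_replace builds a plain nn.Linear instead.
            model._modules[name] = nn.Linear(module.in_features, module.out_features, module.bias is not None)
        else:
            swap_linears(module, modules_to_not_convert, full_name)
    return model


model = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 4)))
swap_linears(model, modules_to_not_convert=["1.0"])  # the nested linear is left untouched
print(model)
```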
+ +""" +Adapted from +https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/utils/quantization_config.py +""" + +import copy +import importlib.metadata +import json +import os +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Union + +from packaging import version + +from ..utils import is_torch_available, logging + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class QuantizationMethod(str, Enum): + BITS_AND_BYTES = "bitsandbytes" + + +@dataclass +class QuantizationConfigMixin: + """ + Mixin class for quantization config + """ + + quant_method: QuantizationMethod + _exclude_attributes_at_init = [] + + @classmethod + def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): + """ + Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters. + + Args: + config_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the configuration object. + return_unused_kwargs (`bool`,*optional*, defaults to `False`): + Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in + `PreTrainedModel`. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters. + """ + + config = cls(**config_dict) + + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + if return_unused_kwargs: + return config, kwargs + else: + return config + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self, use_diff: bool = True) -> str: + """ + Serializes this instance to a JSON string. + + Args: + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` + is serialized to JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. 
+ """ + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class BitsAndBytesConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `bitsandbytes`. + + This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive. + + Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`, + then more arguments will be added to this class. + + Args: + load_in_8bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 8-bit quantization with LLM.int8(). + load_in_4bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from + `bitsandbytes`. + llm_int8_threshold (`float`, *optional*, defaults to 6.0): + This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix + Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value + that is above this threshold will be considered an outlier and the operation on those values will be done + in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but + there are some exceptional systematic outliers that are very differently distributed for large models. + These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of + magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, + but a lower threshold might be needed for more unstable models (small models, fine-tuning). + llm_int8_skip_modules (`List[str]`, *optional*): + An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as + Jukebox that has several heads in different places and not necessarily at the last position. For example + for `CausalLM` models, the last `lm_head` is typically kept in its original `dtype`. + llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`): + This flag is used for advanced use cases and users that are aware of this feature. If you want to split + your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use + this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8 + operations will not be run on CPU. + llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`): + This flag runs LLM.int8() with 16-bit main weights. 
This is useful for fine-tuning as the weights do not + have to be converted back and forth for the backward pass. + bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`): + This sets the computational type which might be different than the input type. For example, inputs might be + fp32, but computation can be set to bf16 for speedups. + bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`): + This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types + which are specified by `fp4` or `nf4`. + bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`): + This flag is used for nested quantization where the quantization constants from the first quantization are + quantized again. + bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`): + This sets the storage type to pack the quanitzed 4-bit prarams. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. + """ + + _exclude_attributes_at_init = ["_load_in_4bit", "_load_in_8bit", "quant_method"] + + def __init__( + self, + load_in_8bit=False, + load_in_4bit=False, + llm_int8_threshold=6.0, + llm_int8_skip_modules=None, + llm_int8_enable_fp32_cpu_offload=False, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=None, + bnb_4bit_quant_type="fp4", + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_storage=None, + **kwargs, + ): + self.quant_method = QuantizationMethod.BITS_AND_BYTES + + if load_in_4bit and load_in_8bit: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + + self._load_in_8bit = load_in_8bit + self._load_in_4bit = load_in_4bit + self.llm_int8_threshold = llm_int8_threshold + self.llm_int8_skip_modules = llm_int8_skip_modules + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + + if bnb_4bit_compute_dtype is None: + self.bnb_4bit_compute_dtype = torch.float32 + elif isinstance(bnb_4bit_compute_dtype, str): + self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype) + elif isinstance(bnb_4bit_compute_dtype, torch.dtype): + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + else: + raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + + if bnb_4bit_quant_storage is None: + self.bnb_4bit_quant_storage = torch.uint8 + elif isinstance(bnb_4bit_quant_storage, str): + if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]: + raise ValueError( + "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') " + ) + self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage) + elif isinstance(bnb_4bit_quant_storage, torch.dtype): + self.bnb_4bit_quant_storage = bnb_4bit_quant_storage + else: + raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype") + + if kwargs and not all(k in self._exclude_attributes_at_init for k in kwargs): + logger.warning(f"Unused kwargs: {list(kwargs.keys())}. 
These kwargs are not used in {self.__class__}.") + + self.post_init() + + @property + def load_in_4bit(self): + return self._load_in_4bit + + @load_in_4bit.setter + def load_in_4bit(self, value: bool): + if not isinstance(value, bool): + raise TypeError("load_in_4bit must be a boolean") + + if self.load_in_8bit and value: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + self._load_in_4bit = value + + @property + def load_in_8bit(self): + return self._load_in_8bit + + @load_in_8bit.setter + def load_in_8bit(self, value: bool): + if not isinstance(value, bool): + raise TypeError("load_in_8bit must be a boolean") + + if self.load_in_4bit and value: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + self._load_in_8bit = value + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if not isinstance(self.load_in_4bit, bool): + raise TypeError("load_in_4bit must be a boolean") + + if not isinstance(self.load_in_8bit, bool): + raise TypeError("load_in_8bit must be a boolean") + + if not isinstance(self.llm_int8_threshold, float): + raise TypeError("llm_int8_threshold must be a float") + + if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list): + raise TypeError("llm_int8_skip_modules must be a list of strings") + if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool): + raise TypeError("llm_int8_enable_fp32_cpu_offload must be a boolean") + + if not isinstance(self.llm_int8_has_fp16_weight, bool): + raise TypeError("llm_int8_has_fp16_weight must be a boolean") + + if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise TypeError("bnb_4bit_compute_dtype must be torch.dtype") + + if not isinstance(self.bnb_4bit_quant_type, str): + raise TypeError("bnb_4bit_quant_type must be a string") + + if not isinstance(self.bnb_4bit_use_double_quant, bool): + raise TypeError("bnb_4bit_use_double_quant must be a boolean") + + if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse( + "0.39.0" + ): + raise ValueError( + "4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version" + ) + + def is_quantizable(self): + r""" + Returns `True` if the model is quantizable, `False` otherwise. + """ + return self.load_in_8bit or self.load_in_4bit + + def quantization_method(self): + r""" + This method returns the quantization method used for the model. If the model is not quantizable, it returns + `None`. + """ + if self.load_in_8bit: + return "llm_int8" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4": + return "fp4" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4": + return "nf4" + else: + return None + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. 
+ """ + output = copy.deepcopy(self.__dict__) + output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1] + output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1] + output["load_in_4bit"] = self.load_in_4bit + output["load_in_8bit"] = self.load_in_8bit + + return output + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = BitsAndBytesConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index c7ea2bcc5b7f..c8f64adf3e8a 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -62,6 +62,7 @@ is_accelerate_available, is_accelerate_version, is_bitsandbytes_available, + is_bitsandbytes_version, is_bs4_available, is_flax_available, is_ftfy_available, @@ -94,7 +95,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import load_image, load_video +from .loading_utils import get_module_from_name, load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 1ab946ce7257..3b8cab24b8b7 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1005,6 +1005,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DiffusersQuantizer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AmusedScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 34cc5fcc8605..8b81b19b8a52 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -740,6 +740,20 @@ def is_peft_version(operation: str, version: str): return compare_versions(parse(_peft_version), operation, version) +def is_bitsandbytes_version(operation: str, version: str): + """ + Args: + Compares the current bitsandbytes version to a given reference with an operation. 
+ operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _bitsandbytes_version: + return False + return compare_versions(parse(_bitsandbytes_version), operation, version) + + def is_k_diffusion_version(operation: str, version: str): """ Args: diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index b36664cb81ff..bac24fa23e63 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,6 +1,6 @@ import os import tempfile -from typing import Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Union from urllib.parse import unquote, urlparse import PIL.Image @@ -135,3 +135,16 @@ def load_video( pil_images = convert_method(pil_images) return pil_images + + +# Taken from `transformers`. +def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + return module, tensor_name diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index be3e9983c80f..1eb35a9c392e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,5 +1,6 @@ import functools import importlib +import importlib.metadata import inspect import io import logging @@ -27,6 +28,8 @@ from .import_utils import ( BACKENDS_MAPPING, + is_accelerate_available, + is_bitsandbytes_available, is_compel_available, is_flax_available, is_note_seq_available, @@ -359,6 +362,20 @@ def require_timm(test_case): return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed. + """ + return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) + + +def require_accelerate(test_case): + """ + Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. + """ + return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case) + + def require_peft_version_greater(peft_version): """ Decorator marking a test that requires PEFT backend with a specific version, this would require some specific @@ -388,6 +405,31 @@ def decorator(test_case): return decorator +def require_bitsandbytes_version_greater(bnb_version): + def decorator(test_case): + correct_bnb_version = is_bitsandbytes_available() and version.parse( + version.parse(importlib.metadata.version("bitsandbytes")).base_version + ) > version.parse(bnb_version) + return unittest.skipUnless( + correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}." 
+ )(test_case) + + return decorator + + +def require_transformers_version_greater(transformers_version): + def decorator(test_case): + correct_transformers_version = is_transformers_available() and version.parse( + version.parse(importlib.metadata.version("transformers")).base_version + ) > version.parse(transformers_version) + return unittest.skipUnless( + correct_transformers_version, + f"test requires transformers backend with the version greater than {transformers_version}", + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend diff --git a/tests/quantization/bnb/README.md b/tests/quantization/bnb/README.md new file mode 100644 index 000000000000..f1585581597d --- /dev/null +++ b/tests/quantization/bnb/README.md @@ -0,0 +1,44 @@ +The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/tree/409fcfdfccde77a14b7cc36972b774cabc371ae1/tests/quantization/bnb). + +They were conducted on the `audace` machine, using a single RTX 4090. Below is `nvidia-smi`: + +```bash ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.90.07 Driver Version: 550.90.07 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 Off | Off | +| 30% 55C P0 61W / 450W | 1MiB / 24564MiB | 2% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA GeForce RTX 4090 Off | 00000000:13:00.0 Off | Off | +| 30% 51C P0 60W / 450W | 1MiB / 24564MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ +``` + +`diffusers-cli`: + +```bash +- 🤗 Diffusers version: 0.31.0.dev0 +- Platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35 +- Running on Google Colab?: No +- Python version: 3.10.12 +- PyTorch version (GPU?): 2.5.0.dev20240818+cu124 (True) +- Flax version (CPU?/GPU?/TPU?): not installed (NA) +- Jax version: not installed +- JaxLib version: not installed +- Huggingface_hub version: 0.24.5 +- Transformers version: 4.44.2 +- Accelerate version: 0.34.0.dev0 +- PEFT version: 0.12.0 +- Bitsandbytes version: 0.43.3 +- Safetensors version: 0.4.4 +- xFormers version: not installed +- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB +NVIDIA GeForce RTX 4090, 24564 MiB +- Using GPU in script?: Yes +``` \ No newline at end of file diff --git a/tests/quantization/bnb/__init__.py b/tests/quantization/bnb/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py new file mode 100644 index 000000000000..96da29b00923 --- /dev/null +++ b/tests/quantization/bnb/test_4bit.py @@ -0,0 +1,558 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
+from diffusers.utils.testing_utils import (
+    is_bitsandbytes_available,
+    is_torch_available,
+    is_transformers_available,
+    load_pt,
+    numpy_cosine_similarity_distance,
+    require_accelerate,
+    require_bitsandbytes_version_greater,
+    require_torch,
+    require_torch_gpu,
+    require_transformers_version_greater,
+    slow,
+    torch_device,
+)
+
+
+def get_some_linear_layer(model):
+    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
+        return model.transformer_blocks[0].attn.to_q
+    else:
+        raise NotImplementedError("Don't know what layer to retrieve here.")
+
+
+if is_transformers_available():
+    from transformers import T5EncoderModel
+
+if is_torch_available():
+    import torch
+    import torch.nn as nn
+
+    class LoRALayer(nn.Module):
+        """Wraps a linear layer with a LoRA-like adapter - used for testing purposes only.
+
+        Taken from
+        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
+        """
+
+        def __init__(self, module: nn.Module, rank: int):
+            super().__init__()
+            self.module = module
+            self.adapter = nn.Sequential(
+                nn.Linear(module.in_features, rank, bias=False),
+                nn.Linear(rank, module.out_features, bias=False),
+            )
+            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+            nn.init.normal_(self.adapter[0].weight, std=small_std)
+            nn.init.zeros_(self.adapter[1].weight)
+            self.adapter.to(module.weight.device)
+
+        def forward(self, input, *args, **kwargs):
+            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+
+if is_bitsandbytes_available():
+    import bitsandbytes as bnb
+
+
+@require_bitsandbytes_version_greater("0.43.2")
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class Base4bitTests(unittest.TestCase):
+    # We need to test on relatively large models (aka >1B parameters), otherwise quantization may not work as expected.
+    # Therefore here we use only SD3 to test our module.
+    model_name = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+    # This was obtained on audace, so the number might change slightly.
+    expected_rel_difference = 3.69
+
+    prompt = "a beautiful sunset amidst the mountains."
+ num_inference_steps = 10 + seed = 0 + + def get_dummy_inputs(self): + prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt" + ) + pooled_prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt" + ) + latent_model_input = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt" + ) + + input_dict_for_transformer = { + "hidden_states": latent_model_input, + "encoder_hidden_states": prompt_embeds, + "pooled_projections": pooled_prompt_embeds, + "timestep": torch.Tensor([1.0]), + "return_dict": False, + } + return input_dict_for_transformer + + +class BnB4BitBasicTests(Base4bitTests): + def setUp(self): + # Models + self.model_fp16 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", torch_dtype=torch.float16 + ) + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + self.model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + def tearDown(self): + del self.model_fp16 + del self.model_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quantization_num_parameters(self): + r""" + Test if the number of returned parameters is correct + """ + num_params_4bit = self.model_4bit.num_parameters() + num_params_fp16 = self.model_fp16.num_parameters() + + self.assertEqual(num_params_4bit, num_params_fp16) + + def test_quantization_config_json_serialization(self): + r""" + A simple test to check if the quantization config is correctly serialized and deserialized + """ + config = self.model_4bit.config + + self.assertTrue("quantization_config" in config) + + _ = config["quantization_config"].to_dict() + _ = config["quantization_config"].to_diff_dict() + + _ = config["quantization_config"].to_json_string() + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + mem_fp16 = self.model_fp16.get_memory_footprint() + mem_4bit = self.model_4bit.get_memory_footprint() + + self.assertAlmostEqual(mem_fp16 / mem_4bit, self.expected_rel_difference, delta=1e-2) + linear = get_some_linear_layer(self.model_4bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + + def test_original_dtype(self): + r""" + A simple test to check if the model succesfully stores the original dtype + """ + self.assertTrue("_pre_quantization_dtype" in self.model_4bit.config) + self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) + self.assertTrue(self.model_4bit.config["_pre_quantization_dtype"] == torch.float16) + + def test_keep_modules_in_fp32(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. + Also ensures if inference works. 
+ """ + fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules + SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] + + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + model = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + self.assertTrue(module.weight.dtype == torch.float32) + else: + # 4-bit parameters are packed in uint8 variables + self.assertTrue(module.weight.dtype == torch.uint8) + + # test if inference works. + with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + _ = model(**model_inputs) + + SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules + + def test_linear_are_4bit(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + self.model_fp16.get_memory_footprint() + self.model_4bit.get_memory_footprint() + + for name, module in self.model_4bit.named_modules(): + if isinstance(module, torch.nn.Linear): + if name not in ["proj_out"]: + # 4-bit parameters are packed in uint8 variables + self.assertTrue(module.weight.dtype == torch.uint8) + + def test_config_from_pretrained(self): + transformer_4bit = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer" + ) + linear = get_some_linear_layer(transformer_4bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + self.assertTrue(hasattr(linear.weight, "quant_state")) + self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState) + + def test_device_assignment(self): + mem_before = self.model_4bit.get_memory_footprint() + + # Move to CPU + self.model_4bit.to("cpu") + self.assertEqual(self.model_4bit.device.type, "cpu") + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + + # Move back to CUDA device + for device in [0, "cuda", "cuda:0", "call()"]: + if device == "call()": + self.model_4bit.cuda(0) + else: + self.model_4bit.to(device) + self.assertEqual(self.model_4bit.device, torch.device(0)) + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + self.model_4bit.to("cpu") + + def test_device_and_dtype_assignment(self): + r""" + Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error. + Checks also if other models are casted correctly. 
+ """ + with self.assertRaises(ValueError): + # Tries with a `dtype` + self.model_4bit.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` and `dtype` + self.model_4bit.to(device="cuda:0", dtype=torch.float16) + + with self.assertRaises(ValueError): + # Tries with a cast + self.model_4bit.float() + + with self.assertRaises(ValueError): + # Tries with a cast + self.model_4bit.half() + + # Test if we did not break anything + self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(dtype=torch.float32, device=torch_device) + for k, v in input_dict_for_transformer.items() + if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + with torch.no_grad(): + _ = self.model_fp16(**model_inputs) + + # Check this does not throw an error + _ = self.model_fp16.to("cpu") + + # Check this does not throw an error + _ = self.model_fp16.half() + + # Check this does not throw an error + _ = self.model_fp16.float() + + # Check that this does not throw an error + _ = self.model_fp16.cuda() + + def test_bnb_4bit_wrong_config(self): + r""" + Test whether creating a bnb config with unsupported values leads to errors. + """ + with self.assertRaises(ValueError): + _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add") + + +class BnB4BitTrainingTests(Base4bitTests): + def setUp(self): + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + self.model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + def test_training(self): + # Step 1: freeze all parameters + for param in self.model_4bit.parameters(): + param.requires_grad = False # freeze the model - train adapters later + if param.ndim == 1: + # cast the small parameters (e.g. 
layernorm) to fp32 for stability + param.data = param.data.to(torch.float32) + + # Step 2: add adapters + for _, module in self.model_4bit.named_modules(): + if "Attention" in repr(type(module)): + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + # Step 3: dummy batch + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + + # Step 4: Check if the gradient is not None + with torch.amp.autocast("cuda", dtype=torch.float16): + out = self.model_4bit(**model_inputs)[0] + out.norm().backward() + + for module in self.model_4bit.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + +@require_transformers_version_greater("4.44.0") +class SlowBnb4BitTests(Base4bitTests): + def setUp(self) -> None: + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + self.pipeline_4bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_4bit, torch_dtype=torch.float16 + ) + self.pipeline_4bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.1123, 0.1296, 0.1609, 0.1042, 0.1230, 0.1274, 0.0928, 0.1165, 0.1216]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + print(f"{max_diff=}") + self.assertTrue(max_diff < 1e-2) + + def test_generate_quality_dequantize(self): + r""" + Test that loading the model and unquantize it produce correct results. + """ + self.pipeline_4bit.transformer.dequantize() + output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.1216, 0.1387, 0.1584, 0.1152, 0.1318, 0.1282, 0.1062, 0.1226, 0.1228]) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) + + # Since we offloaded the `pipeline_4bit.transformer` to CPU (result of `enable_model_cpu_offload()), check + # the following. + self.assertTrue(self.pipeline_4bit.transformer.device.type == "cpu") + # calling it again shouldn't be a problem + _ = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=2, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + +@require_transformers_version_greater("4.44.0") +class SlowBnb4BitFluxTests(Base4bitTests): + def setUp(self) -> None: + # TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo. 
+ model_id = "sayakpaul/flux.1-dev-nf4-pkg" + t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer") + self.pipeline_4bit = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + text_encoder_2=t5_4bit, + transformer=transformer_4bit, + torch_dtype=torch.float16, + ) + self.pipeline_4bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + # keep the resolution and max tokens to a lower number for faster execution. + output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0583, 0.0586, 0.0632, 0.0815, 0.0813, 0.0947, 0.1040, 0.1145, 0.1265]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) + + +@slow +class BaseBnb4BitSerializationTests(Base4bitTests): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True): + r""" + Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default. + See ExtendedSerializationTest class for more params combinations. + """ + + self.quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type=quant_type, + bnb_4bit_use_double_quant=double_quant, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + model_0 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=self.quantization_config + ) + self.assertTrue("_pre_quantization_dtype" in model_0.config) + with tempfile.TemporaryDirectory() as tmpdirname: + model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization) + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + self.assertTrue(hasattr(linear.weight, "quant_state")) + self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState) + + # checking memory footpring + self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2) + + # Matching all parameters and their quant_state items: + d0 = dict(model_0.named_parameters()) + d1 = dict(model_1.named_parameters()) + self.assertTrue(d0.keys() == d1.keys()) + + for k in d0.keys(): + self.assertTrue(d0[k].shape == d1[k].shape) + self.assertTrue(d0[k].device.type == d1[k].device.type) + self.assertTrue(d0[k].device == d1[k].device) + self.assertTrue(d0[k].dtype == d1[k].dtype) + self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device))) + + if isinstance(d0[k], bnb.nn.modules.Params4bit): + for v0, v1 in zip( + d0[k].quant_state.as_dict().values(), + d1[k].quant_state.as_dict().values(), + ): + if isinstance(v0, torch.Tensor): + self.assertTrue(torch.equal(v0, v1.to(v0.device))) + else: + self.assertTrue(v0 == v1) + + # comparing forward() outputs + dummy_inputs = 
self.get_dummy_inputs()
+        inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
+        inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
+        out_0 = model_0(**inputs)[0]
+        out_1 = model_1(**inputs)[0]
+        self.assertTrue(torch.equal(out_0, out_1))
+
+
+class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
+    """
+    Tests more combinations of parameters.
+    """
+
+    def test_nf4_single_unsafe(self):
+        self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False)
+
+    def test_nf4_single_safe(self):
+        self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True)
+
+    def test_nf4_double_unsafe(self):
+        self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False)
+
+    # nf4 double safetensors quantization is tested in test_serialization() method from the parent class
+
+    def test_fp4_single_unsafe(self):
+        self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False)
+
+    def test_fp4_single_safe(self):
+        self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True)
+
+    def test_fp4_double_unsafe(self):
+        self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False)
+
+    def test_fp4_double_safe(self):
+        self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
new file mode 100644
index 000000000000..7da7cd4de410
--- /dev/null
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -0,0 +1,526 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
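+
+# These LLM.int8() (mixed int8) tests mirror the 4-bit suite in `test_4bit.py`: they cover
+# parameter counts, memory footprint, `_keep_in_fp32_modules`, the `llm_int8_skip_modules`
+# option, device/dtype casting guards, LoRA-style training on top of frozen int8 weights,
+# SD3/Flux pipeline outputs, and (sharded) serialization round-trips.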
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers.utils.testing_utils import (
+    CaptureLogger,
+    is_bitsandbytes_available,
+    is_torch_available,
+    is_transformers_available,
+    load_pt,
+    numpy_cosine_similarity_distance,
+    require_accelerate,
+    require_bitsandbytes_version_greater,
+    require_torch,
+    require_torch_gpu,
+    require_transformers_version_greater,
+    slow,
+    torch_device,
+)
+
+
+def get_some_linear_layer(model):
+    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
+        return model.transformer_blocks[0].attn.to_q
+    else:
+        raise NotImplementedError("Don't know what layer to retrieve here.")
+
+
+if is_transformers_available():
+    from transformers import T5EncoderModel
+
+if is_torch_available():
+    import torch
+    import torch.nn as nn
+
+    class LoRALayer(nn.Module):
+        """Wraps a linear layer with a LoRA-like adapter - used for testing purposes only.
+
+        Taken from
+        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
+        """
+
+        def __init__(self, module: nn.Module, rank: int):
+            super().__init__()
+            self.module = module
+            self.adapter = nn.Sequential(
+                nn.Linear(module.in_features, rank, bias=False),
+                nn.Linear(rank, module.out_features, bias=False),
+            )
+            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+            nn.init.normal_(self.adapter[0].weight, std=small_std)
+            nn.init.zeros_(self.adapter[1].weight)
+            self.adapter.to(module.weight.device)
+
+        def forward(self, input, *args, **kwargs):
+            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+
+if is_bitsandbytes_available():
+    import bitsandbytes as bnb
+
+
+@require_bitsandbytes_version_greater("0.43.2")
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class Base8bitTests(unittest.TestCase):
+    # We need to test on relatively large models (aka >1B parameters), otherwise quantization may not work as expected.
+    # Therefore here we use only SD3 to test our module.
+    model_name = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+    # This was obtained on audace, so the number might change slightly.
+    expected_rel_difference = 1.94
+
+    prompt = "a beautiful sunset amidst the mountains."
+ num_inference_steps = 10 + seed = 0 + + def get_dummy_inputs(self): + prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt" + ) + pooled_prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt" + ) + latent_model_input = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt" + ) + + input_dict_for_transformer = { + "hidden_states": latent_model_input, + "encoder_hidden_states": prompt_embeds, + "pooled_projections": pooled_prompt_embeds, + "timestep": torch.Tensor([1.0]), + "return_dict": False, + } + return input_dict_for_transformer + + +class BnB8bitBasicTests(Base8bitTests): + def setUp(self): + # Models + self.model_fp16 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", torch_dtype=torch.float16 + ) + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + self.model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + def tearDown(self): + del self.model_fp16 + del self.model_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quantization_num_parameters(self): + r""" + Test if the number of returned parameters is correct + """ + num_params_8bit = self.model_8bit.num_parameters() + num_params_fp16 = self.model_fp16.num_parameters() + + self.assertEqual(num_params_8bit, num_params_fp16) + + def test_quantization_config_json_serialization(self): + r""" + A simple test to check if the quantization config is correctly serialized and deserialized + """ + config = self.model_8bit.config + + self.assertTrue("quantization_config" in config) + + _ = config["quantization_config"].to_dict() + _ = config["quantization_config"].to_diff_dict() + + _ = config["quantization_config"].to_json_string() + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + mem_fp16 = self.model_fp16.get_memory_footprint() + mem_8bit = self.model_8bit.get_memory_footprint() + + self.assertAlmostEqual(mem_fp16 / mem_8bit, self.expected_rel_difference, delta=1e-2) + linear = get_some_linear_layer(self.model_8bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + + def test_original_dtype(self): + r""" + A simple test to check if the model succesfully stores the original dtype + """ + self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config) + self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) + self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16) + + def test_keep_modules_in_fp32(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. + Also ensures if inference works. 
+ """ + fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules + SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] + + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + model = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + self.assertTrue(module.weight.dtype == torch.float32) + else: + # 8-bit parameters are packed in int8 variables + self.assertTrue(module.weight.dtype == torch.int8) + + # test if inference works. + with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + _ = model(**model_inputs) + + SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules + + def test_linear_are_8bit(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + self.model_fp16.get_memory_footprint() + self.model_8bit.get_memory_footprint() + + for name, module in self.model_8bit.named_modules(): + if isinstance(module, torch.nn.Linear): + if name not in ["proj_out"]: + # 8-bit parameters are packed in int8 variables + self.assertTrue(module.weight.dtype == torch.int8) + + def test_llm_skip(self): + r""" + A simple test to check if `llm_int8_skip_modules` works as expected + """ + config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"]) + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=config + ) + linear = get_some_linear_layer(model_8bit) + self.assertTrue(linear.weight.dtype == torch.int8) + self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt)) + + self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear)) + self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8) + + def test_config_from_pretrained(self): + transformer_8bit = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer" + ) + linear = get_some_linear_layer(transformer_8bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + def test_device_and_dtype_assignment(self): + r""" + Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. + Checks also if other models are casted correctly. 
+ """ + with self.assertRaises(ValueError): + # Tries with `str` + self.model_8bit.to("cpu") + + with self.assertRaises(ValueError): + # Tries with a `dtype`` + self.model_8bit.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.to(torch.device("cuda:0")) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.float() + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.half() + + # Test if we did not break anything + self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(dtype=torch.float32, device=torch_device) + for k, v in input_dict_for_transformer.items() + if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + with torch.no_grad(): + _ = self.model_fp16(**model_inputs) + + # Check this does not throw an error + _ = self.model_fp16.to("cpu") + + # Check this does not throw an error + _ = self.model_fp16.half() + + # Check this does not throw an error + _ = self.model_fp16.float() + + # Check that this does not throw an error + _ = self.model_fp16.cuda() + + +class BnB8bitTrainingTests(Base8bitTests): + def setUp(self): + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + self.model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + def test_training(self): + # Step 1: freeze all parameters + for param in self.model_8bit.parameters(): + param.requires_grad = False # freeze the model - train adapters later + if param.ndim == 1: + # cast the small parameters (e.g. 
layernorm) to fp32 for stability + param.data = param.data.to(torch.float32) + + # Step 2: add adapters + for _, module in self.model_8bit.named_modules(): + if "Attention" in repr(type(module)): + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + # Step 3: dummy batch + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + + # Step 4: Check if the gradient is not None + with torch.amp.autocast("cuda", dtype=torch.float16): + out = self.model_8bit(**model_inputs)[0] + out.norm().backward() + + for module in self.model_8bit.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + +@require_transformers_version_greater("4.44.0") +class SlowBnb8bitTests(Base8bitTests): + def setUp(self) -> None: + mixed_int8_config = BitsAndBytesConfig( + load_in_8bit=True, + bnb_8bit_quant_type="nf4", + bnb_8bit_compute_dtype=torch.float16, + ) + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + self.pipeline_8bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_8bit, torch_dtype=torch.float16 + ) + self.pipeline_8bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0442, 0.0457, 0.0254, 0.0405, 0.0535, 0.0261, 0.0259, 0.04, 0.0452]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-2) + + def test_model_cpu_offload_raises_warning(self): + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True) + ) + pipeline_8bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_8bit, torch_dtype=torch.float16 + ) + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(30) + + with CaptureLogger(logger) as cap_logger: + pipeline_8bit.enable_model_cpu_offload() + + assert "has been loaded in `bitsandbytes` 8bit" in cap_logger.out + + def test_generate_quality_dequantize(self): + r""" + Test that loading the model and unquantize it produce correct results. + """ + self.pipeline_8bit.transformer.dequantize() + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208]) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-2) + + # 8bit models cannot be offloaded to CPU. 
+ self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda") + # calling it again shouldn't be a problem + _ = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=2, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + +@require_transformers_version_greater("4.44.0") +class SlowBnb8bitFluxTests(Base8bitTests): + def setUp(self) -> None: + # TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo. + model_id = "sayakpaul/flux.1-dev-int8-pkg" + t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer") + self.pipeline_8bit = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + text_encoder_2=t5_8bit, + transformer=transformer_8bit, + torch_dtype=torch.float16, + ) + self.pipeline_8bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + # keep the resolution and max tokens to a lower number for faster execution. + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) + + +@slow +class BaseBnb8bitSerializationTests(Base8bitTests): + def setUp(self): + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + ) + self.model_0 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=quantization_config + ) + + def tearDown(self): + del self.model_0 + + gc.collect() + torch.cuda.empty_cache() + + def test_serialization(self): + r""" + Test whether it is possible to serialize a model in 8-bit. Uses most typical params as default. 
+ """ + self.assertTrue("_pre_quantization_dtype" in self.model_0.config) + with tempfile.TemporaryDirectory() as tmpdirname: + self.model_0.save_pretrained(tmpdirname) + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + # checking memory footpring + self.assertAlmostEqual(self.model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2) + + # Matching all parameters and their quant_state items: + d0 = dict(self.model_0.named_parameters()) + d1 = dict(model_1.named_parameters()) + self.assertTrue(d0.keys() == d1.keys()) + + # comparing forward() outputs + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = self.model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1)) + + def test_serialization_sharded(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.model_0.save_pretrained(tmpdirname, max_shard_size="200MB") + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + # comparing forward() outputs + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = self.model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1))