[Quantization] Add quantization support for bitsandbytes #9213

Status: Open. sayakpaul wants to merge 84 commits into base: main.

Commits (84)
e634ff2
quantization config.
sayakpaul Aug 19, 2024
02a6dff
fix-copies
sayakpaul Aug 19, 2024
c385a2b
Merge branch 'main' into quantization-config
sayakpaul Aug 20, 2024
0355875
Merge branch 'main' into quantization-config
sayakpaul Aug 20, 2024
e41b494
Merge branch 'main' into quantization-config
sayakpaul Aug 20, 2024
dfb33eb
Merge branch 'main' into quantization-config
sayakpaul Aug 21, 2024
e492655
Merge branch 'main' into quantization-config
sayakpaul Aug 22, 2024
6e86cc0
fix
sayakpaul Aug 22, 2024
58a3d15
modules_to_not_convert
sayakpaul Aug 22, 2024
1d477f9
Merge branch 'main' into quantization-config
sayakpaul Aug 22, 2024
bd7f46d
Merge branch 'main' into quantization-config
sayakpaul Aug 23, 2024
d5d7bb6
Merge branch 'main' into quantization-config
sayakpaul Aug 28, 2024
44c8a75
Merge branch 'main' into quantization-config
sayakpaul Aug 28, 2024
6a0fcdc
add bitsandbytes utilities.
sayakpaul Aug 28, 2024
e4590fa
make progress.
sayakpaul Aug 28, 2024
77a1438
Merge branch 'main' into quantization-config
sayakpaul Aug 29, 2024
335ab6b
fixes
sayakpaul Aug 29, 2024
d44ef85
quality
sayakpaul Aug 29, 2024
210fa1e
up
sayakpaul Aug 29, 2024
f4feee1
up
sayakpaul Aug 29, 2024
e8c1722
Merge branch 'main' into quantization-config
sayakpaul Aug 29, 2024
7f86a71
Merge branch 'main' into quantization-config
sayakpaul Aug 29, 2024
ba671b6
minor
sayakpaul Aug 30, 2024
c1a9f13
up
sayakpaul Aug 30, 2024
4489c54
Merge branch 'main' into quantization-config
sayakpaul Aug 30, 2024
f2ca5e2
up
sayakpaul Aug 30, 2024
d6b8954
fix
sayakpaul Aug 30, 2024
45029e2
provide credits where due.
sayakpaul Aug 30, 2024
4eb468a
make configurations work.
sayakpaul Aug 30, 2024
939965d
fixes
sayakpaul Aug 30, 2024
8557166
Merge branch 'main' into quantization-config
sayakpaul Aug 30, 2024
d098d07
fix
sayakpaul Aug 30, 2024
c4a0074
update_missing_keys
sayakpaul Aug 30, 2024
ee45612
fix
sayakpaul Aug 30, 2024
b24c0a7
fix
sayakpaul Aug 31, 2024
473505c
make it work.
sayakpaul Aug 31, 2024
c795c82
fix
sayakpaul Aug 31, 2024
c1d5b96
Merge branch 'main' into quantization-config
sayakpaul Aug 31, 2024
af7caca
provide credits to transformers.
sayakpaul Aug 31, 2024
80967f5
empty commit
sayakpaul Sep 1, 2024
3bdf25a
handle to() better.
sayakpaul Sep 2, 2024
27415cc
tests
sayakpaul Sep 2, 2024
51cac09
change to bnb from bitsandbytes
sayakpaul Sep 2, 2024
15f3032
fix tests
sayakpaul Sep 2, 2024
77c9fdb
better safeguard.
sayakpaul Sep 2, 2024
ddc9f29
change merging status
sayakpaul Sep 2, 2024
44c4109
courtesy to transformers.
sayakpaul Sep 2, 2024
27666a8
move upper.
sayakpaul Sep 2, 2024
3464d83
better
sayakpaul Sep 2, 2024
b106124
Merge branch 'main' into quantization-config
sayakpaul Sep 2, 2024
330fa0a
Merge branch 'main' into quantization-config
sayakpaul Sep 2, 2024
abc8607
make the unused kwargs warning friendlier.
sayakpaul Sep 3, 2024
31725aa
harmonize changes with https://github.com/huggingface/transformers/pu…
sayakpaul Sep 3, 2024
e5938a6
style
sayakpaul Sep 3, 2024
444588f
trainin tests
sayakpaul Sep 3, 2024
d3360ce
Merge branch 'main' into quantization-config
sayakpaul Sep 3, 2024
d8b35f4
Merge branch 'main' into quantization-config
sayakpaul Sep 3, 2024
859f2d7
Merge branch 'main' into quantization-config
sayakpaul Sep 4, 2024
3b2d6e1
feedback part i.
sayakpaul Sep 4, 2024
5799954
Add Flux inpainting and Flux Img2Img (#9135)
Gothos Sep 4, 2024
8e4bd08
Revert "Add Flux inpainting and Flux Img2Img (#9135)"
sayakpaul Sep 6, 2024
835d4ad
tests
sayakpaul Sep 6, 2024
27075fe
don
sayakpaul Sep 6, 2024
5c00c1c
Merge branch 'main' into quantization-config
sayakpaul Sep 6, 2024
5d633a0
Merge branch 'main' into quantization-config
sayakpaul Sep 8, 2024
c381fe0
Apply suggestions from code review
sayakpaul Sep 10, 2024
3c92878
Merge branch 'main' into quantization-config
sayakpaul Sep 10, 2024
acdeb25
contribution guide.
sayakpaul Sep 11, 2024
aa295b7
Merge branch 'main' into quantization-config
sayakpaul Sep 11, 2024
7f7c9ce
Merge branch 'main' into quantization-config
sayakpaul Sep 15, 2024
55f96d8
Merge branch 'main' into quantization-config
sayakpaul Sep 15, 2024
b28cc65
changes
sayakpaul Sep 17, 2024
8328e86
Merge branch 'main' into quantization-config
sayakpaul Sep 17, 2024
9758942
empty
sayakpaul Sep 17, 2024
b1a9878
fix tests
sayakpaul Sep 17, 2024
971305b
harmonize with https://github.com/huggingface/transformers/pull/33546.
sayakpaul Sep 18, 2024
f41adf1
numpy_cosine_distance
sayakpaul Sep 19, 2024
0bcb88b
Merge branch 'main' into quantization-config
sayakpaul Sep 19, 2024
55b3696
Merge branch 'main' into quantization-config
sayakpaul Sep 20, 2024
4cb3a6d
Merge branch 'main' into quantization-config
sayakpaul Sep 23, 2024
8a03eae
Merge branch 'main' into quantization-config
sayakpaul Sep 24, 2024
53f0a92
Merge branch 'main' into quantization-config
sayakpaul Sep 26, 2024
6aab47c
Merge branch 'main' into quantization-config
sayakpaul Sep 27, 2024
9b9a610
resolved conflicts,
sayakpaul Sep 29, 2024
5 changes: 4 additions & 1 deletion src/diffusers/__init__.py
@@ -31,6 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -123,7 +124,6 @@
"VQModel",
]
)

_import_structure["optimization"] = [
"get_constant_schedule",
"get_constant_schedule_with_warmup",
@@ -155,6 +155,7 @@
"StableDiffusionMixin",
]
)
_import_structure["quantizers"] = ["DiffusersQuantizer"]
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
@@ -526,6 +527,7 @@

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig

try:
if not is_onnx_available():
@@ -619,6 +621,7 @@
ScoreSdeVePipeline,
StableDiffusionMixin,
)
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
CMStochasticIterativeScheduler,
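
The net effect of the __init__.py changes above is that both new names are exposed from the top-level package, via the lazy _import_structure mapping as well as the eager TYPE_CHECKING branch. A minimal sanity check (a sketch, assuming an environment with this PR installed):

    from diffusers import BitsAndBytesConfig, DiffusersQuantizer

    # Both names should resolve without triggering the optional-dependency fallbacks.
    print(BitsAndBytesConfig, DiffusersQuantizer)
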
13 changes: 12 additions & 1 deletion src/diffusers/configuration_utils.py
@@ -526,7 +526,8 @@ def extract_init_dict(cls, config_dict, **kwargs):
init_dict[key] = config_dict.pop(key)

# 4. Give nice warning if unexpected values have been passed
if len(config_dict) > 0:
only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
Member Author (sayakpaul): Because quantization_config isn't a part of any model's __init__().

Collaborator: I think it is better not to add it to config_dict if it is not going into __init__, i.e. at line 511:

    # remove private attributes
    config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
    # remove quantization_config
    config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}

Member Author (sayakpaul): We cannot remove quantization_config from the config of a model, as that would prevent loading of quantized models via from_pretrained().

quantization_config isn't used for initializing a model; it's used to determine what kind of quantization configuration to inject into the given model. This is why it's only used in from_pretrained() of ModelMixin.
LMK if you have a better idea to handle it.
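
To make the behavior under discussion concrete, here is a hypothetical round trip: the quantization_config entry is persisted in the model's config and consumed again by from_pretrained() rather than by __init__(). The model class, checkpoint name, and kwargs below are illustrative placeholders, not taken from this diff:

    import torch
    from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    model = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quant_config,
    )

    model.save_pretrained("flux-transformer-4bit")  # config.json keeps a "quantization_config" entry
    reloaded = FluxTransformer2DModel.from_pretrained("flux-transformer-4bit")
    # Reloading works because the entry survived serialization; the guard above keeps it from
    # being reported as an unexpected config attribute.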

if len(config_dict) > 0 and not only_quant_config_remaining:
logger.warning(
f"The config attributes {config_dict} were passed to {cls.__name__}, "
"but are not expected and will be ignored. Please verify your "
@@ -586,10 +587,20 @@ def to_json_saveable(value):
value = value.as_posix()
return value

# IFWatermarker, for example, doesn't have a `config`.
if hasattr(self, "config") and "quantization_config" in self.config:
config_dict["quantization_config"] = (
self.config.quantization_config.to_dict()
if not isinstance(self.config.quantization_config, dict)
else self.config.quantization_config
)

config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
# Don't save "_ignore_files" or "_use_default_values"
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = config_dict.pop("_pre_quantization_dtype", None)

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
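
A one-line illustration of why the `_pre_quantization_dtype` entry has to be popped before json.dumps is called:

    import json
    import torch

    try:
        json.dumps({"_pre_quantization_dtype": torch.float16})
    except TypeError as err:
        print(err)  # e.g. "Object of type dtype is not JSON serializable"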

109 changes: 100 additions & 9 deletions src/diffusers/models/model_loading_utils.py
@@ -25,6 +25,7 @@
import torch
from huggingface_hub.utils import EntryNotFoundError

from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -53,11 +54,36 @@


# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
def _determine_device_map(
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
):
if isinstance(device_map, str):
special_dtypes = {}
if hf_quantizer is not None:
special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
special_dtypes.update(
{
name: torch.float32
for name, _ in model.named_parameters()
if any(m in name for m in keep_in_fp32_modules)
}
)

target_dtype = torch_dtype
if hf_quantizer is not None:
target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)

no_split_modules = model._get_no_split_modules(device_map)
device_map_kwargs = {"no_split_module_classes": no_split_modules}

if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
device_map_kwargs["special_dtypes"] = special_dtypes
elif len(special_dtypes) > 0:
logger.warning(
"This model has some weights that should be kept in higher precision, you need to upgrade "
"`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
)

if device_map != "sequential":
max_memory = get_balanced_memory(
model,
@@ -69,8 +95,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
else:
max_memory = get_max_memory(max_memory)

if hf_quantizer is not None:
max_memory = hf_quantizer.adjust_max_memory(max_memory)

device_map_kwargs["max_memory"] = max_memory
device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)

if hf_quantizer is not None:
hf_quantizer.validate_environment(device_map=device_map)

return device_map

@@ -99,6 +131,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
if isinstance(checkpoint_file, dict):
Collaborator: Why are we making this change? When will checkpoint_file be passed as a dict?

Member Author (sayakpaul): We merge the sharded checkpoints (as stated in the PR description and mutually agreed upon internally) in case we're doing quantization:

    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)

^ model_file becomes a state dict, which is then loaded by load_state_dict:

    state_dict = load_state_dict(model_file, variant=variant)

and hence this change.

return checkpoint_file
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
@@ -136,29 +170,57 @@ def load_model_dict_into_meta(
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
hf_quantizer=None,
keep_in_fp32_modules=None,
) -> List[str]:
device = device or torch.device("cpu")
device = device or torch.device("cpu") if hf_quantizer is None else device
Member Author (sayakpaul): More on this in the later changes.

Member: Not specific to this PR, but device = device or torch.device("cpu") is a bit dangerous because, theoretically, 0 is a valid device but it would be considered falsy. AFAICT it's not problematic for the existing code, but something to keep in mind.

Member Author (sayakpaul): Indeed. I have added a comment about it too.

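A small illustration of the falsiness issue raised above:

    import torch

    device = 0                                # a valid CUDA device index
    print(device or torch.device("cpu"))      # cpu -- the explicit index 0 is silently dropped

    device = torch.device("cuda", 0)
    print(device or torch.device("cpu"))      # cuda:0 -- torch.device objects are always truthy
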
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None

accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())

unexpected_keys = []
empty_state_dict = model.state_dict()
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

for param_name, param in state_dict.items():
if param_name not in empty_state_dict:
unexpected_keys.append(param_name)
continue

if empty_state_dict[param_name].shape != param.shape:
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
Collaborator: So, for example:

1. dtype = torch.float16
2. is_quantized = False
3. the module is one of the modules included in keep_in_fp32_modules

Inside this function, as currently written, we would first convert the param to torch.float32, but then this line runs and converts it back to float16 again, because dtype here is still torch.float16. I don't think that is the expected behavior:

            if accepts_dtype:
                set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
            else:
                set_module_tensor_to_device(model, param_name, device, value=param)

Member Author (sayakpaul), Sep 17, 2024: Good point! Added:

+                dtype = torch.float32
                param = param.to(dtype)

Also added tests (test_keep_modules_in_fp32) to ensure effectiveness.

# in int/uint/bool and not cast them.
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
Collaborator: We don't yet support param.dtype == torch.float8_e4m3fn, no? Let's not add this for now then.

Member Author (sayakpaul): Should we throw an error then?

if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and any(
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
)
and dtype == torch.float16
):
param = param.to(torch.float32)
else:
param = param.to(dtype)

is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
Member Author (sayakpaul): Because bnb quantized params are usually flattened.

model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)

if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
if (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device))
):
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)

return unexpected_keys
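
Condensed, the per-parameter flow above can be read as the following sketch. This is a simplified paraphrase for review purposes, not the exact implementation: it assumes a recent accelerate whose set_module_tensor_to_device accepts dtype, and it omits the accepts_dtype compatibility shim and the shape check.

    import torch
    from accelerate.utils import set_module_tensor_to_device

    def _place_param_sketch(model, param_name, param, dtype, device,
                            keep_in_fp32_modules, hf_quantizer, state_dict, unexpected_keys):
        is_quantized = hf_quantizer is not None

        # 1. Cast floating-point params, except float8_e4m3fn, honoring keep_in_fp32_modules.
        is_param_float8 = hasattr(torch, "float8_e4m3fn") and param.dtype == torch.float8_e4m3fn
        if dtype is not None and torch.is_floating_point(param) and not is_param_float8:
            if (
                keep_in_fp32_modules is not None
                and any(m in param_name.split(".") for m in keep_in_fp32_modules)
                and dtype == torch.float16
            ):
                dtype = torch.float32  # the review fix: keep `dtype` in sync so step 2 doesn't downcast again
            param = param.to(dtype)

        # 2. Hand the tensor to the quantizer when it claims it; otherwise place it directly.
        if (
            is_quantized
            and hf_quantizer.requires_parameters_quantization
            and hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
        ):
            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
        else:
            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)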


@@ -228,3 +290,32 @@ def _fetch_index_file(
index_file = None

return index_file


# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
weight_map = sharded_metadata.get("weight_map", None)
if weight_map is None:
raise KeyError("'weight_map' key not found in the shard index file.")

# Collect all unique safetensors files from weight_map
files_to_load = set(weight_map.values())
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
merged_state_dict = {}

# Load tensors from each unique file
for file_name in files_to_load:
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")

if is_safetensors:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
else:
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))

return merged_state_dict
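
For context, the sharded_metadata this helper consumes typically comes from the checkpoint's *.safetensors.index.json file; its weight_map maps each tensor name to the shard file that stores it. A hypothetical sketch (paths and key names are placeholders):

    import json

    # e.g. "transformer/diffusion_pytorch_model.safetensors.index.json" in a cached repo
    with open("diffusion_pytorch_model.safetensors.index.json") as f:
        sharded_metadata = json.load(f)

    # sharded_metadata roughly looks like:
    # {
    #   "metadata": {"total_size": ...},
    #   "weight_map": {
    #     "proj_out.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
    #     ...
    #   }
    # }
    state_dict = _merge_sharded_checkpoints(sharded_ckpt_cached_folder=".", sharded_metadata=sharded_metadata)
    print(f"{len(state_dict)} tensors merged")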