-
Notifications
You must be signed in to change notification settings - Fork 5.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Quantization] Add quantization support for bitsandbytes
#9213
base: main
Are you sure you want to change the base?
Changes from 45 commits
e634ff2
02a6dff
c385a2b
0355875
e41b494
dfb33eb
e492655
6e86cc0
58a3d15
1d477f9
bd7f46d
d5d7bb6
44c8a75
6a0fcdc
e4590fa
77a1438
335ab6b
d44ef85
210fa1e
f4feee1
e8c1722
7f86a71
ba671b6
c1a9f13
4489c54
f2ca5e2
d6b8954
45029e2
4eb468a
939965d
8557166
d098d07
c4a0074
ee45612
b24c0a7
473505c
c795c82
c1d5b96
af7caca
80967f5
3bdf25a
27415cc
51cac09
15f3032
77c9fdb
ddc9f29
44c4109
27666a8
3464d83
b106124
330fa0a
abc8607
31725aa
e5938a6
444588f
d3360ce
d8b35f4
859f2d7
3b2d6e1
5799954
8e4bd08
835d4ad
27075fe
5c00c1c
5d633a0
c381fe0
3c92878
acdeb25
aa295b7
7f7c9ce
55f96d8
b28cc65
8328e86
9758942
b1a9878
971305b
f41adf1
0bcb88b
55b3696
4cb3a6d
8a03eae
53f0a92
6aab47c
9b9a610
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -25,6 +25,7 @@ | |||||
import torch | ||||||
from huggingface_hub.utils import EntryNotFoundError | ||||||
|
||||||
from ..quantizers.quantization_config import QuantizationMethod | ||||||
from ..utils import ( | ||||||
SAFE_WEIGHTS_INDEX_NAME, | ||||||
SAFETENSORS_FILE_EXTENSION, | ||||||
|
@@ -53,11 +54,36 @@ | |||||
|
||||||
|
||||||
# Adapted from `transformers` (see modeling_utils.py) | ||||||
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype): | ||||||
def _determine_device_map( | ||||||
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None | ||||||
): | ||||||
if isinstance(device_map, str): | ||||||
special_dtypes = {} | ||||||
if hf_quantizer is not None: | ||||||
special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) | ||||||
special_dtypes.update( | ||||||
{ | ||||||
name: torch.float32 | ||||||
for name, _ in model.named_parameters() | ||||||
if any(m in name for m in keep_in_fp32_modules) | ||||||
} | ||||||
) | ||||||
|
||||||
target_dtype = torch_dtype | ||||||
if hf_quantizer is not None: | ||||||
target_dtype = hf_quantizer.adjust_target_dtype(target_dtype) | ||||||
|
||||||
no_split_modules = model._get_no_split_modules(device_map) | ||||||
device_map_kwargs = {"no_split_module_classes": no_split_modules} | ||||||
|
||||||
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: | ||||||
device_map_kwargs["special_dtypes"] = special_dtypes | ||||||
elif len(special_dtypes) > 0: | ||||||
logger.warning( | ||||||
"This model has some weights that should be kept in higher precision, you need to upgrade " | ||||||
"`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." | ||||||
) | ||||||
|
||||||
if device_map != "sequential": | ||||||
max_memory = get_balanced_memory( | ||||||
model, | ||||||
|
@@ -69,8 +95,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_ | |||||
else: | ||||||
max_memory = get_max_memory(max_memory) | ||||||
|
||||||
if hf_quantizer is not None: | ||||||
max_memory = hf_quantizer.adjust_max_memory(max_memory) | ||||||
|
||||||
device_map_kwargs["max_memory"] = max_memory | ||||||
device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs) | ||||||
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) | ||||||
|
||||||
if hf_quantizer is not None: | ||||||
hf_quantizer.validate_environment(device_map=device_map) | ||||||
|
||||||
return device_map | ||||||
|
||||||
|
@@ -99,6 +131,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ | |||||
""" | ||||||
Reads a checkpoint file, returning properly formatted errors if they arise. | ||||||
""" | ||||||
if isinstance(checkpoint_file, dict): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why are we making this change? when will There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We merge the sharded checkpoints (as stated in the PR description and mutually agreed upon internally) in case we're doing quantization:
^
and hence this change. |
||||||
return checkpoint_file | ||||||
try: | ||||||
file_extension = os.path.basename(checkpoint_file).split(".")[-1] | ||||||
if file_extension == SAFETENSORS_FILE_EXTENSION: | ||||||
|
@@ -136,29 +170,57 @@ def load_model_dict_into_meta( | |||||
device: Optional[Union[str, torch.device]] = None, | ||||||
dtype: Optional[Union[str, torch.dtype]] = None, | ||||||
model_name_or_path: Optional[str] = None, | ||||||
hf_quantizer=None, | ||||||
keep_in_fp32_modules=None, | ||||||
) -> List[str]: | ||||||
device = device or torch.device("cpu") | ||||||
device = device or torch.device("cpu") if hf_quantizer is None else device | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More on this in the later changes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not specific to this PR but There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added a comment about it too.
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
dtype = dtype or torch.float32 | ||||||
is_quantized = hf_quantizer is not None | ||||||
|
||||||
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) | ||||||
|
||||||
unexpected_keys = [] | ||||||
empty_state_dict = model.state_dict() | ||||||
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] | ||||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") | ||||||
|
||||||
for param_name, param in state_dict.items(): | ||||||
if param_name not in empty_state_dict: | ||||||
unexpected_keys.append(param_name) | ||||||
continue | ||||||
|
||||||
if empty_state_dict[param_name].shape != param.shape: | ||||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so for example:
Inside this function, with the current code, we would first convert it to torch.float32; then, when this line runs later, it would be converted back to float16 again because
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point! Added: + dtype = torch.float32
param = param.to(dtype) Also added tests ( |
||||||
# in int/uint/bool and not cast them. | ||||||
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't yet support There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we throw an error then? |
||||||
if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn: | ||||||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
if ( | ||||||
keep_in_fp32_modules is not None | ||||||
and any( | ||||||
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules | ||||||
) | ||||||
and dtype == torch.float16 | ||||||
): | ||||||
param = param.to(torch.float32) | ||||||
else: | ||||||
param = param.to(dtype) | ||||||
|
||||||
is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES | ||||||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because bnb quantized params are usually flattened.
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" | ||||||
raise ValueError( | ||||||
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." | ||||||
) | ||||||
|
||||||
if accepts_dtype: | ||||||
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) | ||||||
if ( | ||||||
not is_quantized | ||||||
or (not hf_quantizer.requires_parameters_quantization) | ||||||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)) | ||||||
): | ||||||
if accepts_dtype: | ||||||
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) | ||||||
else: | ||||||
set_module_tensor_to_device(model, param_name, device, value=param) | ||||||
else: | ||||||
set_module_tensor_to_device(model, param_name, device, value=param) | ||||||
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) | ||||||
|
||||||
return unexpected_keys | ||||||
|
||||||
|
||||||
|
@@ -228,3 +290,32 @@ def _fetch_index_file( | |||||
index_file = None | ||||||
|
||||||
return index_file | ||||||
|
||||||
|
||||||
# Adapted from | ||||||
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 | ||||||
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata): | ||||||
weight_map = sharded_metadata.get("weight_map", None) | ||||||
if weight_map is None: | ||||||
raise KeyError("'weight_map' key not found in the shard index file.") | ||||||
|
||||||
# Collect all unique safetensors files from weight_map | ||||||
files_to_load = set(weight_map.values()) | ||||||
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load) | ||||||
merged_state_dict = {} | ||||||
|
||||||
# Load tensors from each unique file | ||||||
for file_name in files_to_load: | ||||||
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name) | ||||||
if not os.path.exists(part_file_path): | ||||||
raise FileNotFoundError(f"Part file {file_name} not found.") | ||||||
|
||||||
if is_safetensors: | ||||||
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: | ||||||
for tensor_key in f.keys(): | ||||||
if tensor_key in weight_map: | ||||||
merged_state_dict[tensor_key] = f.get_tensor(tensor_key) | ||||||
else: | ||||||
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu")) | ||||||
|
||||||
return merged_state_dict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because
quantization_config
isn't a part of any model's __init__()
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is better not to add it to config_dict if it is not going into
__init__
, i.e. at line 511. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We cannot remove
quantization_config
from the config of a model, as that would prevent loading of the quantized models via from_pretrained()
.quantization_config
isn't used for initializing a model; it's used to determine what kind of quantization configuration to inject inside the given model. This is why it's only used in from_pretrained()
of ModelMixin
. LMK if you have a better idea for handling it.