diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 306ed57d..92a5f836 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -102,9 +102,9 @@ def hf_device_map(self):
         return getattr(self.model, "hf_device_map", None)
 
     def _prepare_dataset_for_quantization(
-        self,
-        calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
-        batch_size: int = 1,
+            self,
+            calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
+            batch_size: int = 1,
     ):
         def _convert_tensor_to_list(tensor):
             if isinstance(tensor, torch.Tensor):
@@ -138,7 +138,7 @@ def _convert_tensor_to_list(tensor):
             pad_token_id = self.config.eos_token_id
 
         new_calibration_dataset = [
-            collate_data(new_calibration_dataset[start : start + batch_size], pad_token_id)
+            collate_data(new_calibration_dataset[start: start + batch_size], pad_token_id)
             for start in range(0, len(new_calibration_dataset), batch_size)
         ]
         for new_example in new_calibration_dataset:
@@ -183,7 +183,7 @@ def quantize(
 
         if len(calibration_dataset) < MIN_CALIBRATION_DATASET_SIZE:
             logger.warning(f"Calibration dataset size should be greater than {MIN_CALIBRATION_DATASET_SIZE}. "
-                f"Current size: {len(calibration_dataset)}.")
+                           f"Current size: {len(calibration_dataset)}.")
 
         # Calculate the average length of the average input_ids
         total_input_ids_length = 0
@@ -194,7 +194,7 @@
 
         if avg < MIN_CALIBRATION_DATASET_INPUT_IDS_AVG_LENGTH:
             logger.warning(f"The average length of input_ids of calibration_dataset should be greater than "
-                f"{MIN_CALIBRATION_DATASET_INPUT_IDS_AVG_LENGTH}! Current AVG is {avg}.")
+                           f"{MIN_CALIBRATION_DATASET_INPUT_IDS_AVG_LENGTH}! Current AVG is {avg}.")
 
         device_map = self.hf_device_map
         if device_map:
@@ -239,10 +239,7 @@ def store_input_hook(_, args, kwargs):
             if pos_ids is not None:
                 position_ids.append(move_to(pos_ids, data_device))
             one_kwargs = {}
-            for (
-                k,
-                v,
-            ) in kwargs.items():  # make sure other arguments also be captured
+            for (k, v) in kwargs.items():  # make sure other arguments also be captured
                 if k not in ["hidden_states", "attention_mask", "position_ids"]:
                     one_kwargs[k] = nested_move_to(v, data_device)
             layer_input_kwargs.append(one_kwargs)
@@ -497,8 +494,8 @@ def save_quantized(
 
         if model_base_name is None:
             model_base_name = (
-                self.quantize_config.model_file_base_name or
-                f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+                    self.quantize_config.model_file_base_name or
+                    f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
             )
 
         state_dict = self.model.state_dict()
@@ -542,7 +539,25 @@ def save_quantized(
         if format is None and quantize_config.format == FORMAT.GPTQ:
             # Model qzeros may be edited in place.
             # TODO: avoid inplace modification of the weights
-            model = copy.deepcopy(self.model)
+            # fix ModelCloud/GPTQModel/issues/47
+            # fix gptqmodel_cuda cannot be serialized
+            # no need to set it back, no calculation below
+            if quantize_config.bits != 4:
+                cuda_name_modules = {}
+                from gptqmodel.nn_modules.qlinear.qlinear_cuda import BaseCudaQuantLinear
+                for name, module in model.named_modules():
+                    if isinstance(module, BaseCudaQuantLinear):
+                        cuda_name_modules[name] = module.gptqmodel_cuda
+                        module.gptqmodel_cuda = None
+                model = copy.deepcopy(self.model)
+
+                for name, module in model.named_modules():
+                    if isinstance(module, BaseCudaQuantLinear) and name in cuda_name_modules:
+                        module.gptqmodel_cuda = cuda_name_modules[name]
+
+                del cuda_name_modules
+            else:
+                model = copy.deepcopy(self.model)
             model = convert_gptq_v2_to_v1_format(
                 model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel
             )
diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py
index 7c0eefc0..7ef6cb94 100644
--- a/gptqmodel/nn_modules/qlinear/__init__.py
+++ b/gptqmodel/nn_modules/qlinear/__init__.py
@@ -4,3 +4,8 @@
 class BaseQuantLinear(nn.Module):
     # override me
     QUANT_TYPE = "base"
+
+
+class BaseCudaQuantLinear(BaseQuantLinear):
+    # override me
+    QUANT_TYPE = "base-cuda"
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py
index d2919406..dbd55ada 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py
@@ -7,12 +7,12 @@
 import torch
 import torch.nn as nn
 import transformers
-from gptqmodel.nn_modules.qlinear import BaseQuantLinear
+from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear
 
 logger = getLogger(__name__)
 
 
-class QuantLinear(BaseQuantLinear):
+class QuantLinear(BaseCudaQuantLinear):
     QUANT_TYPE = "cuda"
 
     def __init__(
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py b/gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py
index 7002dbdd..6a95ae9a 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py
@@ -7,12 +7,12 @@
 import torch
 import torch.nn as nn
 import transformers
-from gptqmodel.nn_modules.qlinear import BaseQuantLinear
+from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear
 
 logger = getLogger(__name__)
 
 
-class QuantLinear(BaseQuantLinear):
+class QuantLinear(BaseCudaQuantLinear):
     QUANT_TYPE = "cuda-old"
 
     def __init__(
diff --git a/tests/test_quant_formats.py b/tests/test_quant_formats.py
index f1b4f6eb..54ef85d5 100644
--- a/tests/test_quant_formats.py
+++ b/tests/test_quant_formats.py
@@ -13,6 +13,18 @@
 
 
 class TestQuantization(unittest.TestCase):
+
+    def setUp(self):
+        self.pretrained_model_dir = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_dir, use_fast=True)
+        self.calibration_dataset = [
+            self.tokenizer(
+                "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
+            ),
+            self.tokenizer("Today I am in Paris and it is a wonderful day."),
+        ]
+
     @parameterized.expand(
         [
             (False, True, FORMAT.GPTQ_V2),
@@ -21,16 +33,6 @@ class TestQuantization(unittest.TestCase):
         ]
     )
     def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT):
-        pretrained_model_dir = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
-
-        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
-        calibration_dataset = [
-            tokenizer(
-                "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
-            ),
-            tokenizer("Today I am in Paris and it is a wonderful day."),
-        ]
-
         quantize_config = QuantizeConfig(
             bits=4,
             group_size=128,
@@ -40,17 +42,15 @@ def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT):
         )
 
         model = GPTQModel.from_pretrained(
-            pretrained_model_dir,
+            self.pretrained_model_dir,
             quantize_config=quantize_config,
             use_flash_attention_2=False,
         )
 
-        model.quantize(calibration_dataset)
+        model.quantize(self.calibration_dataset)
 
         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_pretrained(
-                tmpdirname,
-            )
+            model.save_quantized(tmpdirname)
 
             logging.info(f"Saved config mem: {model.quantize_config}")
 
@@ -117,3 +117,27 @@ def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT):
                 format=format,
             )
             assert isinstance(model.quantize_config, QuantizeConfig)
+
+    def test_gptq_8bit(self):
+        quantize_config = QuantizeConfig(
+            bits=8,
+            group_size=128,
+            format=FORMAT.GPTQ,
+            desc_act=True
+        )
+
+        model = GPTQModel.from_pretrained(
+            self.pretrained_model_dir,
+            quantize_config=quantize_config,
+        )
+
+        model.quantize(self.calibration_dataset)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            err = None
+            try:
+                model.save_quantized(tmpdirname)
+            except Exception as e:
+                print(e)
+                err = e
+            self.assertTrue(err is None)