From 8031585c48ff3adc0fd702cc6ff8a287ebfc2ab7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Jun 2024 06:39:31 +0000 Subject: [PATCH 01/17] fix padding --- gptqmodel/nn_modules/qlinear/qlinear_exllama.py | 10 +++++----- gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index f6396155..2e907c96 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -44,13 +44,13 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat super().__init__() self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act) - self.padding = -outfeatures % 32 - self.outfeatures = outfeatures + self.padding - outfeatures = self.outfeatures - - self.infeatures = infeatures self.bits = bits self.group_size = group_size if group_size != -1 else infeatures + + # auto pad + self.outfeatures = outfeatures + (-outfeatures % 32) + self.infeatures = infeatures + (-infeatures % self.group_size) + self.maxq = 2**self.bits - 1 assert infeatures % 32 == 0 diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index cad775b6..0782cb9c 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -108,13 +108,13 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat self.q_handle = None self.q_tensors = None - self.padding = -outfeatures % 32 - self.outfeatures = outfeatures + self.padding - outfeatures = self.outfeatures - - self.infeatures = infeatures self.bits = bits self.group_size = group_size if group_size != -1 else infeatures + + # auto pad + self.outfeatures = outfeatures + (-outfeatures % 32) + self.infeatures = infeatures + (-infeatures % self.group_size) + self.maxq = 2**self.bits - 1 assert infeatures % 32 == 0 From 44b97a66762caf57a07f8f169ec4f28c7c77d8b4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Jun 2024 06:47:13 +0000 Subject: [PATCH 02/17] fix padding --- gptqmodel/nn_modules/qlinear/qlinear_marlin.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py index 5e01f22b..a844087e 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py @@ -78,16 +78,21 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel. Please do not use `backend=Backend.MARLIN`, or please upgrade your GPU ("The more you buy, the more you save." - Taiwanese proverb).' 
) + self.group_size = group_size if group_size != -1 else infeatures + if self.group_size not in [-1, 128]: + raise ValueError("Only group_size -1 and 128 are supported.") + + self.infeatures = infeatures + (-outfeatures % self.group_size) + self.outfeatures = outfeatures + (-outfeatures % 256) + if infeatures % 128 != 0 or outfeatures % 256 != 0: raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.") - if group_size not in [-1, 128] and group_size != infeatures: - raise ValueError("Only group_size -1 and 128 are supported.") + + #TODO: why is this condition here? => and group_size != infeatures if infeatures % group_size != 0: - raise ValueError("`infeatures` must be divisible by `group_size`.") + raise ValueError("`infeatures` must be divisible by `group_size`.") + - self.infeatures = infeatures - self.outfeatures = outfeatures - self.group_size = group_size if group_size != -1 else infeatures self.register_buffer( "B", torch.empty((self.infeatures // 16, self.outfeatures * 16 // 8), dtype=torch.int), From 671799133320d77276d5701e87e3ddcf6afb168d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Jun 2024 07:10:31 +0000 Subject: [PATCH 03/17] store original in/out features --- gptqmodel/nn_modules/qlinear/qlinear_exllama.py | 4 ++++ gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py | 4 ++++ gptqmodel/nn_modules/qlinear/qlinear_marlin.py | 5 +++++ 3 files changed, 13 insertions(+) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index 2e907c96..ab214e15 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -51,6 +51,10 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat self.outfeatures = outfeatures + (-outfeatures % 32) self.infeatures = infeatures + (-infeatures % self.group_size) + # backup original values + self.original_outfeatures = outfeatures + self.original_infeatures = infeatures + self.maxq = 2**self.bits - 1 assert infeatures % 32 == 0 diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index 0782cb9c..6ca479a1 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -115,6 +115,10 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat self.outfeatures = outfeatures + (-outfeatures % 32) self.infeatures = infeatures + (-infeatures % self.group_size) + # backup original values + self.original_outfeatures = outfeatures + self.original_infeatures = infeatures + self.maxq = 2**self.bits - 1 assert infeatures % 32 == 0 diff --git a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py index a844087e..d0a2aba1 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py @@ -82,9 +82,14 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat if self.group_size not in [-1, 128]: raise ValueError("Only group_size -1 and 128 are supported.") + # auto pad self.infeatures = infeatures + (-outfeatures % self.group_size) self.outfeatures = outfeatures + (-outfeatures % 256) + # backup original values + self.original_outfeatures = outfeatures + self.original_infeatures = infeatures + if infeatures % 128 != 0 or outfeatures % 256 != 0: raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.") From 
690908029b5dfd9573bc3780fa70b12bb0b82363 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Jun 2024 09:33:00 +0000 Subject: [PATCH 04/17] fix bad var reference --- .../nn_modules/qlinear/qlinear_exllama.py | 23 +++++++++++-------- .../nn_modules/qlinear/qlinear_exllamav2.py | 23 +++++++++++-------- .../nn_modules/qlinear/qlinear_marlin.py | 13 +++++++---- 3 files changed, 37 insertions(+), 22 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index ab214e15..237367ef 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -55,22 +55,27 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat self.original_outfeatures = outfeatures self.original_infeatures = infeatures + # code bug prevention + del infeatures + del outfeatures + del group_size + self.maxq = 2**self.bits - 1 - assert infeatures % 32 == 0 - assert infeatures % self.group_size == 0 - assert outfeatures % 32 == 0 + assert self.infeatures % 32 == 0 + assert self.infeatures % self.group_size == 0 + assert self.outfeatures % 32 == 0 self.register_buffer( "qweight", - torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), + torch.zeros((self.infeatures // 32 * self.bits, self.outfeatures), dtype=torch.int32), ) self.register_buffer( "qzeros", torch.zeros( ( - math.ceil(infeatures / self.group_size), - outfeatures // 32 * self.bits, + math.ceil(self.infeatures / self.group_size), + self.outfeatures // 32 * self.bits, ), dtype=torch.int32, ), @@ -78,17 +83,17 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat self.register_buffer( "scales", torch.zeros( - (math.ceil(infeatures / self.group_size), outfeatures), + (math.ceil(self.infeatures / self.group_size), self.outfeatures), dtype=torch.float16, ), ) self.register_buffer( "g_idx", - torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), + torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32), ) if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) + self.register_buffer("bias", torch.zeros(self.outfeatures, dtype=torch.float16)) else: self.bias = None diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index 6ca479a1..c387df8c 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -119,23 +119,28 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat self.original_outfeatures = outfeatures self.original_infeatures = infeatures + # code bug prevention + del infeatures + del outfeatures + del group_size + self.maxq = 2**self.bits - 1 - assert infeatures % 32 == 0 - assert infeatures % self.group_size == 0 - assert outfeatures % 32 == 0 + assert self.infeatures % 32 == 0 + assert self.infeatures % self.group_size == 0 + assert self.outfeatures % 32 == 0 # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ... 
self.register_buffer( "qweight", - torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), + torch.zeros((self.infeatures // 32 * self.bits, self.outfeatures), dtype=torch.int32), ) self.register_buffer( "qzeros", torch.zeros( ( - math.ceil(infeatures / self.group_size), - outfeatures // 32 * self.bits, + math.ceil(self.infeatures / self.group_size), + self.outfeatures // 32 * self.bits, ), dtype=torch.int32, ), @@ -143,17 +148,17 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat self.register_buffer( "scales", torch.zeros( - (math.ceil(infeatures / self.group_size), outfeatures), + (math.ceil(self.infeatures / self.group_size), self.outfeatures), dtype=torch.float16, ), ) self.register_buffer( "g_idx", - torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), + torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32), ) if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) + self.register_buffer("bias", torch.zeros((self.outfeatures), dtype=torch.float16)) else: self.bias = None diff --git a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py index d0a2aba1..c215071d 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py @@ -90,11 +90,16 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat self.original_outfeatures = outfeatures self.original_infeatures = infeatures - if infeatures % 128 != 0 or outfeatures % 256 != 0: + # code bug prevention + del infeatures + del outfeatures + del group_size + + if self.infeatures % 128 != 0 or self.outfeatures % 256 != 0: raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.") #TODO: why is this condition here? 
=> and group_size != infeatures - if infeatures % group_size != 0: + if self.infeatures % self.group_size != 0: raise ValueError("`infeatures` must be divisible by `group_size`.") @@ -104,7 +109,7 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat ) self.register_buffer( "s", - torch.empty((self.infeatures // group_size, self.outfeatures), dtype=torch.half), + torch.empty((self.infeatures // self.group_size, self.outfeatures), dtype=torch.half), ) # 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par` self.register_buffer( @@ -113,7 +118,7 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat persistent=False, ) if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.half)) + self.register_buffer("bias", torch.zeros((self.outfeatures), dtype=torch.half)) else: self.bias = None From 32897439562c2859ac3f5dda0fa4bcc50d4835bd Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Jun 2024 12:34:07 +0000 Subject: [PATCH 05/17] shorter var name --- gptqmodel/utils/bitblas.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py index 9fed21f1..2aebbee7 100644 --- a/gptqmodel/utils/bitblas.py +++ b/gptqmodel/utils/bitblas.py @@ -75,7 +75,7 @@ def prepare_model_for_bitblas_load( @torch.no_grad() -def convert_to_bitblas(model, model_quantlinear, quantization_config: QuantizeConfig, sym: bool, desc_act: bool, repack: bool, +def convert_to_bitblas(model, model_quantlinear, quant_config: QuantizeConfig, sym: bool, desc_act: bool, repack: bool, strict: bool = False): """ Converts GPTQ-packed weights to the Bitblas format. @@ -101,8 +101,8 @@ def convert_to_bitblas(model, model_quantlinear, quantization_config: QuantizeCo # from checkpoints holding zero bias. with torch.device("meta"): bitblas_module = BitBLASQuantLinear( - bits=quantization_config.bits, - group_size=quantization_config.group_size, + bits=quant_config.bits, + group_size=quant_config.group_size, sym=sym, desc_act=desc_act, infeatures=module.infeatures, @@ -124,6 +124,6 @@ def convert_to_bitblas(model, model_quantlinear, quantization_config: QuantizeCo gc.collect() # Set quantization config to be BitBLAS. - quantization_config.format = FORMAT.BITBLAS + quant_config.format = FORMAT.BITBLAS return model From 1c064fd8f74d5992d2f2d5f7a83c8a89a7a7918f Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Jun 2024 12:37:09 +0000 Subject: [PATCH 06/17] limit bitblas convert to use 1 thread --- gptqmodel/utils/bitblas.py | 68 ++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py index 2aebbee7..3271a4e3 100644 --- a/gptqmodel/utils/bitblas.py +++ b/gptqmodel/utils/bitblas.py @@ -5,6 +5,8 @@ import torch from accelerate.utils import find_tied_parameters from tqdm import tqdm +import threadpoolctl as tctl + from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear from ..quantization import FORMAT, QuantizeConfig @@ -90,38 +92,40 @@ def convert_to_bitblas(model, model_quantlinear, quant_config: QuantizeConfig, s # TODO: load directly BitBLAS QuantLinear. message = "Overriding QuantLinear layers to use BitBLAS's QuantLinear..." 
- for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))): - if not isinstance(module, model_quantlinear): - continue - - parent_name = ".".join(name.split(".")[:-1]) - layer_name = name[len(parent_name) + 1:] - - # We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights - # from checkpoints holding zero bias. - with torch.device("meta"): - bitblas_module = BitBLASQuantLinear( - bits=quant_config.bits, - group_size=quant_config.group_size, - sym=sym, - desc_act=desc_act, - infeatures=module.infeatures, - outfeatures=module.outfeatures, - bias=module.bias is not None, - enable_tuning=True - ) - - # Dequantize the weight. - if repack: - bitblas_module.repack_from_gptq(module) - - # Save to parent. - parent_module = model.get_submodule(parent_name) - setattr(parent_module, layer_name, bitblas_module) - - # Free cuda memory. - del module - gc.collect() + # TODO: need to benchmark to see multiple threads help with bitblas/tvm compilation and runtime + with tctl.threadpool_limits(limits=1): + for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))): + if not isinstance(module, model_quantlinear): + continue + + parent_name = ".".join(name.split(".")[:-1]) + layer_name = name[len(parent_name) + 1:] + + # We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights + # from checkpoints holding zero bias. + with torch.device("meta"): + bitblas_module = BitBLASQuantLinear( + bits=quant_config.bits, + group_size=quant_config.group_size, + sym=sym, + desc_act=desc_act, + infeatures=module.infeatures, + outfeatures=module.outfeatures, + bias=module.bias is not None, + enable_tuning=True + ) + + # Dequantize the weight. + if repack: + bitblas_module.repack_from_gptq(module) + + # Save to parent. + parent_module = model.get_submodule(parent_name) + setattr(parent_module, layer_name, bitblas_module) + + # Free cuda memory. + del module + gc.collect() # Set quantization config to be BitBLAS. 
quant_config.format = FORMAT.BITBLAS From 12552e4e063176d704011d752774af4389d6f70c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Jun 2024 12:37:36 +0000 Subject: [PATCH 07/17] ruff --- gptqmodel/utils/bitblas.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py index 3271a4e3..23dc88c7 100644 --- a/gptqmodel/utils/bitblas.py +++ b/gptqmodel/utils/bitblas.py @@ -2,11 +2,10 @@ from logging import getLogger import accelerate +import threadpoolctl as tctl import torch from accelerate.utils import find_tied_parameters from tqdm import tqdm -import threadpoolctl as tctl - from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear from ..quantization import FORMAT, QuantizeConfig From c4b9ea202a8b8fc55d1138abb82aa615a3416667 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sat, 29 Jun 2024 00:00:40 +0800 Subject: [PATCH 08/17] fix qlinear_exllama pack --- gptqmodel/nn_modules/qlinear/qlinear_exllama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index 237367ef..40341b26 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -129,7 +129,7 @@ def pack(self, linear, scales, zeros, g_idx=None): self.bias = linear.bias.clone().half() intweight = [] - for idx in range(self.infeatures): + for idx in range(self.original_infeatures): intweight.append( torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[ :, None From 8670dc4ca8e14fd79afe218cc92479cde589448c Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sat, 29 Jun 2024 01:49:50 +0800 Subject: [PATCH 09/17] revert qliner_marlin change --- .../nn_modules/qlinear/qlinear_marlin.py | 36 ++++++------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py index c215071d..657ca573 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py @@ -78,38 +78,24 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel. Please do not use `backend=Backend.MARLIN`, or please upgrade your GPU ("The more you buy, the more you save." - Taiwanese proverb).' ) - self.group_size = group_size if group_size != -1 else infeatures - if self.group_size not in [-1, 128]: - raise ValueError("Only group_size -1 and 128 are supported.") - - # auto pad - self.infeatures = infeatures + (-outfeatures % self.group_size) - self.outfeatures = outfeatures + (-outfeatures % 256) - - # backup original values - self.original_outfeatures = outfeatures - self.original_infeatures = infeatures - - # code bug prevention - del infeatures - del outfeatures - del group_size - - if self.infeatures % 128 != 0 or self.outfeatures % 256 != 0: + if infeatures % 128 != 0 or outfeatures % 256 != 0: raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.") + if group_size not in [-1, 128] and group_size != infeatures: + raise ValueError("Only group_size -1 and 128 are supported.") + # Marlin groups infeatures according to group_size, so infeatures must be an integer multiple of group_size. 
+ if infeatures % group_size != 0: + raise ValueError("`infeatures` must be divisible by `group_size`.") - #TODO: why is this condition here? => and group_size != infeatures - if self.infeatures % self.group_size != 0: - raise ValueError("`infeatures` must be divisible by `group_size`.") - - + self.infeatures = infeatures + self.outfeatures = outfeatures + self.group_size = group_size if group_size != -1 else infeatures self.register_buffer( "B", torch.empty((self.infeatures // 16, self.outfeatures * 16 // 8), dtype=torch.int), ) self.register_buffer( "s", - torch.empty((self.infeatures // self.group_size, self.outfeatures), dtype=torch.half), + torch.empty((self.infeatures // group_size, self.outfeatures), dtype=torch.half), ) # 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par` self.register_buffer( @@ -118,7 +104,7 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat persistent=False, ) if bias: - self.register_buffer("bias", torch.zeros((self.outfeatures), dtype=torch.half)) + self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.half)) else: self.bias = None From f37b4ae27e25fef631728a265d51c3dbed127a01 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sat, 29 Jun 2024 02:01:23 +0800 Subject: [PATCH 10/17] cleanup code --- gptqmodel/nn_modules/qlinear/qlinear_exllama.py | 6 ------ gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py | 7 ------- 2 files changed, 13 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index 40341b26..bb4e4cde 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -55,15 +55,9 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat self.original_outfeatures = outfeatures self.original_infeatures = infeatures - # code bug prevention - del infeatures - del outfeatures - del group_size - self.maxq = 2**self.bits - 1 assert self.infeatures % 32 == 0 - assert self.infeatures % self.group_size == 0 assert self.outfeatures % 32 == 0 self.register_buffer( diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index c387df8c..50506f9a 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -118,16 +118,9 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat # backup original values self.original_outfeatures = outfeatures self.original_infeatures = infeatures - - # code bug prevention - del infeatures - del outfeatures - del group_size - self.maxq = 2**self.bits - 1 assert self.infeatures % 32 == 0 - assert self.infeatures % self.group_size == 0 assert self.outfeatures % 32 == 0 # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ... 
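
Note on the auto-padding arithmetic used throughout patches 01-10: `x + (-x % m)` rounds `x` up to the next multiple of `m`, and is a no-op when `x` is already aligned. A minimal standalone sketch (the helper name is illustrative, not part of the repository):

    # Round a dimension up to the next multiple of m; the added padding satisfies 0 <= pad < m.
    def pad_to_multiple(x: int, m: int) -> int:
        return x + (-x % m)

    assert pad_to_multiple(4096, 32) == 4096     # already aligned, unchanged
    assert pad_to_multiple(4095, 32) == 4096     # padded up by 1
    assert pad_to_multiple(11111, 256) == 11264  # e.g. Marlin's 256-alignment for outfeatures

This is why `self.infeatures`/`self.outfeatures` only differ from `original_infeatures`/`original_outfeatures` when the model's shapes are not already multiples of the kernel alignment (32 for exllama/exllamav2; 256 and `group_size` for Marlin).
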
From d78d75ff45334afb2ba99fde37afa781dbaecb6f Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Jun 2024 22:13:04 +0000 Subject: [PATCH 11/17] plan b: init with original shape, then model load, then do padding/resize in post_init --- .../nn_modules/qlinear/qlinear_exllama.py | 25 ++++++++++++++----- .../nn_modules/qlinear/qlinear_exllamav2.py | 25 ++++++++++++++----- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index bb4e4cde..3af23972 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -62,14 +62,14 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat self.register_buffer( "qweight", - torch.zeros((self.infeatures // 32 * self.bits, self.outfeatures), dtype=torch.int32), + torch.zeros((self.original_infeatures // 32 * self.bits, self.original_outfeatures), dtype=torch.int32), ) self.register_buffer( "qzeros", torch.zeros( ( - math.ceil(self.infeatures / self.group_size), - self.outfeatures // 32 * self.bits, + math.ceil(self.original_infeatures / self.group_size), + self.original_outfeatures // 32 * self.bits, ), dtype=torch.int32, ), @@ -77,17 +77,17 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat self.register_buffer( "scales", torch.zeros( - (math.ceil(self.infeatures / self.group_size), self.outfeatures), + (math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures), dtype=torch.float16, ), ) self.register_buffer( "g_idx", - torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32), + torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32), ) if bias: - self.register_buffer("bias", torch.zeros(self.outfeatures, dtype=torch.float16)) + self.register_buffer("bias", torch.zeros(self.original_outfeatures, dtype=torch.float16)) else: self.bias = None @@ -95,6 +95,19 @@ def post_init(self): assert self.qweight.device.type == "cuda" assert self.qweight.device.index is not None + # resize due to padding after model weights have been loaded + if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures: + self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures) + self.qzeros.resize_( + math.ceil(self.infeatures / self.group_size), + self.outfeatures // 32 * self.bits + ) + self.scales.resize_((math.ceil(self.infeatures / self.group_size), self.outfeatures),) + self.g_idx.resize_(i // self.group_size for i in range(self.infeatures)) + if self.bias is not None: + self.bias.resize_(self.outfeatures) + + self.width = self.qweight.shape[1] # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index 50506f9a..deb61da1 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -126,14 +126,14 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ... 
self.register_buffer( "qweight", - torch.zeros((self.infeatures // 32 * self.bits, self.outfeatures), dtype=torch.int32), + torch.zeros((self.original_infeatures // 32 * self.bits, self.original_outfeatures), dtype=torch.int32), ) self.register_buffer( "qzeros", torch.zeros( ( - math.ceil(self.infeatures / self.group_size), - self.outfeatures // 32 * self.bits, + math.ceil(self.original_infeatures / self.group_size), + self.original_outfeatures // 32 * self.bits, ), dtype=torch.int32, ), @@ -141,23 +141,36 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat self.register_buffer( "scales", torch.zeros( - (math.ceil(self.infeatures / self.group_size), self.outfeatures), + (math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures), dtype=torch.float16, ), ) self.register_buffer( "g_idx", - torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32), + torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32), ) if bias: - self.register_buffer("bias", torch.zeros((self.outfeatures), dtype=torch.float16)) + self.register_buffer("bias", torch.zeros((self.original_outfeatures), dtype=torch.float16)) else: self.bias = None def post_init(self, temp_dq): assert self.qweight.device.type == "cuda" assert self.qweight.device.index is not None + + # resize due to padding after model weights have been loaded + if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures: + self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures) + self.qzeros.resize_( + math.ceil(self.infeatures / self.group_size), + self.outfeatures // 32 * self.bits + ) + self.scales.resize_(math.ceil(self.infeatures / self.group_size), self.outfeatures) + self.g_idx.resize_(i // self.group_size for i in range(self.infeatures)) + if self.bias is not None: + self.bias.resize_(self.outfeatures) + self.q_tensors = { "qweight": self.qweight, "qzeros": self.qzeros, From 105b6a9e71be6dbbda29676f0f8b10959d8ccbc8 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 29 Jun 2024 09:19:44 +0000 Subject: [PATCH 12/17] fix g_idx post_init --- gptqmodel/nn_modules/qlinear/qlinear_exllama.py | 2 +- gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index 3af23972..6c50d76e 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -103,7 +103,7 @@ def post_init(self): self.outfeatures // 32 * self.bits ) self.scales.resize_((math.ceil(self.infeatures / self.group_size), self.outfeatures),) - self.g_idx.resize_(i // self.group_size for i in range(self.infeatures)) + self.g_idx = torch.tensor((i // self.group_size for i in range(self.infeatures)), dtype=torch.int32, device=self.g_idx.device) if self.bias is not None: self.bias.resize_(self.outfeatures) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index deb61da1..40780d49 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -167,7 +167,7 @@ def post_init(self, temp_dq): self.outfeatures // 32 * self.bits ) self.scales.resize_(math.ceil(self.infeatures / self.group_size), self.outfeatures) - self.g_idx.resize_(i // self.group_size for i in range(self.infeatures)) + self.g_idx = 
torch.tensor((i // self.group_size for i in range(self.infeatures)), dtype=torch.int32, device=self.g_idx.device) if self.bias is not None: self.bias.resize_(self.outfeatures) From a96ef4dc8e6896d8e698fdcd3c8856423091571c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 29 Jun 2024 09:24:37 +0000 Subject: [PATCH 13/17] const var reformat to all caps --- .../nn_modules/qlinear/qlinear_exllama.py | 4 +-- .../nn_modules/qlinear/qlinear_exllamav2.py | 26 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index 6c50d76e..12cb67de 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -14,12 +14,12 @@ # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") +NON_TENSOR = torch.empty((1, 1), device="meta") def ext_make_q4(qweight, qzeros, scales, g_idx, device): """Construct Q4Matrix, return handle""" - return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device) + return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else NON_TENSOR, device) def ext_q4_matmul(x, q4, q4_width): diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index 40780d49..e0b21c07 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -12,7 +12,7 @@ # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") +NONE_TENSOR = torch.empty((1, 1), device="meta") def _torch_device(idx): @@ -47,9 +47,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): w["q_scale"], w["q_scale_max"], w["q_groups"], - none_tensor, - none_tensor, - none_tensor, + NONE_TENSOR, + NONE_TENSOR, + NONE_TENSOR, temp_dq, ) # GPTQ @@ -70,9 +70,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): w["qweight"], w["q_perm"], w["q_invperm"], - none_tensor, - none_tensor, - none_tensor, + NONE_TENSOR, + NONE_TENSOR, + NONE_TENSOR, w["qzeros"], w["scales"], w["g_idx"].cpu(), @@ -82,14 +82,14 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): else: return make_q_matrix( w["qweight"], - none_tensor, - none_tensor, - none_tensor, - none_tensor, - none_tensor, + NONE_TENSOR, + NONE_TENSOR, + NONE_TENSOR, + NONE_TENSOR, + NONE_TENSOR, w["qzeros"], w["scales"], - none_tensor, + NONE_TENSOR, temp_dq, ) From 90f3f544a7633d4cd60d08e075c8f85b18b73c00 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 29 Jun 2024 09:29:08 +0000 Subject: [PATCH 14/17] fix ( -> [ --- gptqmodel/nn_modules/qlinear/qlinear_exllama.py | 2 +- gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index 12cb67de..c5afba30 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -103,7 +103,7 @@ def post_init(self): self.outfeatures // 32 * self.bits ) self.scales.resize_((math.ceil(self.infeatures / self.group_size), self.outfeatures),) - self.g_idx = torch.tensor((i // self.group_size for i in range(self.infeatures)), dtype=torch.int32, device=self.g_idx.device) + self.g_idx = torch.tensor([i // self.group_size 
for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device) if self.bias is not None: self.bias.resize_(self.outfeatures) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index e0b21c07..7390202b 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -167,7 +167,7 @@ def post_init(self, temp_dq): self.outfeatures // 32 * self.bits ) self.scales.resize_(math.ceil(self.infeatures / self.group_size), self.outfeatures) - self.g_idx = torch.tensor((i // self.group_size for i in range(self.infeatures)), dtype=torch.int32, device=self.g_idx.device) + self.g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device) if self.bias is not None: self.bias.resize_(self.outfeatures) From 4189817436602f5147f9c4b1753acdc44cfad65f Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sat, 29 Jun 2024 18:51:17 +0800 Subject: [PATCH 15/17] padding the x that passes in forward --- gptqmodel/nn_modules/qlinear/qlinear_exllama.py | 5 +++++ gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index c5afba30..e5be24e2 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -6,6 +6,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F import transformers from gptqmodel.nn_modules.qlinear import BaseQuantLinear from gptqmodel_exllama_kernels import make_q4, q4_matmul @@ -180,6 +181,10 @@ def forward(self, x): x = x.half() + # If we padding infeatures, we also need to padding the x that passes in forward + if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures: + x = F.pad(x, (0, self.infeatures - self.original_infeatures)) + out = ext_q4_matmul(x, self.q4, self.width) if self.bias is not None: diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index 7390202b..25eb9d22 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -4,6 +4,7 @@ from logging import getLogger import torch +import torch.nn.functional as F from gptqmodel.nn_modules.qlinear import BaseQuantLinear from gptqmodel_exllamav2_kernels import gemm_half_q_half, make_q_matrix @@ -188,6 +189,10 @@ def forward(self, x, force_cuda=False): x = x.half() + # If we padding infeatures, we also need to padding the x that passes in forward + if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures: + x = F.pad(x, (0, self.infeatures - self.original_infeatures)) + output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) if self.bias is not None: From 32dec09c9f4bf59329c565b63b9be177c7e9c860 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 29 Jun 2024 10:58:40 +0000 Subject: [PATCH 16/17] comments/todo --- gptqmodel/nn_modules/qlinear/qlinear_exllama.py | 3 ++- gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index e5be24e2..3e5deb3d 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ 
-181,7 +181,8 @@ def forward(self, x): x = x.half() - # If we padding infeatures, we also need to padding the x that passes in forward + # TODO: need to run checks to make sure there is no performance regression padding with F.pad + # if infeatures is padded, we need to pad the the input as well if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures: x = F.pad(x, (0, self.infeatures - self.original_infeatures)) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index 25eb9d22..e1bfc0e4 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -189,7 +189,8 @@ def forward(self, x, force_cuda=False): x = x.half() - # If we padding infeatures, we also need to padding the x that passes in forward + # TODO: need to run checks to make sure there is no performance regression padding with F.pad + # if infeatures is padded, we need to pad the the input as well if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures: x = F.pad(x, (0, self.infeatures - self.original_infeatures)) From 3c2ba94acabe4e3c05523efc2dd8b071a49ecc31 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 29 Jun 2024 11:00:00 +0000 Subject: [PATCH 17/17] comments --- gptqmodel/nn_modules/qlinear/qlinear_exllama.py | 2 +- gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index 3e5deb3d..40e20dac 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -182,7 +182,7 @@ def forward(self, x): x = x.half() # TODO: need to run checks to make sure there is no performance regression padding with F.pad - # if infeatures is padded, we need to pad the the input as well + # if infeatures is padded, we need to pad the input as well if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures: x = F.pad(x, (0, self.infeatures - self.original_infeatures)) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index e1bfc0e4..d3d26ba2 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -190,7 +190,7 @@ def forward(self, x, force_cuda=False): x = x.half() # TODO: need to run checks to make sure there is no performance regression padding with F.pad - # if infeatures is padded, we need to pad the the input as well + # if infeatures is padded, we need to pad the input as well if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures: x = F.pad(x, (0, self.infeatures - self.original_infeatures))
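
A small sketch of the `F.pad` call used in the padded `forward()` paths above: with a two-element pad tuple `(left, right)`, only the last dimension is touched and the appended tail is zero-filled by default. Shapes here are illustrative only:

    import torch
    import torch.nn.functional as F

    infeatures, original_infeatures = 4096, 4095
    x = torch.randn(2, 5, original_infeatures)  # e.g. (batch, seq_len, in_features)

    # pad only the right side of the last dimension, as in the forward() changes above
    padded = F.pad(x, (0, infeatures - original_infeatures))

    assert padded.shape == (2, 5, infeatures)
    assert torch.equal(padded[..., :original_infeatures], x)  # original activations untouched
    assert torch.all(padded[..., original_infeatures:] == 0)  # appended columns are zeros
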