diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py
index f6396155..40e20dac 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -6,6 +6,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import transformers
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 from gptqmodel_exllama_kernels import make_q4, q4_matmul
@@ -14,12 +15,12 @@
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
-none_tensor = torch.empty((1, 1), device="meta")
+NON_TENSOR = torch.empty((1, 1), device="meta")
 
 
 def ext_make_q4(qweight, qzeros, scales, g_idx, device):
     """Construct Q4Matrix, return handle"""
-    return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device)
+    return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else NON_TENSOR, device)
 
 
 def ext_q4_matmul(x, q4, q4_width):
@@ -44,29 +45,32 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat
         super().__init__()
         self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)
 
-        self.padding = -outfeatures % 32
-        self.outfeatures = outfeatures + self.padding
-        outfeatures = self.outfeatures
-
-        self.infeatures = infeatures
         self.bits = bits
         self.group_size = group_size if group_size != -1 else infeatures
+
+        # auto pad
+        self.outfeatures = outfeatures + (-outfeatures % 32)
+        self.infeatures = infeatures + (-infeatures % self.group_size)
+
+        # backup original values
+        self.original_outfeatures = outfeatures
+        self.original_infeatures = infeatures
+
         self.maxq = 2**self.bits - 1
 
-        assert infeatures % 32 == 0
-        assert infeatures % self.group_size == 0
-        assert outfeatures % 32 == 0
+        assert self.infeatures % 32 == 0
+        assert self.outfeatures % 32 == 0
 
         self.register_buffer(
             "qweight",
-            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
+            torch.zeros((self.original_infeatures // 32 * self.bits, self.original_outfeatures), dtype=torch.int32),
         )
         self.register_buffer(
             "qzeros",
             torch.zeros(
                 (
-                    math.ceil(infeatures / self.group_size),
-                    outfeatures // 32 * self.bits,
+                    math.ceil(self.original_infeatures / self.group_size),
+                    self.original_outfeatures // 32 * self.bits,
                 ),
                 dtype=torch.int32,
             ),
@@ -74,17 +78,17 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat
         self.register_buffer(
             "scales",
             torch.zeros(
-                (math.ceil(infeatures / self.group_size), outfeatures),
+                (math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures),
                 dtype=torch.float16,
             ),
         )
         self.register_buffer(
             "g_idx",
-            torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
+            torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32),
         )
 
         if bias:
-            self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
+            self.register_buffer("bias", torch.zeros(self.original_outfeatures, dtype=torch.float16))
         else:
             self.bias = None
 
@@ -92,6 +96,19 @@ def post_init(self):
         assert self.qweight.device.type == "cuda"
         assert self.qweight.device.index is not None
 
+        # resize due to padding after model weights have been loaded
+        if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures:
+            self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures)
+            self.qzeros.resize_(
+                math.ceil(self.infeatures / self.group_size),
+                self.outfeatures // 32 * self.bits
+            )
+            self.scales.resize_((math.ceil(self.infeatures / self.group_size), self.outfeatures),)
+            self.g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device)
+            if self.bias is not None:
+                self.bias.resize_(self.outfeatures)
+
+
         self.width = self.qweight.shape[1]
 
         # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
@@ -120,7 +137,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
             self.bias = linear.bias.clone().half()
 
         intweight = []
-        for idx in range(self.infeatures):
+        for idx in range(self.original_infeatures):
             intweight.append(
                 torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[
                     :, None
@@ -164,6 +181,11 @@ def forward(self, x):
         x = x.half()
 
+        # TODO: need to run checks to make sure there is no performance regression padding with F.pad
+        # if infeatures is padded, we need to pad the input as well
+        if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures:
+            x = F.pad(x, (0, self.infeatures - self.original_infeatures))
+
         out = ext_q4_matmul(x, self.q4, self.width)
 
         if self.bias is not None:
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
index cad775b6..d3d26ba2 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
@@ -4,6 +4,7 @@
 from logging import getLogger
 
 import torch
+import torch.nn.functional as F
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 from gptqmodel_exllamav2_kernels import gemm_half_q_half, make_q_matrix
@@ -12,7 +13,7 @@
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
-none_tensor = torch.empty((1, 1), device="meta")
+NONE_TENSOR = torch.empty((1, 1), device="meta")
 
 
 def _torch_device(idx):
@@ -47,9 +48,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             w["q_scale"],
             w["q_scale_max"],
             w["q_groups"],
-            none_tensor,
-            none_tensor,
-            none_tensor,
+            NONE_TENSOR,
+            NONE_TENSOR,
+            NONE_TENSOR,
             temp_dq,
         )
     # GPTQ
@@ -70,9 +71,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
                 w["qweight"],
                 w["q_perm"],
                 w["q_invperm"],
-                none_tensor,
-                none_tensor,
-                none_tensor,
+                NONE_TENSOR,
+                NONE_TENSOR,
+                NONE_TENSOR,
                 w["qzeros"],
                 w["scales"],
                 w["g_idx"].cpu(),
@@ -82,14 +83,14 @@
         else:
             return make_q_matrix(
                 w["qweight"],
-                none_tensor,
-                none_tensor,
-                none_tensor,
-                none_tensor,
-                none_tensor,
+                NONE_TENSOR,
+                NONE_TENSOR,
+                NONE_TENSOR,
+                NONE_TENSOR,
+                NONE_TENSOR,
                 w["qzeros"],
                 w["scales"],
-                none_tensor,
+                NONE_TENSOR,
                 temp_dq,
             )
@@ -108,30 +109,32 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat
         self.q_handle = None
         self.q_tensors = None
 
-        self.padding = -outfeatures % 32
-        self.outfeatures = outfeatures + self.padding
-        outfeatures = self.outfeatures
-
-        self.infeatures = infeatures
         self.bits = bits
         self.group_size = group_size if group_size != -1 else infeatures
+
+        # auto pad
+        self.outfeatures = outfeatures + (-outfeatures % 32)
+        self.infeatures = infeatures + (-infeatures % self.group_size)
+
+        # backup original values
+        self.original_outfeatures = outfeatures
+        self.original_infeatures = infeatures
         self.maxq = 2**self.bits - 1
 
-        assert infeatures % 32 == 0
-        assert infeatures % self.group_size == 0
-        assert outfeatures % 32 == 0
+        assert self.infeatures % 32 == 0
+        assert self.outfeatures % 32 == 0
 
         # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
         self.register_buffer(
             "qweight",
-            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
+            torch.zeros((self.original_infeatures // 32 * self.bits, self.original_outfeatures), dtype=torch.int32),
         )
         self.register_buffer(
             "qzeros",
             torch.zeros(
                 (
-                    math.ceil(infeatures / self.group_size),
-                    outfeatures // 32 * self.bits,
+                    math.ceil(self.original_infeatures / self.group_size),
+                    self.original_outfeatures // 32 * self.bits,
                 ),
                 dtype=torch.int32,
             ),
@@ -139,23 +142,36 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat
         self.register_buffer(
             "scales",
             torch.zeros(
-                (math.ceil(infeatures / self.group_size), outfeatures),
+                (math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures),
                 dtype=torch.float16,
             ),
         )
         self.register_buffer(
             "g_idx",
-            torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
+            torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32),
        )
 
         if bias:
-            self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
+            self.register_buffer("bias", torch.zeros((self.original_outfeatures), dtype=torch.float16))
         else:
             self.bias = None
 
     def post_init(self, temp_dq):
         assert self.qweight.device.type == "cuda"
         assert self.qweight.device.index is not None
+
+        # resize due to padding after model weights have been loaded
+        if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures:
+            self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures)
+            self.qzeros.resize_(
+                math.ceil(self.infeatures / self.group_size),
+                self.outfeatures // 32 * self.bits
+            )
+            self.scales.resize_(math.ceil(self.infeatures / self.group_size), self.outfeatures)
+            self.g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device)
+            if self.bias is not None:
+                self.bias.resize_(self.outfeatures)
+
         self.q_tensors = {
             "qweight": self.qweight,
             "qzeros": self.qzeros,
@@ -173,6 +189,11 @@ def forward(self, x, force_cuda=False):
         x = x.half()
 
+        # TODO: need to run checks to make sure there is no performance regression padding with F.pad
+        # if infeatures is padded, we need to pad the input as well
+        if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures:
+            x = F.pad(x, (0, self.infeatures - self.original_infeatures))
+
         output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
 
         if self.bias is not None:
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
index 5e01f22b..657ca573 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -82,6 +82,7 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat
             raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
         if group_size not in [-1, 128] and group_size != infeatures:
             raise ValueError("Only group_size -1 and 128 are supported.")
+        # Marlin groups infeatures according to group_size, so infeatures must be an integer multiple of group_size.
         if infeatures % group_size != 0:
             raise ValueError("`infeatures` must be divisible by `group_size`.")
diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py
index 9fed21f1..23dc88c7 100644
--- a/gptqmodel/utils/bitblas.py
+++ b/gptqmodel/utils/bitblas.py
@@ -2,6 +2,7 @@
 from logging import getLogger
 
 import accelerate
+import threadpoolctl as tctl
 import torch
 from accelerate.utils import find_tied_parameters
 from tqdm import tqdm
@@ -75,7 +76,7 @@ def prepare_model_for_bitblas_load(
 
 
 @torch.no_grad()
-def convert_to_bitblas(model, model_quantlinear, quantization_config: QuantizeConfig, sym: bool, desc_act: bool, repack: bool,
+def convert_to_bitblas(model, model_quantlinear, quant_config: QuantizeConfig, sym: bool, desc_act: bool, repack: bool,
                        strict: bool = False):
     """
     Converts GPTQ-packed weights to the Bitblas format.
@@ -90,40 +91,42 @@ def convert_to_bitblas(model, model_quantlinear, quantization_config: QuantizeCo
     # TODO: load directly BitBLAS QuantLinear.
     message = "Overriding QuantLinear layers to use BitBLAS's QuantLinear..."
 
-    for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
-        if not isinstance(module, model_quantlinear):
-            continue
-
-        parent_name = ".".join(name.split(".")[:-1])
-        layer_name = name[len(parent_name) + 1:]
-
-        # We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights
-        # from checkpoints holding zero bias.
-        with torch.device("meta"):
-            bitblas_module = BitBLASQuantLinear(
-                bits=quantization_config.bits,
-                group_size=quantization_config.group_size,
-                sym=sym,
-                desc_act=desc_act,
-                infeatures=module.infeatures,
-                outfeatures=module.outfeatures,
-                bias=module.bias is not None,
-                enable_tuning=True
-            )
-
-        # Dequantize the weight.
-        if repack:
-            bitblas_module.repack_from_gptq(module)
-
-        # Save to parent.
-        parent_module = model.get_submodule(parent_name)
-        setattr(parent_module, layer_name, bitblas_module)
-
-        # Free cuda memory.
-        del module
-        gc.collect()
+    # TODO: need to benchmark to see multiple threads help with bitblas/tvm compilation and runtime
+    with tctl.threadpool_limits(limits=1):
+        for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
+            if not isinstance(module, model_quantlinear):
+                continue
+
+            parent_name = ".".join(name.split(".")[:-1])
+            layer_name = name[len(parent_name) + 1:]
+
+            # We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights
+            # from checkpoints holding zero bias.
+            with torch.device("meta"):
+                bitblas_module = BitBLASQuantLinear(
+                    bits=quant_config.bits,
+                    group_size=quant_config.group_size,
+                    sym=sym,
+                    desc_act=desc_act,
+                    infeatures=module.infeatures,
+                    outfeatures=module.outfeatures,
+                    bias=module.bias is not None,
+                    enable_tuning=True
+                )
+
+            # Dequantize the weight.
+            if repack:
+                bitblas_module.repack_from_gptq(module)
+
+            # Save to parent.
+            parent_module = model.get_submodule(parent_name)
+            setattr(parent_module, layer_name, bitblas_module)
+
+            # Free cuda memory.
+            del module
+            gc.collect()
 
     # Set quantization config to be BitBLAS.
-    quantization_config.format = FORMAT.BITBLAS
+    quant_config.format = FORMAT.BITBLAS
 
     return model
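Note on the auto-padding above, as an illustration only (not part of the diff): both exllama QuantLinear classes now round outfeatures up to the next multiple of 32 and infeatures up to the next multiple of group_size, keep the checkpoint-shaped buffers until the weights are loaded, resize_ them in post_init, and have forward() right-pad the activation with zeros to the padded infeatures. A minimal sketch of that arithmetic, with made-up sizes:

import torch
import torch.nn.functional as F

# Hypothetical layer sizes; real values come from the quantized checkpoint.
infeatures, outfeatures, group_size = 700, 300, 128

padded_in = infeatures + (-infeatures % group_size)   # 768, next multiple of group_size
padded_out = outfeatures + (-outfeatures % 32)        # 320, next multiple of 32

x = torch.randn(2, infeatures, dtype=torch.float16)
# Same call the patched forward() uses: zero-pad the last dimension on the right.
x = F.pad(x, (0, padded_in - infeatures))
assert x.shape == (2, padded_in)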
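The bitblas.py change wraps the conversion loop in threadpoolctl.threadpool_limits(limits=1), which caps native BLAS/OpenMP thread pools for the duration of the block. A standalone sketch of the pattern (hypothetical workload, not the converter itself):

import numpy as np
import threadpoolctl as tctl

with tctl.threadpool_limits(limits=1):
    # Library-level threading (BLAS, OpenMP) inside this block is limited to one thread,
    # mirroring how convert_to_bitblas constrains threads around BitBLAS/TVM work.
    a = np.random.rand(256, 256)
    b = a @ a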