diff --git a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
index ca867cde..fd827007 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -91,10 +91,6 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
         self.group_size = group_size if group_size != -1 else infeatures
 
-        del infeatures
-        del outfeatures
-        del group_size
-
         self.register_buffer(
             "B",
             torch.empty((self.infeatures // 16, self.outfeatures * 16 // 8), dtype=torch.int),
         )
@@ -132,11 +128,11 @@ def pack(self, linear, scales):
         w = linear.weight.data.t()
 
         if self.infeatures != self.original_infeatures or self.outfeatures != self.original_outfeatures:
-            padded_w = torch.zeros((self.infeatures, self.outfeatures))
+            padded_w = torch.zeros((self.infeatures, self.outfeatures), dtype=w.dtype, device=w.device)
             padded_w[:w.size(0), :w.size(1)] = w
             w = padded_w
 
-            padded_s = torch.zeros((s.size(0), self.outfeatures))
+            padded_s = torch.zeros((s.size(0), self.outfeatures), dtype=torch.half, device=s.device)
             padded_s[:s.size(0), :s.size(1)] = s
             s = padded_s
 
@@ -184,10 +180,6 @@ def forward(self, A):
             A = F.pad(A, (0, self.infeatures - self.original_infeatures))
 
         C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
-
-        if C.size(-1) != self.outfeatures:
-            C = F.pad(C, (0, self.outfeatures - self.original_outfeatures))
-
         mul(
             A.view((-1, A.shape[-1])),
             self.B,
diff --git a/gptqmodel/utils/marlin.py b/gptqmodel/utils/marlin.py
index c098a7ce..a2509188 100644
--- a/gptqmodel/utils/marlin.py
+++ b/gptqmodel/utils/marlin.py
@@ -149,7 +149,13 @@ def convert_to_marlin(
         if repack:
             import gptqmodel_marlin_cuda
 
-            marlin_repacked_weight = gptqmodel_marlin_cuda.gptq_repack(module.qweight)
+            qweight = module.qweight
+            if new_module.infeatures != new_module.original_infeatures or new_module.outfeatures != new_module.original_outfeatures:
+                padded_qweight = torch.zeros((new_module.infeatures, new_module.outfeatures), dtype=torch.int, device=module.qweight.device)
+                padded_qweight[:module.qweight.size(0), :module.qweight.size(1)] = qweight
+                qweight = padded_qweight
+
+            marlin_repacked_weight = gptqmodel_marlin_cuda.gptq_repack(qweight)
 
             if strict:
                 dequantized_qzeros = unpack_qzeros(module.qzeros)
@@ -163,12 +169,18 @@ def convert_to_marlin(
             _, _scale_perm, _scale_perm_single = _get_perms()
 
             s = module.scales.data.clone()
+
+            if new_module.infeatures != new_module.original_infeatures or new_module.outfeatures != new_module.original_outfeatures:
+                padded_s = torch.zeros((s.size(0), new_module.outfeatures), dtype=torch.half, device=s.device)
+                padded_s[:s.size(0), :s.size(1)] = s
+                s = padded_s
+
             if module.group_size != module.infeatures:
                 s = s.reshape((1, -1))
                 s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
             else:
                 s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
-            s = s.reshape((-1, module.outfeatures)).contiguous()
+            s = s.reshape((-1, new_module.outfeatures)).contiguous()
 
             new_module.B = marlin_repacked_weight
             new_module.s = s
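
Note for reviewers: every hunk above applies the same zero-padding pattern, growing a tensor from its original shape up to the Marlin-aligned (`infeatures`, `outfeatures`) shape, and the dtype/device fixes address `torch.zeros()` defaulting to float32 on the CPU. A minimal sketch of the corrected pattern, with `pad_to` as a hypothetical helper name (not part of this diff):

```python
import torch

def pad_to(t: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
    """Zero-pad a 2-D tensor up to (rows, cols), keeping dtype and device."""
    # Without dtype=/device=, torch.zeros() typically allocates float32 on
    # the CPU; copying a CUDA half tensor into that would fail with a device
    # mismatch, or silently upcast when only the dtype differs.
    padded = torch.zeros((rows, cols), dtype=t.dtype, device=t.device)
    padded[: t.size(0), : t.size(1)] = t  # original data sits in the top-left block
    return padded

# Example: 4000 output features rounded up to 4096, a multiple of Marlin's tile size
w = torch.randn(4096, 4000, dtype=torch.half)
padded = pad_to(w, 4096, 4096)
assert padded.dtype == w.dtype and padded.shape == (4096, 4096)
```

The same reasoning motivates dropping the `F.pad` on `C` in `forward()`: `C` is already allocated with `self.s.shape[1]` columns, i.e. the padded output width, so padding it again was redundant.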