Skip to content

Commit

Permalink
[FIX] Marlin runtime conversion padding (ModelCloud#192)
Browse files Browse the repository at this point in the history
* add marlin convert padding

* cleanup

* revert base.py change

* use module.qweight.device

* add infeatures and outfeatures check

* cleanup

* no padding required for C

---------

Co-authored-by: LRL-ModelCloud <[email protected]>
  • Loading branch information
LRL-ModelCloud and LRL-ModelCloud authored Jul 10, 2024
1 parent 40308cd commit 8b97466
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
12 changes: 2 additions & 10 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat

self.group_size = group_size if group_size != -1 else infeatures

del infeatures
del outfeatures
del group_size

self.register_buffer(
"B",
torch.empty((self.infeatures // 16, self.outfeatures * 16 // 8), dtype=torch.int),
Expand Down Expand Up @@ -132,11 +128,11 @@ def pack(self, linear, scales):
w = linear.weight.data.t()

if self.infeatures != self.original_infeatures or self.outfeatures != self.original_outfeatures:
padded_w = torch.zeros((self.infeatures, self.outfeatures))
padded_w = torch.zeros((self.infeatures, self.outfeatures), dtype=w.dtype, device=w.device)
padded_w[:w.size(0), :w.size(1)] = w
w = padded_w

padded_s = torch.zeros((s.size(0), self.outfeatures))
padded_s = torch.zeros((s.size(0), self.outfeatures), dtype=torch.half, device=s.device)
padded_s[:s.size(0), :s.size(1)] = s
s = padded_s

Expand Down Expand Up @@ -184,10 +180,6 @@ def forward(self, A):
A = F.pad(A, (0, self.infeatures - self.original_infeatures))

C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)

if C.size(-1) != self.outfeatures:
C = F.pad(C, (0, self.outfeatures - self.original_outfeatures))

mul(
A.view((-1, A.shape[-1])),
self.B,
Expand Down
16 changes: 14 additions & 2 deletions gptqmodel/utils/marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,13 @@ def convert_to_marlin(
if repack:
import gptqmodel_marlin_cuda

marlin_repacked_weight = gptqmodel_marlin_cuda.gptq_repack(module.qweight)
qweight = module.qweight
if new_module.infeatures != new_module.original_infeatures or new_module.outfeatures != new_module.original_outfeatures:
padded_qweight = torch.zeros((new_module.infeatures, new_module.outfeatures), dtype=torch.int, device=module.qweight.device)
padded_qweight[:module.qweight.size(0), :module.qweight.size(1)] = qweight
qweight = padded_qweight

marlin_repacked_weight = gptqmodel_marlin_cuda.gptq_repack(qweight)

if strict:
dequantized_qzeros = unpack_qzeros(module.qzeros)
Expand All @@ -163,12 +169,18 @@ def convert_to_marlin(
_, _scale_perm, _scale_perm_single = _get_perms()

s = module.scales.data.clone()

if new_module.infeatures != new_module.original_infeatures or new_module.outfeatures != new_module.original_outfeatures:
padded_s = torch.zeros((s.size(0), new_module.outfeatures), dtype=torch.half, device=s.device)
padded_s[:s.size(0), :s.size(1)] = s
s = padded_s

if module.group_size != module.infeatures:
s = s.reshape((1, -1))
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, module.outfeatures)).contiguous()
s = s.reshape((-1, new_module.outfeatures)).contiguous()

new_module.B = marlin_repacked_weight
new_module.s = s
Expand Down

0 comments on commit 8b97466

Please sign in to comment.