[FIX] Padding infeatures/outfeatures for exllama, exllama v2, and marlin #98

Merged: 19 commits, Jun 29, 2024
56 changes: 39 additions & 17 deletions gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -6,6 +6,7 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel_exllama_kernels import make_q4, q4_matmul
@@ -14,12 +15,12 @@


# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
NON_TENSOR = torch.empty((1, 1), device="meta")


def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device)
return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else NON_TENSOR, device)


def ext_q4_matmul(x, q4, q4_width):
@@ -44,54 +45,70 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat
super().__init__()
self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)

self.padding = -outfeatures % 32
self.outfeatures = outfeatures + self.padding
outfeatures = self.outfeatures

self.infeatures = infeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures

# auto pad
self.outfeatures = outfeatures + (-outfeatures % 32)
self.infeatures = infeatures + (-infeatures % self.group_size)

# backup original values
self.original_outfeatures = outfeatures
self.original_infeatures = infeatures

self.maxq = 2**self.bits - 1

assert infeatures % 32 == 0
assert infeatures % self.group_size == 0
assert outfeatures % 32 == 0
assert self.infeatures % 32 == 0
assert self.outfeatures % 32 == 0

self.register_buffer(
"qweight",
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
torch.zeros((self.original_infeatures // 32 * self.bits, self.original_outfeatures), dtype=torch.int32),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
math.ceil(infeatures / self.group_size),
outfeatures // 32 * self.bits,
math.ceil(self.original_infeatures / self.group_size),
self.original_outfeatures // 32 * self.bits,
),
dtype=torch.int32,
),
)
self.register_buffer(
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
(math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures),
dtype=torch.float16,
),
)
self.register_buffer(
"g_idx",
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32),
)

if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer("bias", torch.zeros(self.original_outfeatures, dtype=torch.float16))
else:
self.bias = None

def post_init(self):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None

# resize due to padding after model weights have been loaded
if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures:
self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures)
self.qzeros.resize_(
math.ceil(self.infeatures / self.group_size),
self.outfeatures // 32 * self.bits
)
self.scales.resize_(math.ceil(self.infeatures / self.group_size), self.outfeatures)
self.g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device)
if self.bias is not None:
self.bias.resize_(self.outfeatures)


self.width = self.qweight.shape[1]

# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
@@ -120,7 +137,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
self.bias = linear.bias.clone().half()

intweight = []
for idx in range(self.infeatures):
for idx in range(self.original_infeatures):
intweight.append(
torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[
:, None
@@ -164,6 +181,11 @@ def forward(self, x):

x = x.half()

# TODO: check that padding the input with F.pad does not introduce a performance regression
# if infeatures was padded, the input must be padded to match
if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures:
x = F.pad(x, (0, self.infeatures - self.original_infeatures))

out = ext_q4_matmul(x, self.q4, self.width)

if self.bias is not None:
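For reference on the auto-pad arithmetic used above: `-n % 32` is the smallest non-negative amount that rounds n up to the next multiple of 32 (it is zero when n is already aligned), and F.pad zero-pads the trailing input dimension so the activations match the padded weight width. A minimal standalone sketch, with illustrative names and sizes not taken from the PR:

import torch
import torch.nn.functional as F

def pad_to_multiple(n: int, multiple: int) -> int:
    # -n % multiple is the smallest k >= 0 such that (n + k) % multiple == 0
    return n + (-n % multiple)

original_infeatures = 4001
infeatures = pad_to_multiple(original_infeatures, 32)   # 4001 -> 4032
assert pad_to_multiple(4096, 32) == 4096                # already aligned, unchanged

x = torch.randn(2, original_infeatures).half()
if infeatures > original_infeatures:
    # zero-pad the last dimension on the right so x matches the padded weight width
    x = F.pad(x, (0, infeatures - original_infeatures))
assert x.size(-1) == infeatures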
75 changes: 48 additions & 27 deletions gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
@@ -4,6 +4,7 @@
from logging import getLogger

import torch
import torch.nn.functional as F
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel_exllamav2_kernels import gemm_half_q_half, make_q_matrix

@@ -12,7 +13,7 @@


# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
NONE_TENSOR = torch.empty((1, 1), device="meta")


def _torch_device(idx):
@@ -47,9 +48,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
none_tensor,
none_tensor,
none_tensor,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
temp_dq,
)
# GPTQ
@@ -70,9 +71,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["qweight"],
w["q_perm"],
w["q_invperm"],
none_tensor,
none_tensor,
none_tensor,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
@@ -82,14 +83,14 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
else:
return make_q_matrix(
w["qweight"],
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
w["qzeros"],
w["scales"],
none_tensor,
NONE_TENSOR,
temp_dq,
)

@@ -108,54 +109,69 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat
self.q_handle = None
self.q_tensors = None

self.padding = -outfeatures % 32
self.outfeatures = outfeatures + self.padding
outfeatures = self.outfeatures

self.infeatures = infeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures

# auto pad
self.outfeatures = outfeatures + (-outfeatures % 32)
self.infeatures = infeatures + (-infeatures % self.group_size)

# backup original values
self.original_outfeatures = outfeatures
self.original_infeatures = infeatures
self.maxq = 2**self.bits - 1

assert infeatures % 32 == 0
assert infeatures % self.group_size == 0
assert outfeatures % 32 == 0
assert self.infeatures % 32 == 0
assert self.outfeatures % 32 == 0

# I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
self.register_buffer(
"qweight",
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
torch.zeros((self.original_infeatures // 32 * self.bits, self.original_outfeatures), dtype=torch.int32),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
math.ceil(infeatures / self.group_size),
outfeatures // 32 * self.bits,
math.ceil(self.original_infeatures / self.group_size),
self.original_outfeatures // 32 * self.bits,
),
dtype=torch.int32,
),
)
self.register_buffer(
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
(math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures),
dtype=torch.float16,
),
)
self.register_buffer(
"g_idx",
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32),
)

if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer("bias", torch.zeros((self.original_outfeatures), dtype=torch.float16))
else:
self.bias = None

def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None

# resize due to padding after model weights have been loaded
if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures:
self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures)
self.qzeros.resize_(
math.ceil(self.infeatures / self.group_size),
self.outfeatures // 32 * self.bits
)
self.scales.resize_(math.ceil(self.infeatures / self.group_size), self.outfeatures)
self.g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device)
if self.bias is not None:
self.bias.resize_(self.outfeatures)

self.q_tensors = {
"qweight": self.qweight,
"qzeros": self.qzeros,
@@ -173,6 +189,11 @@ def forward(self, x, force_cuda=False):

x = x.half()

# TODO: check that padding the input with F.pad does not introduce a performance regression
# if infeatures was padded, the input must be padded to match
if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures:
x = F.pad(x, (0, self.infeatures - self.original_infeatures))

output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)

if self.bias is not None:
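The post_init resize above relies on the buffers being registered with their original checkpoint shapes, loaded via the state dict, and only then grown in place to the padded shapes. A minimal sketch of that load-then-resize pattern, using a made-up module and buffer rather than the PR's classes:

import torch
import torch.nn as nn

class PaddedBuffer(nn.Module):
    def __init__(self, original_out: int, padded_out: int):
        super().__init__()
        self.original_out = original_out
        self.padded_out = padded_out
        # register with the unpadded shape so checkpoint loading matches exactly
        self.register_buffer("scales", torch.zeros(4, original_out, dtype=torch.float16))

    def post_init(self):
        if self.padded_out != self.original_out:
            # grow in place after loading; newly added elements are uninitialized
            self.scales.resize_(4, self.padded_out)

m = PaddedBuffer(original_out=4001, padded_out=4032)
m.load_state_dict({"scales": torch.ones(4, 4001, dtype=torch.float16)})
m.post_init()
assert m.scales.shape == (4, 4032)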
1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -82,6 +82,7 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat
raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
if group_size not in [-1, 128] and group_size != infeatures:
raise ValueError("Only group_size -1 and 128 are supported.")
# Marlin groups infeatures according to group_size, so infeatures must be an integer multiple of group_size.
if infeatures % group_size != 0:
raise ValueError("`infeatures` must be divisible by `group_size`.")

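The Marlin constraints referenced above can be summarized as a standalone check; this hypothetical helper mirrors the validation in qlinear_marlin.py but is not part of the PR:

def validate_marlin_shapes(infeatures: int, outfeatures: int, group_size: int) -> None:
    # Marlin kernel tiles require these alignments
    if infeatures % 128 != 0 or outfeatures % 256 != 0:
        raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
    # group_size -1 means a single group spanning all of infeatures
    if group_size not in (-1, 128) and group_size != infeatures:
        raise ValueError("Only group_size -1 and 128 are supported.")
    # groups must cover infeatures exactly, so infeatures must be a multiple of group_size
    if group_size != -1 and infeatures % group_size != 0:
        raise ValueError("`infeatures` must be divisible by `group_size`.")

validate_marlin_shapes(infeatures=4096, outfeatures=4096, group_size=128)  # passes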
71 changes: 37 additions & 34 deletions gptqmodel/utils/bitblas.py
@@ -2,6 +2,7 @@
from logging import getLogger

import accelerate
import threadpoolctl as tctl
import torch
from accelerate.utils import find_tied_parameters
from tqdm import tqdm
@@ -75,7 +76,7 @@ def prepare_model_for_bitblas_load(


@torch.no_grad()
def convert_to_bitblas(model, model_quantlinear, quantization_config: QuantizeConfig, sym: bool, desc_act: bool, repack: bool,
def convert_to_bitblas(model, model_quantlinear, quant_config: QuantizeConfig, sym: bool, desc_act: bool, repack: bool,
strict: bool = False):
"""
Converts GPTQ-packed weights to the Bitblas format.
@@ -90,40 +91,42 @@ def convert_to_bitblas(model, model_quantlinear, quantization_config: QuantizeCo
# TODO: load directly BitBLAS QuantLinear.
message = "Overriding QuantLinear layers to use BitBLAS's QuantLinear..."

for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
if not isinstance(module, model_quantlinear):
continue

parent_name = ".".join(name.split(".")[:-1])
layer_name = name[len(parent_name) + 1:]

# We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights
# from checkpoints holding zero bias.
with torch.device("meta"):
bitblas_module = BitBLASQuantLinear(
bits=quantization_config.bits,
group_size=quantization_config.group_size,
sym=sym,
desc_act=desc_act,
infeatures=module.infeatures,
outfeatures=module.outfeatures,
bias=module.bias is not None,
enable_tuning=True
)

# Dequantize the weight.
if repack:
bitblas_module.repack_from_gptq(module)

# Save to parent.
parent_module = model.get_submodule(parent_name)
setattr(parent_module, layer_name, bitblas_module)

# Free cuda memory.
del module
gc.collect()
# TODO: benchmark whether multiple threads help with bitblas/tvm compilation and runtime
with tctl.threadpool_limits(limits=1):
for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
if not isinstance(module, model_quantlinear):
continue

parent_name = ".".join(name.split(".")[:-1])
layer_name = name[len(parent_name) + 1:]

# We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights
# from checkpoints holding zero bias.
with torch.device("meta"):
bitblas_module = BitBLASQuantLinear(
bits=quant_config.bits,
group_size=quant_config.group_size,
sym=sym,
desc_act=desc_act,
infeatures=module.infeatures,
outfeatures=module.outfeatures,
bias=module.bias is not None,
enable_tuning=True
)

# Dequantize the weight.
if repack:
bitblas_module.repack_from_gptq(module)

# Save to parent.
parent_module = model.get_submodule(parent_name)
setattr(parent_module, layer_name, bitblas_module)

# Free cuda memory.
del module
gc.collect()

# Set quantization config to be BitBLAS.
quantization_config.format = FORMAT.BITBLAS
quant_config.format = FORMAT.BITBLAS

return model
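The conversion loop above follows a common replace-in-place pattern: walk named_modules, locate each layer's parent, and setattr the new module, all under threadpoolctl's thread limit. A minimal sketch of that pattern with a made-up replacement module (nothing here is BitBLAS-specific):

import threadpoolctl as tctl
import torch.nn as nn

class ReplacementLinear(nn.Module):
    def __init__(self, src: nn.Linear):
        super().__init__()
        self.weight = src.weight  # illustrative: reuse the original parameter

    def forward(self, x):
        return x @ self.weight.t()

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# cap BLAS/OpenMP worker threads for the duration of the conversion
with tctl.threadpool_limits(limits=1):
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        parent_name = ".".join(name.split(".")[:-1])
        parent = model.get_submodule(parent_name)  # "" returns the model itself
        setattr(parent, name.split(".")[-1], ReplacementLinear(module))

assert all(not isinstance(m, nn.Linear) for m in model.modules())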