This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[GPTQ Enhence] Add GPTQ int8 weight unpack function #184

Merged
merged 6 commits on Mar 22, 2024
47 changes: 47 additions & 0 deletions neural_speed/convert/common.py
@@ -207,6 +207,8 @@ def unpack_weight(qweight, scales, qzeros, q_config):
        return unpack_gptq_weight_4bits(qweight, scales, qzeros, q_config)
    elif qbits == 3:
        return unpack_gptq_weight_3bits(qweight, scales, qzeros, q_config)
    elif qbits == 8:
        return unpack_gptq_weight_8bits(qweight, scales, qzeros, q_config)

    raise ValueError(f"Unsupported q_config[bits]: {qbits}")

@@ -215,6 +217,51 @@ def unpack_weight(qweight, scales, qzeros, q_config):
    raise ValueError(f"Unsupported quant_method: {quant_method}")


def unpack_gptq_weight_8bits(qweight, scales, qzeros, q_config):
    sym = q_config['sym']
    group_size = q_config['group_size']
    bits = q_config['bits']
    s32_bits = 32

    assert bits == 8
    # An int32 packs 32 // 8 = 4 int8 values; wf holds the bit offset of each packed value.
    wf = torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32).unsqueeze(0)
    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
                                      wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)

    if bits == 8:
        zeros = zeros.to(torch.int8 if sym else torch.uint8)

    zeros = zeros + 1
    try:
        zeros = zeros.reshape(scales.shape)
    except Exception:
        # zeros and scales have different item counts.
        # Drop the padding entries that are still equal to 1 (0 + 1 from the increment above).
        zeros = zeros[zeros != 1]
        zeros = zeros.reshape(scales.shape)

    if not sym and bits == 8:
        zeros = (zeros.to(torch.int32) - 128).to(torch.int8)

    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
                                       wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(weight, (2**bits) - 1, out=weight)

    if bits == 8:
        # INC adds a shift bias of 2**(bits - 1) for symmetric quantization; undo it here.
        if sym:
            shift_bias = 2**(bits - 1)
            weight -= shift_bias
        weight = weight.to(torch.int8 if sym else torch.uint8)
        # INC returns torch.uint8 for asymmetric quantization, but the backend expects int8,
        # so shift it into int8 range with an offset of 128.
        if not sym:
            weight = (weight.to(torch.int32) - 128).to(torch.int8)
    return weight, scales, zeros


def unpack_gptq_weight_4bits(qweight, scales, qzeros, q_config):
    group_size = q_config['group_size']
    bits = q_config['bits']
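For readers unfamiliar with the GPTQ packing layout, the sketch below (not part of this PR) round-trips a toy matrix through the same shift-and-mask scheme that unpack_gptq_weight_8bits relies on: each int32 packs four unsigned 8-bit values at bit offsets 0, 8, 16 and 24 along the packed dimension. The pack_rows_int8 helper is hypothetical and only exists to make the example self-contained; the layout assumption (four consecutive rows per int32, low byte first) and the omission of the symmetric/asymmetric zero-point offsets are simplifications, not the real converter code.

import torch

bits = 8
shifts = torch.arange(0, 32, bits, dtype=torch.int64)                   # bit offsets 0, 8, 16, 24

def pack_rows_int8(vals):
    """Hypothetical helper: pack every 4 consecutive values along dim 0 into one int32."""
    v = vals.to(torch.int64).reshape(-1, 32 // bits, vals.shape[1])      # (rows / 4, 4, cols)
    packed = (v << shifts.view(1, -1, 1)).sum(dim=1)                     # combine the four bytes
    packed = ((packed + 2**31) % 2**32) - 2**31                          # wrap into signed int32 range
    return packed.to(torch.int32)

# Toy "quantized weight": 8 input rows x 3 output columns of values in [0, 255].
w_ref = torch.randint(0, 256, (8, 3), dtype=torch.int32)
qweight = pack_rows_int8(w_ref)                                          # shape (2, 3)

# Unpack with the same shift-and-mask pattern the new function uses.
wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)
weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
                                   wf.unsqueeze(-1))
weight = torch.bitwise_and(weight, (2**bits) - 1)

assert torch.equal(weight.reshape(-1, weight.shape[-1]), w_ref)          # byte-exact round trip

The bitwise_and mask is what keeps this correct even when a packed int32 comes out negative: the arithmetic right shift sign-extends, and masking with 0xFF discards the extension bits.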
3 changes: 2 additions & 1 deletion neural_speed/convert/convert_quantized_llama.py
@@ -106,7 +106,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
    out_path = args.outfile.as_posix()
    model_path = args.model.as_posix()

    model, config, quantize_config = load_quantized_model(model_path)
    #model, config, quantize_config = load_quantized_model(model_path)
    model, config, quantize_config = load_quantized_safetensors(model_path)
    f = open(out_path, "wb")

    # 1. write hparams
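Since the converter now reads the checkpoint via load_quantized_safetensors instead of load_quantized_model, a rough, hypothetical outline of what such a loader typically does is sketched below: gather every *.safetensors shard into one flat state dict and read the model and quantization configs from their JSON files. The function name, the file names config.json and quantize_config.json, and the returned tuple shape are assumptions based on common GPTQ checkpoint layouts; the real implementation in Neural Speed may differ.

import json
from glob import glob
from pathlib import Path

from safetensors.torch import load_file

def load_quantized_safetensors_sketch(model_path):
    """Hypothetical stand-in for load_quantized_safetensors, for illustration only."""
    state_dict = {}
    for shard in sorted(glob(str(Path(model_path) / "*.safetensors"))):
        state_dict.update(load_file(shard))           # each shard is a flat name -> tensor dict
    config = json.loads((Path(model_path) / "config.json").read_text())
    quantize_config = json.loads((Path(model_path) / "quantize_config.json").read_text())
    return state_dict, config, quantize_config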