Skip to content

Commit

Permalink
Support Weight-Only quantization on CPU device with QBits backend (#437)
Browse files Browse the repository at this point in the history
Signed-off-by: Cheng Penghui <[email protected]>
Signed-off-by: Cheng, Penghui <[email protected]>
Co-authored-by: Casper Hansen <[email protected]>
  • Loading branch information
PenghuiCheng and casper-hansen authored Jun 8, 2024
1 parent 6a46ad6 commit 2627364
Show file tree
Hide file tree
Showing 15 changed files with 415 additions and 53 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,29 @@ GPU: 2x NVIDIA GeForce RTX 4090
| Mixtral | 46.7B | 🔵GEMM | 1 | 2048 | 2048 | 2446.15 | 77.0516 | 27.98 GB (59.15%) |
| Mixtral | 46.7B | 🔵GEMM | 1 | 4096 | 4096 | 1985.78 | 77.5689 | 34.65 GB (73.26%) |

### CPU

- CPU: INTEL(R) XEON(R) PLATINUM 8592+ with 8-channel 4800MT/s memory.
- Command: `python examples/benchmark.py --model_path <hf_model> --batch_size 1`

| Model | Size | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (RAM) |
|--------:|------:|-----------:|-------------:|-----------------:|----------------:|---------------:|:------------------|
| Mixtral | 7B | 1 | 64 | 64 | 389.24 | 16.01 | 5.59 GB (0.02%) |
| Mixtral | 7B | 1 | 2048 | 2048 | 1412 | 17.76 | 6.29 GB (0.03%) |
| Vicuna | 7B | 1 | 64 | 64 | 346 | 18.13 | 8.18 GB (0.03%) |
| Vicuna | 7B | 1 | 2048 | 2048 | 1023.4 | 18.18 | 8.80 GB (0.04%) |
| LLaMA2 | 13B | 1 | 64 | 64 | 160.24 | 9.87 | 14.65 GB (0.06%) |
| LLaMA2 | 13B | 1 | 2048 | 2048 | 592.35 | 9.93 | 16.87 GB (0.07%) |
| Mosaicml | 7B | 1 | 64 | 64 | 433.17 | 18.79 | 4.60 GB (0.02%) |
| Mosaicml | 7B | 1 | 2048 | 2048 | 404.25 | 19.91 | 4.75 GB (0.02%) |
| Falcon | 7B | 1 | 64 | 64 | 303.16 | 14.41 | 5.18 GB (0.02%) |
| Falcon | 7B | 1 | 2048 | 2048 | 634.57 | 15.55 | 5.80 GB (0.02%) |
| CodeLlama | 34B | 1 | 64 | 64 | 153.73 | 4.23 | 29.00 GB (0.12%) |
| CodeLlama | 34B | 1 | 2048 | 2048 | 274.25 | 4.38 | 35.21 GB (0.15%) |
| Deepseek-coder | 33B | 1 | 64 | 64 | 83.08 | 4.07 | 22.16 GB (0.09%) |
| Deepseek-coder | 33B | 1 | 2048 | 2048 | 296.04 | 4.33 | 37.05 GB |


## Reference

If you find AWQ useful or relevant to your research, you can cite their [paper](https://arxiv.org/abs/2306.00978):
Expand Down
2 changes: 2 additions & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def from_quantized(
fuse_layers=True,
use_exllama=False,
use_exllama_v2=False,
use_qbits=False,
batch_size=1,
safetensors=True,
device_map="balanced",
Expand Down Expand Up @@ -109,6 +110,7 @@ def from_quantized(
fuse_layers=fuse_layers,
use_exllama=use_exllama,
use_exllama_v2=use_exllama_v2,
use_qbits=use_qbits,
safetensors=safetensors,
device_map=device_map,
max_memory=max_memory,
Expand Down
54 changes: 43 additions & 11 deletions awq/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import gc
import json
import logging
import torch
import transformers
import torch.nn as nn
Expand All @@ -15,19 +16,22 @@
from awq.modules.linear import (
WQLinear_GEMM,
WQLinear_GEMV,
WQLinear_QBits,
WQLinear_Marlin,
WQLinear_Exllama,
WQLinear_ExllamaV2,
WQLinear_GEMVFast,
marlin_post_init,
exllama_post_init,
exllamav2_post_init,
qbits_post_init,
)
from awq.utils.module import (
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize,
)
from awq.utils.utils import get_best_device, qbits_available
from transformers import (
AutoConfig,
PreTrainedModel,
Expand All @@ -46,6 +50,10 @@
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.module import get_named_linears, set_op_by_name

if qbits_available:
from intel_extension_for_transformers.qbits import check_isa_supported


# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
TRANSFORMERS_AUTO_MAPPING_DICT = {
Expand Down Expand Up @@ -389,6 +397,9 @@ def from_quantized(
use_exllama_v2: Annotated[
bool, Doc("Whether to map the weights to ExLlamaV2 kernels.")
] = False,
use_qbits: Annotated[
bool, Doc("Whether to map the weights to qbits kernels for CPU device.")
] = False,
device_map: Annotated[
Union[str, Dict],
Doc(
Expand Down Expand Up @@ -439,6 +450,14 @@ def from_quantized(
trust_remote_code=trust_remote_code,
)

use_cpu_qbits = use_qbits or get_best_device() == "cpu"
if use_cpu_qbits:
if not qbits_available:
raise ImportError("Please install intel-extension-for-transformers with "
"`pip install intel-extension-for-transformers` for 'qbits' kernel!")

fuse_layers = False
logging.warn("Unsupport fuse_layers featrue for CPU device with QBits backend!")
# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(
self,
Expand All @@ -447,6 +466,7 @@ def from_quantized(
quant_config.version,
use_exllama=use_exllama,
use_exllama_v2=use_exllama_v2,
use_qbits=use_cpu_qbits,
)

model.tie_weights()
Expand All @@ -467,9 +487,13 @@ def from_quantized(
if fuse_layers:
self.fuse_layers(model)

if quant_config.version == "marlin":
if use_cpu_qbits:
dtype = torch.bfloat16 if check_isa_supported("AMX") else torch.float32
model.to(dtype=dtype, device="cpu")
# repack qweight to match the QBits kernel.
model = qbits_post_init(model)
elif quant_config.version == "marlin":
model = marlin_post_init(model)

elif use_exllama:
# creates q4 handle
model = exllama_post_init(model)
Expand Down Expand Up @@ -507,10 +531,10 @@ def _load_config(
ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
else:
ignore_patterns.append("*.safetensors*")

if download_kwargs is None:
download_kwargs = {}

if "ignore_patterns" in download_kwargs:
download_kwargs_ignore_patterns = download_kwargs.pop("ignore_patterns")

Expand Down Expand Up @@ -551,11 +575,11 @@ def _load_config(
return model_weights_path, config, quant_config

def _load_quantized_modules(
self, model, quant_config, version, use_exllama, use_exllama_v2
self, model, quant_config, version, use_exllama, use_exllama_v2, use_qbits=False
):
# Real quantization of weights
assert not (
version == "gemv" and (use_exllama or use_exllama_v2)
version == "gemv" and (use_exllama or use_exllama_v2 or use_qbits)
), "Exllama kernels only support GEMM version."

# Get blocks of model
Expand All @@ -577,7 +601,9 @@ def _load_quantized_modules(

# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
if version == "marlin":
if use_qbits:
q_linear_module = WQLinear_QBits
elif version == "marlin":
q_linear_module = WQLinear_Marlin
elif use_exllama:
q_linear_module = WQLinear_Exllama
Expand All @@ -590,13 +616,19 @@ def _load_quantized_modules(
elif version == "gemv_fast":
q_linear_module = WQLinear_GEMVFast

q_linear = q_linear_module.from_linear(
module, quant_config.w_bit, quant_config.q_group_size, True
)
if use_qbits:
q_linear = q_linear_module.from_linear(module,
quant_config.w_bit,
quant_config.q_group_size,
True,
has_zero_points=quant_config.zero_point)
else:
q_linear = q_linear_module.from_linear(module, quant_config.w_bit, quant_config.q_group_size, True)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)

torch.cuda.empty_cache()
if not use_qbits:
torch.cuda.empty_cache()
gc.collect()

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions awq/modules/linear/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .exllama import WQLinear_Exllama, exllama_post_init
from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from .gemm import WQLinear_GEMM
from .gemm_qbits import WQLinear_QBits, qbits_post_init
from .gemv import WQLinear_GEMV
from .marlin import WQLinear_Marlin, marlin_post_init
from .gemv_fast import WQLinear_GEMVFast
155 changes: 155 additions & 0 deletions awq/modules/linear/gemm_qbits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
from ...utils.packing_utils import reverse_awq_order, unpack_awq

try:
from intel_extension_for_transformers import qbits # with QBits kernels ()

QBITS_INSTALLED = True
except:
QBITS_INSTALLED = False

BITS_DTYPE_MAPPING = {
4: "int4_clip",
8: "int8",
}


def convert_dtype_torch2str(dtype):
if dtype == torch.int8:
return "int8"
elif dtype == torch.float:
return "fp32"
elif dtype == torch.float16:
return "fp16"
elif dtype == torch.bfloat16:
return "bf16"
elif isinstance(dtype, str) and dtype in ["int8", "fp32", "fp16", "bf16"]:
return dtype
else:
assert False, "Unsupported pytorch dtype {} to str dtype".format(dtype)


class WQLinear_QBits(nn.Module):

def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
super().__init__()
assert QBITS_INSTALLED, \
"Please install ITREX qbits package with `pip install intel-extension-for-transformers`."

self.use_bf16 = qbits.check_isa_supported("AMX")

if w_bit not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2, 3, 4, 8 bits are supported for now.")

self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.zero_point = zero_point
self.scale_dtype = torch.float32

# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
self.pack_num = 32 // self.w_bit

self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // self.pack_num),
dtype=torch.int8,
device=dev,
) if self.zero_point else None,
)
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.bfloat16 if self.use_bf16 else torch.float32,
device=dev,
))
if bias:
self.register_buffer(
"bias",
torch.zeros((out_features), dtype=torch.bfloat16 if self.use_bf16 else torch.float32, device=dev),
)
else:
self.register_buffer(
"bias",
None,
)
qweight = torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev)
self.register_buffer("qweight", qweight)

def post_init(self):
assert self.qweight.device.type == "cpu"

intweight, zeros = unpack_awq(self.qweight, self.qzeros, self.w_bit) # weight: k x n zeros: k / group_size x n
intweight, zeros = reverse_awq_order(intweight, zeros, self.w_bit) # weight: k x n zeros: k / group_size x n
if self.zero_point:
intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1) - (2**(self.w_bit - 1))
zeros = torch.bitwise_and(zeros, (2**self.w_bit) - 1) - (2**(self.w_bit - 1))
else:
intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1)
g_idx = torch.empty(0, dtype=torch.int32)

self.qweight = qbits.repack_quantized_weight(intweight, self.scales.float(), zeros, g_idx,
BITS_DTYPE_MAPPING[self.w_bit],
convert_dtype_torch2str(self.scale_dtype),
convert_dtype_torch2str(self.scales.dtype), self.zero_point,
self.group_size)

@classmethod
def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
has_zero_points,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear

raise NotImplementedError("Only inference is supported for Exllama kernels")

@torch.no_grad()
def forward(self, x):
assert QBITS_INSTALLED, (
"QBits kernels could not be loaded. "
"Please install with `pip install intel-extension-for-transformers` and "
"refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md")

input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
x = x.view(-1, x.shape[-1]) # convert xd to 2d
out_2d_shape = x.shape[:-1] + (self.out_features,)

outputs = torch.zeros(out_2d_shape, dtype=input_dtype)
bias = self.bias if self.bias is not None else torch.empty(
0, dtype=torch.bfloat16 if self.use_bf16 else torch.float32)

qbits.woq_linear(x, self.qweight, bias, outputs, convert_dtype_torch2str(input_dtype),
BITS_DTYPE_MAPPING[self.w_bit], convert_dtype_torch2str(self.scale_dtype), self.zero_point)

return outputs.view(out_shape)

def extra_repr(self) -> str:
return ("in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
))


def qbits_post_init(model):
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_QBits):
submodule.post_init()

return model
3 changes: 2 additions & 1 deletion awq/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):

if self.version == "gemm":
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
if zeros is not None:
zeros = zeros.t().contiguous()
q_linear_module = WQLinear_GEMM

elif self.version == "gemv":
Expand Down
Loading

0 comments on commit 2627364

Please sign in to comment.