merge multithread pack
DeJoker authored and xx committed Jul 19, 2024
1 parent bd68ac5 commit 1141275
Showing 3 changed files with 1,666 additions and 22 deletions.
19 changes: 18 additions & 1 deletion gptqmodel/models/base.py
@@ -83,6 +83,8 @@ def __init__(
qlinear_kernel: nn.Module = None,
):
super().__init__()
logger.info(f"start base")


self.model = model
self.model_type = self.model.config.model_type
@@ -151,11 +153,24 @@ def _convert_tensor_to_list(tensor):
return new_calibration_dataset_batched

def quantize(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
calibration_enable_gpu_cache: bool = True,
):
if isinstance(self.quantize_config, AutoRoundQuantizeConfig):
return self._quantize(calibration_dataset, batch_size, calibration_enable_gpu_cache)
else:
with torch.inference_mode():
return self._quantize(calibration_dataset, batch_size, calibration_enable_gpu_cache)

def _quantize(
self,
calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
batch_size: int = 1,
calibration_enable_gpu_cache: bool = True,
):
logger.info(f"start quant")
if self.quantized:
raise EnvironmentError("quantize() is called on a model that is already quantized")

@@ -496,11 +511,13 @@ def tmp(_, inp, out):
[],
) # TODO: is it really OK to cache only the first positional argument?
torch.cuda.empty_cache()

logger.info(f"Quantization summary:\n{quant_log}")
for module_log in quant_log:
logger.info(module_log)

return quant_log, quantizers, force_layer_back_to_cpu, device_map, forward_pass_use_cache

def pack(self, quant_log, quantizers, force_layer_back_to_cpu, device_map, forward_pass_use_cache):
self.qlinear_kernel = pack_model(
model=self.model,
quantizers=quantizers,
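Taken together, the base.py changes split quantization and packing into two steps: the new quantize() wrapper runs _quantize() under torch.inference_mode() (skipped for AutoRoundQuantizeConfig, which presumably cannot run under inference mode), _quantize() now returns the intermediate quantizer state, and the new pack() method consumes that state via pack_model(). A minimal caller-side sketch under that reading (the call site itself is not part of this commit; model stands for a loaded instance of this base class, and argument names follow the diff):

    # assumed usage, not from the commit: quantize first, then pack as a separate step
    quant_log, quantizers, force_layer_back_to_cpu, device_map, forward_pass_use_cache = model.quantize(
        calibration_dataset,
        batch_size=1,
        calibration_enable_gpu_cache=True,
    )
    model.pack(quant_log, quantizers, force_layer_back_to_cpu, device_map, forward_pass_use_cache)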
53 changes: 32 additions & 21 deletions gptqmodel/utils/model.py
@@ -15,6 +15,7 @@
from tqdm import tqdm
from transformers import AutoConfig, PretrainedConfig
from transformers.utils.hub import cached_file
from concurrent.futures import ThreadPoolExecutor, as_completed

from ..models._const import CPU, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS
from ..nn_modules.qlinear import BaseQuantLinear
@@ -32,7 +33,7 @@
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.info(f"start base")

def recurse_getattr(obj, attr: str):
"""
@@ -290,31 +291,41 @@ def pack_model(
qlayers = find_layers(model, [QuantLinear])

# Limit pack() thread usage to avoid auto-parallelization regression
with tctl.threadpool_limits(limits=1):
pbar = tqdm(qlayers.keys(), leave=True)
for name in pbar:
pbar.set_description(f"Packing {name}")

quantizers[name], scale, zero, g_idx = quantizers[name]
# so far can only pack layer on CPU
layer_device = qlayers[name].device
qlayers[name].to(CPU)
layers[name], scale, zero, g_idx = (
layers[name].to(CPU),
scale.to(CPU),
zero.to(CPU),
g_idx.to(CPU),
)
if QuantLinear is MarlinQuantLinear:
qlayers[name].pack(layers[name], scale)
else:
qlayers[name].pack(layers[name], scale, zero, g_idx)
qlayers[name].to(layer_device)
with ThreadPoolExecutor(max_workers=1) as executor:
# future_to_item = {executor.submit(pack_layer, quantizers[name], qlayers[name], layers[name], QuantLinear): name for name in qlayers.keys()}
# for future in tqdm(as_completed(future_to_item), total=len(qlayers.keys())):
# item = future_to_item[future]
# result = future.result()
# print(f"Processed {item} to {result}")
futures = [(executor.submit(pack_layer, quantizers[name], qlayers[name], layers[name], QuantLinear), name) for name in qlayers.keys()]
for item in futures:
future, name = item
future.result()
print(f"{name} Processed")

logger.info("Model packed.")

return QuantLinear

def pack_layer(quantizer, qlayer, layer, QuantLinear):
with tctl.threadpool_limits(limits=1):
_quantizer, scale, zero, g_idx = quantizer
# so far can only pack layer on CPU
layer_device = qlayer.device
qlayer.to(CPU)
layer, scale, zero, g_idx = (
layer.to(CPU),
scale.to(CPU),
zero.to(CPU),
g_idx.to(CPU),
)
if QuantLinear is MarlinQuantLinear:
qlayer.pack(layer, scale)
else:
qlayer.pack(layer, scale, zero, g_idx)
qlayer.to(layer_device)


def verify_model_hash(file_path: str, verify_hash: str):
if not isinstance(verify_hash, str):
raise ValueError("model verify_hash must be a string")
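In model.py, the sequential packing loop is replaced by a per-layer pack_layer() helper submitted to a ThreadPoolExecutor (currently max_workers=1), with the tctl.threadpool_limits(limits=1) guard moved inside pack_layer(). The commented-out block in the diff hints at a completion-order progress variant using as_completed; a sketch of that idea, assuming the names defined above (pack_layer, quantizers, qlayers, layers, QuantLinear); the worker count and the tqdm progress bar are assumptions, not part of the committed code:

    from concurrent.futures import ThreadPoolExecutor, as_completed
    from tqdm import tqdm

    # assumed variant: several workers and a progress bar driven by completion order
    with ThreadPoolExecutor(max_workers=4) as executor:
        future_to_name = {
            executor.submit(pack_layer, quantizers[name], qlayers[name], layers[name], QuantLinear): name
            for name in qlayers
        }
        for future in tqdm(as_completed(future_to_name), total=len(future_to_name), desc="Packing"):
            name = future_to_name[future]
            future.result()  # re-raise any exception from packing this layer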