From c88e0262255ffe251cfc3ba5cfdf7ea0fb5994bd Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com> Date: Thu, 27 Jun 2024 13:28:11 +0800 Subject: [PATCH] Consolidate Backend (#68) * Consolidate Backend * change Backend.TRITON_V2 to Backend.TRITON * According to quantize_config.format, determine when the Backend is packing the model. * Auto choose the fastest one Backend based on quant model compatibility * fix issue: Automatically select Backend, returns incorrect qlinear. * cleanup * cleanup --- examples/benchmark/generation_speed.py | 21 ++--- examples/benchmark/perplexity.py | 13 ++- .../evaluation/run_language_modeling_task.py | 9 +- .../run_sequence_classification_task.py | 9 +- .../evaluation/run_text_summarization_task.py | 9 +- examples/quantization/basic_usage.py | 5 +- examples/quantization/basic_usage_bitblas.py | 13 +-- examples/quantization/basic_usage_gpt_xl.py | 7 +- .../quantization/basic_usage_wikitext2.py | 5 +- gptqmodel/__init__.py | 1 + gptqmodel/models/auto.py | 20 +---- gptqmodel/models/base.py | 83 +++++-------------- gptqmodel/nn_modules/qlinear/__init__.py | 34 +++++--- .../nn_modules/qlinear/qlinear_bitblas.py | 2 + gptqmodel/nn_modules/qlinear/qlinear_cuda.py | 5 +- .../nn_modules/qlinear/qlinear_cuda_old.py | 5 +- .../nn_modules/qlinear/qlinear_exllama.py | 3 +- .../nn_modules/qlinear/qlinear_exllamav2.py | 3 +- .../nn_modules/qlinear/qlinear_marlin.py | 5 +- gptqmodel/utils/__init__.py | 1 + gptqmodel/utils/backend.py | 19 +++++ gptqmodel/utils/importer.py | 82 +++++++++++++----- gptqmodel/utils/model.py | 61 +++++--------- tests/test_lm_head.py | 3 +- tests/test_perplexity.py | 5 +- tests/test_q4_bitblas.py | 8 +- tests/test_q4_cuda.py | 16 ++-- tests/test_q4_exallama.py | 25 +++--- tests/test_q4_exallama_v2.py | 18 ++-- tests/test_q4_marlin.py | 10 +-- tests/test_q4_triton.py | 15 ++-- tests/test_quant_formats.py | 13 +-- tests/test_serialization.py | 10 +-- tests/test_sharded.py | 3 +- tests/test_triton.py | 5 +- tests/test_verify_hash.py | 6 +- 36 files changed, 278 insertions(+), 274 deletions(-) create mode 100644 gptqmodel/utils/backend.py diff --git a/examples/benchmark/generation_speed.py b/examples/benchmark/generation_speed.py index 16e42d75..ec0d9b1f 100644 --- a/examples/benchmark/generation_speed.py +++ b/examples/benchmark/generation_speed.py @@ -7,11 +7,12 @@ import torch from datasets import Dataset, load_dataset -from gptqmodel import GPTQModel, QuantizeConfig from tqdm import tqdm from transformers import AutoTokenizer, GenerationConfig from transformers.generation.logits_process import LogitsProcessor +from gptqmodel import Backend, GPTQModel, QuantizeConfig, get_backend + logger = logging.getLogger(__name__) random.seed(0) @@ -143,17 +144,15 @@ def tokenize(examples): def load_model_tokenizer( model_name_or_path: str, + backend: Backend, tokenizer_name_or_path: Optional[str] = None, from_pretrained: bool = False, max_memory: Optional[dict] = None, model_basename: Optional[str] = None, quantize_config: Optional[str] = None, trust_remote_code: bool = False, - use_triton: bool = False, - use_bitblas: bool = False, use_safetensors: bool = True, use_fast_tokenizer: bool = False, - disable_exllama: bool = False, ): tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path=tokenizer_name_or_path or model_name_or_path, @@ -174,15 +173,13 @@ def load_model_tokenizer( model = GPTQModel.from_quantized( model_name_or_path, max_memory=max_memory, - use_triton=use_triton, - 
use_bitblas=use_bitblas, use_cuda_fp16=True, quantize_config=quantize_config, model_basename=model_basename, use_safetensors=use_safetensors, trust_remote_code=trust_remote_code, warmup_triton=False, - disable_exllama=disable_exllama, + backend=backend, ) return model, tokenizer @@ -234,11 +231,9 @@ def main(): parser.add_argument("--model_basename", type=str, default=None) parser.add_argument("--quantize_config_save_dir", type=str, default=None) parser.add_argument("--trust_remote_code", action="store_true") - parser.add_argument("--use_triton", action="store_true") - parser.add_argument("--use_bitblas", action="store_true") + parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS']) parser.add_argument("--use_safetensors", action="store_true") parser.add_argument("--use_fast_tokenizer", action="store_true") - parser.add_argument("--disable_exllama", action="store_true") parser.add_argument("--num_samples", type=int, default=10) parser.add_argument("--per_gpu_max_memory", type=int, default=None) parser.add_argument("--cpu_max_memory", type=int, default=None) @@ -277,11 +272,9 @@ def main(): model_basename=args.model_basename, quantize_config=quantize_config, trust_remote_code=args.trust_remote_code, - use_triton=args.use_triton, - use_bitblas=args.use_bitblas, use_safetensors=True, use_fast_tokenizer=args.use_fast_tokenizer, - disable_exllama=args.disable_exllama, + backend=get_backend(args.backend), ) end = time.time() logger.info(f"model and tokenizer loading time: {end - start:.4f}s") @@ -289,7 +282,7 @@ def main(): logger.info(f"quantize config: {model.quantize_config.to_dict()}") logger.info(f"model device map: {model.hf_device_map}") - if args.use_triton: + if args.backend == Backend.TRITON: logger.info("warmup triton, this may take a while.") model.warmup_triton() diff --git a/examples/benchmark/perplexity.py b/examples/benchmark/perplexity.py index 0b5543f4..955d0043 100644 --- a/examples/benchmark/perplexity.py +++ b/examples/benchmark/perplexity.py @@ -2,9 +2,10 @@ import os import torch -from gptqmodel.utils import Perplexity from transformers import AutoTokenizer +from gptqmodel.utils import Perplexity, get_backend + if __name__ == "__main__": """ Example usage. 
@@ -42,7 +43,7 @@ default=None, help="Max memory used in each GPU.", ) - parser.add_argument("--cpu_max_memory", type=int, default=None, help="Mx memory used in CPU.") + parser.add_argument("--cpu_max_memory", type=int, default=None, help="Max memory used in CPU.") parser.add_argument("--is_quantized", action="store_true", help="Is the model GPTQ quantized?") parser.add_argument( "--use_safetensors", @@ -51,11 +52,7 @@ ) parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer") parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code") - parser.add_argument( - "--disable_exllama", - action="store_true", - help="Whether to use disable exllama kernel", - ) + parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'], help="Whether to use Backend format") args = parser.parse_args() os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -88,7 +85,7 @@ model_basename=args.model_basename, use_safetensors=True, trust_remote_code=args.trust_remote_code, - disable_exllama=args.disable_exllama, + backend=get_backend(args.backend), ) else: from transformers import AutoModelForCausalLM diff --git a/examples/evaluation/run_language_modeling_task.py b/examples/evaluation/run_language_modeling_task.py index 867c4009..83e4a5e7 100644 --- a/examples/evaluation/run_language_modeling_task.py +++ b/examples/evaluation/run_language_modeling_task.py @@ -2,10 +2,11 @@ import datasets import torch -from gptqmodel import GPTQModel, QuantizeConfig -from gptqmodel.eval_tasks import LanguageModelingTask from transformers import AutoTokenizer +from gptqmodel import GPTQModel, QuantizeConfig, get_backend +from gptqmodel.eval_tasks import LanguageModelingTask + DATASET = "tatsu-lab/alpaca" WITH_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nInput:\n{input}\n\nOutput:\n" WITHOUT_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nOutput:\n" @@ -40,7 +41,7 @@ def main(): ) parser.add_argument("--sample_max_len", type=int, default=1024, help="max tokens for each sample") parser.add_argument("--block_max_len", type=int, default=2048, help="max tokens for each data block") - parser.add_argument("--use_triton", action="store_true") + parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS']) args = parser.parse_args() tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir) @@ -70,7 +71,7 @@ def main(): del model torch.cuda.empty_cache() - model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton) + model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", backend=get_backend(args.backend)) task.model = model task.device = model.device print(f"eval result for quantized model: {task.run()}") diff --git a/examples/evaluation/run_sequence_classification_task.py b/examples/evaluation/run_sequence_classification_task.py index 77869897..f7be2b58 100644 --- a/examples/evaluation/run_sequence_classification_task.py +++ b/examples/evaluation/run_sequence_classification_task.py @@ -3,10 +3,11 @@ import datasets import torch -from gptqmodel import GPTQModel, QuantizeConfig -from gptqmodel.eval_tasks import SequenceClassificationTask from transformers import AutoTokenizer +from gptqmodel import GPTQModel, QuantizeConfig, get_backend +from gptqmodel.eval_tasks import SequenceClassificationTask + DATASET = "cardiffnlp/tweet_sentiment_multilingual" TEMPLATE = 
"Question:What's the sentiment of the given text? Choices are {labels}.\nText: {text}\nAnswer:" ID2LABEL = {0: "negative", 1: "neutral", 2: "positive"} @@ -38,7 +39,7 @@ def main(): ) parser.add_argument("--sample_max_len", type=int, default=1024, help="max tokens for each sample") parser.add_argument("--block_max_len", type=int, default=2048, help="max tokens for each data block") - parser.add_argument("--use_triton", action="store_true") + parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS']) args = parser.parse_args() tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir) @@ -69,7 +70,7 @@ def main(): del model torch.cuda.empty_cache() - model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton) + model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", backend=get_backend(args.backend)) task.model = model task.device = model.device print(f"eval result for quantized model: {task.run()}") diff --git a/examples/evaluation/run_text_summarization_task.py b/examples/evaluation/run_text_summarization_task.py index a010ee83..349dd0a8 100644 --- a/examples/evaluation/run_text_summarization_task.py +++ b/examples/evaluation/run_text_summarization_task.py @@ -3,10 +3,11 @@ import datasets import torch -from gptqmodel import GPTQModel, QuantizeConfig -from gptqmodel.eval_tasks import TextSummarizationTask from transformers import AutoTokenizer, GenerationConfig +from gptqmodel import GPTQModel, QuantizeConfig, get_backend +from gptqmodel.eval_tasks import TextSummarizationTask + os.system("pip install py7zr") @@ -37,7 +38,7 @@ def main(): ) parser.add_argument("--sample_max_len", type=int, default=1024, help="max tokens for each sample") parser.add_argument("--block_max_len", type=int, default=2048, help="max tokens for each data block") - parser.add_argument("--use_triton", action="store_true") + parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS']) args = parser.parse_args() tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir) @@ -67,7 +68,7 @@ def main(): del model torch.cuda.empty_cache() - model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton) + model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", backend=get_backend(args.backend)) task.model = model task.device = model.device print(f"eval result for quantized model: {task.run(generation_config=GenerationConfig(max_new_tokens=32))}") diff --git a/examples/quantization/basic_usage.py b/examples/quantization/basic_usage.py index 1f2115f5..9ef3aa8c 100644 --- a/examples/quantization/basic_usage.py +++ b/examples/quantization/basic_usage.py @@ -1,6 +1,7 @@ -from gptqmodel import GPTQModel, QuantizeConfig from transformers import AutoTokenizer, TextGenerationPipeline +from gptqmodel import GPTQModel, QuantizeConfig + pretrained_model_dir = "facebook/opt-125m" quantized_model_dir = "opt-125m-4bit-128g" @@ -48,7 +49,7 @@ def main(): model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0") # download quantized model from Hugging Face Hub and load to the first GPU - # model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False) + # model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True,) # inference with model.generate print(tokenizer.decode(model.generate(**tokenizer("gptqmodel 
is", return_tensors="pt").to(model.device))[0])) diff --git a/examples/quantization/basic_usage_bitblas.py b/examples/quantization/basic_usage_bitblas.py index a9339f48..0eb249f9 100644 --- a/examples/quantization/basic_usage_bitblas.py +++ b/examples/quantization/basic_usage_bitblas.py @@ -1,13 +1,14 @@ import torch -from gptqmodel import GPTQModel -from gptqmodel.quantization import QuantizeConfig from transformers import AutoTokenizer, TextGenerationPipeline -use_bitblas = True +from gptqmodel import Backend, GPTQModel +from gptqmodel.quantization import QuantizeConfig + +backend = Backend.BITBLAS pretrained_model_dir = "facebook/opt-125m" quantized_model_dir = "./facebook/opt-125m-4bit-128g" -if use_bitblas: +if backend == Backend.BITBLAS: quantized_model_dir += "-bitblas" def main(): @@ -52,11 +53,11 @@ def main(): # load quantized model to the first GPU model = GPTQModel.from_quantized( - quantized_model_dir, device="cuda:0", use_bitblas=use_bitblas + quantized_model_dir, device="cuda:0", backend=backend, ) # download quantized model from Hugging Face Hub and load to the first GPU - # model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False) + # model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True) # -- simple token evaluate -- input_ids = torch.ones((1, 1), dtype=torch.long, device="cuda:0") diff --git a/examples/quantization/basic_usage_gpt_xl.py b/examples/quantization/basic_usage_gpt_xl.py index 1056ee10..90546c96 100644 --- a/examples/quantization/basic_usage_gpt_xl.py +++ b/examples/quantization/basic_usage_gpt_xl.py @@ -3,9 +3,10 @@ import numpy as np import torch from datasets import load_dataset -from gptqmodel import GPTQModel, QuantizeConfig from transformers import TextGenerationPipeline +from gptqmodel import GPTQModel, QuantizeConfig + pretrained_model_dir = "gpt2-xl" quantized_model_dir = "gpt2-large-4bit-128g" @@ -65,7 +66,7 @@ def main(): # quantize model, the calibration_dataset should be list of dict whose keys contains "input_ids" and "attention_mask" # with value under torch.LongTensor type. - model.quantize(traindataset, use_triton=False) + model.quantize(traindataset) # save quantized model model.save_quantized(quantized_model_dir) @@ -74,7 +75,7 @@ def main(): model.save_quantized(quantized_model_dir, use_safetensors=True) # load quantized model, currently only support cpu or single gpu - model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False) + model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0") # inference with model.generate print(tokenizer.decode(model.generate(**tokenizer("test is", return_tensors="pt").to("cuda:0"))[0])) diff --git a/examples/quantization/basic_usage_wikitext2.py b/examples/quantization/basic_usage_wikitext2.py index 40f5ff8c..cde1ed6e 100644 --- a/examples/quantization/basic_usage_wikitext2.py +++ b/examples/quantization/basic_usage_wikitext2.py @@ -1,6 +1,7 @@ import numpy as np import torch import torch.nn as nn + from gptqmodel import GPTQModel, QuantizeConfig pretrained_model_dir = "facebook/opt-125m" @@ -145,7 +146,7 @@ def main(): # quantize model, the calibration_dataset should be list of dict whose keys can only be "input_ids" and "attention_mask" # with value under torch.LongTensor type. 
- model.quantize(traindataset, use_triton=False) + model.quantize(traindataset) # save quantized model model.save_quantized(quantized_model_dir) @@ -154,7 +155,7 @@ def main(): model.save_quantized(quantized_model_dir, use_safetensors=True) # load quantized model, currently only support cpu or single gpu - model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False) + model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0") opt_eval(model.model, testenc, "cuda:0") diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index 50f7fd9e..39c525e3 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -1,4 +1,5 @@ from .models import GPTQModel from .quantization import BaseQuantizeConfig, QuantizeConfig +from .utils import Backend, get_backend from .utils.exllama import exllama_set_max_input_length from .version import __version__ diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 72feac8c..f717bccf 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -1,5 +1,6 @@ from typing import Dict, List, Optional, Union +from ..utils import Backend from ..utils.model import check_and_get_model_type from .baichuan import BaiChuanGPTQ from .base import BaseGPTQModel, QuantizeConfig @@ -108,30 +109,19 @@ def from_quantized( device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None, max_memory: Optional[dict] = None, device: Optional[Union[str, int]] = None, - use_triton: bool = False, + backend: Backend = Backend.AUTO, use_cuda_fp16: bool = True, quantize_config: Optional[QuantizeConfig | Dict] = None, model_basename: Optional[str] = None, use_safetensors: bool = True, trust_remote_code: bool = False, warmup_triton: bool = False, - disable_exllama: Optional[bool] = None, - disable_exllamav2: bool = False, - use_marlin: bool = False, - use_bitblas: bool = False, # verify weight files matches predefined hash during loading # usage: hash_format:hash_value, example: md5:ugkdh232 # supports all hashlib hash methods verify_hash: Optional[Union[str, List[str]]] = None, **kwargs, ) -> BaseGPTQModel: - # If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones. 
- if disable_exllama is None: - if disable_exllamav2: - disable_exllama = False - else: - disable_exllama = True - model_type = check_and_get_model_type(model_name_or_path, trust_remote_code) quant_func = MODEL_MAP[model_type].from_quantized @@ -140,17 +130,13 @@ def from_quantized( device_map=device_map, max_memory=max_memory, device=device, - use_triton=use_triton, + backend=backend, use_cuda_fp16=use_cuda_fp16, quantize_config=quantize_config, model_basename=model_basename, use_safetensors=use_safetensors, trust_remote_code=trust_remote_code, warmup_triton=warmup_triton, - disable_exllama=disable_exllama, - disable_exllamav2=disable_exllamav2, - use_marlin=use_marlin, - use_bitblas=use_bitblas, verify_hash=verify_hash, **kwargs, ) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index e40de30e..44726625 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -20,6 +20,7 @@ from ..quantization import GPTQ, QuantizeConfig from ..quantization.config import (FORMAT, FORMAT_FIELD_JSON, META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL, MIN_VERSION_WITH_V2, QUANTIZE_BLACK_LIST) +from ..utils.backend import Backend from ..utils.bitblas import convert_to_bitblas, prepare_model_for_bitblas_load from ..utils.data import collate_data from ..utils.importer import select_quant_linear @@ -76,8 +77,6 @@ def __init__( model: PreTrainedModel, quantized: bool, quantize_config: QuantizeConfig, - # TODO: remove is_triton_backend arg..why? doesn't pass smell test @ZX-ModelCloud - is_triton_backend: bool = False, qlinear_kernel: nn.Module = None, ): super().__init__() @@ -88,8 +87,6 @@ def __init__( self.quantize_config = quantize_config self.config = self.model.config - self.is_triton_backend = is_triton_backend - # compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion self.qlinear_kernel = qlinear_kernel @@ -156,8 +153,8 @@ def quantize( calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]], batch_size: int = 1, - # TODO: remove use_triton and use_cuda_fp16 arg..why? doesn't pass smell test @ZX-ModelCloud - use_triton: bool = False, + backend: Backend = Backend.AUTO, + # TODO: remove use_cuda_fp16 arg..why? 
doesn't pass smell test @ZX-ModelCloud use_cuda_fp16: bool = True, autotune_warmup_after_quantized: bool = False, @@ -427,13 +424,12 @@ def tmp(_, inp, out): quantizers=quantizers, bits=self.quantize_config.bits, group_size=self.quantize_config.group_size, - use_triton=use_triton, + backend=backend, use_cuda_fp16=use_cuda_fp16, desc_act=self.quantize_config.desc_act, warmup_triton=autotune_warmup_after_quantized, force_layer_back_to_cpu=force_layer_back_to_cpu, - use_marlin=self.quantize_config.format == FORMAT.MARLIN, - use_bitblas=self.quantize_config.format == FORMAT.BITBLAS, + format=self.quantize_config.format, ) if device_map: self.model = remove_hook_from_module(self.model, recurse=True) @@ -810,14 +806,7 @@ def from_quantized( max_memory: Optional[dict] = None, device: Optional[Union[str, int]] = None, - # TODO: refract this bewildering amount of ugly args @ZX-ModelCloud - # combine into Backend.ENUM class of Backend.AUTO, Backend.TRITON, Backend.MARLIN - # single arp of backend: Backend = Backend.AUTO (default to auto) - use_triton: bool = True, - use_marlin: bool = True, - use_bitblas: bool = False, - disable_exllama: bool = False, - disable_exllamav2: bool = False, + backend: Backend = Backend.AUTO, torch_dtype: [str | torch.dtype] = "auto", use_cuda_fp16: bool = True, @@ -832,13 +821,6 @@ def from_quantized( **kwargs, ): """load quantized model from local disk""" - # If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones. - if disable_exllama is None: - if disable_exllamav2: - disable_exllama = False - else: - disable_exllama = True - if cls.require_trust_remote_code and not trust_remote_code: raise ValueError( f"{model_name_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model." @@ -868,12 +850,6 @@ def from_quantized( "_commit_hash": commit_hash, } - if not disable_exllamav2 and not disable_exllama: - logger.warning( - "You have activated both exllama and exllamav2 kernel. Setting disable_exllama to True and keeping disable_exllamav2 to False" - ) - disable_exllama = True - # == step1: prepare configs and file names == # config: PretrainedConfig = AutoConfig.from_pretrained( model_name_or_path, @@ -899,20 +875,20 @@ def from_quantized( if quantize_config.format == FORMAT.MARLIN: # format marlin requires marlin kernel - use_marlin = True + backend = Backend.MARLIN marlin_compatible = _validate_marlin_device_support() - if not use_marlin: + if backend != Backend.MARLIN: unsupported = _validate_marlin_compatibility(quantize_config) if unsupported is None and marlin_compatible: logger.info( - "You passed a model that is compatible with the Marlin int4*fp16 GPTQ kernel but use_marlin is False. We recommend using `use_marlin=True` to use the optimized Marlin kernels for inference. Example: `model = GPTQModel.from_quantized(..., use_marlin=True)`." + "You passed a model that is compatible with the Marlin int4*fp16 GPTQ kernel but backend is not Backend.MARLIN. We recommend using `backend=Backend.MARLIN` to use the optimized Marlin kernels for inference. Example: `model = GPTQModel.from_quantized(..., backend=Backend.MARLIN)`." 
) if quantize_config.format == FORMAT.BITBLAS: # format bitblas requires bitblas kernel - use_bitblas = True + backend = Backend.BITBLAS if model_basename is None: if quantize_config.model_file_base_name: @@ -1012,13 +988,10 @@ def skip(*args, **kwargs): layers, quantize_config.bits, quantize_config.group_size, - use_triton=use_triton, - disable_exllama=disable_exllama, - disable_exllamav2=disable_exllamav2, + backend=Backend.AUTO, + format=quantize_config.format, use_cuda_fp16=use_cuda_fp16, desc_act=quantize_config.desc_act, - use_marlin=quantize_config.format == FORMAT.MARLIN, - use_bitblas=quantize_config.format == FORMAT.BITBLAS, ) model.tie_weights() @@ -1056,14 +1029,14 @@ def skip(*args, **kwargs): no_split_module_classes=[cls.layer_type], ) - if use_marlin: + if backend == Backend.MARLIN: if is_sharded: raise ValueError( "The loading of sharded checkpoints with Marlin is currently not supported." ) if not _validate_marlin_device_support(): raise ValueError( - f'Marlin kernel does not support this gpu with compute capability of `{torch.cuda.get_device_capability()}`. Please do not use `use_marlin=True`.' + f'Marlin kernel does not support this gpu with compute capability of `{torch.cuda.get_device_capability()}`. Please do not use `back=Backend.MARLIN`.' ) # Validate the model can run in Marlin. @@ -1079,11 +1052,8 @@ def skip(*args, **kwargs): group_size=quantize_config.group_size, desc_act=quantize_config.desc_act, sym=quantize_config.sym, - use_triton=use_triton, - disable_exllama=disable_exllama, - disable_exllamav2=disable_exllamav2, - use_marlin=False, - use_bitblas=False, + backend=Backend.AUTO, + format=quantize_config.format, ) # Prepare model for marlin load. @@ -1099,7 +1069,7 @@ def skip(*args, **kwargs): sym=quantize_config.sym, ) - if use_bitblas: + if backend == Backend.BITBLAS: if is_sharded: raise ValueError( "The loading of sharded checkpoints with BitBLAS is currently not supported. Please raise an issue in GPTQModel repository.") @@ -1111,11 +1081,8 @@ def skip(*args, **kwargs): group_size=quantize_config.group_size, desc_act=quantize_config.desc_act, sym=quantize_config.sym, - use_triton=use_triton, - disable_exllama=disable_exllama, - disable_exllamav2=disable_exllamav2, - use_marlin=False, - use_bitblas=False, + backend=Backend.AUTO, + format=quantize_config.format, ) # Prepare model for bitblas load. 
@@ -1133,7 +1100,7 @@ def skip(*args, **kwargs): # If we use marlin or bitblas to load the quantized model, the model is already a converted model, # and we no longer need to call load_checkpoint_in_model() - if not (use_marlin or use_bitblas): + if backend != Backend.MARLIN and backend != Backend.BITBLAS: accelerate.utils.modeling.load_checkpoint_in_model( model, dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292 @@ -1151,11 +1118,8 @@ def skip(*args, **kwargs): group_size=quantize_config.group_size, desc_act=quantize_config.desc_act, sym=quantize_config.sym, - use_triton=use_triton, - disable_exllama=disable_exllama, - disable_exllamav2=disable_exllamav2, - use_marlin=use_marlin, - use_bitblas=use_bitblas, + backend=backend, + format=quantize_config.format, ) # compat: runtime convert checkpoint gptq(v1) to gptq_v2 format @@ -1194,7 +1158,7 @@ def skip(*args, **kwargs): model.eval() # == step6: (optional) warmup triton == # - if use_triton and warmup_triton: + if backend != Backend.TRITON and warmup_triton: from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear QuantLinear.warmup(model, seqlen=model.seqlen) @@ -1203,7 +1167,6 @@ def skip(*args, **kwargs): model, quantized=True, quantize_config=quantize_config, - is_triton_backend=use_triton, qlinear_kernel=qlinear_kernel, ) diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index a8d81ed9..ff590d3c 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -11,19 +11,27 @@ class BaseQuantLinear(nn.Module): SUPPORTED_SYM = [True, False] SUPPORTED_SHARDS: bool = True - def validate(self, bits: int, group_size: int, desc_act: bool, sym: bool): - if self.SUPPORTED_BITS and bits not in self.SUPPORTED_BITS: - raise NotImplementedError(f"{self.QUANT_TYPE} only supports `{self.SUPPORTED_BITS}` bits: actual bits = `{bits}`") - - if self.SUPPORTED_GROUP_SIZE and group_size not in self.SUPPORTED_GROUP_SIZE: - raise NotImplementedError( - f"{self.QUANT_TYPE} only supports `{self.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}`") - - if self.SUPPORTED_SYM and sym not in self.SUPPORTED_SYM: - raise NotImplementedError(f"{self.QUANT_TYPE} only supports `{self.SUPPORTED_SYM}` bits: actual sym = `{sym}`") - - if self.SUPPORTED_DESC_ACT and desc_act not in self.SUPPORTED_DESC_ACT: - raise NotImplementedError(f"{self.QUANT_TYPE} only supports `{self.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`") + @classmethod + def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_error: bool = True) -> bool: + validate = True + err = "" + if cls.SUPPORTED_BITS and bits not in cls.SUPPORTED_BITS: + validate = False + err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_BITS}` bits: actual bits = `{bits}`" + elif cls.SUPPORTED_GROUP_SIZE and group_size not in cls.SUPPORTED_GROUP_SIZE: + validate = False + err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}`" + elif cls.SUPPORTED_SYM and sym not in cls.SUPPORTED_SYM: + validate = False + err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_SYM}` bits: actual sym = `{sym}`" + elif cls.SUPPORTED_DESC_ACT and desc_act not in cls.SUPPORTED_DESC_ACT: + validate = False + err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`" + + if 
not validate and raise_error: + raise NotImplementedError(err) + + return validate # override me def post_init(self): diff --git a/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py b/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py index d4e56ce0..c366680e 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn + from gptqmodel.nn_modules.qlinear import BaseQuantLinear from .qlinear_cuda_old import QuantLinear as QuantLinearOld @@ -36,6 +37,7 @@ def import_bitblas(): bitblas.set_log_level("INFO") from bitblas.cache import get_database_path + from .bitblas_target_detector import patched_auto_detect_nvidia_target BITBLAS_TARGET = patched_auto_detect_nvidia_target(int(os.environ.get("CUDA_VISIBLE_DEVICES", "0"))) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py index 6e345f9a..1e30d91b 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py @@ -1,12 +1,13 @@ import math from logging import getLogger -import gptqmodel_cuda_64 -import gptqmodel_cuda_256 import numpy as np import torch import torch.nn as nn import transformers + +import gptqmodel_cuda_64 +import gptqmodel_cuda_256 from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear logger = getLogger(__name__) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py b/gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py index b1b2360d..1248eb5f 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py @@ -1,12 +1,13 @@ import math from logging import getLogger -import gptqmodel_cuda_64 -import gptqmodel_cuda_256 import numpy as np import torch import torch.nn as nn import transformers + +import gptqmodel_cuda_64 +import gptqmodel_cuda_256 from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear logger = getLogger(__name__) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py index f6396155..dfd78f13 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py @@ -7,9 +7,10 @@ import torch import torch.nn as nn import transformers -from gptqmodel.nn_modules.qlinear import BaseQuantLinear from gptqmodel_exllama_kernels import make_q4, q4_matmul +from gptqmodel.nn_modules.qlinear import BaseQuantLinear + logger = getLogger(__name__) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py index cad775b6..093c6153 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py @@ -4,9 +4,10 @@ from logging import getLogger import torch -from gptqmodel.nn_modules.qlinear import BaseQuantLinear from gptqmodel_exllamav2_kernels import gemm_half_q_half, make_q_matrix +from gptqmodel.nn_modules.qlinear import BaseQuantLinear + logger = getLogger(__name__) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py index 6cbba536..768c3639 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py @@ -3,9 +3,10 @@ from logging import getLogger -import gptqmodel_marlin_cuda import numpy as np import torch + +import gptqmodel_marlin_cuda from gptqmodel.nn_modules.qlinear import BaseQuantLinear logger = getLogger(__name__) @@ -75,7 +76,7 @@ def 
__init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat if not torch.cuda.get_device_capability()[0] >= 8: raise ValueError( - f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel. Please do not use `use_marlin=True`, or please upgrade your GPU ("The more you buy, the more you save." - Taiwanese proverb).' + f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel. Please do not use `backend=Backend.MARLIN`, or please upgrade your GPU ("The more you buy, the more you save." - Taiwanese proverb).' ) if infeatures % 128 != 0 or outfeatures % 256 != 0: diff --git a/gptqmodel/utils/__init__.py b/gptqmodel/utils/__init__.py index bf1d68dc..a123dc04 100644 --- a/gptqmodel/utils/__init__.py +++ b/gptqmodel/utils/__init__.py @@ -1 +1,2 @@ +from .backend import Backend, get_backend from .perplexity import Perplexity diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py new file mode 100644 index 00000000..16fc882c --- /dev/null +++ b/gptqmodel/utils/backend.py @@ -0,0 +1,19 @@ +from enum import Enum + + +class Backend(Enum): + AUTO = 0 # choose the fastest one based on quant model compatibility + CUDA_OLD = 1 + CUDA = 2 + TRITON = 3 + EXLLAMA = 4 + EXLLAMA_V2 = 5 + MARLIN = 6 + BITBLAS = 7 + + +def get_backend(backend: str): + try: + return Backend[backend] + except KeyError: + raise ValueError(f"Invalid Backend str: {backend}") diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 76f177bd..f8dc19f7 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -1,35 +1,73 @@ +from collections import OrderedDict from logging import getLogger +from ..quantization import FORMAT +from .backend import Backend + logger = getLogger(__name__) +from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear +from ..nn_modules.qlinear.qlinear_cuda import QuantLinear as CudaQuantLinear +from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear as CudaOldQuantLinear +from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear +from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear as ExllamaV2QuantLinear +from ..nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear +from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as TritonV2QuantLinear + +backend_dict = OrderedDict({ + Backend.MARLIN: MarlinQuantLinear, + Backend.EXLLAMA_V2: ExllamaV2QuantLinear, + Backend.EXLLAMA: ExllamaQuantLinear, + Backend.TRITON: TritonV2QuantLinear, + Backend.CUDA_OLD: CudaOldQuantLinear, + Backend.CUDA: CudaQuantLinear, + Backend.BITBLAS: BitBLASQuantLinear, +}) + +format_dict = { + FORMAT.GPTQ: [Backend.EXLLAMA_V2, Backend.EXLLAMA, Backend.CUDA_OLD, Backend.CUDA], + FORMAT.GPTQ_V2: [Backend.EXLLAMA_V2, Backend.EXLLAMA, Backend.CUDA_OLD, Backend.CUDA], + FORMAT.MARLIN: [Backend.MARLIN], + FORMAT.BITBLAS: [Backend.BITBLAS], +} + # auto select the correct/optimal QuantLinear class def select_quant_linear( - bits: int, - group_size: int, - desc_act: bool, - sym: bool, - use_triton: bool, - disable_exllama: bool = False, - disable_exllamav2: bool = False, - use_marlin: bool = False, - use_bitblas: bool = False, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + backend: Backend, + format: str, + pack: bool = False, ): - if use_triton: + # 
Handle the case where backend is AUTO. + if backend == Backend.AUTO: + allow_backends = format_dict[format] + for k, v in backend_dict.items(): + in_allow_backends = k in allow_backends + validate = v.validate(bits, group_size, desc_act, sym, raise_error=False) + check_pack_func = hasattr(v, "pack") if pack else True + if in_allow_backends and validate and check_pack_func: + logger.info(f"Auto choose the fastest one based on quant model compatibility: {v}") + return v + + # Handle the case where backend is not AUTO. + if backend == Backend.TRITON: logger.info("Using tritonv2 for GPTQ") from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear + elif backend == Backend.BITBLAS: + from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear + elif bits == 4 and sym and not desc_act and backend == Backend.MARLIN: + from ..nn_modules.qlinear.qlinear_marlin import QuantLinear + elif bits == 4 and backend == Backend.EXLLAMA_V2: + from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear + elif bits == 4 and backend == Backend.EXLLAMA: + from ..nn_modules.qlinear.qlinear_exllama import QuantLinear + elif not desc_act or group_size == -1: + from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear else: - if use_bitblas: - from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear - elif bits == 4 and sym and not desc_act and use_marlin: - from ..nn_modules.qlinear.qlinear_marlin import QuantLinear - elif bits == 4 and not disable_exllamav2: - from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear - elif bits == 4 and not disable_exllama: - from ..nn_modules.qlinear.qlinear_exllama import QuantLinear - elif not desc_act or group_size == -1: - from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear - else: - from ..nn_modules.qlinear.qlinear_cuda import QuantLinear + from ..nn_modules.qlinear.qlinear_cuda import QuantLinear return QuantLinear diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 0053651f..d24e6243 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -18,6 +18,7 @@ from ..models._const import CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS from ..nn_modules.qlinear import BaseQuantLinear from ..quantization import QuantizeConfig +from .backend import Backend from .importer import select_quant_linear logger = getLogger(__name__) @@ -104,26 +105,21 @@ def make_quant( names, bits: int, group_size: int, + backend: Backend, + format: str, desc_act: bool = False, sym: bool = True, - use_triton: bool = False, - use_marlin: bool = False, - use_bitblas: bool = False, - disable_exllama: bool = False, - disable_exllamav2: bool = False, use_cuda_fp16: bool = True, - + pack: bool = False, ): QuantLinear = select_quant_linear_with_pack( bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, - use_triton=use_triton, - use_marlin=use_marlin, - use_bitblas=use_bitblas, - disable_exllama=disable_exllama, - disable_exllamav2=disable_exllamav2, + backend=backend, + format=format, + pack=pack, ) if isinstance(module, QuantLinear): @@ -146,7 +142,7 @@ def make_quant( raise NotImplementedError(f"Unsupported module {submodule}") bias = submodule.bias is not None - if (not (desc_act) or group_size == -1) and not use_triton: + if (not (desc_act) or group_size == -1) and backend != Backend.TRITON: new_layer = QuantLinear( bits=bits, group_size=group_size, @@ -243,25 +239,19 @@ def select_quant_linear_with_pack(bits: int, group_size: int, desc_act: bool, sym: bool, - use_triton: bool, - 
disable_exllama: bool, - disable_exllamav2: bool , - use_marlin: bool , - use_bitblas: bool,): + backend: Backend, format: str, pack: bool): # If Format is BitBLAS, BitBLASQuantLinear is not used during packing, # and the format is converted to BitBLAS in save_quantized(). - if use_bitblas: - use_bitblas = False + if backend == Backend.BITBLAS: + backend = Backend.AUTO QuantLinear = select_quant_linear( bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, - use_triton=use_triton, - disable_exllama=disable_exllama, - disable_exllamav2=disable_exllamav2, - use_marlin=use_marlin, - use_bitblas=use_bitblas, + backend=backend, + format=format, + pack=pack, ) return QuantLinear @@ -270,25 +260,22 @@ def pack_model( quantizers, bits, group_size, + backend: Backend, + format: str, desc_act=False, sym: bool = True, - use_triton=False, use_cuda_fp16=True, warmup_triton: bool = False, force_layer_back_to_cpu: bool = False, - use_marlin: bool = False, - use_bitblas: bool = False, ): QuantLinear = select_quant_linear_with_pack( bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, - use_triton=use_triton, - disable_exllama=False, - disable_exllamav2=True, - use_marlin=use_marlin, - use_bitblas=use_bitblas, + backend=backend, + format=format, + pack=True, ) if force_layer_back_to_cpu: @@ -302,13 +289,11 @@ def pack_model( quantizers, bits, group_size, - use_triton=use_triton, + backend=backend, + format=format, use_cuda_fp16=use_cuda_fp16, desc_act=desc_act, - disable_exllama=False, - disable_exllamav2=True, - use_marlin=use_marlin, - use_bitblas=use_bitblas, + pack=True, ) qlayers = find_layers(model, [QuantLinear]) @@ -336,7 +321,7 @@ def pack_model( logger.info("Model packed.") - if use_triton and warmup_triton: + if backend != Backend.TRITON and warmup_triton: logger.warning( "using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the whole model." 
) diff --git a/tests/test_lm_head.py b/tests/test_lm_head.py index bb20734b..ed13b2a0 100644 --- a/tests/test_lm_head.py +++ b/tests/test_lm_head.py @@ -9,9 +9,10 @@ import numpy # noqa: E402 import torch # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import GPTQModel # noqa: E402 + class TestLmHead(unittest.TestCase): MODEL_ID = "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse" diff --git a/tests/test_perplexity.py b/tests/test_perplexity.py index 09d66d16..1582e9aa 100644 --- a/tests/test_perplexity.py +++ b/tests/test_perplexity.py @@ -7,11 +7,12 @@ import tempfile # noqa: E402 import unittest # noqa: E402 +from parameterized import parameterized # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.quantization import FORMAT, QuantizeConfig # noqa: E402 from gptqmodel.utils import Perplexity # noqa: E402 -from parameterized import parameterized # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 class TestPerplexity(unittest.TestCase): diff --git a/tests/test_q4_bitblas.py b/tests/test_q4_bitblas.py index 44611688..ef508f48 100644 --- a/tests/test_q4_bitblas.py +++ b/tests/test_q4_bitblas.py @@ -7,6 +7,7 @@ import unittest # noqa: E402 import torch # noqa: E402 + from gptqmodel.nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear # noqa: E402 try: @@ -14,9 +15,10 @@ except ImportError as e: print(f"[WARNING] Could not load gptqmodel_exllama_kernels: {e}") -from gptqmodel import GPTQModel # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import Backend, GPTQModel # noqa: E402 + class TestQ4BitBLAS(unittest.TestCase): def test_generation(self): @@ -29,7 +31,7 @@ def test_generation(self): model_id = "TheBloke/Llama-2-7B-Chat-GPTQ" try: - model_q = GPTQModel.from_quantized(model_id, device="cuda:0", use_bitblas=True) + model_q = GPTQModel.from_quantized(model_id, device="cuda:0", backend=Backend.BITBLAS) except ValueError as e: raise e @@ -54,7 +56,7 @@ def test_bias(self): # TheBloke/Llama-2-7B-Chat-GPTQ has bias, but they are all zeros, use a checkpoint which really uses bias. 
model_id = "s3nh/starcoderbase-1b-GPTQ" try: - model_q = GPTQModel.from_quantized(model_id, device="cuda:0", use_bitblas=True) + model_q = GPTQModel.from_quantized(model_id, device="cuda:0", backend=Backend.BITBLAS) except ValueError as e: raise e diff --git a/tests/test_q4_cuda.py b/tests/test_q4_cuda.py index 1bb5a43c..def42b23 100644 --- a/tests/test_q4_cuda.py +++ b/tests/test_q4_cuda.py @@ -1,23 +1,27 @@ # -- do not touch import os +from gptqmodel import Backend + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import unittest # noqa: E402 import torch # noqa: E402 -from gptqmodel.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as QuantLinearCudaOld # noqa: E402 from parameterized import parameterized # noqa: E402 +from gptqmodel.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as QuantLinearCudaOld # noqa: E402 + try: from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: F401 except ImportError as e: print(f"[WARNING] Could not load gptqmodel_exllama_kernels: {e}") -from gptqmodel import GPTQModel # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import GPTQModel # noqa: E402 + GENERATE_EVAL_SIZE = 100 def get_diff(a, ref): @@ -619,9 +623,7 @@ def test_generation_desc_act_true(self, torch_dtype, device): model_id, revision=revision, device=device, - use_triton=False, - disable_exllama=True, - disable_exllamav2=True, + backend=Backend.CUDA, torch_dtype=torch_dtype, ) @@ -664,9 +666,7 @@ def test_generation_desc_act_false(self, torch_dtype, device): model_q = GPTQModel.from_quantized( model_id, device=device, - use_triton=False, - disable_exllama=True, - disable_exllamav2=True, + backend=Backend.CUDA, torch_dtype=torch_dtype, ) tokenizer = AutoTokenizer.from_pretrained(model_id) diff --git a/tests/test_q4_exallama.py b/tests/test_q4_exallama.py index 27b53aaa..e4fde0c8 100644 --- a/tests/test_q4_exallama.py +++ b/tests/test_q4_exallama.py @@ -1,20 +1,23 @@ # -- do not touch import os +from gptqmodel import Backend + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import unittest # noqa: E402 import torch # noqa: E402 +from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: E402 +from test_q4_cuda import get_diff # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import GPTQModel, exllama_set_max_input_length # noqa: E402 from gptqmodel.models._const import EXLLAMA_DEFAULT_MAX_INPUT_LENGTH # noqa: E402 from gptqmodel.nn_modules.qlinear.qlinear_exllama import QuantLinear # noqa: E402 from gptqmodel.utils.importer import select_quant_linear # noqa: E402 from gptqmodel.utils.model import gptqmodel_post_init # noqa: E402 -from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: E402 -from test_q4_cuda import get_diff # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 CUDA_OLD_REFERENCE = torch.Tensor( [ @@ -1062,9 +1065,7 @@ def test_exllama(self): group_size=group_size, desc_act=False, sym=True, - use_triton=False, - disable_exllama=False, - disable_exllamav2=True, + backend=Backend.EXLLAMA, ) linear = linear_class( @@ -1128,9 +1129,7 @@ def test_exllama_buffer_size(self): model_id, revision=revision, device="cuda:0", - use_triton=False, - disable_exllama=False, - disable_exllamav2=True, + backend=Backend.EXLLAMA, ) tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -1165,9 +1164,7 @@ def test_generation_desc_act_false(self): model_q = GPTQModel.from_quantized( model_id, 
device="cuda:0", - use_triton=False, - disable_exllama=False, - disable_exllamav2=True, + backend=Backend.EXLLAMA, ) tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -1193,9 +1190,7 @@ def test_generation_desc_act_true(self): model_id, revision=revision, device="cuda:0", - use_triton=False, - disable_exllama=False, - disable_exllamav2=True, + backend=Backend.EXLLAMA, ) tokenizer = AutoTokenizer.from_pretrained(model_id) diff --git a/tests/test_q4_exallama_v2.py b/tests/test_q4_exallama_v2.py index ea85b21e..44928f1d 100644 --- a/tests/test_q4_exallama_v2.py +++ b/tests/test_q4_exallama_v2.py @@ -7,15 +7,15 @@ import unittest # noqa: E402 import torch # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 -from gptqmodel.nn_modules.qlinear.qlinear_exllamav2 import QuantLinear # noqa: E402 -from gptqmodel.utils.importer import select_quant_linear # noqa: E402 -from gptqmodel.utils.model import gptqmodel_post_init # noqa: E402 -from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: F401 from test_q4_cuda import get_diff # noqa: E402 from test_q4_exallama import CUDA_OLD_REFERENCE # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import Backend, GPTQModel # noqa: E402 +from gptqmodel.nn_modules.qlinear.qlinear_exllamav2 import QuantLinear # noqa: E402 +from gptqmodel.utils.importer import select_quant_linear # noqa: E402 +from gptqmodel.utils.model import gptqmodel_post_init # noqa: E402 + GENERATE_EVAL_SIZE = 100 class TestsQ4ExllamaV2(unittest.TestCase): @@ -32,7 +32,7 @@ def test_exllamav2(self): group_size=group_size, desc_act=False, sym=True, - use_triton=False, + backend=Backend.EXLLAMA_V2, ) linear = linear_class( @@ -79,7 +79,7 @@ def test_generation_desc_act_false(self): model_id = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" - model_q = GPTQModel.from_quantized(model_id, device="cuda:0", use_triton=False) + model_q = GPTQModel.from_quantized(model_id, device="cuda:0") tokenizer = AutoTokenizer.from_pretrained(model_id) inp = tokenizer(prompt, return_tensors="pt").to(device) @@ -104,7 +104,7 @@ def test_generation_desc_act_true(self): model_id, rivision=revision, device="cuda:0", - use_triton=False, + backend=Backend.EXLLAMA_V2, ) tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -128,7 +128,7 @@ def test_exllama_v2_buffer_size(self): model_id, revision=revision, device="cuda:0", - use_triton=False, + backend=Backend.EXLLAMA_V2, ) tokenizer = AutoTokenizer.from_pretrained(model_id) diff --git a/tests/test_q4_marlin.py b/tests/test_q4_marlin.py index 51ac3b47..84894ec1 100644 --- a/tests/test_q4_marlin.py +++ b/tests/test_q4_marlin.py @@ -7,11 +7,11 @@ import unittest # noqa: E402 import torch # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 -from gptqmodel.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear # noqa: E402 -from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: F401 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import Backend, GPTQModel # noqa: E402 +from gptqmodel.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear # noqa: E402 + class TestQ4Marlin(unittest.TestCase): def test_generation(self): @@ -24,7 +24,7 @@ def test_generation(self): model_id = "TheBloke/Llama-2-7B-Chat-GPTQ" try: - model_q = GPTQModel.from_quantized(model_id, device="cuda:0", use_marlin=True) + model_q = GPTQModel.from_quantized(model_id, device="cuda:0", backend=Backend.MARLIN) except ValueError as e: raise e @@ -49,7 +49,7 @@ 
def test_bias(self): # TheBloke/Llama-2-7B-Chat-GPTQ has bias, but they are all zeros, use a checkpoint which really uses bias. model_id = "s3nh/starcoderbase-1b-GPTQ" try: - model_q = GPTQModel.from_quantized(model_id, device="cuda:0", use_marlin=True) + model_q = GPTQModel.from_quantized(model_id, device="cuda:0", backend=Backend.MARLIN) except ValueError as e: raise e diff --git a/tests/test_q4_triton.py b/tests/test_q4_triton.py index 80409a73..d0e1c160 100644 --- a/tests/test_q4_triton.py +++ b/tests/test_q4_triton.py @@ -7,11 +7,11 @@ import unittest # noqa: E402 import torch # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 -from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as TritonV2QuantLinear # noqa: E402 -from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: F401 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import Backend, GPTQModel # noqa: E402 +from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as TritonV2QuantLinear # noqa: E402 + GENERATE_EVAL_SIZE = 100 class TestsQ4Triton(unittest.TestCase): @@ -26,9 +26,7 @@ def test_generation_desc_act_false(self): model_q = GPTQModel.from_quantized( model_id, device="cuda:0", - use_triton=True, - disable_exllama=True, - disable_exllamav2=True, + backend=Backend.TRITON, torch_dtype=torch.float16, ) for _, submodule in model_q.named_modules(): @@ -66,10 +64,9 @@ def test_generation_desc_act_true(self): model_q = GPTQModel.from_quantized( model_id, device="cuda:0", + backend=Backend.TRITON, revision=revision, - use_triton=True, - disable_exllama=True, - disable_exllamav2=True, + ) for _, submodule in model_q.named_modules(): if isinstance(submodule, TritonV2QuantLinear): diff --git a/tests/test_quant_formats.py b/tests/test_quant_formats.py index 0ba45972..8a485b88 100644 --- a/tests/test_quant_formats.py +++ b/tests/test_quant_formats.py @@ -11,12 +11,13 @@ import unittest # noqa: E402 import torch.cuda # noqa: E402 -from gptqmodel import GPTQModel, __version__ # noqa: E402 -from gptqmodel.quantization import FORMAT, QUANT_CONFIG_FILENAME, QuantizeConfig # noqa: E402 -from gptqmodel.quantization.config import META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import Backend, GPTQModel, __version__ # noqa: E402 +from gptqmodel.quantization import FORMAT, QUANT_CONFIG_FILENAME, QuantizeConfig # noqa: E402 +from gptqmodel.quantization.config import META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL # noqa: E402 + class TestQuantization(unittest.TestCase): @@ -38,7 +39,7 @@ def setUp(self): (True, True, FORMAT.MARLIN), ] ) - def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT): + def test_quantize(self, backend: Backend, sym: bool, format: FORMAT): quantize_config = QuantizeConfig( bits=4, group_size=128, @@ -73,7 +74,7 @@ def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT): model = GPTQModel.from_quantized( tmpdirname, device="cuda:0", - use_marlin=use_marlin, + backend=Backend, ) logging.info(f"Loaded config: {model.quantize_config}") @@ -94,7 +95,7 @@ def test_quantize(self, use_marlin: bool, sym: bool, format: FORMAT): "group_size": 128, "sym": sym, "desc_act": False if format == FORMAT.MARLIN else True, - "is_marlin_format": use_marlin, + "is_marlin_format": backend == Backend.MARLIN, } model = GPTQModel.from_quantized( diff --git a/tests/test_serialization.py 
b/tests/test_serialization.py index 08ee203f..f33c0ddf 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -9,7 +9,7 @@ import tempfile # noqa: E402 import unittest # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 +from gptqmodel import Backend, GPTQModel # noqa: E402 from gptqmodel.quantization import FORMAT, FORMAT_FIELD_JSON, QUANT_CONFIG_FILENAME # noqa: E402 @@ -17,7 +17,7 @@ class TestSerialization(unittest.TestCase): MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" def test_marlin_local_serialization(self): - model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True) + model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", backend=Backend.MARLIN) with tempfile.TemporaryDirectory() as tmpdir: model.save_pretrained(tmpdir) @@ -29,13 +29,13 @@ def test_marlin_local_serialization(self): self.assertTrue(config[FORMAT_FIELD_JSON] == FORMAT.MARLIN) - model = GPTQModel.from_quantized(tmpdir, device="cuda:0", use_marlin=True) + model = GPTQModel.from_quantized(tmpdir, device="cuda:0", backend=Backend.MARLIN) def test_marlin_hf_cache_serialization(self): - model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True) + model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", backend=Backend.MARLIN) self.assertTrue(model.quantize_config.format == FORMAT.MARLIN) - model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True) + model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", backend=Backend.MARLIN) self.assertTrue(model.quantize_config.format == FORMAT.MARLIN) def test_gptq_v1_to_v2_runtime_convert(self): diff --git a/tests/test_sharded.py b/tests/test_sharded.py index cd660d97..3c93b944 100644 --- a/tests/test_sharded.py +++ b/tests/test_sharded.py @@ -8,10 +8,11 @@ import tempfile # noqa: E402 import unittest # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.quantization import QuantizeConfig # noqa: E402 from gptqmodel.quantization.config import FORMAT # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 class TestSharded(unittest.TestCase): diff --git a/tests/test_triton.py b/tests/test_triton.py index a0abe7c2..5bba43e4 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -9,9 +9,10 @@ import torch # noqa: E402 import torch.utils.benchmark as benchmark # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import Backend, GPTQModel # noqa: E402 + MODEL_ID = "TheBloke/Llama-7B-GPTQ" DATASET_ID = "timdettmers/openassistant-guanaco" LEARNING_RATE = 3e-5 @@ -79,7 +80,7 @@ class TestTriton(unittest.TestCase): def test_triton_qlinear(self): ref_model, _ = get_model_and_tokenizer( model_id=MODEL_ID, - use_triton=True, + backend=Backend.TRITON, ) hidden_size = ref_model.model.model.embed_tokens.weight.shape[1] diff --git a/tests/test_verify_hash.py b/tests/test_verify_hash.py index 8b534a5d..2020bb6f 100644 --- a/tests/test_verify_hash.py +++ b/tests/test_verify_hash.py @@ -1,6 +1,6 @@ import unittest -from gptqmodel import GPTQModel +from gptqmodel import Backend, GPTQModel class TestVerifyHashFunction(unittest.TestCase): @@ -11,13 +11,13 @@ class TestVerifyHashFunction(unittest.TestCase): def test_verify_md5_hash_function(self): # Load the model with MD5 verify_hash parameter - model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True, + model = 
GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", backend=Backend.MARLIN, verify_hash=self.EXPECTED_MD5_HASH) self.assertIsNotNone(model) def test_verify_sha256_hash_function(self): # Load the model with SHA-256 verify_hash parameter - model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True, + model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", backend=Backend.MARLIN, verify_hash=self.EXPECTED_SHA256_HASH) # Add additional checks to ensure the model is loaded correctly self.assertIsNotNone(model)
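
For reference, a minimal usage sketch of the consolidated loader API this patch introduces. The checkpoint path below is a placeholder, and the commented-out Marlin line further assumes a 4-bit, sym=True checkpoint on a GPU with compute capability >= 8.0; treat it as an illustration of the new `backend` argument, not as additional code from the commit.

    from transformers import AutoTokenizer

    from gptqmodel import Backend, GPTQModel, get_backend

    quantized_model_dir = "opt-125m-4bit-128g"  # placeholder: any local GPTQ checkpoint

    # Backend.AUTO (the default) walks backend_dict in priority order and returns the
    # fastest QuantLinear whose validate() accepts the checkpoint's bits/group_size/sym/desc_act.
    model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0", backend=Backend.AUTO)

    # CLI strings map onto the same enum via get_backend(), e.g. --backend MARLIN.
    # model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0", backend=get_backend("MARLIN"))

    tokenizer = AutoTokenizer.from_pretrained(quantized_model_dir)
    print(tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0]))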