Consolidate Backend (ModelCloud#68)
* Consolidate Backend

* change Backend.TRITON_V2 to Backend.TRITON

* Determine the Backend used for packing the model according to quantize_config.format.

* Automatically choose the fastest Backend based on quantized-model compatibility.

* Fix issue: automatic Backend selection returned an incorrect qlinear.

* cleanup

* cleanup
ZX-ModelCloud authored Jun 27, 2024
1 parent 7212ecc commit c88e026
Showing 36 changed files with 278 additions and 274 deletions.
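Across the files below, the per-kernel boolean switches (`use_triton`, `use_bitblas`, `disable_exllama`, `disable_exllamav2`, `use_marlin`) are replaced by a single `backend` argument. As a rough usage sketch based only on the example scripts changed in this commit (the model path is a placeholder):

```python
from gptqmodel import Backend, GPTQModel, get_backend

# Explicit kernel choice (the path is a placeholder for a quantized checkpoint).
model = GPTQModel.from_quantized(
    "path/to/quantized-model",
    device="cuda:0",
    backend=Backend.TRITON,
)

# Or parse a CLI string such as "--backend AUTO" and let AUTO pick the
# fastest kernel the quantized model is compatible with.
model = GPTQModel.from_quantized(
    "path/to/quantized-model",
    device="cuda:0",
    backend=get_backend("AUTO"),
)
```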
21 changes: 7 additions & 14 deletions examples/benchmark/generation_speed.py
@@ -7,11 +7,12 @@

import torch
from datasets import Dataset, load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
from tqdm import tqdm
from transformers import AutoTokenizer, GenerationConfig
from transformers.generation.logits_process import LogitsProcessor

from gptqmodel import Backend, GPTQModel, QuantizeConfig, get_backend

logger = logging.getLogger(__name__)

random.seed(0)
@@ -143,17 +144,15 @@ def tokenize(examples):

def load_model_tokenizer(
model_name_or_path: str,
backend: Backend,
tokenizer_name_or_path: Optional[str] = None,
from_pretrained: bool = False,
max_memory: Optional[dict] = None,
model_basename: Optional[str] = None,
quantize_config: Optional[str] = None,
trust_remote_code: bool = False,
use_triton: bool = False,
use_bitblas: bool = False,
use_safetensors: bool = True,
use_fast_tokenizer: bool = False,
disable_exllama: bool = False,
):
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=tokenizer_name_or_path or model_name_or_path,
@@ -174,15 +173,13 @@
model = GPTQModel.from_quantized(
model_name_or_path,
max_memory=max_memory,
use_triton=use_triton,
use_bitblas=use_bitblas,
use_cuda_fp16=True,
quantize_config=quantize_config,
model_basename=model_basename,
use_safetensors=use_safetensors,
trust_remote_code=trust_remote_code,
warmup_triton=False,
disable_exllama=disable_exllama,
backend=backend,
)

return model, tokenizer
@@ -234,11 +231,9 @@ def main():
parser.add_argument("--model_basename", type=str, default=None)
parser.add_argument("--quantize_config_save_dir", type=str, default=None)
parser.add_argument("--trust_remote_code", action="store_true")
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--use_bitblas", action="store_true")
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'])
parser.add_argument("--use_safetensors", action="store_true")
parser.add_argument("--use_fast_tokenizer", action="store_true")
parser.add_argument("--disable_exllama", action="store_true")
parser.add_argument("--num_samples", type=int, default=10)
parser.add_argument("--per_gpu_max_memory", type=int, default=None)
parser.add_argument("--cpu_max_memory", type=int, default=None)
@@ -277,19 +272,17 @@ def main():
model_basename=args.model_basename,
quantize_config=quantize_config,
trust_remote_code=args.trust_remote_code,
use_triton=args.use_triton,
use_bitblas=args.use_bitblas,
use_safetensors=True,
use_fast_tokenizer=args.use_fast_tokenizer,
disable_exllama=args.disable_exllama,
backend=get_backend(args.backend),
)
end = time.time()
logger.info(f"model and tokenizer loading time: {end - start:.4f}s")
logger.info(f"model quantized: {model.quantized}")
logger.info(f"quantize config: {model.quantize_config.to_dict()}")
logger.info(f"model device map: {model.hf_device_map}")

if args.use_triton:
if args.backend == Backend.TRITON:
logger.info("warmup triton, this may take a while.")
model.warmup_triton()

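The benchmark script above now takes one `--backend` choice and converts it with `get_backend`. That helper's body is not part of the hunks shown on this page; the following is a hypothetical sketch of such a string-to-enum lookup, using the member names implied by the CLI choices, not the library's actual implementation:

```python
from typing import Optional

from gptqmodel import Backend


def get_backend(name: Optional[str]) -> Backend:
    # Map an argparse string such as "TRITON" onto the enum; treat a missing
    # --backend flag (None) as AUTO. Hypothetical body, for illustration only.
    if name is None:
        return Backend.AUTO
    try:
        return Backend[name.upper()]
    except KeyError as e:
        raise ValueError(f"unknown backend: {name}") from e
```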
13 changes: 5 additions & 8 deletions examples/benchmark/perplexity.py
@@ -2,9 +2,10 @@
import os

import torch
from gptqmodel.utils import Perplexity
from transformers import AutoTokenizer

from gptqmodel.utils import Perplexity, get_backend

if __name__ == "__main__":
"""
Example usage.
@@ -42,7 +43,7 @@
default=None,
help="Max memory used in each GPU.",
)
parser.add_argument("--cpu_max_memory", type=int, default=None, help="Mx memory used in CPU.")
parser.add_argument("--cpu_max_memory", type=int, default=None, help="Max memory used in CPU.")
parser.add_argument("--is_quantized", action="store_true", help="Is the model GPTQ quantized?")
parser.add_argument(
"--use_safetensors",
@@ -51,11 +52,7 @@
)
parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
parser.add_argument(
"--disable_exllama",
action="store_true",
help="Whether to use disable exllama kernel",
)
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'], help="Whether to use Backend format")
args = parser.parse_args()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -88,7 +85,7 @@
model_basename=args.model_basename,
use_safetensors=True,
trust_remote_code=args.trust_remote_code,
disable_exllama=args.disable_exllama,
backend=get_backend(args.backend),
)
else:
from transformers import AutoModelForCausalLM
9 changes: 5 additions & 4 deletions examples/evaluation/run_language_modeling_task.py
@@ -2,10 +2,11 @@

import datasets
import torch
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.eval_tasks import LanguageModelingTask
from transformers import AutoTokenizer

from gptqmodel import GPTQModel, QuantizeConfig, get_backend
from gptqmodel.eval_tasks import LanguageModelingTask

DATASET = "tatsu-lab/alpaca"
WITH_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nInput:\n{input}\n\nOutput:\n"
WITHOUT_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nOutput:\n"
@@ -40,7 +41,7 @@ def main():
)
parser.add_argument("--sample_max_len", type=int, default=1024, help="max tokens for each sample")
parser.add_argument("--block_max_len", type=int, default=2048, help="max tokens for each data block")
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'])
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir)
@@ -70,7 +71,7 @@ def main():
del model
torch.cuda.empty_cache()

model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", backend=get_backend(args.backend))
task.model = model
task.device = model.device
print(f"eval result for quantized model: {task.run()}")
9 changes: 5 additions & 4 deletions examples/evaluation/run_sequence_classification_task.py
@@ -3,10 +3,11 @@

import datasets
import torch
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.eval_tasks import SequenceClassificationTask
from transformers import AutoTokenizer

from gptqmodel import GPTQModel, QuantizeConfig, get_backend
from gptqmodel.eval_tasks import SequenceClassificationTask

DATASET = "cardiffnlp/tweet_sentiment_multilingual"
TEMPLATE = "Question:What's the sentiment of the given text? Choices are {labels}.\nText: {text}\nAnswer:"
ID2LABEL = {0: "negative", 1: "neutral", 2: "positive"}
@@ -38,7 +39,7 @@ def main():
)
parser.add_argument("--sample_max_len", type=int, default=1024, help="max tokens for each sample")
parser.add_argument("--block_max_len", type=int, default=2048, help="max tokens for each data block")
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'])
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir)
@@ -69,7 +70,7 @@ def main():
del model
torch.cuda.empty_cache()

model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", backend=get_backend(args.backend))
task.model = model
task.device = model.device
print(f"eval result for quantized model: {task.run()}")
9 changes: 5 additions & 4 deletions examples/evaluation/run_text_summarization_task.py
@@ -3,10 +3,11 @@

import datasets
import torch
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.eval_tasks import TextSummarizationTask
from transformers import AutoTokenizer, GenerationConfig

from gptqmodel import GPTQModel, QuantizeConfig, get_backend
from gptqmodel.eval_tasks import TextSummarizationTask

os.system("pip install py7zr")


@@ -37,7 +38,7 @@ def main():
)
parser.add_argument("--sample_max_len", type=int, default=1024, help="max tokens for each sample")
parser.add_argument("--block_max_len", type=int, default=2048, help="max tokens for each data block")
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'])
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir)
@@ -67,7 +68,7 @@ def main():
del model
torch.cuda.empty_cache()

model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", backend=get_backend(args.backend))
task.model = model
task.device = model.device
print(f"eval result for quantized model: {task.run(generation_config=GenerationConfig(max_new_tokens=32))}")
5 changes: 3 additions & 2 deletions examples/quantization/basic_usage.py
@@ -1,6 +1,7 @@
from gptqmodel import GPTQModel, QuantizeConfig
from transformers import AutoTokenizer, TextGenerationPipeline

from gptqmodel import GPTQModel, QuantizeConfig

pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-4bit-128g"

@@ -48,7 +49,7 @@ def main():
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0")

# download quantized model from Hugging Face Hub and load to the first GPU
# model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
# model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True,)

# inference with model.generate
print(tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0]))
13 changes: 7 additions & 6 deletions examples/quantization/basic_usage_bitblas.py
@@ -1,13 +1,14 @@
import torch
from gptqmodel import GPTQModel
from gptqmodel.quantization import QuantizeConfig
from transformers import AutoTokenizer, TextGenerationPipeline

use_bitblas = True
from gptqmodel import Backend, GPTQModel
from gptqmodel.quantization import QuantizeConfig

backend = Backend.BITBLAS
pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "./facebook/opt-125m-4bit-128g"

if use_bitblas:
if backend == Backend.BITBLAS:
quantized_model_dir += "-bitblas"

def main():
@@ -52,11 +53,11 @@ def main():

# load quantized model to the first GPU
model = GPTQModel.from_quantized(
quantized_model_dir, device="cuda:0", use_bitblas=use_bitblas
quantized_model_dir, device="cuda:0", backend=backend,
)

# download quantized model from Hugging Face Hub and load to the first GPU
# model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
# model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True)

# -- simple token evaluate --
input_ids = torch.ones((1, 1), dtype=torch.long, device="cuda:0")
7 changes: 4 additions & 3 deletions examples/quantization/basic_usage_gpt_xl.py
@@ -3,9 +3,10 @@
import numpy as np
import torch
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
from transformers import TextGenerationPipeline

from gptqmodel import GPTQModel, QuantizeConfig

pretrained_model_dir = "gpt2-xl"
quantized_model_dir = "gpt2-large-4bit-128g"

@@ -65,7 +66,7 @@ def main():

# quantize model, the calibration_dataset should be list of dict whose keys contains "input_ids" and "attention_mask"
# with value under torch.LongTensor type.
model.quantize(traindataset, use_triton=False)
model.quantize(traindataset)

# save quantized model
model.save_quantized(quantized_model_dir)
@@ -74,7 +75,7 @@ def main():
model.save_quantized(quantized_model_dir, use_safetensors=True)

# load quantized model, currently only support cpu or single gpu
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False)
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0")

# inference with model.generate
print(tokenizer.decode(model.generate(**tokenizer("test is", return_tensors="pt").to("cuda:0"))[0]))
5 changes: 3 additions & 2 deletions examples/quantization/basic_usage_wikitext2.py
@@ -1,6 +1,7 @@
import numpy as np
import torch
import torch.nn as nn

from gptqmodel import GPTQModel, QuantizeConfig

pretrained_model_dir = "facebook/opt-125m"
@@ -145,7 +146,7 @@ def main():

# quantize model, the calibration_dataset should be list of dict whose keys can only be "input_ids" and "attention_mask"
# with value under torch.LongTensor type.
model.quantize(traindataset, use_triton=False)
model.quantize(traindataset)

# save quantized model
model.save_quantized(quantized_model_dir)
@@ -154,7 +155,7 @@ def main():
model.save_quantized(quantized_model_dir, use_safetensors=True)

# load quantized model, currently only support cpu or single gpu
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False)
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0")

opt_eval(model.model, testenc, "cuda:0")

1 change: 1 addition & 0 deletions gptqmodel/__init__.py
@@ -1,4 +1,5 @@
from .models import GPTQModel
from .quantization import BaseQuantizeConfig, QuantizeConfig
from .utils import Backend, get_backend
from .utils.exllama import exllama_set_max_input_length
from .version import __version__
20 changes: 3 additions & 17 deletions gptqmodel/models/auto.py
@@ -1,5 +1,6 @@
from typing import Dict, List, Optional, Union

from ..utils import Backend
from ..utils.model import check_and_get_model_type
from .baichuan import BaiChuanGPTQ
from .base import BaseGPTQModel, QuantizeConfig
@@ -108,30 +109,19 @@ def from_quantized(
device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
use_triton: bool = False,
backend: Backend = Backend.AUTO,
use_cuda_fp16: bool = True,
quantize_config: Optional[QuantizeConfig | Dict] = None,
model_basename: Optional[str] = None,
use_safetensors: bool = True,
trust_remote_code: bool = False,
warmup_triton: bool = False,
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
use_marlin: bool = False,
use_bitblas: bool = False,
# verify weight files matches predefined hash during loading
# usage: hash_format:hash_value, example: md5:ugkdh232
# supports all hashlib hash methods
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> BaseGPTQModel:
# If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
if disable_exllama is None:
if disable_exllamav2:
disable_exllama = False
else:
disable_exllama = True

model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
quant_func = MODEL_MAP[model_type].from_quantized

@@ -140,17 +130,13 @@
device_map=device_map,
max_memory=max_memory,
device=device,
use_triton=use_triton,
backend=backend,
use_cuda_fp16=use_cuda_fp16,
quantize_config=quantize_config,
model_basename=model_basename,
use_safetensors=use_safetensors,
trust_remote_code=trust_remote_code,
warmup_triton=warmup_triton,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_marlin=use_marlin,
use_bitblas=use_bitblas,
verify_hash=verify_hash,
**kwargs,
)
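With the boolean flags gone from `from_quantized`, everything hinges on the single `backend: Backend = Backend.AUTO` parameter, and per the commit message `AUTO` picks the fastest backend the quantized model is compatible with. The selection code itself is not in the hunks shown on this page, so the sketch below is only an illustration of that idea, with an assumed preference order and a caller-supplied compatibility check:

```python
from typing import Callable

from gptqmodel import Backend

# Assumed fastest-first preference order; the real order and membership
# inside gptqmodel may differ.
_PREFERENCE = [
    Backend.MARLIN,
    Backend.EXLLAMA_V2,
    Backend.EXLLAMA,
    Backend.TRITON,
    Backend.CUDA,
    Backend.CUDA_OLD,
]


def resolve_backend(
    backend: Backend,
    is_compatible: Callable[[Backend], bool],
) -> Backend:
    # An explicit choice is honored as-is; only AUTO triggers the search.
    if backend != Backend.AUTO:
        return backend
    # Walk the preference list and return the first kernel the quantized
    # model's config (bits, group_size, format, ...) can actually run on.
    for candidate in _PREFERENCE:
        if is_compatible(candidate):
            return candidate
    return Backend.CUDA_OLD  # assumed most-permissive fallback
```

For example, `resolve_backend(Backend.AUTO, lambda b: b == Backend.TRITON)` returns `Backend.TRITON` under these assumptions.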
(diffs for the remaining changed files are not shown on this page)
