[CORE] Consolidate 6+ kernel boolean toggle args into a single Backend arg #68

Merged
merged 8 commits on Jun 27, 2024
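
This PR collapses the per-kernel boolean flags (use_triton, use_bitblas, use_marlin, disable_exllama, disable_exllamav2) into a single backend argument across loading, quantization, and the example scripts. A minimal before/after sketch of a from_quantized call; the quantized model directory is reused from the examples below:

from gptqmodel import Backend, GPTQModel

# Before: one boolean toggle per kernel, e.g. forcing the Triton kernel
# model = GPTQModel.from_quantized("opt-125m-4bit-128g", device="cuda:0", use_triton=True)

# After: a single Backend selector picks the kernel
model = GPTQModel.from_quantized("opt-125m-4bit-128g", device="cuda:0", backend=Backend.TRITON)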
21 changes: 7 additions & 14 deletions examples/benchmark/generation_speed.py
@@ -7,11 +7,12 @@

import torch
from datasets import Dataset, load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
from tqdm import tqdm
from transformers import AutoTokenizer, GenerationConfig
from transformers.generation.logits_process import LogitsProcessor

from gptqmodel import Backend, GPTQModel, QuantizeConfig, get_backend

logger = logging.getLogger(__name__)

random.seed(0)
@@ -143,17 +144,15 @@ def tokenize(examples):

def load_model_tokenizer(
model_name_or_path: str,
backend: Backend,
tokenizer_name_or_path: Optional[str] = None,
from_pretrained: bool = False,
max_memory: Optional[dict] = None,
model_basename: Optional[str] = None,
quantize_config: Optional[str] = None,
trust_remote_code: bool = False,
use_triton: bool = False,
use_bitblas: bool = False,
use_safetensors: bool = True,
use_fast_tokenizer: bool = False,
disable_exllama: bool = False,
):
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=tokenizer_name_or_path or model_name_or_path,
@@ -174,15 +173,13 @@ def load_model_tokenizer(
model = GPTQModel.from_quantized(
model_name_or_path,
max_memory=max_memory,
use_triton=use_triton,
use_bitblas=use_bitblas,
use_cuda_fp16=True,
quantize_config=quantize_config,
model_basename=model_basename,
use_safetensors=use_safetensors,
trust_remote_code=trust_remote_code,
warmup_triton=False,
disable_exllama=disable_exllama,
backend=backend,
)

return model, tokenizer
@@ -234,11 +231,9 @@ def main():
parser.add_argument("--model_basename", type=str, default=None)
parser.add_argument("--quantize_config_save_dir", type=str, default=None)
parser.add_argument("--trust_remote_code", action="store_true")
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--use_bitblas", action="store_true")
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'])
parser.add_argument("--use_safetensors", action="store_true")
parser.add_argument("--use_fast_tokenizer", action="store_true")
parser.add_argument("--disable_exllama", action="store_true")
parser.add_argument("--num_samples", type=int, default=10)
parser.add_argument("--per_gpu_max_memory", type=int, default=None)
parser.add_argument("--cpu_max_memory", type=int, default=None)
@@ -277,19 +272,17 @@ def main():
model_basename=args.model_basename,
quantize_config=quantize_config,
trust_remote_code=args.trust_remote_code,
use_triton=args.use_triton,
use_bitblas=args.use_bitblas,
use_safetensors=True,
use_fast_tokenizer=args.use_fast_tokenizer,
disable_exllama=args.disable_exllama,
backend=get_backend(args.backend),
)
end = time.time()
logger.info(f"model and tokenizer loading time: {end - start:.4f}s")
logger.info(f"model quantized: {model.quantized}")
logger.info(f"quantize config: {model.quantize_config.to_dict()}")
logger.info(f"model device map: {model.hf_device_map}")

if args.use_triton:
if args.backend == Backend.TRITON:
logger.info("warmup triton, this may take a while.")
model.warmup_triton()
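
The example scripts now take the kernel name as a --backend string and convert it with get_backend. A condensed sketch of that flow as it appears in generation_speed.py; the default value and model path are assumptions, and the real script takes many more arguments:

import argparse

from gptqmodel import Backend, GPTQModel, get_backend

parser = argparse.ArgumentParser()
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'], default='AUTO')
args = parser.parse_args()

backend = get_backend(args.backend)  # CLI string -> Backend enum member
model = GPTQModel.from_quantized("opt-125m-4bit-128g", device="cuda:0", backend=backend)

if backend == Backend.TRITON:
    model.warmup_triton()  # the Triton kernel still needs an explicit warmup pass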

13 changes: 5 additions & 8 deletions examples/benchmark/perplexity.py
@@ -2,9 +2,10 @@
import os

import torch
from gptqmodel.utils import Perplexity
from transformers import AutoTokenizer

from gptqmodel.utils import Perplexity, get_backend

if __name__ == "__main__":
"""
Example usage.
@@ -42,7 +43,7 @@
default=None,
help="Max memory used in each GPU.",
)
parser.add_argument("--cpu_max_memory", type=int, default=None, help="Mx memory used in CPU.")
parser.add_argument("--cpu_max_memory", type=int, default=None, help="Max memory used in CPU.")
parser.add_argument("--is_quantized", action="store_true", help="Is the model GPTQ quantized?")
parser.add_argument(
"--use_safetensors",
@@ -51,11 +52,7 @@
)
parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
parser.add_argument(
"--disable_exllama",
action="store_true",
help="Whether to use disable exllama kernel",
)
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'], help="Whether to use Backend format")
args = parser.parse_args()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -88,7 +85,7 @@
model_basename=args.model_basename,
use_safetensors=True,
trust_remote_code=args.trust_remote_code,
disable_exllama=args.disable_exllama,
backend=get_backend(args.backend),
)
else:
from transformers import AutoModelForCausalLM
9 changes: 5 additions & 4 deletions examples/evaluation/run_language_modeling_task.py
@@ -2,10 +2,11 @@

import datasets
import torch
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.eval_tasks import LanguageModelingTask
from transformers import AutoTokenizer

from gptqmodel import GPTQModel, QuantizeConfig, get_backend
from gptqmodel.eval_tasks import LanguageModelingTask

DATASET = "tatsu-lab/alpaca"
WITH_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nInput:\n{input}\n\nOutput:\n"
WITHOUT_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nOutput:\n"
@@ -40,7 +41,7 @@ def main():
)
parser.add_argument("--sample_max_len", type=int, default=1024, help="max tokens for each sample")
parser.add_argument("--block_max_len", type=int, default=2048, help="max tokens for each data block")
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'])
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir)
@@ -70,7 +71,7 @@ def main():
del model
torch.cuda.empty_cache()

model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", backend=get_backend(args.backend))
task.model = model
task.device = model.device
print(f"eval result for quantized model: {task.run()}")
9 changes: 5 additions & 4 deletions examples/evaluation/run_sequence_classification_task.py
@@ -3,10 +3,11 @@

import datasets
import torch
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.eval_tasks import SequenceClassificationTask
from transformers import AutoTokenizer

from gptqmodel import GPTQModel, QuantizeConfig, get_backend
from gptqmodel.eval_tasks import SequenceClassificationTask

DATASET = "cardiffnlp/tweet_sentiment_multilingual"
TEMPLATE = "Question:What's the sentiment of the given text? Choices are {labels}.\nText: {text}\nAnswer:"
ID2LABEL = {0: "negative", 1: "neutral", 2: "positive"}
@@ -38,7 +39,7 @@ def main():
)
parser.add_argument("--sample_max_len", type=int, default=1024, help="max tokens for each sample")
parser.add_argument("--block_max_len", type=int, default=2048, help="max tokens for each data block")
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'])
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir)
@@ -69,7 +70,7 @@ def main():
del model
torch.cuda.empty_cache()

model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", backend=get_backend(args.backend))
task.model = model
task.device = model.device
print(f"eval result for quantized model: {task.run()}")
9 changes: 5 additions & 4 deletions examples/evaluation/run_text_summarization_task.py
@@ -3,10 +3,11 @@

import datasets
import torch
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.eval_tasks import TextSummarizationTask
from transformers import AutoTokenizer, GenerationConfig

from gptqmodel import GPTQModel, QuantizeConfig, get_backend
from gptqmodel.eval_tasks import TextSummarizationTask

os.system("pip install py7zr")


@@ -37,7 +38,7 @@ def main():
)
parser.add_argument("--sample_max_len", type=int, default=1024, help="max tokens for each sample")
parser.add_argument("--block_max_len", type=int, default=2048, help="max tokens for each data block")
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--backend", choices=['AUTO', 'CUDA_OLD', 'CUDA', 'TRITON', 'EXLLAMA', 'EXLLAMA_V2', 'MARLIN', 'BITBLAS'])
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir)
@@ -67,7 +68,7 @@ def main():
del model
torch.cuda.empty_cache()

model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
model = GPTQModel.from_quantized(args.quantized_model_dir, device="cuda:0", backend=get_backend(args.backend))
task.model = model
task.device = model.device
print(f"eval result for quantized model: {task.run(generation_config=GenerationConfig(max_new_tokens=32))}")
5 changes: 3 additions & 2 deletions examples/quantization/basic_usage.py
@@ -1,6 +1,7 @@
from gptqmodel import GPTQModel, QuantizeConfig
from transformers import AutoTokenizer, TextGenerationPipeline

from gptqmodel import GPTQModel, QuantizeConfig

pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-4bit-128g"

@@ -48,7 +49,7 @@ def main():
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0")

# download quantized model from Hugging Face Hub and load to the first GPU
# model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
# model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True,)

# inference with model.generate
print(tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0]))
13 changes: 7 additions & 6 deletions examples/quantization/basic_usage_bitblas.py
@@ -1,13 +1,14 @@
import torch
from gptqmodel import GPTQModel
from gptqmodel.quantization import QuantizeConfig
from transformers import AutoTokenizer, TextGenerationPipeline

use_bitblas = True
from gptqmodel import Backend, GPTQModel
from gptqmodel.quantization import QuantizeConfig

backend = Backend.BITBLAS
pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "./facebook/opt-125m-4bit-128g"

if use_bitblas:
if backend == Backend.BITBLAS:
quantized_model_dir += "-bitblas"

def main():
@@ -52,11 +53,11 @@ def main():

# load quantized model to the first GPU
model = GPTQModel.from_quantized(
quantized_model_dir, device="cuda:0", use_bitblas=use_bitblas
quantized_model_dir, device="cuda:0", backend=backend,
)

# download quantized model from Hugging Face Hub and load to the first GPU
# model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
# model = GPTQModel.from_quantized(repo_id, device="cuda:0", use_safetensors=True)

# -- simple token evaluate --
input_ids = torch.ones((1, 1), dtype=torch.long, device="cuda:0")
7 changes: 4 additions & 3 deletions examples/quantization/basic_usage_gpt_xl.py
@@ -3,9 +3,10 @@
import numpy as np
import torch
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
from transformers import TextGenerationPipeline

from gptqmodel import GPTQModel, QuantizeConfig

pretrained_model_dir = "gpt2-xl"
quantized_model_dir = "gpt2-large-4bit-128g"

@@ -65,7 +66,7 @@ def main():

# quantize model, the calibration_dataset should be list of dict whose keys contains "input_ids" and "attention_mask"
# with value under torch.LongTensor type.
model.quantize(traindataset, use_triton=False)
model.quantize(traindataset)

# save quantized model
model.save_quantized(quantized_model_dir)
@@ -74,7 +75,7 @@ def main():
model.save_quantized(quantized_model_dir, use_safetensors=True)

# load quantized model, currently only support cpu or single gpu
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False)
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0")

# inference with model.generate
print(tokenizer.decode(model.generate(**tokenizer("test is", return_tensors="pt").to("cuda:0"))[0]))
5 changes: 3 additions & 2 deletions examples/quantization/basic_usage_wikitext2.py
@@ -1,6 +1,7 @@
import numpy as np
import torch
import torch.nn as nn

from gptqmodel import GPTQModel, QuantizeConfig

pretrained_model_dir = "facebook/opt-125m"
@@ -145,7 +146,7 @@ def main():

# quantize model, the calibration_dataset should be list of dict whose keys can only be "input_ids" and "attention_mask"
# with value under torch.LongTensor type.
model.quantize(traindataset, use_triton=False)
model.quantize(traindataset)

# save quantized model
model.save_quantized(quantized_model_dir)
@@ -154,7 +155,7 @@ def main():
model.save_quantized(quantized_model_dir, use_safetensors=True)

# load quantized model, currently only support cpu or single gpu
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False)
model = GPTQModel.from_quantized(quantized_model_dir, device="cuda:0")

opt_eval(model.model, testenc, "cuda:0")

1 change: 1 addition & 0 deletions gptqmodel/__init__.py
@@ -1,4 +1,5 @@
from .models import GPTQModel
from .quantization import BaseQuantizeConfig, QuantizeConfig
from .utils import Backend, get_backend
from .utils.exllama import exllama_set_max_input_length
from .version import __version__
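
Backend and get_backend are re-exported at the package root so the example scripts can import them directly. Their implementation is not shown in this diff; a plausible minimal sketch, assuming a plain Enum with one member per supported kernel and a case-insensitive string lookup:

from enum import Enum, auto


class Backend(Enum):
    AUTO = auto()
    CUDA_OLD = auto()
    CUDA = auto()
    TRITON = auto()
    EXLLAMA = auto()
    EXLLAMA_V2 = auto()
    MARLIN = auto()
    BITBLAS = auto()


def get_backend(name):
    # Map a CLI string such as "TRITON" onto the enum; fall back to AUTO when no value was passed.
    if not name:
        return Backend.AUTO
    try:
        return Backend[name.upper()]
    except KeyError:
        raise ValueError(f"unknown backend: {name}") from None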
20 changes: 3 additions & 17 deletions gptqmodel/models/auto.py
@@ -1,5 +1,6 @@
from typing import Dict, List, Optional, Union

from ..utils import Backend
from ..utils.model import check_and_get_model_type
from .baichuan import BaiChuanGPTQ
from .base import BaseGPTQModel, QuantizeConfig
@@ -108,30 +109,19 @@ def from_quantized(
device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
use_triton: bool = False,
backend: Backend = Backend.AUTO,
use_cuda_fp16: bool = True,
quantize_config: Optional[QuantizeConfig | Dict] = None,
model_basename: Optional[str] = None,
use_safetensors: bool = True,
trust_remote_code: bool = False,
warmup_triton: bool = False,
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
use_marlin: bool = False,
use_bitblas: bool = False,
# verify weight files matches predefined hash during loading
# usage: hash_format:hash_value, example: md5:ugkdh232
# supports all hashlib hash methods
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> BaseGPTQModel:
# If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
if disable_exllama is None:
if disable_exllamav2:
disable_exllama = False
else:
disable_exllama = True

model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
quant_func = MODEL_MAP[model_type].from_quantized

@@ -140,17 +130,13 @@
device_map=device_map,
max_memory=max_memory,
device=device,
use_triton=use_triton,
backend=backend,
use_cuda_fp16=use_cuda_fp16,
quantize_config=quantize_config,
model_basename=model_basename,
use_safetensors=use_safetensors,
trust_remote_code=trust_remote_code,
warmup_triton=warmup_triton,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_marlin=use_marlin,
use_bitblas=use_bitblas,
verify_hash=verify_hash,
**kwargs,
)
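
The block removed above resolved disable_exllama from disable_exllamav2 before a kernel was chosen. For call sites being migrated, the old flag combinations map onto the new argument roughly as sketched below; this helper is illustrative only, not part of the library, and the precedence shown is an approximation of the old behaviour:

from gptqmodel import Backend


def legacy_flags_to_backend(use_triton=False, use_marlin=False, use_bitblas=False,
                            disable_exllama=False, disable_exllamav2=False):
    # Approximate translation of the old boolean precedence into a single Backend value.
    if use_triton:
        return Backend.TRITON
    if use_marlin:
        return Backend.MARLIN
    if use_bitblas:
        return Backend.BITBLAS
    if not disable_exllamav2:
        return Backend.EXLLAMA_V2  # old default: exllama v2 kernel
    if not disable_exllama:
        return Backend.EXLLAMA     # v2 disabled, fall back to v1
    return Backend.CUDA            # both exllama kernels disabled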