add_exllamav2 #1419

Merged (8 commits) on Oct 24, 2023
21 changes: 18 additions & 3 deletions docs/source/llm_quantization/usage_guides/quantization.mdx
@@ -76,7 +76,7 @@ quantized_model = load_quantized_model(empty_model, save_folder=save_folder, dev

### Exllama kernels for faster inference

For 4-bit models, you can use the exllama kernels for faster inference speed. They are activated by default. If you want to change this behavior, you just need to pass `disable_exllama` in [`~optimum.gptq.load_quantized_model`]. In order to use these kernels, you need to have the entire model on GPUs.
With the release of the exllamav2 kernel, you can get faster inference speed compared to the exllama kernels for 4-bit models. It is activated by default: `disable_exllamav2=False` in [`~optimum.gptq.load_quantized_model`]. In order to use these kernels, you need to have the entire model on GPUs.

```py
from optimum.gptq import GPTQQuantizer, load_quantized_model
@@ -86,10 +86,25 @@ from accelerate import init_empty_weights
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllama=False)
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto")
```
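
As a quick check, the loaded model can then be used like any other `transformers` model. The snippet below is a minimal sketch that is not part of the original guide: the tokenizer, prompt, and device placement are illustrative assumptions, reusing `model_name` and `quantized_model` from above.

```py
from transformers import AutoTokenizer

# Illustrative only: assumes `model_name` and `quantized_model` from the snippet above.
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello my name is", return_tensors="pt").to(0)
outputs = quantized_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```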

Note that only 4-bit models are supported with exllama kernels for now. Furthermore, it is recommended to disable the exllama kernel when you are finetuning your model with peft.
If you wish to use exllama kernels, you will have to disable the exllamav2 kernel and activate the exllama kernel:

```py
from optimum.gptq import GPTQQuantizer, load_quantized_model
import torch

from accelerate import init_empty_weights
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllama=False, disable_exllamav2=True)
```

Note that only 4-bit models are supported with the exllama/exllamav2 kernels for now. Furthermore, it is recommended to disable the exllama/exllamav2 kernels when you are fine-tuning your model with PEFT.
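
As an example of that recommendation, a fine-tuning setup would load the quantized weights with both kernels turned off before attaching adapters. The sketch below is illustrative and not part of the original guide: the LoRA hyperparameters and the direct use of `get_peft_model` on the loaded model are assumptions; see the fine-tuning section below for the documented workflow.

```py
from accelerate import init_empty_weights
from optimum.gptq import load_quantized_model
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
import torch

with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()

# Disable both exllama kernels, as recommended above for fine-tuning.
quantized_model = load_quantized_model(
    empty_model,
    save_folder=save_folder,
    device_map="auto",
    disable_exllama=True,
    disable_exllamav2=True,
)

# Illustrative LoRA configuration (assumption); target modules depend on the architecture.
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")
peft_model = get_peft_model(quantized_model, lora_config)
peft_model.print_trainable_parameters()
```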

You can find the benchmark of these kernels [here](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark).

#### Fine-tune a quantized model

46 changes: 34 additions & 12 deletions optimum/gptq/quantizer.py
@@ -69,7 +69,8 @@ def __init__(
module_name_preceding_first_block: Optional[List[str]] = None,
batch_size: int = 1,
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
disable_exllama: bool = True,
disable_exllamav2: bool = False,
max_input_length: Optional[int] = None,
*args,
**kwargs,
@@ -107,8 +108,10 @@ def __init__(
The batch size of the dataset
pad_token_id (`Optional[int]`, defaults to `None`):
The pad token id. Needed to prepare the dataset when `batch_size` > 1.
disable_exllama (`bool`, defaults to `False`):
disable_exllama (`bool`, defaults to `True`):
Whether to use exllama backend. Only works with `bits` = 4.
disable_exllamav2 (`bool`, defaults to `False`):
Whether to use exllamav2 backend. Only works with `bits` = 4.
max_input_length (`Optional[int]`, defaults to `None`):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
It is specific to the exllama backend with act-order.
@@ -128,6 +131,7 @@ def __init__(
self.batch_size = batch_size
self.pad_token_id = pad_token_id
self.disable_exllama = disable_exllama
self.disable_exllamav2 = disable_exllamav2
self.max_input_length = max_input_length
self.quant_method = QuantizationMethod.GPTQ

@@ -137,6 +141,10 @@ def __init__(
raise ValueError("group_size must be greater than 0 or equal to -1")
if not (0 < self.damp_percent < 1):
raise ValueError("damp_percent must be between 0 and 1.")
if not self.disable_exllamav2 and not self.disable_exllama:
raise ValueError(
"disable_exllamav2 and disable_exllama are both set to `False`. Please disable one of the kernels."
)

def to_dict(self):
"""
@@ -205,6 +213,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
disable_exllamav2=self.disable_exllamav2,
)
if isinstance(module, QuantLinear):
return
Expand Down Expand Up @@ -440,13 +449,21 @@ def tmp(_, input, output):
layer_inputs, layer_outputs = layer_outputs, []
torch.cuda.empty_cache()

if self.bits == 4 and not self.disable_exllama:
if self.bits == 4:
# device not on gpu
if device == torch.device("cpu") or (has_device_map and any(d in devices for d in ["cpu", "disk"])):
logger.warning(
"Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU. Setting `disable_exllama=True`"
)
self.disable_exllama = True
elif self.desc_act:
if not self.disable_exllama:
logger.warning(
"Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU. Setting `disable_exllama=True`"
)
self.disable_exllama = True
if not self.disable_exllamav2:
logger.warning(
"Found modules on cpu/disk. Using Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllamav2=True`"
)
self.disable_exllamav2 = True
# act order and exllama
elif self.desc_act and not self.disable_exllama:
logger.warning(
"Using Exllama backend with act_order will reorder the weights offline, thus you will not be able to save the model with the right weights."
"Setting `disable_exllama=True`. You should only use Exllama backend with act_order for inference. "
@@ -475,13 +492,13 @@ def post_init_model(self, model):
model (`nn.Module`):
The input model
"""
if self.bits == 4 and not self.disable_exllama:
if self.bits == 4 and (not self.disable_exllama or not self.disable_exllamav2):
if get_device(model) == torch.device("cpu") or (
hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"])
):
raise ValueError(
"Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU."
"You can deactivate exllama backend by setting `disable_exllama=True` in the quantization config object"
"Found modules on cpu/disk. Using Exllama or Exllamav2 backend requires all the modules to be on GPU."
"You can deactivate exllama backend by setting `disable_exllama=True` or `disable_exllamav2=True` in the quantization config object"
)

class StoreAttr(object):
@@ -514,6 +531,7 @@ def pack_model(
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
disable_exllamav2=self.disable_exllamav2,
)
logger.info("Packing model...")
layers = get_layers(model)
@@ -579,7 +597,8 @@ def load_quantized_model(
offload_folder: Optional[str] = None,
offload_buffers: Optional[str] = None,
offload_state_dict: bool = False,
disable_exllama: bool = False,
disable_exllama: bool = True,
disable_exllamav2: bool = False,
max_input_length: Optional[int] = None,
):
"""
@@ -615,6 +634,8 @@
picked contains `"disk"` values.
disable_exllama (`bool`, defaults to `True`):
Whether to use exllama backend. Only works with `bits` = 4.
disable_exllamav2 (`bool`, defaults to `False`):
Whether to use exllamav2 backend. Only works with `bits` = 4.
max_input_length (`Optional[int]`, defaults to `None`):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
It is specific to the exllama backend with act-order.
@@ -648,6 +669,7 @@
) from err
quantizer = GPTQQuantizer.from_dict(quantize_config_dict)
quantizer.disable_exllama = disable_exllama
quantizer.disable_exllamav2 = disable_exllamav2
quantizer.max_input_length = max_input_length

model = quantizer.convert_model(model)
4 changes: 2 additions & 2 deletions optimum/utils/import_utils.py
@@ -110,11 +110,11 @@ def is_timm_available():
def is_auto_gptq_available():
if _auto_gptq_available:
version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq"))
if AUTOGPTQ_MINIMUM_VERSION <= version_autogptq:
if AUTOGPTQ_MINIMUM_VERSION < version_autogptq:
return True
else:
raise ImportError(
f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, but only {AUTOGPTQ_MINIMUM_VERSION} and above are supported"
f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, but only versions above {AUTOGPTQ_MINIMUM_VERSION} are supported"
)


56 changes: 53 additions & 3 deletions tests/gptq/test_quantization.py
@@ -46,6 +46,7 @@ class GPTQTest(unittest.TestCase):
group_size = 128
desc_act = False
disable_exllama = True
disable_exllamav2 = True

dataset = [
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
@@ -69,6 +70,7 @@ def setUpClass(cls):
group_size=cls.group_size,
desc_act=cls.desc_act,
disable_exllama=cls.disable_exllama,
disable_exllamav2=cls.disable_exllamav2,
)

cls.quantized_model = cls.quantizer.quantize_model(cls.model_fp16, cls.tokenizer)
@@ -96,6 +98,7 @@ def test_quantized_layers_class(self):
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
disable_exllamav2=self.disable_exllamav2,
)
self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)

@@ -133,13 +136,18 @@ def test_serialization(self):
)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=self.disable_exllama
empty_model,
save_folder=tmpdirname,
device_map={"": 0},
disable_exllama=self.disable_exllama,
disable_exllamav2=self.disable_exllamav2,
)
self.check_inference_correctness(quantized_model_from_saved)


class GPTQTestExllama(GPTQTest):
disable_exllama = False
disable_exllamav2 = True
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I")
EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.")
@@ -153,6 +161,7 @@ class GPTQTestActOrder(GPTQTest):
EXPECTED_OUTPUTS.add("Hello my name is nathalie, I am a young girl from")

disable_exllama = True
disable_exllamav2 = True
desc_act = True

def test_generate_quality(self):
@@ -178,7 +187,7 @@ def test_exllama_serialization(self):
)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=False
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=False, disable_exllamav2=True
)
self.check_inference_correctness(quantized_model_from_saved)

@@ -197,7 +206,12 @@ def test_exllama_max_input_length(self):
)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=False, max_input_length=4028
empty_model,
save_folder=tmpdirname,
device_map={"": 0},
disable_exllama=False,
max_input_length=4028,
disable_exllamav2=True,
)

prompt = "I am in Paris and" * 1000
@@ -213,6 +227,42 @@
quantized_model_from_saved.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)


class GPTQTestExllamav2(GPTQTest):
desc_act = False
disable_exllama = True
disable_exllamav2 = True

def test_generate_quality(self):
# don't need to test
pass

def test_serialization(self):
# don't need to test
pass
Comment on lines +235 to +241
Contributor: why not?

Member (author): Here, we quantize the model using the cuda-old kernel and save it so that it can later be loaded with exllamav2 in the test_exllama_serialization test. Since these tests would use the cuda-old kernel, we don't need to run them again; they are already covered by a previous test.

Contributor: And how about generate_quality?

Member (author): This is also tested in the GPTQTest class. The wording is confusing, but test_exllama_serialization in GPTQTestExllamav2 does two things: it tests loading the quantized weights with the exllamav2 kernels and it tests inference correctness.


def test_exllama_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel
"""
from accelerate import init_empty_weights

with tempfile.TemporaryDirectory() as tmpdirname:
self.quantizer.save(self.quantized_model, tmpdirname)
self.quantized_model.config.save_pretrained(tmpdirname)
with init_empty_weights():
empty_model = AutoModelForCausalLM.from_config(
AutoConfig.from_pretrained(self.model_name), torch_dtype=torch.float16
)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(
empty_model,
save_folder=tmpdirname,
device_map={"": 0},
disable_exllamav2=False,
)
self.check_inference_correctness(quantized_model_from_saved)


class GPTQUtilsTest(unittest.TestCase):
"""
Test utilities