Add exllamav2 better #27111

Merged
merged 38 commits on Nov 1, 2023
Changes from 17 commits

Commits (38 total)
ccf4017
add_ xllamav2 arg
SunMarc Sep 27, 2023
dc9ad60
add test
SunMarc Sep 27, 2023
bcd44ea
style
SunMarc Sep 27, 2023
3fa339d
add check
SunMarc Sep 27, 2023
cd6c0a9
add doc
SunMarc Sep 27, 2023
a1f53a1
replace by use_exllama_v2
SunMarc Oct 23, 2023
8a4d493
fix tests
SunMarc Oct 23, 2023
e43d892
fix doc
SunMarc Oct 23, 2023
97f8531
style
SunMarc Oct 23, 2023
73c482a
better condition
SunMarc Oct 23, 2023
b68d207
fix logic
SunMarc Oct 24, 2023
4e2ff0b
Merge remote-tracking branch 'upstream/main' into add_exllamav2_arg
SunMarc Oct 24, 2023
788412f
add deprecate msg
SunMarc Oct 26, 2023
7250a72
deprecate exllama
SunMarc Oct 27, 2023
d2d6c2d
remove disable_exllama from the linter
SunMarc Oct 27, 2023
831c1e0
remove
SunMarc Oct 27, 2023
b26d837
fix warning
SunMarc Oct 27, 2023
be19406
Revert the commits deprecating exllama
SunMarc Oct 27, 2023
0c4ae07
deprecate disable_exllama for use_exllama
SunMarc Oct 27, 2023
098b0da
fix
SunMarc Oct 27, 2023
864e193
fix loading attribute
SunMarc Oct 27, 2023
39d87ab
better handling of args
SunMarc Oct 27, 2023
3a33b37
Merge branch 'main' into add_exllamav2_better
SunMarc Oct 27, 2023
269852d
remove disable_exllama from init and linter
SunMarc Oct 27, 2023
c8b6beb
Apply suggestions from code review
SunMarc Oct 30, 2023
72ca82b
better arg
SunMarc Oct 30, 2023
ecb6512
Merge branch 'add_exllamav2_better' of https://github.com/SunMarc/tra…
SunMarc Oct 30, 2023
e1f3c48
fix warning
SunMarc Oct 30, 2023
48ebca6
Merge remote-tracking branch 'upstream/main' into add_exllamav2_better
SunMarc Oct 31, 2023
4fd403a
Apply suggestions from code review
SunMarc Oct 31, 2023
a8c81e0
switch to dict
SunMarc Oct 31, 2023
2004896
Apply suggestions from code review
SunMarc Nov 1, 2023
2d37325
style
SunMarc Nov 1, 2023
4a6d702
nits
SunMarc Nov 1, 2023
b761e48
Merge remote-tracking branch 'upstream/main' into add_exllamav2_better
SunMarc Nov 1, 2023
d577f9e
style
SunMarc Nov 1, 2023
0955658
better tests
SunMarc Nov 1, 2023
71f81b7
style
SunMarc Nov 1, 2023
8 changes: 5 additions & 3 deletions docs/source/en/main_classes/quantization.md
@@ -122,18 +122,20 @@ from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto")
```

### Exllama kernels for faster inference
### Exllamav2 kernels for faster inference

For 4-bit models, you can use the exllama kernels to get a faster inference speed. They are activated by default. You can change that behavior by passing `disable_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the config. Note that you will only be able to overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on GPUs if you want to use the exllama kernels.
For 4-bit models, you can use the exllamav2 kernels to get a faster inference speed compared to the exllama kernels. You just need to
pass `use_exllama_v2=True` in [`GPTQConfig`]. This will overwrite the quantization config stored in the config. Note that you will only be able to overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on GPUs if you want to use the exllamav2 kernels.

```py
import torch
gptq_config = GPTQConfig(bits=4, disable_exllama=False)
gptq_config = GPTQConfig(bits=4, use_exllama_v2=True)
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config = gptq_config)
```

Note that only 4-bit models are supported for now. Furthermore, it is recommended to deactivate the exllama kernels if you are fine-tuning a quantized model with PEFT.

You can find the benchmark of these kernels [here](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark).
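
For example, here is a minimal sketch of deactivating the kernels before fine-tuning with PEFT, assuming the `disable_exllama` flag described earlier and reusing the placeholder repo name from the example above:

```py
from transformers import AutoModelForCausalLM, GPTQConfig

# Sketch: keep both kernel backends off for PEFT fine-tuning.
# `use_exllama_v2` stays at its default (False) and `disable_exllama=True`
# turns off the exllama (v1) kernels.
gptq_config = GPTQConfig(bits=4, disable_exllama=True)
model = AutoModelForCausalLM.from_pretrained(
    "{your_username}/opt-125m-gptq",  # placeholder repo from the example above
    device_map="auto",
    quantization_config=gptq_config,
)
```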
#### Fine-tune a quantized model

With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
@@ -2714,7 +2714,7 @@ def from_pretrained(
logger.warning(
"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a "
"`quantization_config` attribute and has already quantized weights. However, loading attributes"
" (e.g. disable_exllama, use_cuda_fp16, max_input_length) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
" (e.g. use_cuda_fp16, max_input_length, use_exllama_v2) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
)
if (
quantization_method_from_args == QuantizationMethod.GPTQ
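
As a hedged sketch of the behavior this warning describes (the repo name is a placeholder for an already-quantized GPTQ checkpoint), only the kernel-related loading attributes of the config passed to `from_pretrained` are applied; the rest of the stored quantization config is kept:

```py
from transformers import AutoModelForCausalLM, GPTQConfig

# The checkpoint already carries a quantization_config, so only the loading
# attributes (use_cuda_fp16, max_input_length, use_exllama_v2) of this config
# take effect; bits, group_size, etc. remain as stored in the checkpoint.
override_config = GPTQConfig(bits=4, use_exllama_v2=True, max_input_length=4096)
model = AutoModelForCausalLM.from_pretrained(
    "{your_username}/opt-125m-gptq",  # placeholder: an already-quantized repo
    device_map="auto",
    quantization_config=override_config,
)
```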
28 changes: 23 additions & 5 deletions src/transformers/utils/quantization_config.py
@@ -344,11 +344,11 @@ class GPTQConfig(QuantizationConfigMixin):
The batch size used when processing the dataset
pad_token_id (`int`, *optional*):
The pad token id. Needed to prepare the dataset when `batch_size` > 1.
disable_exllama (`bool`, *optional*, defaults to `False`):
Whether to use exllama backend. Only works with `bits` = 4.
max_input_length (`int`, *optional*):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
length. It is specific to the exllama backend with act-order.
use_exllama_v2 (`bool`, *optional*, defaults to `False`):
Whether to use exllamav2 backend. Only works with `bits` = 4.
"""

def __init__(
@@ -367,8 +367,8 @@ def __init__(
module_name_preceding_first_block: Optional[List[str]] = None,
batch_size: int = 1,
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
max_input_length: Optional[int] = None,
use_exllama_v2: bool = False,
**kwargs,
):
self.quant_method = QuantizationMethod.GPTQ
@@ -386,13 +386,16 @@ def __init__(
self.module_name_preceding_first_block = module_name_preceding_first_block
self.batch_size = batch_size
self.pad_token_id = pad_token_id
self.disable_exllama = disable_exllama
self.disable_exllama = kwargs.get("disable_exllama", False)
self.max_input_length = max_input_length
self.use_exllama_v2 = use_exllama_v2
# needed for compatibility with optimum gptq config
self.disable_exllamav2 = not use_exllama_v2
self.post_init()

def get_loading_attributes(self):
attibutes_dict = copy.deepcopy(self.__dict__)
loading_attibutes = ["disable_exllama", "use_cuda_fp16", "max_input_length"]
loading_attibutes = ["disable_exllama", "use_exllama_v2", "use_cuda_fp16", "max_input_length"]
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
return loading_attibutes_dict

@@ -418,3 +421,18 @@ def post_init(self):
f"""dataset needs to be either a list of string or a value in
['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
)
if self.bits == 4:
if self.use_exllama_v2:
optimum_version = version.parse(importlib.metadata.version("optimum"))
autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
raise ValueError(
f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}"
)
self.disable_exllama = True
logger.warning("You have activated exllamav2 kernels. Exllama kernels will be disabled.")
if not self.disable_exllama:
logger.warning(
"You have activated exllama backend. Using `disable_exllama` is deprecated and will be removed in version 5.0 of 🤗 Transformers."
"Use `use_exllama_v2` instead. Note that you can get better inference speed using exllamav2 kernel by setting `use_exllama_v2=True`"
)
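
To illustrate the precedence implemented in `post_init` above, a small sketch (assuming recent optimum and auto-gptq versions are installed, as the version check requires): requesting the exllamav2 kernels disables the exllama (v1) backend, and `get_loading_attributes` only exposes the kernel-related fields.

```py
from transformers import GPTQConfig

# Requesting exllamav2 passes the version check, then disables the v1 backend.
config = GPTQConfig(bits=4, use_exllama_v2=True)
assert config.use_exllama_v2 is True
assert config.disable_exllama is True     # set in post_init
assert config.disable_exllamav2 is False  # kept for compatibility with optimum

# Only the kernel-related loading attributes are forwarded at load time:
# disable_exllama, use_exllama_v2, use_cuda_fp16 and max_input_length.
print(config.get_loading_attributes())
```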
60 changes: 58 additions & 2 deletions tests/quantization/gptq/test_gptq.py
@@ -178,6 +178,7 @@ def test_quantized_layers_class(self):
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
disable_exllamav2=True,
)
self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)

@@ -281,8 +282,7 @@ def setUpClass(cls):
"""
Setup quantized model
"""

cls.quantization_config = GPTQConfig(bits=4, disable_exllama=False, max_input_length=4028)
cls.quantization_config = GPTQConfig(bits=4, max_input_length=4028)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
revision=cls.revision,
@@ -334,6 +334,62 @@ def test_max_input_length(self):
self.quantized_model.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)


@slow
@require_optimum
@require_auto_gptq
@require_torch_gpu
@require_accelerate
class GPTQTestExllamaV2(unittest.TestCase):
"""
Test GPTQ model with exllamav2 kernel and desc_act=True (also known as act-order).
More information on those arguments here:
https://huggingface.co/docs/transformers/main_classes/quantization#transformers.GPTQConfig
"""

EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is Katie and I am a 20 year")
model_name = "hf-internal-testing/Llama-2-7B-GPTQ"
revision = "gptq-4bit-128g-actorder_True"
input_text = "Hello my name is"

@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
cls.quantization_config = GPTQConfig(bits=4, use_exllama_v2=True)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
revision=cls.revision,
torch_dtype=torch.float16,
device_map={"": 0},
quantization_config=cls.quantization_config,
)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True)

def check_inference_correctness(self, model):
"""
Test the generation quality of the quantized model and see that we are matching the expected output.
Given that we are operating on small numbers and the testing model is relatively small, we might not get
the same output across GPUs. So we'll generate a few tokens (5-10) and check their output.
"""

# Check that inference pass works on the model
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

# Check the exactness of the results
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

# Get the generation
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_generate_quality(self):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
"""
self.check_inference_correctness(self.quantized_model)


# fail when run all together
@pytest.mark.skip
@require_accelerate