Faster generation using AWQ + Fused modules (#27411)

* v1 fusing modules

* add fused mlp support

* up

* fix CI

* block save_pretrained

* fixup

* small fix

* add new condition

* add v1 docs

* add some comments

* style

* fix nit

* adapt from suggestion

* add check

* change arg names

* change variables name

* Update src/transformers/integrations/awq.py

Co-authored-by: amyeroberts <[email protected]>

* style

* split up into 3 different private methods

* more conditions

* more checks

* add fused tests for custom models

* fix

* fix tests

* final update docs

* final fixes

* fix importlib metadata

* Update src/transformers/utils/quantization_config.py

Co-authored-by: amyeroberts <[email protected]>

* change it to `do_fuse`

* nit

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <[email protected]>

* few fixes

* revert

* fix test

* fix copies

* raise error if model is not quantized

* add test

* use quantization_config.config when fusing

* Update src/transformers/modeling_utils.py

---------

Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
3 people authored Dec 5, 2023
1 parent df40edf commit fdb85be
Showing 7 changed files with 623 additions and 33 deletions.
2 changes: 1 addition & 1 deletion docker/transformers-all-latest-gpu/Dockerfile
@@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
RUN python3 -m pip install --no-cache-dir einops

# Add autoawq for quantization testing
-RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/autoawq-0.1.6+cu118-cp38-cp38-linux_x86_64.whl
+RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp38-cp38-linux_x86_64.whl

# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
141 changes: 141 additions & 0 deletions docs/source/en/quantization.md
@@ -85,6 +85,147 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
```


### Benchmarks

We performed some speed, throughput and latency benchmarks using the [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.

Note that at the time of writing this documentation section, the available quantization methods were: `awq`, `gptq` and `bitsandbytes`.

The benchmark was run on an NVIDIA A100 instance, using [`TheBloke/Mistral-7B-v0.1-AWQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-AWQ) as the AWQ model and [`TheBloke/Mistral-7B-v0.1-GPTQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-GPTQ) as the GPTQ model. We also benchmarked them against `bitsandbytes` quantization and a native `float16` model. Some results are shown below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/forward_memory_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/generate_memory_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/generate_throughput_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/forward_latency_plot.png">
</div>

You can find the full results together with package versions at [this link](https://github.com/huggingface/optimum-benchmark/tree/main/examples/running-mistrals).

From these results, AWQ appears to be the fastest quantization method for inference and text generation, and among the lowest in peak memory usage for text generation. However, AWQ seems to have the largest forward latency per batch size.


### Make use of fused modules

You can benefit from fused modules by passing an `AwqConfig` with `do_fuse=True` and your expected maximum sequence length for generation as `fuse_max_seq_len`. For architectures that do not support `do_fuse=True` out of the box, you can still fuse the modules, but you need to pass a custom fusing mapping via `modules_to_fuse` to `AwqConfig()`. Let's dive into these specific use cases.

Note that you cannot combine module fusing with other optimization techniques such as Flash Attention 2.

#### Fusing modules for supported architectures

Currently, we support out-of-the-box AWQ module fusing for `llama` and `mistral`.

To enable this feature for supported architectures, simply create an `AwqConfig` and pass the arguments `fuse_max_seq_len` and `do_fuse=True`.

For example, to enable module fusing for the model `TheBloke/Mistral-7B-OpenOrca-AWQ`, run:

```python
import torch
from transformers import AwqConfig, AutoModelForCausalLM

model_id = "TheBloke/Mistral-7B-OpenOrca-AWQ"

quantization_config = AwqConfig(
    bits=4,
    fuse_max_seq_len=512,
    do_fuse=True,
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(0)
```

Note that you need to pass `fuse_max_seq_len` to `AwqConfig`. That total sequence length should include the context length and the expected generation length. You can set it to a larger value to be on the safe side.
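
As a quick usage illustration (a minimal sketch, assuming the fused model created above and a hypothetical prompt), keep the prompt length plus `max_new_tokens` below the `fuse_max_seq_len` of 512 set earlier:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("What is deep learning?", return_tensors="pt").to(0)

# prompt tokens + max_new_tokens should stay below fuse_max_seq_len (512 above)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```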

You can also apply module fusing to architectures that are not supported out of the box.

#### Fusing modules for unsupported architectures

For architectures that do not support module fusing out of the box, you can pass a custom fusing mapping: simply pass a dictionary `modules_to_fuse` to `AwqConfig`. Let's take the Yi model as an example:


```python
import torch
from transformers import AwqConfig, AutoModelForCausalLM

model_id = "TheBloke/Yi-34B-AWQ"

quantization_config = AwqConfig(
    bits=4,
    fuse_max_seq_len=512,
    modules_to_fuse={
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "layernorm": ["ln1", "ln2", "norm"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "use_alibi": False,
        "num_attention_heads": 56,
        "num_key_value_heads": 8,
        "hidden_size": 7168
    }
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(0)
```

The parameter `modules_to_fuse` needs to have the following fields (a short sketch showing how to fill the numeric fields from the model config follows this list):

- `"attention"`: The names of the attention layers to fuse, in the following order: query, key, value and output projection layer. If you don't want to fuse the attention layers, pass an empty list.
- `"layernorm"`: The names of all the layernorm layers you want to replace with a custom fused layer norm. If you don't want to fuse these layers, you can also pass an empty list.
- `"mlp"`: The names of the MLP layers you want to fuse into a single MLP layer, in the following order: gate (the dense layer after attention), up, down.
- `"use_alibi"`: Whether your model uses ALiBi positional embeddings.
- `"num_attention_heads"`: The number of attention heads.
- `"num_key_value_heads"`: The number of key/value heads used to implement Grouped Query Attention (GQA). If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if `num_key_value_heads=1`, it will use Multi Query Attention (MQA); otherwise GQA is used.
- `"hidden_size"`: The dimension of the hidden representations.


#### Benchmarks

We first benchmarked the `TheBloke/Mistral-7B-OpenOrca-AWQ` model with and without fused modules, using only `batch_size=1`. The results are below:

*unfused case*

| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:----------------|
| 1 | 32 | 32 | 60.0984 | 38.4537 | 4.50 GB (5.68%) |
| 1 | 64 | 64 | 1333.67 | 31.6604 | 4.50 GB (5.68%) |
| 1 | 128 | 128 | 2434.06 | 31.6272 | 4.50 GB (5.68%) |
| 1 | 256 | 256 | 3072.26 | 38.1731 | 4.50 GB (5.68%) |
| 1 | 512 | 512 | 3184.74 | 31.6819 | 4.59 GB (5.80%) |
| 1 | 1024 | 1024 | 3148.18 | 36.8031 | 4.81 GB (6.07%) |
| 1 | 2048 | 2048 | 2927.33 | 35.2676 | 5.73 GB (7.23%) |

*fused case*

| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:----------------|
| 1 | 32 | 32 | 81.4899 | 80.2569 | 4.00 GB (5.05%) |
| 1 | 64 | 64 | 1756.1 | 106.26 | 4.00 GB (5.05%) |
| 1 | 128 | 128 | 2479.32 | 105.631 | 4.00 GB (5.06%) |
| 1 | 256 | 256 | 1813.6 | 85.7485 | 4.01 GB (5.06%) |
| 1 | 512 | 512 | 2848.9 | 97.701 | 4.11 GB (5.19%) |
| 1 | 1024 | 1024 | 3044.35 | 87.7323 | 4.41 GB (5.57%) |
| 1 | 2048 | 2048 | 2715.11 | 89.4709 | 5.57 GB (7.04%) |

We also performed benchmarks with the [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library; the results are shown below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/fused_forward_memory_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/fused_generate_throughput.png">
</div>
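
If you want a rough feel for these numbers on your own hardware, below is a minimal, illustrative timing sketch (assuming the fused model and tokenizer from the examples above). It is not the setup used to produce the tables, so expect different absolute values:

```python
import time

import torch

prompt = "Hello " * 32
inputs = tokenizer(prompt, return_tensors="pt").to(0)

torch.cuda.synchronize()
start = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

# decode throughput: newly generated tokens divided by wall-clock time
new_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
print(f"~{new_tokens / elapsed:.1f} decode tokens/s")
```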


## AutoGPTQ

<Tip>
4 changes: 2 additions & 2 deletions src/transformers/integrations/__init__.py
@@ -17,7 +17,7 @@


_import_structure = {
"awq": ["replace_with_awq_linear"],
"awq": ["fuse_awq_modules", "replace_with_awq_linear"],
"bitsandbytes": [
"get_keys_to_not_convert",
"replace_8bit_linear",
@@ -80,7 +80,7 @@
}

if TYPE_CHECKING:
-    from .awq import replace_with_awq_linear
+    from .awq import fuse_awq_modules, replace_with_awq_linear
    from .bitsandbytes import (
        get_keys_to_not_convert,
        replace_8bit_linear,