Faster generation using AWQ + Fused modules #27411

Merged
merged 51 commits on Dec 5, 2023
Changes from 50 commits
Commits
51 commits
- `6c995f9` v1 fusing modules (younesbelkada, Nov 9, 2023)
- `85cc9c7` add fused mlp support (younesbelkada, Nov 9, 2023)
- `b6cd554` Merge remote-tracking branch 'upstream/main' into awq-fused-modules (younesbelkada, Nov 13, 2023)
- `7ffbaa3` up (younesbelkada, Nov 13, 2023)
- `05b5f62` fix CI (younesbelkada, Nov 13, 2023)
- `8670aa2` block save_pretrained (younesbelkada, Nov 13, 2023)
- `9ee6b38` fixup (younesbelkada, Nov 13, 2023)
- `f8d4177` Merge remote-tracking branch 'upstream/main' into awq-fused-modules (younesbelkada, Nov 13, 2023)
- `1a8c915` small fix (younesbelkada, Nov 14, 2023)
- `b541b4d` add new condition (younesbelkada, Nov 15, 2023)
- `2ea1f47` Merge remote-tracking branch 'upstream/main' into awq-fused-modules (younesbelkada, Nov 16, 2023)
- `024b737` Merge branch 'awq-fused-modules' of https://github.com/younesbelkada/… (younesbelkada, Nov 16, 2023)
- `a7d74f8` add v1 docs (younesbelkada, Nov 16, 2023)
- `85e1e3b` add some comments (younesbelkada, Nov 16, 2023)
- `3e6ba9b` Merge branch 'main' into awq-fused-modules (younesbelkada, Nov 17, 2023)
- `26194d0` Merge remote-tracking branch 'upstream/main' into awq-fused-modules (younesbelkada, Nov 21, 2023)
- `f160a16` style (younesbelkada, Nov 21, 2023)
- `14c820d` fix nit (younesbelkada, Nov 21, 2023)
- `03d8dff` adapt from suggestion (younesbelkada, Nov 21, 2023)
- `0a08551` add check (younesbelkada, Nov 21, 2023)
- `234165f` change arg names (younesbelkada, Nov 21, 2023)
- `03980d9` change variables name (younesbelkada, Nov 21, 2023)
- `8a68a23` Update src/transformers/integrations/awq.py (younesbelkada, Nov 21, 2023)
- `21f6879` style (younesbelkada, Nov 21, 2023)
- `cde53ef` split up into 3 different private methods (younesbelkada, Nov 21, 2023)
- `8517e32` more conditions (younesbelkada, Nov 21, 2023)
- `b187c07` more checks (younesbelkada, Nov 21, 2023)
- `c3e32ab` add fused tests for custom models (younesbelkada, Nov 22, 2023)
- `d3c7753` fix (younesbelkada, Nov 22, 2023)
- `4113c45` fix tests (younesbelkada, Nov 22, 2023)
- `0bd1b0c` final update docs (younesbelkada, Nov 22, 2023)
- `61db430` final fixes (younesbelkada, Nov 22, 2023)
- `cd37d32` fix importlib metadata (younesbelkada, Nov 23, 2023)
- `8f381ed` Merge remote-tracking branch 'upstream/main' into awq-fused-modules (younesbelkada, Dec 4, 2023)
- `e80ad75` Merge branch 'awq-fused-modules' of https://github.com/younesbelkada/… (younesbelkada, Dec 4, 2023)
- `b5c337c` Update src/transformers/utils/quantization_config.py (younesbelkada, Dec 4, 2023)
- `3f98913` change it to `do_fuse` (younesbelkada, Dec 4, 2023)
- `3bd0446` nit (younesbelkada, Dec 4, 2023)
- `e1b3bfa` Update src/transformers/utils/quantization_config.py (younesbelkada, Dec 4, 2023)
- `cb31546` Update src/transformers/utils/quantization_config.py (younesbelkada, Dec 4, 2023)
- `45875fd` Update src/transformers/utils/quantization_config.py (younesbelkada, Dec 4, 2023)
- `faaa255` Merge branch 'awq-fused-modules' of https://github.com/younesbelkada/… (younesbelkada, Dec 4, 2023)
- `c1ea9b2` few fixes (younesbelkada, Dec 4, 2023)
- `d90eec7` revert (younesbelkada, Dec 4, 2023)
- `e65687b` fix test (younesbelkada, Dec 4, 2023)
- `da78cf4` fix copies (younesbelkada, Dec 4, 2023)
- `2fcc465` Merge remote-tracking branch 'upstream/main' into awq-fused-modules (younesbelkada, Dec 5, 2023)
- `0697687` raise error if model is not quantized (younesbelkada, Dec 5, 2023)
- `12aff7c` add test (younesbelkada, Dec 5, 2023)
- `498fe55` use quantization_config.config when fusing (younesbelkada, Dec 5, 2023)
- `196095e` Update src/transformers/modeling_utils.py (younesbelkada, Dec 5, 2023)
2 changes: 1 addition & 1 deletion docker/transformers-all-latest-gpu/Dockerfile
@@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
RUN python3 -m pip install --no-cache-dir einops

# Add autoawq for quantization testing
-RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/autoawq-0.1.6+cu118-cp38-cp38-linux_x86_64.whl
+RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp38-cp38-linux_x86_64.whl

# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
141 changes: 141 additions & 0 deletions docs/source/en/quantization.md
@@ -85,6 +85,147 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
```


### Benchmarks

We performed speed, throughput, and latency benchmarks using the [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.

Note that at the time of writing this documentation section, the available quantization methods were `awq`, `gptq`, and `bitsandbytes`.

The benchmark was run on an NVIDIA A100 instance, using [`TheBloke/Mistral-7B-v0.1-AWQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-AWQ) as the AWQ model and [`TheBloke/Mistral-7B-v0.1-GPTQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-GPTQ) as the GPTQ model. We also benchmarked against the `bitsandbytes` quantization methods and the native `float16` model. Some results are shown below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/forward_memory_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/generate_memory_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/generate_throughput_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/forward_latency_plot.png">
</div>

You can find the full results together with package versions at [this link](https://github.com/huggingface/optimum-benchmark/tree/main/examples/running-mistrals).

From the results, AWQ appears to be the fastest quantization method for inference and text generation, and it has among the lowest peak memory usage for text generation. However, AWQ seems to have the largest forward latency per batch size.


### Make use of fused modules

You can benefit from fused modules by passing an `AwqConfig` with `do_fuse=True` and your expected maximum sequence length for generation in `fuse_max_seq_len`. For architectures that do not support `do_fuse=True` out of the box, you can still fuse the modules; however, you need to pass a custom fusing mapping through `modules_to_fuse` to `AwqConfig()`. Let's dive into these specific use cases.

Note that you cannot combine module fusing with other optimization techniques such as Flash Attention 2.

#### Fusing modules for supported architectures

Currently, we support AWQ module fusing out of the box for `llama` and `mistral`.

To enable this feature for supported architectures, simply create an `AwqConfig` and pass the arguments `fuse_max_seq_len` and `do_fuse=True`.

For example, to enable module fusing for the model `TheBloke/Mistral-7B-OpenOrca-AWQ`, run:

```python
import torch
from transformers import AwqConfig, AutoModelForCausalLM

model_id = "TheBloke/Mistral-7B-OpenOrca-AWQ"

quantization_config = AwqConfig(
    bits=4,
    fuse_max_seq_len=512,
    do_fuse=True,
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(0)
```
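
Assuming the checkpoint loads as above, a quick way to exercise the fused generation path is a standard `generate` call; the prompt below is just an illustrative placeholder:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Illustrative prompt; the prompt plus generated tokens should stay under `fuse_max_seq_len`.
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(0)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```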

Note that you need to pass `fuse_max_seq_len` to `AwqConfig`. This total sequence length should include the context length and the expected generation length. You can set it to a large value to be on the safe side.
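
For instance, if you expect prompts of up to 384 tokens and want to generate up to 128 new tokens, a setting consistent with the config above would be (these numbers are purely illustrative):

```python
# illustrative budget: prompt tokens + generated tokens
fuse_max_seq_len = 384 + 128  # = 512, matching the value used above
```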

You can also apply module fusing to other architectures that are not supported out of the box, as described in the next section.

#### Fusing modules for unsupported architectures

For architectures that do not support module fusing out of the box, you can pass a custom fusing mapping: simply pass a dictionary `modules_to_fuse` to `AwqConfig`. Let's take the Yi model as an example:


```python
import torch
from transformers import AwqConfig, AutoModelForCausalLM

model_id = "TheBloke/Yi-34B-AWQ"

quantization_config = AwqConfig(
    bits=4,
    fuse_max_seq_len=512,
    modules_to_fuse={
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "layernorm": ["ln1", "ln2", "norm"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "use_alibi": False,
        "num_attention_heads": 56,
        "num_key_value_heads": 8,
        "hidden_size": 7168,
    },
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(0)
```

The parameter `modules_to_fuse` needs to contain the following fields:

- `"attention"`: The names of the attention layers to fuse, in the following order: query, key, value, and output projection layer. If you don't want to fuse the attention layers, pass an empty list.
- `"layernorm"`: The names of all the layernorm layers you want to replace with a custom fused layer norm. If you don't want to fuse these layers, pass an empty list.
- `"mlp"`: The names of the MLP layers you want to fuse into a single MLP layer, in the following order: gate (dense layer post-attention), up, and down projection layers.
- `"use_alibi"`: Whether your model uses ALiBi positional embeddings.
- `"num_attention_heads"`: The number of attention heads.
- `"num_key_value_heads"`: The number of key/value heads used to implement Grouped Query Attention (GQA). If `num_key_value_heads=num_attention_heads`, the model uses Multi Head Attention (MHA); if `num_key_value_heads=1`, the model uses Multi Query Attention (MQA); otherwise GQA is used (see the sketch after this list).
- `"hidden_size"`: The dimension of the hidden representations.
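
To make the relationship between the last three fields concrete, here is a minimal sketch (independent of any particular checkpoint) of how query heads are grouped onto key/value heads; the helper function is purely illustrative:

```python
def kv_head_for_query_head(query_head: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    """Return the index of the key/value head that a given query head attends with."""
    # num_key_value_heads == num_attention_heads -> MHA (one KV head per query head)
    # num_key_value_heads == 1                   -> MQA (all query heads share one KV head)
    # anything in between                        -> GQA (query heads are grouped)
    group_size = num_attention_heads // num_key_value_heads
    return query_head // group_size

# With the Yi values above (56 query heads, 8 KV heads), each group of 7 query heads
# shares one KV head: query head 10 falls into group 1.
print(kv_head_for_query_head(10, num_attention_heads=56, num_key_value_heads=8))  # -> 1
```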


#### Benchmarks

We first benchmarked the `TheBloke/Mistral-7B-OpenOrca-AWQ` model with and without fused modules, using only `batch_size=1`. The results are below:

*unfused case*

| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:----------------|
| 1 | 32 | 32 | 60.0984 | 38.4537 | 4.50 GB (5.68%) |
| 1 | 64 | 64 | 1333.67 | 31.6604 | 4.50 GB (5.68%) |
| 1 | 128 | 128 | 2434.06 | 31.6272 | 4.50 GB (5.68%) |
| 1 | 256 | 256 | 3072.26 | 38.1731 | 4.50 GB (5.68%) |
| 1 | 512 | 512 | 3184.74 | 31.6819 | 4.59 GB (5.80%) |
| 1 | 1024 | 1024 | 3148.18 | 36.8031 | 4.81 GB (6.07%) |
| 1 | 2048 | 2048 | 2927.33 | 35.2676 | 5.73 GB (7.23%) |

*fused case*

| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:----------------|
| 1 | 32 | 32 | 81.4899 | 80.2569 | 4.00 GB (5.05%) |
| 1 | 64 | 64 | 1756.1 | 106.26 | 4.00 GB (5.05%) |
| 1 | 128 | 128 | 2479.32 | 105.631 | 4.00 GB (5.06%) |
| 1 | 256 | 256 | 1813.6 | 85.7485 | 4.01 GB (5.06%) |
| 1 | 512 | 512 | 2848.9 | 97.701 | 4.11 GB (5.19%) |
| 1 | 1024 | 1024 | 3044.35 | 87.7323 | 4.41 GB (5.57%) |
| 1 | 2048 | 2048 | 2715.11 | 89.4709 | 5.57 GB (7.04%) |
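
If you want to reproduce a rough decode tokens/s figure yourself, outside of any benchmark harness, a minimal timing sketch (re-using `model` and `model_id` from the fused example above, with an arbitrary prompt and token budget) could look like this:

```python
import time

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("The quick brown fox", return_tensors="pt").to(0)

max_new_tokens = 128
torch.cuda.synchronize()
start = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

generated_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
print(f"~{generated_tokens / elapsed:.1f} tokens/s (prefill + decode combined)")
```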

We also performed benchmarks with the [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library; the results are shown below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/fused_forward_memory_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/fused_generate_throughput.png">
</div>


## AutoGPTQ

<Tip>
4 changes: 2 additions & 2 deletions src/transformers/integrations/__init__.py
@@ -17,7 +17,7 @@


_import_structure = {
-    "awq": ["replace_with_awq_linear"],
+    "awq": ["fuse_awq_modules", "replace_with_awq_linear"],
"bitsandbytes": [
"get_keys_to_not_convert",
"replace_8bit_linear",
@@ -80,7 +80,7 @@
}

if TYPE_CHECKING:
-    from .awq import replace_with_awq_linear
+    from .awq import fuse_awq_modules, replace_with_awq_linear
from .bitsandbytes import (
get_keys_to_not_convert,
replace_8bit_linear,