From fdb85be40fa255c015819e711c15117c2aaa5101 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 5 Dec 2023 12:14:45 +0100 Subject: [PATCH] Faster generation using AWQ + Fused modules (#27411) * v1 fusing modules * add fused mlp support * up * fix CI * block save_pretrained * fixup * small fix * add new condition * add v1 docs * add some comments * style * fix nit * adapt from suggestion * add check * change arg names * change variables name * Update src/transformers/integrations/awq.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * style * split up into 3 different private methods * more conditions * more checks * add fused tests for custom models * fix * fix tests * final update docs * final fixes * fix importlib metadata * Update src/transformers/utils/quantization_config.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * change it to `do_fuse` * nit * Update src/transformers/utils/quantization_config.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/utils/quantization_config.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/utils/quantization_config.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * few fixes * revert * fix test * fix copies * raise error if model is not quantized * add test * use quantization_config.config when fusing * Update src/transformers/modeling_utils.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- docker/transformers-all-latest-gpu/Dockerfile | 2 +- docs/source/en/quantization.md | 141 +++++++++++ src/transformers/integrations/__init__.py | 4 +- src/transformers/integrations/awq.py | 237 +++++++++++++++++- src/transformers/modeling_utils.py | 44 +++- src/transformers/utils/quantization_config.py | 58 ++++- tests/quantization/autoawq/test_awq.py | 170 +++++++++++-- 7 files changed, 623 insertions(+), 33 deletions(-) diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index d108ba5ace5805..7ab236a55d5902 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu RUN python3 -m pip install --no-cache-dir einops # Add autoawq for quantization testing -RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/autoawq-0.1.6+cu118-cp38-cp38-linux_x86_64.whl +RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp38-cp38-linux_x86_64.whl # For bettertransformer + gptq RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md index 60903e36ad5968..00fe899e73bcbc 100644 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -85,6 +85,147 @@ from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0") ``` + +### Benchmarks + +We performed some speed, throughput and latency benchmarks using 
the [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.
+
+Note that at the time of writing this documentation section, the available quantization methods were: `awq`, `gptq` and `bitsandbytes`.
+
+The benchmark was run on an NVIDIA A100 instance. The AWQ model was [`TheBloke/Mistral-7B-v0.1-AWQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-AWQ) and the GPTQ model was [`TheBloke/Mistral-7B-v0.1-GPTQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-GPTQ). We also benchmarked against `bitsandbytes` quantization and a native `float16` model. Some results are shown below:
+
+<!-- Benchmark plots comparing AWQ, GPTQ, bitsandbytes and native float16 (throughput, latency and peak memory); images omitted from this text-only excerpt. -->
+
+You can find the full results together with the package versions in [this link](https://github.com/huggingface/optimum-benchmark/tree/main/examples/running-mistrals).
+
+From the results, AWQ appears to be the fastest quantization method for both inference and text generation, and it has among the lowest peak memory usage for text generation. However, AWQ seems to have the largest forward latency per batch size.
+
+
+### Make use of fused modules
+
+You can benefit from fused modules by passing an `AwqConfig` with `do_fuse=True` and your expected maximum sequence length for generation to `fuse_max_seq_len`. For architectures that do not support `do_fuse=True` out of the box, you can still fuse the modules, but you need to pass a custom fusing mapping `modules_to_fuse` to `AwqConfig()`. Let's dive into these specific use cases.
+
+Note that you cannot combine module fusing with other optimization techniques such as Flash Attention 2.
+
+#### Fusing modules for supported architectures
+
+Currently we support out-of-the-box AWQ module fusing for `llama` and `mistral`.
+
+To enable this feature for supported architectures, simply create an `AwqConfig` and pass the arguments `fuse_max_seq_len` and `do_fuse=True`.
+
+For example, to enable module fusing for the model `TheBloke/Mistral-7B-OpenOrca-AWQ`, run:
+
+```python
+import torch
+from transformers import AwqConfig, AutoModelForCausalLM
+
+model_id = "TheBloke/Mistral-7B-OpenOrca-AWQ"
+
+quantization_config = AwqConfig(
+    bits=4,
+    fuse_max_seq_len=512,
+    do_fuse=True,
+)
+
+model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(0)
+```
+
+Note that you need to pass `fuse_max_seq_len` to `AwqConfig`. That total sequence length should include the context length and the expected generation length. You can set it to a large value to be on the safe side.
+
+You can also apply module fusing to architectures that are not supported out of the box.
+
+#### Fusing modules for unsupported architectures
+
+For architectures that do not support module fusing out of the box, you can pass a custom fusing mapping: simply pass a dictionary `modules_to_fuse` to `AwqConfig`. Let's take an example with the Yi model:
+
+
+```python
+import torch
+from transformers import AwqConfig, AutoModelForCausalLM
+
+model_id = "TheBloke/Yi-34B-AWQ"
+
+quantization_config = AwqConfig(
+    bits=4,
+    fuse_max_seq_len=512,
+    modules_to_fuse={
+        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
+        "layernorm": ["ln1", "ln2", "norm"],
+        "mlp": ["gate_proj", "up_proj", "down_proj"],
+        "use_alibi": False,
+        "num_attention_heads": 56,
+        "num_key_value_heads": 8,
+        "hidden_size": 7168
+    }
+)
+
+model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(0)
+```
+
+The parameter `modules_to_fuse` needs to have the following fields:
+
+- `"attention"`: The names of the attention layers to fuse, in the order: query, key, value and output projection layer. If you don't want to fuse the attention layers, you can pass an empty list.
+- `"layernorm"`: The names of all the layernorm layers you want to replace with a custom fused layer norm. If you don't want to fuse these layers, you can also pass an empty list.
+- `"mlp"`: The names of the MLP layers you want to fuse into a single MLP layer, in the order: gate (the dense layer applied after attention) / up / down projections.
+- `"use_alibi"`: Whether your model uses ALiBi positional embeddings.
+- `"num_attention_heads"`: The number of attention heads.
+- `"num_key_value_heads"`: The number of key-value heads that should be used to implement Grouped Query Attention (GQA). If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if `num_key_value_heads=1`, it will use Multi Query Attention (MQA); otherwise GQA is used.
+- `"hidden_size"`: Dimension of the hidden representations.
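Rather than hard-coding the last three values, you can usually read them from the checkpoint's configuration. Below is a minimal, illustrative sketch; the `infer_fusing_dims` helper is hypothetical (not part of `transformers`) and assumes the config exposes llama-style attribute names, which holds for the Yi checkpoint above. The layer names for `"attention"`, `"layernorm"` and `"mlp"` still have to be looked up in the modeling code.

```python
from transformers import AutoConfig

# Hypothetical helper: read the size-related fields of `modules_to_fuse`
# from a checkpoint's config instead of hard-coding them.
def infer_fusing_dims(model_id: str, trust_remote_code: bool = False) -> dict:
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    return {
        "hidden_size": config.hidden_size,
        "num_attention_heads": config.num_attention_heads,
        # Fall back to MHA when the config does not define separate key/value heads,
        # mirroring what the AWQ integration does for natively supported architectures.
        "num_key_value_heads": getattr(config, "num_key_value_heads", config.num_attention_heads),
    }


# For the Yi checkpoint above this should yield 7168 / 56 / 8.
print(infer_fusing_dims("TheBloke/Yi-34B-AWQ", trust_remote_code=True))
```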
+- `"use_alibi"`: If you model uses alibi positional embedding +- `"num_attention_heads"`: The number of attention heads +- `"num_key_value_heads"`: This is the number of key value heads that should be used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, the model will use Multi Head Attention (MHA), if num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. +- `"hidden_size"`: Dimension of the hidden representations. + + +#### Benchmarks + +We benchmarked the model with and without fused modules first using only `batch_size=1` on the `TheBloke/Mistral-7B-OpenOrca-AWQ` model and below are the results: + +*unfused case* + +| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) | +|-------------:|-----------------:|----------------:|-------------------:|------------------:|:----------------| +| 1 | 32 | 32 | 60.0984 | 38.4537 | 4.50 GB (5.68%) | +| 1 | 64 | 64 | 1333.67 | 31.6604 | 4.50 GB (5.68%) | +| 1 | 128 | 128 | 2434.06 | 31.6272 | 4.50 GB (5.68%) | +| 1 | 256 | 256 | 3072.26 | 38.1731 | 4.50 GB (5.68%) | +| 1 | 512 | 512 | 3184.74 | 31.6819 | 4.59 GB (5.80%) | +| 1 | 1024 | 1024 | 3148.18 | 36.8031 | 4.81 GB (6.07%) | +| 1 | 2048 | 2048 | 2927.33 | 35.2676 | 5.73 GB (7.23%) | + +*fused case* + +| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) | +|-------------:|-----------------:|----------------:|-------------------:|------------------:|:----------------| +| 1 | 32 | 32 | 81.4899 | 80.2569 | 4.00 GB (5.05%) | +| 1 | 64 | 64 | 1756.1 | 106.26 | 4.00 GB (5.05%) | +| 1 | 128 | 128 | 2479.32 | 105.631 | 4.00 GB (5.06%) | +| 1 | 256 | 256 | 1813.6 | 85.7485 | 4.01 GB (5.06%) | +| 1 | 512 | 512 | 2848.9 | 97.701 | 4.11 GB (5.19%) | +| 1 | 1024 | 1024 | 3044.35 | 87.7323 | 4.41 GB (5.57%) | +| 1 | 2048 | 2048 | 2715.11 | 89.4709 | 5.57 GB (7.04%) | + +We also performed benchmarks with [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library. And below are the results: + +
+<!-- optimum-benchmark plots for the fused vs. unfused model; images omitted from this text-only excerpt. -->
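If you only want a rough sanity check of the fused speedup on your own hardware, timing `generate` directly is usually enough. The snippet below is a minimal sketch, not the setup used to produce the numbers above, so absolute figures will differ.

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

model_id = "TheBloke/Mistral-7B-OpenOrca-AWQ"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("You are standing on the surface of the Earth.", return_tensors="pt").to(0)


def decode_tokens_per_second(do_fuse: bool, max_new_tokens: int = 128) -> float:
    # When fusing is enabled, `fuse_max_seq_len` must cover the prompt plus the generated tokens.
    quantization_config = AwqConfig(bits=4, fuse_max_seq_len=512, do_fuse=do_fuse)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=quantization_config, low_cpu_mem_usage=True
    ).to(0)

    # Warmup so CUDA kernels and caches are initialized before timing.
    model.generate(**inputs, max_new_tokens=8)
    torch.cuda.synchronize()

    start = time.perf_counter()
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    new_tokens = output.shape[1] - inputs["input_ids"].shape[1]

    # Free the model before loading the next variant.
    del model
    torch.cuda.empty_cache()
    return new_tokens / elapsed


print("unfused decode tokens/s:", decode_tokens_per_second(do_fuse=False))
print("fused decode tokens/s:  ", decode_tokens_per_second(do_fuse=True))
```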
+ + ## AutoGPTQ diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 427b5e00000feb..3d1e41263eef70 100644 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -17,7 +17,7 @@ _import_structure = { - "awq": ["replace_with_awq_linear"], + "awq": ["fuse_awq_modules", "replace_with_awq_linear"], "bitsandbytes": [ "get_keys_to_not_convert", "replace_8bit_linear", @@ -80,7 +80,7 @@ } if TYPE_CHECKING: - from .awq import replace_with_awq_linear + from .awq import fuse_awq_modules, replace_with_awq_linear from .bitsandbytes import ( get_keys_to_not_convert, replace_8bit_linear, diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py index 94d996b0fffd4c..336a216e401461 100644 --- a/src/transformers/integrations/awq.py +++ b/src/transformers/integrations/awq.py @@ -12,14 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. "AWQ (Activation aware Weight Quantization) integration file" +from ..activations import ACT2FN +from ..modeling_utils import PreTrainedModel from ..utils import is_auto_awq_available, is_torch_available -from ..utils.quantization_config import AwqBackendPackingMethod, AWQLinearVersion +from ..utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion if is_torch_available(): + import torch import torch.nn as nn +AWQ_FUSED_MAPPINGS = { + "mistral": { + "attention": ["q_proj", "k_proj", "v_proj", "o_proj"], + "mlp": ["gate_proj", "up_proj", "down_proj"], + "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], + "use_alibi": False, + }, + "llama": { + "attention": ["q_proj", "k_proj", "v_proj", "o_proj"], + "mlp": ["gate_proj", "up_proj", "down_proj"], + "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], + "use_alibi": False, + }, +} + + def replace_with_awq_linear( model, modules_to_not_convert=None, @@ -102,3 +121,219 @@ def replace_with_awq_linear( # Remove the last key for recursion current_key_name.pop(-1) return model, has_been_replaced + + +def get_modules_to_fuse(model, quantization_config): + """ + Returns the fusing mapping given the quantization config and the model + + Args: + model (`~PreTrainedModel`): + The model to fuse - note this model should have been converted into AWQ format beforehand. + quantization_config (`~transformers.quantization_config.AWQConfig`): + The quantization configuration to use. + """ + if not isinstance(model, PreTrainedModel): + raise ValueError(f"The model should be an instance of `PreTrainedModel`, got {model.__class__.__name__}") + + # Always default to `quantization_config.modules_to_fuse` + if quantization_config.modules_to_fuse is not None: + current_fused_mapping = quantization_config.modules_to_fuse + current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len + elif model.config.model_type in AWQ_FUSED_MAPPINGS: + current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type] + + # Handle hidden_size, num_attention_heads, num_key_value_heads on our own. 
+ hidden_size = model.config.hidden_size + num_attention_heads = model.config.num_attention_heads + num_key_value_heads = getattr(model.config, "num_key_value_heads", num_attention_heads) + + # Fill `current_fused_mapping` with the expected values + current_fused_mapping["hidden_size"] = hidden_size + current_fused_mapping["num_attention_heads"] = num_attention_heads + current_fused_mapping["num_key_value_heads"] = num_key_value_heads + current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len + else: + raise ValueError( + "Fusing mapping not found either on the quantization config or the supported `AWQ_FUSED_MAPPINGS`. Please pass a `fused_mapping` argument" + " in the `quantization_config` or raise an issue on transformers https://github.com/huggingface/transformers to add its support." + ) + return current_fused_mapping + + +def fuse_awq_modules(model, quantization_config): + """ + Optionally fuse some modules in the model to speedup inference. + + Args: + model (`~PreTrainedModel`): + The model to fuse - note this model should have been converted into AWQ format beforehand. + quantization_config (`dict`): + The quantization configuration to use. + """ + # We need to convert it from dict in order to get an AwqConfig object + # otherwise the fields `backend` etc. will not be available + # https://github.com/huggingface/transformers/pull/27411#discussion_r1414044495 + awq_config = AwqConfig.from_dict(quantization_config) + backend = awq_config.backend + + modules_to_fuse = get_modules_to_fuse(model, awq_config) + + if backend == AwqBackendPackingMethod.AUTOAWQ: + from awq.modules.fused.attn import QuantAttentionFused + from awq.modules.fused.mlp import QuantFusedMLP + from awq.modules.fused.norm import FasterTransformerRMSNorm + else: + raise ValueError("Fusing is only supported for the AutoAWQ backend") + + for name, module in model.named_modules(): + # Replace layer norms + _fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm) + + # Replace MLP layers + _fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP) + + # Replace attention layers + _fuse_awq_attention_layers(model, module, modules_to_fuse, name, QuantAttentionFused) + return model + + +def _fuse_awq_layernorm(fuse_module_names, module, target_cls): + """ + Fuse the LayerNorm layers into a target class using autoawq + + Args: + fuse_module_names (`List[str]`): + The list of module names to fuse + module (`nn.Module`): + The pytorch parent module that has layernorm modules to fuse + target_cls (`~autoawq.FasterTransformerRMSNorm`): + The `FasterTransformerRMSNorm` class as it only supports that class + for now. + """ + for module_name in fuse_module_names: + if hasattr(module, module_name): + old_module = getattr(module, module_name) + module._modules[module_name] = target_cls( + old_module.weight, + old_module.variance_epsilon, + ).to(old_module.weight.device) + del old_module + + +def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_cls): + """ + Fuse the MLP layers into a target class using autoawq + + Args: + model (`~PreTrainedModel`): + The input pretrained model + current_module_name (`str`): + The current submodule name + fuse_module_names (`List[str]`): + The list of module names to fuse. 
For the MLP layers it has to be an array + of length 3 that consists of the 3 MLP layers in the order (gate (dense layer post-attention) / up / down layers) + module (`nn.Module`): + The pytorch parent module that has layernorm modules to fuse + target_cls (`~autoawq.QuantFusedMLP`): + The `QuantFusedMLP` class as it only supports that class + for now. + """ + if len(fuse_module_names) == 0: + return + + if hasattr(module, fuse_module_names[0]): + gate_proj = getattr(module, fuse_module_names[0]) + up_proj = getattr(module, fuse_module_names[1]) + down_proj = getattr(module, fuse_module_names[2]) + + previous_device = gate_proj.qweight.device + activation_fn = ACT2FN[model.config.hidden_act] + new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn) + + parent_name, child_name = current_module_name.rsplit(".", 1) + parent = model.get_submodule(parent_name) + setattr(parent, child_name, new_module.to(previous_device)) + + del gate_proj, up_proj, down_proj + + +def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_name, target_cls): + """ + Fuse the Attention layers into a target class using autoawq + + Args: + model (`~PreTrainedModel`): + The input pretrained model + module (`nn.Module`): + The pytorch parent module that has layernorm modules to fuse + modules_to_fuse (`List[str]`): + The module fusing mapping. The dictionary has to contain a field `attention` with attention module names + in the correct order: q, k, v, o layer + current_module_name (`str`): + The current submodule name + target_cls (`~autoawq.QuantAttentionFused`): + The `QuantAttentionFused` class as it only supports that class + for now. + """ + from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV + + if len(modules_to_fuse["attention"]) == 0: + return + + if hasattr(module, modules_to_fuse["attention"][0]): + # First, we pack the QKV layers together + q_proj = getattr(module, modules_to_fuse["attention"][0]) + previous_device = q_proj.qweight.device + + if isinstance(q_proj, WQLinear_GEMV): + linear_target_cls = WQLinear_GEMV + cat_dim = 0 + elif isinstance(q_proj, WQLinear_GEMM): + linear_target_cls = WQLinear_GEMM + cat_dim = 1 + else: + raise ValueError("Unsupported q_proj type: {type(q_proj)}") + + k_proj = getattr(module, modules_to_fuse["attention"][1]) + v_proj = getattr(module, modules_to_fuse["attention"][2]) + o_proj = getattr(module, modules_to_fuse["attention"][3]) + + bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + + qkv_layer = linear_target_cls( + q_proj.w_bit, + q_proj.group_size, + q_proj.in_features, + q_proj.out_features + k_proj.out_features + v_proj.out_features, + q_proj.bias is not None, + next(iter(module.state_dict().values())).device, + ) + + qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=cat_dim) + qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=cat_dim) + qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=cat_dim) + + if isinstance(qkv_layer, WQLinear_GEMV): + qkv_layer.split_k_iters = q_proj.split_k_iters + + qkv_layer.bias = bias + + fused_attention_layer = target_cls( + modules_to_fuse["hidden_size"], + modules_to_fuse["num_attention_heads"], + modules_to_fuse["num_key_value_heads"], + qkv_layer, + o_proj, + previous_device, + modules_to_fuse["max_seq_len"], + use_alibi=modules_to_fuse["use_alibi"], + ) + + fused_attention_layer.is_hf_transformers = True + + parent_name, child_name = 
current_module_name.rsplit(".", 1) + parent = model.get_submodule(parent_name) + setattr(parent, child_name, fused_attention_layer.to(previous_device)) + + del q_proj, k_proj, v_proj, o_proj diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4d1178bc6862a0..e478a016c8afb8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2071,6 +2071,9 @@ def save_pretrained( "You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported" ) + if getattr(self, "_awq_is_fused", False): + raise ValueError("You cannot save an AWQ model that uses fused modules!") + if "save_config" in kwargs: warnings.warn( "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead." @@ -2726,18 +2729,12 @@ def from_pretrained( ) quantization_method_from_args = None + if quantization_config is not None: quantization_method_from_args = getattr( quantization_config, "quant_method", QuantizationMethod.BITS_AND_BYTES ) - if quantization_method_from_args == QuantizationMethod.AWQ: - raise ValueError( - "You cannot pass an `AwqConfig` when loading a model as you can only use AWQ models" - " for inference. To quantize transformers models with AWQ algorithm, please refer to our" - " quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization " - ) - if quantization_config is None and (load_in_8bit or load_in_4bit): quantization_method_from_args = QuantizationMethod.BITS_AND_BYTES quantization_config, kwargs = BitsAndBytesConfig.from_dict( @@ -2830,21 +2827,36 @@ def from_pretrained( quantization_method_from_config = config.quantization_config.get( "quant_method", QuantizationMethod.BITS_AND_BYTES ) + + if ( + quantization_method_from_args is not None + and quantization_method_from_args == QuantizationMethod.AWQ + and quantization_method_from_config is None + ): + raise ValueError( + "You cannot quantize with AWQ a non-quantized model using transformers, please refer to the quantization documentation" + " to read more about how to quantize models with AWQ algorithm https://huggingface.co/docs/transformers/main_classes/quantization" + ) + if quantization_method_from_config is not None and quantization_method_from_args is not None: if quantization_method_from_config != quantization_method_from_args: raise ValueError( f"The model is already quantized with {quantization_method_from_config}. " f"You can't quantize it again with {quantization_method_from_args}" ) - if quantization_method_from_config == QuantizationMethod.GPTQ and quantization_method_from_args is not None: + + if ( + quantization_method_from_config in (QuantizationMethod.GPTQ, QuantizationMethod.AWQ) + and quantization_method_from_args is not None + ): loading_attr_dict = quantization_config.get_loading_attributes() for attr, val in loading_attr_dict.items(): config.quantization_config[attr] = val quantization_method_from_args = None logger.warning( - "You passed `quantization_config` to `from_pretrained` but the model you're loading already has a " - "`quantization_config` attribute and has already quantized weights. However, loading attributes" - " (e.g. use_exllama, exllama_config, use_cuda_fp16, max_input_length) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored." + f"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a " + f"`quantization_config` attribute and has already quantized weights. 
However, loading attributes" + f" (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored." ) if ( quantization_method_from_args == QuantizationMethod.GPTQ @@ -3372,7 +3384,7 @@ def from_pretrained( model = quantizer.convert_model(model) model._is_quantized_training_enabled = True elif quantization_method_from_config == QuantizationMethod.AWQ: - from .integrations import get_keys_to_not_convert, replace_with_awq_linear + from .integrations import fuse_awq_modules, get_keys_to_not_convert, replace_with_awq_linear modules_to_not_convert = get_keys_to_not_convert(model) @@ -3590,6 +3602,14 @@ def from_pretrained( ) pass + if ( + quantization_config is not None + and quantization_config.quant_method == QuantizationMethod.AWQ + and quantization_config.do_fuse + ): + model = fuse_awq_modules(model, config.quantization_config) + model._awq_is_fused = True + # Dispatch model with hooks on all devices if necessary if device_map is not None: device_map_kwargs = { diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 222ba68a6dc1e4..4f268ab6bc7102 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -24,7 +24,7 @@ from packaging import version -from ..utils import is_torch_available, logging +from ..utils import is_auto_awq_available, is_torch_available, logging if is_torch_available(): @@ -543,6 +543,12 @@ class AwqConfig(QuantizationConfigMixin): backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`): The quantization backend. Some models might be quantized using `llm-awq` backend. This is useful for users that quantize their own models using `llm-awq` library. + do_fuse (`bool`, *optional*, defaults to `False`): + Whether to fuse attention and mlp layers together for faster inference + fuse_max_seq_len (`int`, *optional*): + The Maximum sequence length to generate when using fusing. + modules_to_fuse (`dict`, *optional*, default to `None`): + Overwrite the natively supported fusing scheme with the one specified by the users. 
""" def __init__( @@ -552,6 +558,9 @@ def __init__( zero_point: bool = True, version: AWQLinearVersion = AWQLinearVersion.GEMM, backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ, + do_fuse: Optional[bool] = None, + fuse_max_seq_len: Optional[int] = None, + modules_to_fuse: Optional[dict] = None, **kwargs, ): self.quant_method = QuantizationMethod.AWQ @@ -561,6 +570,14 @@ def __init__( self.zero_point = zero_point self.version = version self.backend = backend + self.fuse_max_seq_len = fuse_max_seq_len + + self.modules_to_fuse = modules_to_fuse + if do_fuse is None: + self.do_fuse = modules_to_fuse is not None and len(modules_to_fuse) > 0 + else: + self.do_fuse = do_fuse + self.fuse_max_seq_len = fuse_max_seq_len self.post_init() @@ -587,3 +604,42 @@ def post_init(self): major, minor = compute_capability if major < 8: raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0") + + if self.do_fuse and self.fuse_max_seq_len is None: + raise ValueError( + "You cannot enable fused modules without specifying a `fuse_max_seq_len`, make sure to pass a valid `fuse_max_seq_len` for your usecase" + ) + + if self.do_fuse: + awq_version_supports_fusing = False + MIN_AWQ_VERSION = "0.1.7" + if is_auto_awq_available(): + awq_version_supports_fusing = version.parse(importlib.metadata.version("autoawq")) >= version.parse( + MIN_AWQ_VERSION + ) + + if not awq_version_supports_fusing: + raise ValueError( + f"You current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}." + ) + + if self.do_fuse and self.modules_to_fuse is not None: + required_keys = [ + "hidden_size", + "num_attention_heads", + "num_key_value_heads", + "mlp", + "attention", + "layernorm", + "use_alibi", + ] + if not all(key in self.modules_to_fuse for key in required_keys): + raise ValueError( + f"Required fields are missing in the fusing mapping, required fields are {required_keys}" + ) + + def get_loading_attributes(self): + attibutes_dict = copy.deepcopy(self.__dict__) + loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"] + loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} + return loading_attibutes_dict diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index f0854e42553e16..8f9cbd91aad773 100644 --- a/tests/quantization/autoawq/test_awq.py +++ b/tests/quantization/autoawq/test_awq.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import tempfile import unittest @@ -107,6 +108,11 @@ def setUpClass(cls): device_map=cls.device_map, ) + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + def test_quantized_model_conversion(self): """ Simple test that checks if the quantized model has been converted properly @@ -158,6 +164,13 @@ def test_quantized_model(self): output = self.quantized_model.generate(**input_ids, max_new_tokens=40) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + def test_raise_if_non_quantized(self): + model_id = "facebook/opt-125m" + quantization_config = AwqConfig(bits=4) + + with self.assertRaises(ValueError): + _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) + def test_quantized_model_bf16(self): """ Simple test that checks if the quantized model is working properly with bf16 @@ -195,22 +208,6 @@ def test_save_pretrained(self): output = model.generate(**input_ids, max_new_tokens=40) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - def test_raise_quantization(self): - """ - Simple test that checks if one passes a quantization config to quantize a model, it raises an error - """ - quantization_config = AwqConfig(bits=4) - - with self.assertRaises(ValueError) as context: - _ = AutoModelForCausalLM.from_pretrained( - self.dummy_transformers_model_name, quantization_config=quantization_config - ) - - self.assertEqual( - str(context.exception), - "You cannot pass an `AwqConfig` when loading a model as you can only use AWQ models for inference. To quantize transformers models with AWQ algorithm, please refer to our quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization ", - ) - @require_torch_multi_gpu def test_quantized_model_multi_gpu(self): """ @@ -225,3 +222,144 @@ def test_quantized_model_multi_gpu(self): output = quantized_model.generate(**input_ids, max_new_tokens=40) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + +@slow +@require_torch_gpu +@require_auto_awq +@require_accelerate +class AwqFusedTest(unittest.TestCase): + model_name = "TheBloke/Mistral-7B-OpenOrca-AWQ" + model_revision = "7048b2af77d0dd1c81b000b19d73f9cc8950b510" + + custom_mapping_model_id = "TheBloke/Yi-34B-AWQ" + custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589" + + prompt = ( + "You're standing on the surface of the Earth. " + "You walk one mile south, one mile west and one mile north. " + "You end up exactly where you started. Where are you?" 
+ ) + + EXPECTED_GENERATION = prompt + "\n\nThis is a classic puzzle that has been around for" + EXPECTED_GENERATION_CUSTOM_MODEL = "HelloWorld.java:11)\r\n\tat org" + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def _check_fused_modules(self, model): + has_fused_modules = False + fused_modules_name = ["QuantAttentionFused", "QuantFusedMLP", "FasterTransformerRMSNorm"] + + for _, module in model.named_modules(): + if module.__class__.__name__ in fused_modules_name: + has_fused_modules = True + break + + self.assertTrue(has_fused_modules, "Modules fusing not performed correctly!") + + def test_raise_save_pretrained(self): + """ + Test that `save_pretrained` is effectively blocked for fused models + """ + quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True) + + model = AutoModelForCausalLM.from_pretrained( + self.model_name, + quantization_config=quantization_config, + low_cpu_mem_usage=True, + revision=self.model_revision, + ).to(torch_device) + + self._check_fused_modules(model) + + with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + def test_generation_fused(self): + """ + Test generation quality for fused models - single batch case + """ + quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True) + + model = AutoModelForCausalLM.from_pretrained( + self.model_name, + quantization_config=quantization_config, + low_cpu_mem_usage=True, + revision=self.model_revision, + ).to(torch_device) + + self._check_fused_modules(model) + + tokenizer = AutoTokenizer.from_pretrained(self.model_name, revision=self.model_revision) + + inputs = tokenizer(self.prompt, return_tensors="pt").to(torch_device) + + outputs = model.generate(**inputs, max_new_tokens=12) + + self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION) + + def test_generation_fused_batched(self): + """ + Test generation quality for fused models - multi batch case + """ + quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True) + + model = AutoModelForCausalLM.from_pretrained( + self.model_name, + quantization_config=quantization_config, + low_cpu_mem_usage=True, + revision=self.model_revision, + ).to(torch_device) + + self._check_fused_modules(model) + + tokenizer = AutoTokenizer.from_pretrained(self.model_name, revision=self.model_revision) + + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer([self.prompt, self.prompt], return_tensors="pt", padding=True).to(torch_device) + + outputs = model.generate(**inputs, max_new_tokens=12) + + self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION) + + @require_torch_multi_gpu + def test_generation_custom_model(self): + """ + Test generation quality for fused models using custom fused map. 
+ """ + quantization_config = AwqConfig( + bits=4, + fuse_max_seq_len=512, + modules_to_fuse={ + "attention": ["q_proj", "k_proj", "v_proj", "o_proj"], + "layernorm": ["ln1", "ln2", "norm"], + "mlp": ["gate_proj", "up_proj", "down_proj"], + "use_alibi": False, + "num_attention_heads": 56, + "num_key_value_heads": 8, + "hidden_size": 7168, + }, + ) + + model = AutoModelForCausalLM.from_pretrained( + self.custom_mapping_model_id, + quantization_config=quantization_config, + trust_remote_code=True, + device_map="balanced", + revision=self.custom_model_revision, + ) + + self._check_fused_modules(model) + + tokenizer = AutoTokenizer.from_pretrained( + self.custom_mapping_model_id, revision=self.custom_model_revision, trust_remote_code=True + ) + + prompt = "Hello" + inputs = tokenizer(prompt, return_tensors="pt").to(torch_device) + + outputs = model.generate(**inputs, max_new_tokens=12) + self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL)