TensorRT-LLM Quantization Toolkit Installation Guide

Introduction

This document introduces:

  • The steps to install the TensorRT-LLM quantization toolkit.
  • The Python APIs to quantize the models.

Detailed quantization recipes for each model are provided in the README.md of the corresponding model example.

Installation

The NVIDIA TensorRT Model Optimizer quantization toolkit is installed automatically as a dependency of TensorRT-LLM.

# Install the additional requirements
cd examples/quantization
pip install -r requirements.txt

Usage

# FP8 quantization.
python quantize.py --model_dir $MODEL_PATH --qformat fp8 --kv_cache_dtype fp8 --output_dir $OUTPUT_PATH

# INT4_AWQ tp4 quantization.
python quantize.py --model_dir $MODEL_PATH --qformat int4_awq --awq_block_size 64 --tp_size 4 --output_dir $OUTPUT_PATH

# INT8 SQ with INT8 kv cache.
python quantize.py --model_dir $MODEL_PATH --qformat int8_sq --kv_cache_dtype int8 --output_dir $OUTPUT_PATH

# Auto quantization (e.g. fp8 + int4_awq + w4a8_awq) with an average of 5 bits per weight.
python quantize.py --model_dir $MODEL_PATH --autoq_format fp8,int4_awq,w4a8_awq --output_dir $OUTPUT_PATH --auto_quantize_bits 5 --tp_size 2

# FP8 quantization for NeMo model.
python quantize.py --nemo_ckpt_path nemotron-3-8b-base-4k/Nemotron-3-8B-Base-4k.nemo \
                   --dtype bfloat16 \
                   --batch_size 64 \
                   --qformat fp8 \
                   --output_dir nemotron-3-8b/trt_ckpt/fp8/1-gpu

# FP8 quantization for Medusa model.
python quantize.py --model_dir $MODEL_PATH \
                   --dtype float16 \
                   --qformat fp8 \
                   --kv_cache_dtype fp8 \
                   --output_dir $OUTPUT_PATH \
                   --calib_size 512 \
                   --tp_size 1 \
                   --medusa_model_dir /path/to/medusa_head/ \
                   --num_medusa_heads 4

The checkpoint saved in output_dir can be passed directly to trtllm-build.
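As a minimal sketch of that hand-off, the snippet below assembles a trtllm-build invocation from Python. It assumes only the `--checkpoint_dir` and `--output_dir` options of trtllm-build; the helper name and the example paths are hypothetical, and the command is not executed unless you uncomment the last line.

```python
import shlex


def trtllm_build_command(checkpoint_dir, engine_dir):
    """Return the argv list for building an engine from a quantized checkpoint.

    Hypothetical helper: only --checkpoint_dir and --output_dir are passed.
    Add any further build options your model needs.
    """
    return [
        "trtllm-build",
        "--checkpoint_dir", checkpoint_dir,
        "--output_dir", engine_dir,
    ]


# Example paths are placeholders, not paths used by quantize.py.
cmd = trtllm_build_command("/tmp/quantized_ckpt", "/tmp/engine")
print(shlex.join(cmd))
# To run it: import subprocess; subprocess.run(cmd, check=True)
```

Building the command as a list (rather than a single string) avoids shell-quoting problems when paths contain spaces.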

Quantization Arguments:

  • model_dir: Hugging Face model path.
  • qformat: Specify the quantization algorithm applied to the checkpoint.
    • fp8: Weights are quantized to FP8 tensor wise. Activation ranges are calibrated tensor wise.
    • int8_sq: Weights are smoothed and quantized to INT8 channel wise. Activation ranges are calibrated tensor wise.
    • int4_awq: Weights are re-scaled and block-wise quantized to INT4. Block size is specified by awq_block_size.
    • w4a8_awq: Weights are re-scaled and block-wise quantized to INT4. Block size is specified by awq_block_size. Activation ranges are calibrated tensor wise.
    • int8_wo: No quantization is applied to the weights in the checkpoint. Weights are quantized to INT8 channel wise when TensorRT-LLM builds the engine.
    • int4_wo: Same as int8_wo but in INT4.
    • full_prec: No quantization.
  • autoq_format: Quantization algorithms to search over in auto quantization. Each algorithm must be one of ['fp8', 'int4_awq', 'w4a8_awq', 'int8_sq']; use ',' to separate multiple algorithms (e.g. --autoq_format fp8,int4_awq,w4a8_awq).
  • auto_quantize_bits: Effective bits constraint for auto quantization. If not set, regular quantization without the auto quantization search is applied. Note: the value must be within the valid range; otherwise it is set to the lowest valid value if possible.
  • output_dir: Path to save the quantized checkpoint.
  • dtype: Specify data type of model when loading from Hugging Face.
  • kv_cache_dtype: Specify kv cache data type.
    • int8: Use int8 kv cache.
    • fp8: Use FP8 kv cache.
    • None (default): Use kv cache as model dtype.
  • batch_size: Batch size for calibration. Default is 1.
  • calib_size: Number of calibration samples. Default is 512.
  • calib_max_seq_length: Max sequence length of calibration samples. Default is 512.
  • tp_size: Tensor parallelism size of the exported checkpoint. Default is 1.
  • pp_size: Pipeline parallelism size of the exported checkpoint. Default is 1.
  • awq_block_size: AWQ-specific parameter. Indicates the block size used when quantizing weights. TensorRT-LLM supports 64 and 128.
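The qformat options above differ mainly in how many scaling factors cover a weight matrix: one per tensor, one per channel, or one per block. The following pure-Python sketch illustrates only that granularity. It is not the ModelOpt implementation: it assumes simple absolute-maximum scaling, treats each row as one output channel, omits the AWQ re-scaling step, and uses a toy block size of 2 instead of 64 or 128. All function names are hypothetical.

```python
def absmax(values):
    """Largest absolute value in an iterable of floats."""
    return max(abs(v) for v in values)


def per_tensor_scale(weight, qmax):
    """One scale for the whole 2-D weight (a list of rows)."""
    return absmax(v for row in weight for v in row) / qmax


def per_channel_scales(weight, qmax):
    """One scale per row (assumed here to be one output channel)."""
    return [absmax(row) / qmax for row in weight]


def block_wise_scales(weight, qmax, block_size):
    """One scale per contiguous block of `block_size` values in each row."""
    return [
        [absmax(row[i:i + block_size]) / qmax
         for i in range(0, len(row), block_size)]
        for row in weight
    ]


weight = [
    [0.5, -1.0, 0.25, 2.0],
    [0.1, 0.2, -0.4, 0.05],
]
INT8_MAX = 127  # largest magnitude of a signed 8-bit integer
INT4_MAX = 7    # largest positive value of a signed 4-bit integer

print("per-tensor :", per_tensor_scale(weight, INT8_MAX))
print("per-channel:", per_channel_scales(weight, INT8_MAX))
print("block-wise :", block_wise_scales(weight, INT4_MAX, block_size=2))
```

Finer granularity stores more scaling factors but lets a small-valued row or block keep its resolution even when another part of the tensor contains a large outlier.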

NeMo model specific arguments:

  • nemo_ckpt_path: NeMo checkpoint path.
  • calib_tp_size: TP size for NeMo checkpoint calibration.
  • calib_pp_size: PP size for NeMo checkpoint calibration.

Medusa specific arguments:

  • medusa_model_dir: Model path of medusa.
  • quant_medusa_head: Whether to quantize the weights of medusa heads.
  • num_medusa_heads: Number of medusa heads.
  • num_medusa_layers: Number of medusa layers.
  • max_draft_len: Max length of draft.
  • medusa_hidden_act: Activation function of medusa.

Building Arguments:

Several arguments for the build stage are related to quantization.

  • use_fp8_context_fmha: This is a Hopper-only feature. Use FP8 GEMMs to compute the attention operation:
      qkv scale = 1.0
      FP_O = quantize(softmax(FP8_Q * FP8_K), scale=1.0) * FP8_V
      FP_O * output_scale = FP8_O

Checkpoint Conversion Arguments (not supported by all models)

  • FP8
    • use_fp8_rowwise: Enable FP8 per-token, per-channel quantization for linear layers. (FP8 from quantize.py is per-tensor.)
  • INT8
    • smoothquant: Enable INT8 quantization for linear layers. Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf) to apply SmoothQuant to the model and output INT8 weights. Must be in [0, 1]; 0.5 is a good first try.
    • per_channel: Use per-channel quantization for weights when smoothquant is enabled.
    • per_token: Use per-token quantization for activations when smoothquant is enabled.
  • Weight-Only
    • use_weight_only: Weights are quantized to INT4 or INT8 channel wise.
    • weight_only_precision: Select int4 or int8 when use_weight_only is enabled, or int4_gptq when quant_ckpt_path is provided, which means the checkpoint is a GPTQ checkpoint.
    • quant_ckpt_path: Path of a GPTQ quantized model checkpoint in .safetensors format.
    • group_size: Group size used in GPTQ quantization.
    • per_group: Should be enabled when loading from a GPTQ checkpoint.
  • KV Cache
    • int8_kv_cache: By default, the KV cache uses the model dtype. int8_kv_cache selects INT8 quantization for the KV cache.
    • fp8_kv_cache: By default, the KV cache uses the model dtype. fp8_kv_cache selects FP8 quantization for the KV cache.
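To make the role of the smoothquant α parameter concrete, here is a minimal pure-Python sketch of the smoothing step from the SmoothQuant paper, where each input channel j gets a factor s_j = max|X_j|^α / max|W_j|^(1-α). Activations are divided by s_j and the matching weight rows are multiplied by it, so the layer output is unchanged while activation outliers shrink. The function names and toy values are hypothetical, and a single activation vector stands in for calibrated per-channel maxima.

```python
def smoothing_factors(act_absmax, weight, alpha):
    """s_j = max|X_j|**alpha / max|W_j|**(1 - alpha) for each input channel j.

    `weight` is a list of rows, one row per input channel.
    """
    return [
        (a ** alpha) / (max(abs(w) for w in row) ** (1.0 - alpha))
        for a, row in zip(act_absmax, weight)
    ]


def matvec(x, weight):
    """Compute x @ weight for a single activation vector x."""
    n_out = len(weight[0])
    return [sum(x[j] * weight[j][k] for j in range(len(x))) for k in range(n_out)]


x = [8.0, 0.5]                       # channel 0 is an activation outlier
weight = [[0.1, -0.2], [0.4, 0.3]]   # shape [in_channels, out_channels]
act_absmax = [abs(v) for v in x]     # stand-in for calibrated maxima

s = smoothing_factors(act_absmax, weight, alpha=0.5)
x_smooth = [v / sj for v, sj in zip(x, s)]
w_smooth = [[sj * w for w in row] for sj, row in zip(s, weight)]

# Both products agree up to floating-point rounding.
print(matvec(x, weight))
print(matvec(x_smooth, w_smooth))
```

With α = 0.5 the smoothed activation and weight maxima of each channel become equal, which is why it is a reasonable starting point; larger α moves more of the quantization difficulty onto the weights.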

APIs

quantize.py uses the quantization toolkit to calibrate PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format). When auto quantization is enabled, an additional quantization config file with per-layer information is also produced. The checkpoints can be used directly by the trtllm-build command to build TensorRT-LLM engines. See this doc for more details on the TensorRT-LLM checkpoint format.

This quantization step may take a long time to finish and requires a large amount of GPU memory. Please use a server-grade GPU if a GPU out-of-memory error occurs.

If the model was trained on multiple GPUs with tensor parallelism, the PTQ calibration process requires the same number of GPUs as were used for training.

PTQ (Post Training Quantization)

PTQ can be achieved with simple calibration on a small set of training or evaluation data (typically 128-512 samples) after converting a regular PyTorch model to a quantized model.

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
import modelopt.torch.quantization as mtq
import modelopt.torch.utils.dataset_utils as dataset_utils

model = AutoModelForCausalLM.from_pretrained(...)

# Select the quantization config, for example, FP8
config = mtq.FP8_DEFAULT_CFG

# Prepare the calibration set and define a forward loop
calib_dataloader = DataLoader(...)
calibrate_loop = dataset_utils.create_forward_loop(
    model=model, dataloader=calib_dataloader
)

# PTQ with in-place replacement to quantized modules
with torch.no_grad():
    mtq.quantize(model, config, forward_loop=calibrate_loop)

# or PTQ with auto quantization
with torch.no_grad():
    model, search_history = mtq.auto_quantize(
        model,
        data_loader=calib_dataloader,
        loss_func=lambda output, batch: output.loss,
        constraints={"effective_bits": auto_quantize_bits}, # The average bits of quantized weights
        forward_step=lambda model, batch: model(**batch),
        quantization_formats=[quant_algo1, quant_algo2,...] + [None],
        num_calib_steps=len(calib_dataloader),
        num_score_steps=min(
            len(calib_dataloader), 128 // batch_size
        ),  # Limit the number of score steps to avoid long calibration time
        verbose=True,
    )
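The effective_bits constraint above is described as the average number of bits of the quantized weights. The sketch below shows one plausible reading of that quantity: a per-layer bit width weighted by the number of weights in each layer. This is an illustration only; the bit widths per format, the 16-bit value for unquantized layers, and the omission of scaling-factor overhead are assumptions, and ModelOpt's exact accounting may differ.

```python
# Assumed bits per weight for each format (illustrative values only).
# None means the layer is left unquantized in 16-bit precision.
BITS_PER_FORMAT = {"fp8": 8, "int4_awq": 4, "w4a8_awq": 4, None: 16}


def effective_bits(layers):
    """Weighted average of weight bits.

    `layers` is a list of (num_weights, format) pairs, one per layer.
    """
    total = sum(n for n, _ in layers)
    return sum(n * BITS_PER_FORMAT[fmt] for n, fmt in layers) / total


# Hypothetical per-layer assignment produced by a search.
assignment = [
    (4096 * 4096, "fp8"),
    (4096 * 11008, "int4_awq"),
    (4096 * 11008, "int4_awq"),
    (4096 * 4096, None),
]
print(effective_bits(assignment))
```

Under this reading, a lower target forces more layers into 4-bit formats, while a higher target leaves room to keep sensitive layers in FP8 or unquantized.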

Export Quantized Model

After the model is quantized, it can be exported to a TensorRT-LLM checkpoint, which includes:

  • One json file recording the model structure and metadata, and
  • One or several rank weight files storing quantized model weights and scaling factors.

The export API is:

from modelopt.torch.export import export_tensorrt_llm_checkpoint

with torch.inference_mode():
    export_tensorrt_llm_checkpoint(
        model,  # The quantized model.
        decoder_type,  # The type of the model as str, e.g. gptj, llama or gptnext.
        dtype,  # The exported weights data type as torch.dtype.
        export_dir,  # The directory where the exported files will be stored.
        inference_tensor_parallel=tp_size,  # The tensor parallelism size for inference.
        inference_pipeline_parallel=pp_size,  # The pipeline parallelism size for inference.
    )
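After exporting, it can be useful to confirm that the directory holds one config file and one weight file per rank before passing it to trtllm-build. The following sketch assumes the layout config.json plus rank<N>.safetensors and a "mapping" → "world_size" entry in the config; the helper name is hypothetical, and the demo runs against a fake directory with empty weight files rather than a real export.

```python
import json
import tempfile
from pathlib import Path


def check_checkpoint(export_dir):
    """Return (config, rank_files) and verify one weight file per rank.

    Assumes config.json + rank<N>.safetensors and a
    config["mapping"]["world_size"] entry; adjust if your layout differs.
    """
    export_dir = Path(export_dir)
    config = json.loads((export_dir / "config.json").read_text())
    rank_files = sorted(p.name for p in export_dir.glob("rank*.safetensors"))
    world_size = config["mapping"]["world_size"]
    if len(rank_files) != world_size:
        raise ValueError(
            f"expected {world_size} rank files, found {len(rank_files)}"
        )
    return config, rank_files


# Demo on a fake checkpoint directory (tp_size=2, pp_size=1).
with tempfile.TemporaryDirectory() as tmp:
    fake_config = {"mapping": {"world_size": 2, "tp_size": 2, "pp_size": 1}}
    (Path(tmp) / "config.json").write_text(json.dumps(fake_config))
    for rank in range(2):
        (Path(tmp) / f"rank{rank}.safetensors").touch()
    config, rank_files = check_checkpoint(tmp)
    print(rank_files)
```

A mismatch between the number of rank files and the world size usually means the export used different tp_size or pp_size values than intended.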