Skip to content

Commit

Permalink
adding default inductor config settings
Browse files Browse the repository at this point in the history
Summary:

making autoquant and quantize apis call a new
recommended_inductor_config_setter util to set recommended apis

also update groupsize -> groupsize in generate.py

Test Plan:

sh benchmarks.sh

comparison of different config combinations for matmul precision,
mixed_mm and coordinate_descent

tok/s=  9.14, mem/s=  60.55 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=147.02, mem/s= 973.53 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.23, mem/s=  61.11 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=139.59, mem/s= 924.33 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.10, mem/s=  60.26 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=146.98, mem/s= 973.23 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.28, mem/s=  61.48 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=146.90, mem/s= 972.73 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.08, mem/s=  60.09 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=137.58, mem/s= 911.00 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,
tok/s=  9.19, mem/s=  60.87 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf,
tok/s=166.02, mem/s=1099.30 GB/s, peak_mem= 8.97 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf,

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
HDCharles committed Jun 25, 2024
1 parent 9dc2c11 commit bfe2ea2
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 22 deletions.
4 changes: 1 addition & 3 deletions torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
import time
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._models.llama.model import prepare_inputs_for_model

torch._inductor.config.fx_graph_cache = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torchao.quantization.utils.recommended_inductor_config_setter()

def run_evaluation(
checkpoint_path: Path,
Expand Down
14 changes: 4 additions & 10 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch._dynamo.config
import torch._inductor.config
from torchao.utils import get_model_size_in_bytes
torchao.quantization.utils.recommended_inductor_config_setter()

def device_sync(device):
if "cuda" in device:
Expand All @@ -22,13 +23,6 @@ def device_sync(device):
else:
print(f"device={device} is not yet suppported")


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
torch._inductor.config.force_fuse_int_mm_with_mul = True
# torch._inductor.config.use_mixed_mm = True

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# support running without installing as a package
Expand Down Expand Up @@ -203,7 +197,7 @@ def main(
if "int4wo" in quantization:
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize(model, int4_weight_only(groupsize=groupsize))
quantize(model, int4_weight_only(group_size=groupsize))
if "autoquant" == quantization:
model = autoquant(model, manual=True)

Expand Down Expand Up @@ -339,8 +333,8 @@ def callback(x):
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument("--quantization", type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
Expand Down
10 changes: 3 additions & 7 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ of the activations that the different linear layers see, it then benchmarks thes
import torch
import torchao

# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
Expand Down Expand Up @@ -107,9 +103,6 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
group_size = 32
m = quantize(m, int4_weight_only(group_size=group_size))

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# temporary workaround for tensor subclass + torch.compile
from torchao.quantization.utils import unwrap_tensor_subclass
m = unwrap_tensor_subclass(m)
Expand Down Expand Up @@ -163,6 +156,9 @@ m = torch.export.export(m_unwrapped, example_inputs).module()
torch._export.aot_compile(m_unwrapped, example_inputs)
```

### Automatic Inductor Configuration
The `quantize` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues.

### Other Available Quantization Techniques
#### A8W8 Dynamic Quantization

Expand Down
9 changes: 8 additions & 1 deletion torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torchao
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
Expand Down Expand Up @@ -443,8 +444,10 @@ def autoquant(
model,
example_input=None,
qtensor_class_list=DEFAULT_CLASS_LIST,
filter_fn=None, mode=["interpolate", .85],
filter_fn=None,
mode=["interpolate", .85],
manual=False,
set_inductor_config=True,
**aq_kwargs
):
"""
Expand Down Expand Up @@ -477,6 +480,7 @@ def autoquant(
and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
manual (bool, optional): Whether to stop shape calibration and do autoquant after a single run (default, False) or to wait for
the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged.
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
**aq_kwargs: Additional keyword arguments for the autoquantization process.
Returns:
Expand All @@ -493,6 +497,9 @@ def autoquant(
model(*example_input2)
model.finalize_autoquant()
"""
if set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()


# perform initial swap from linear weights
# to AutoQuantizableLinearWeight
Expand Down
6 changes: 5 additions & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

import torch
import torchao
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Union, Dict
Expand Down Expand Up @@ -258,13 +259,14 @@ def insert_subclass(lin):

return insert_subclass

def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module:
def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None, set_inductor_config=True) -> torch.nn.Module:
"""Convert the weight of linear modules in the model with `apply_tensor_subclass`
Args:
model: input model
apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance
filter_fn: used to filter out the modules that we don't want to apply tenosr subclass
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
Example::
Expand All @@ -291,6 +293,8 @@ def filter_fn(module, fqn):
m = MyModel(...)
m = quantize(m, apply_weight_quant, filter_fn)
"""
if set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()
if isinstance(apply_tensor_subclass, str):
if apply_tensor_subclass not in _APPLY_TS_TABLE:
raise ValueError(f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}")
Expand Down
18 changes: 18 additions & 0 deletions torchao/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"groupwise_affine_dequantize_tensor",
"per_token_dynamic_quant",
"get_group_qparams_symmetric",
"recommended_inductor_config_setter"
]

try:
Expand Down Expand Up @@ -456,3 +457,20 @@ def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor:
input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype
)
return input.to(orig_dtype)

def recommended_inductor_config_setter():
"""
Set inductor config to use the following optimizations which have been showed to improve performance for quantized models:
coordinate_descent_tuning = True
coordinate_descent_check_all_directions = True
force_fuse_int_mm_with_mul = True
fx_graph_cache = True
triton.unique_kernel_names = True
torch.set_float32_matmul_precision("high")
"""
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.triton.unique_kernel_names = True
torch.set_float32_matmul_precision("high")

0 comments on commit bfe2ea2

Please sign in to comment.