Add static quantization as an example for calibration flow (#487)
Summary:
So far, the quantization flow API we provide (`quantize_`) has not required calibration (calibrating a model with sample data). This PR adds a static quantization example that demonstrates the calibration flow; a condensed sketch of the three steps follows the list below.

1. Prepare the model for calibration.
2. Calibrate the prepared model with sample data.
3. Convert the calibrated model to a quantized model.
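
The sketch below is based on the tutorials/calibration_flow/static_quant.py file added in this PR. The helpers `ObservedLinear`, `ToyLinearModel`, `insert_observers_`, and `apply_static_quant` are defined in that tutorial, not in the torchao library, and the import line for them is hypothetical (it assumes they are importable from the tutorial module):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torchao.quantization import quantize_

# hypothetical import: these helpers are defined in tutorials/calibration_flow/static_quant.py
from static_quant import ObservedLinear, ToyLinearModel, insert_observers_, apply_static_quant

m = ToyLinearModel().eval().to("cuda")
example_inputs = m.example_inputs(device="cuda")

# 1. prepare: replace nn.Linear with ObservedLinear modules that record activation/weight statistics
act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")
insert_observers_(m, act_obs, weight_obs)

# 2. calibrate: run representative data through the observed model
for _ in range(10):
    m(*example_inputs)

# 3. convert: swap each ObservedLinear for a statically quantized linear
is_observed_linear = lambda mod, fqn: isinstance(mod, ObservedLinear)
quantize_(m, apply_static_quant, is_observed_linear)
```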

Test Plan:
python tutorials/calibration_flow/static_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Jul 17, 2024
1 parent e9e6671 commit 6dd82d8
Showing 6 changed files with 202 additions and 13 deletions.
6 changes: 4 additions & 2 deletions test/dtypes/test_affine_quantized.py
@@ -14,10 +14,12 @@ class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_tensor_core_layout_transpose(self):
t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = l.weight
shape = t.shape
apply_int4_weight_only_quant = int4_weight_only(group_size=32)
aqt = apply_int4_weight_only_quant(t)
ql = apply_int4_weight_only_quant(l)
aqt = ql.weight
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -4,6 +4,7 @@
from .affine_quantized_tensor import (
AffineQuantizedTensor,
to_affine_quantized,
to_affine_quantized_static,
LayoutType,
PlainLayoutType,
TensorCoreTiledLayoutType,
@@ -15,6 +16,7 @@
"UInt4Tensor"
"AffineQuantizedTensor",
"to_affine_quantized",
"to_affine_quantized_static",
"LayoutType",
"PlainLayoutType",
"TensorCoreTiledLayoutType",
36 changes: 35 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -232,6 +232,7 @@ def from_float(

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

int_data = layout_type.post_process(int_data)

layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -246,8 +247,40 @@ def from_float(
dtype=input_float.dtype
)

@classmethod
def from_float_static(
cls,
input_float: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
layout_type: LayoutType = PlainLayoutType(),
):
original_shape = input_float.shape
input_float = layout_type.pre_process(input_float)

int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

int_data = layout_type.post_process(int_data)

layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
return cls(
layout_tensor,
block_size,
original_shape,
quant_min,
quant_max,
zero_point_domain,
dtype=input_float.dtype,
)

@property
def layout_type(self) -> str:
def layout_type(self) -> LayoutType:
return self.layout_tensor.layout_type

@classmethod
@@ -809,3 +842,4 @@ def t(func, *args, **kwargs):
return return_and_correct_aliasing(func, args, kwargs, new)

to_affine_quantized = AffineQuantizedTensor.from_float
to_affine_quantized_static = AffineQuantizedTensor.from_float_static
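
For context, a minimal usage sketch of the new `to_affine_quantized_static` constructor exposed above. The scale/zero_point values here are computed by hand purely for illustration; in the calibration flow they would come from an observer's `calculate_qparams()`, as in the tutorial added by this PR:

```python
import torch
from torchao.dtypes import to_affine_quantized_static

weight = torch.randn(32, 64)

# illustrative per-row (per-channel) affine parameters over the uint8 range [0, 255]
w_min = weight.amin(dim=1)
w_max = weight.amax(dim=1)
scale = (w_max - w_min) / 255.0
zero_point = torch.clamp((-w_min / scale).round(), 0, 255).to(torch.int64)

qweight = to_affine_quantized_static(
    weight,
    scale,
    zero_point,
    block_size=(1, weight.shape[1]),  # one scale/zero_point per row
    target_dtype=torch.uint8,
)
```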
5 changes: 3 additions & 2 deletions torchao/prototype/quant_llm/quant_llm.py
@@ -7,6 +7,7 @@
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones
from torchao.ops import quant_llm_linear
from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
from torchao.quantization.quant_api import _get_linear_subclass_inserter


_ONES_TABLE = [_n_ones(i) for i in range(8)]
@@ -456,8 +457,8 @@ def apply_quant_llm(weight: Tensor) -> Tensor:
if (in_dim % 64 != 0) or (out_dim % 256 != 0):
return weight
return QuantLlmLinearWeight.from_float(weight, ebits, mbits)
return apply_quant_llm
return _get_linear_subclass_inserter(apply_quant_llm)


def fp6_llm_weight_only():
return quant_llm_fpx_weight_only(3, 2)
return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
21 changes: 13 additions & 8 deletions torchao/quantization/quant_api.py
@@ -259,12 +259,12 @@ def insert_subclass(lin):

return insert_subclass

def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
"""Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
Args:
model (torch.nn.Module): input model
apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance (e.g. affine quantized tensor instance)
apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor)
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
the weight of the module
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
@@ -300,19 +300,24 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Ten
x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")
def apply_weight_quant_to_linear(linear):
linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False)
return linear
# apply to modules under block0 submodule
def filter_fn(module: nn.Module, fqn: str) -> bool:
return isinstance(module, nn.Linear)
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m, apply_weight_quant, filter_fn)
quantize_(m, apply_weight_quant_to_linear, filter_fn)
"""
if set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

_replace_with_custom_fn_if_matches_filter(
model,
_get_linear_subclass_inserter(apply_tensor_subclass),
apply_tensor_subclass,
_is_linear if filter_fn is None else filter_fn,
)

@@ -356,7 +361,7 @@ def get_per_token_block_size(x):
weight = to_linear_act_quantized(weight, input_quant_func)
return weight

return apply_int8_dynamic_activation_int4_weight_quant
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant)


def int4_weight_only(group_size=128, inner_k_tiles=8):
@@ -394,7 +399,7 @@ def apply_int4_weight_only_quant(weight):
layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

return apply_int4_weight_only_quant
return _get_linear_subclass_inserter(apply_int4_weight_only_quant)


def int8_weight_only():
@@ -412,7 +417,7 @@ def apply_int8wo_quant(weight):
block_size = (1, weight.shape[1])
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

return apply_int8wo_quant
return _get_linear_subclass_inserter(apply_int8wo_quant)

def int8_dynamic_activation_int8_weight():
"""
@@ -454,4 +459,4 @@ def get_per_token_block_size(x):
weight = to_linear_act_quantized(weight, input_quant_func)
return weight

return apply_int8_dynamic_activation_int8_weight_quant
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)
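
The recurring change in this file is that `quantize_` now takes a module-level callable, so each weight-only configurator wraps its tensor-level function with `_get_linear_subclass_inserter` itself rather than relying on `quantize_` to do the wrapping. Below is a minimal sketch of what such an inserter does, following the `insert_subclass` pattern visible in the hunk header above; the actual implementation lives in torchao/quantization/quant_api.py and may differ in details:

```python
import torch

def _get_linear_subclass_inserter(constructor):
    """Turn a Tensor -> Tensor quantization function into a Module -> Module callable."""
    def insert_subclass(lin: torch.nn.Linear) -> torch.nn.Linear:
        # replace the float weight with its quantized tensor-subclass counterpart, in place
        lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
        return lin
    return insert_subclass
```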
145 changes: 145 additions & 0 deletions tutorials/calibration_flow/static_quant.py
@@ -0,0 +1,145 @@
"""
Demo for static quantization flow
"""
import torch
import copy

# TODO: use the generalized observer for affine quantization in the future
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
import torch.nn.functional as F
from torch import Tensor
from torchao.dtypes import to_affine_quantized_static
from torchao.quantization.utils import compute_error
from torchao.quantization import quantize_
from torchao.quantization.subclass import to_linear_act_quantized
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter



class ObservedLinear(torch.nn.Linear):
def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
super().__init__(in_features, out_features, bias, device, dtype)
self.act_obs = act_obs
self.weight_obs = weight_obs

def forward(self, input: Tensor):
observed_input = self.act_obs(input)
observed_weight = self.weight_obs(self.weight)
return F.linear(observed_input, observed_weight, self.bias)

@classmethod
def from_float(cls, float_linear, act_obs, weight_obs):
observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
observed_linear.weight = float_linear.weight
observed_linear.bias = float_linear.bias
return observed_linear

def insert_observers_(model, act_obs, weight_obs):
_is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
replacement_fn = lambda m: ObservedLinear.from_float(m, act_obs, weight_obs)
act_obs = copy.deepcopy(act_obs)
weight_obs = copy.deepcopy(weight_obs)
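# note: every replaced linear shares these same deep-copied observer instances, so the
# calibration statistics (and the resulting quantization parameters) are aggregated across all matched layers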
_replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_static_quant(observed_linear):
target_dtype = torch.uint8

# weight quantization
weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
def weight_quant_func(weight):
block_size = (1, weight.shape[1])
return to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias

linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)

# activation quantization
act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype)
linear.weight = torch.nn.Parameter(to_linear_act_quantized(linear.weight, input_quant_func), requires_grad=False)

return linear


# alternative for converting observed linear module to quantized linear module
class QuantizedLinear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
super().__init__()
self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
weight_scale, weight_zero_point = weight_obs.calculate_qparams()
assert weight.dim() == 2
block_size = (1, weight.shape[1])
target_dtype = torch.uint8
self.qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
self.bias = bias

def forward(self, input: Tensor):
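# using the full input shape as block_size gives a single scale/zero_point for the whole activation tensor (per-tensor quantization)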
block_size = input.shape
target_dtype = torch.uint8
qinput = to_affine_quantized_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
return F.linear(qinput, self.qweight, self.bias)

@classmethod
def from_observed(cls, observed_linear):
quantized_linear = cls(observed_linear.in_features, observed_linear.out_features, observed_linear.act_obs, observed_linear.weight_obs, observed_linear.weight, observed_linear.bias)
return quantized_linear

def apply_static_quant2(observed_linear):
return QuantizedLinear.from_observed(observed_linear)

class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)

def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x

dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')

# TODO: use the generalized observer for affine quantization in the future
act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")

before_quant = m(*example_inputs)

insert_observers_(m, act_obs, weight_obs)
# calibrating / training
for _ in range(10):
m(*example_inputs)

after_obs = m(*example_inputs)

m2 = copy.deepcopy(m)

is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, apply_static_quant, is_observed_linear)
print("quantized model (applying tensor subclass to weight):", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")

# quantized linear as a standalone module
quantize_(m2, apply_static_quant2, is_observed_linear)
print("quantized model (quantized module):", m2)
after_quant = m2(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")
