From eb2554228759ca86175cd42592344b74e67c23bc Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 8 Oct 2024 12:59:06 -0700 Subject: [PATCH] Add generic fake quantized linear for QAT Summary: This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows: ``` from torchao.quantization.prototype.qat.api import FakeQuantizeConfig from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear activation_config = FakeQuantizeConfig( bit_width=8, granularity="per_token", symmetric=False, dynamic=True, ) weight_config = FakeQuantizeConfig( bit_width=4, group_size=8, symmetric=True, dynamic=True, ) fq_linear = FakeQuantizedLinear( 16, 32, False, activation_config, weight_config, ) ``` The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module. Test Plan: python test/quantization/test_qat.py -k test_fake_quantize_config python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w python test/quantization/test_qat.py -k test_fake_quantized_linear_4w ghstack-source-id: 2598aa9a704e109b443c299ec6b8497b18e13716 Pull Request resolved: https://github.com/pytorch/ao/pull/1020 --- test/quantization/test_qat.py | 229 +++++++++--- torchao/quantization/prototype/qat/api.py | 67 +++- .../prototype/qat/fake_quantizer.py | 116 +++++++ torchao/quantization/prototype/qat/linear.py | 327 ++++++++++-------- torchao/quantization/prototype/qat/utils.py | 10 +- 5 files changed, 558 insertions(+), 191 deletions(-) create mode 100644 torchao/quantization/prototype/qat/fake_quantizer.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index e1e670d5d..67a59965f 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -11,17 +11,27 @@ import unittest import torch +import torch.nn.functional as F from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torchao.dtypes import ( TensorCoreTiledLayoutType, ) from torchao.quantization.prototype.qat.api import ( ComposableQATQuantizer, + FakeQuantizeConfig, + QuantizationGranularity, +) +from torchao.quantization.prototype.qat.fake_quantizer import ( + FakeQuantizer, +) +from torchao.quantization.prototype.qat.linear import ( + FakeQuantizedLinear, ) from torchao.quantization.prototype.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, + _get_qmin_qmax, _GenericFakeQuantize, ) from torchao.quantization.quant_api import ( @@ -92,15 +102,10 @@ def forward(self, x): class TestQAT(unittest.TestCase): SEED = 123 - def _get_qmin_qmax(self, n_bit: int): - qmin = -(2 ** (n_bit - 1)) - qmax = 2 ** (n_bit - 1) - 1 - return (qmin, qmax) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_fake_quantize_per_channel_group(self): n_bit = 4 - (qmin, qmax) = self._get_qmin_qmax(n_bit) + (qmin, qmax) = _get_qmin_qmax(n_bit) group_size = 128 torch.manual_seed(self.SEED) @@ -126,7 +131,7 @@ def test_fake_quantize_per_channel_group(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def 
test_fake_quantize_per_token(self): - (qmin, qmax) = self._get_qmin_qmax(8) + (qmin, qmax) = _get_qmin_qmax(8) torch.manual_seed(self.SEED) x = torch.randn(100, 256).requires_grad_() @@ -165,11 +170,11 @@ def _set_ptq_weight( Int4WeightOnlyQATLinear, ) n_bit = 4 - (qmin, qmax) = self._get_qmin_qmax(n_bit) + (qmin, qmax) = _get_qmin_qmax(n_bit) + group_size = qat_linear.weight_fake_quantizer.config.group_size if isinstance(ptq_linear, Int8DynActInt4WeightLinear): assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear) fp32_weight = qat_linear.weight - group_size = qat_linear.groupsize (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, @@ -180,7 +185,7 @@ def _set_ptq_weight( elif isinstance(ptq_linear, WeightOnlyInt4Linear): assert isinstance(qat_linear, Int4WeightOnlyQATLinear) (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - qat_linear.weight, n_bit, qat_linear.groupsize, + qat_linear.weight, n_bit, group_size, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( q_weight.to("cuda"), qat_linear.inner_k_tiles, @@ -218,31 +223,36 @@ def test_qat_8da4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer - from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer group_size = 16 torch.manual_seed(self.SEED) m = M() m2 = copy.deepcopy(m) - subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - subclass_model = subclass_quantizer.prepare(m) - module_swap_model = module_swap_quantizer.prepare(m2) + qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size) + qat_model = qat_quantizer.prepare(m) + ptq_model = ptq_quantizer.quantize(m2) # Compare model values torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + qat_out = qat_model(*x) + ptq_out = ptq_model(*x2) + torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) # Convert QAT model and compare model values - subclass_model = subclass_quantizer.convert(subclass_model) - module_swap_model = module_swap_quantizer.convert(module_swap_model) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + converted_model = qat_quantizer.convert(qat_model) + converted_out = converted_model(*x) + torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0) + + # Compare converted state dict + ptq_state_dict = ptq_model.state_dict() + converted_state_dict = converted_model.state_dict() + self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) + for k in ptq_state_dict.keys(): + torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): @@ -275,9 +285,12 @@ def 
test_qat_8da4w_quantizer_disable_fake_quant(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - self.assertFalse(qat_model.linear1._fake_quant_enabled) - self.assertFalse(qat_model.linear2._fake_quant_enabled) - self.assertFalse(qat_model.sub.linear._fake_quant_enabled) + self.assertFalse(qat_model.linear1.activation_fake_quantizer.enabled) + self.assertFalse(qat_model.linear1.weight_fake_quantizer.enabled) + self.assertFalse(qat_model.linear2.activation_fake_quantizer.enabled) + self.assertFalse(qat_model.linear2.weight_fake_quantizer.enabled) + self.assertFalse(qat_model.sub.linear.activation_fake_quantizer.enabled) + self.assertFalse(qat_model.sub.linear.weight_fake_quantizer.enabled) # Disabled fake quant is just a normal linear m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight) @@ -292,9 +305,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): # Renable fake quant qat_model.apply(enable_8da4w_fake_quant) - self.assertTrue(qat_model.linear1._fake_quant_enabled) - self.assertTrue(qat_model.linear2._fake_quant_enabled) - self.assertTrue(qat_model.sub.linear._fake_quant_enabled) + self.assertTrue(qat_model.linear1.activation_fake_quantizer.enabled) + self.assertTrue(qat_model.linear1.weight_fake_quantizer.enabled) + self.assertTrue(qat_model.linear2.activation_fake_quantizer.enabled) + self.assertTrue(qat_model.linear2.weight_fake_quantizer.enabled) + self.assertTrue(qat_model.sub.linear.activation_fake_quantizer.enabled) + self.assertTrue(qat_model.sub.linear.weight_fake_quantizer.enabled) # Fake quant should be applied as normal quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) @@ -407,7 +423,7 @@ def test_qat_generic_fake_quantize(self): the numerics of existing fake quantize ops in Pytorch in both the forward and the backward passes. 
""" - (qmin, qmax) = self._get_qmin_qmax(4) + (qmin, qmax) = _get_qmin_qmax(4) py_input = torch.randn(16, 64).float().requires_grad_() py_s = torch.randn(16).float() py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32) @@ -521,7 +537,7 @@ def test_qat_4w_quantizer_gradients(self): @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer - from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer group_size = 32 inner_k_tiles = 8 @@ -530,29 +546,34 @@ def test_qat_4w_quantizer(self): torch.manual_seed(self.SEED) m = M().to(device).to(dtype) m2 = copy.deepcopy(m) - subclass_quantizer = Int4WeightOnlyQATQuantizer( + qat_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - module_swap_quantizer = Int4WeightOnlyQATQuantizer( + ptq_quantizer = Int4WeightOnlyQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - subclass_model = subclass_quantizer.prepare(m) - module_swap_model = module_swap_quantizer.prepare(m2) + qat_model = qat_quantizer.prepare(m) + ptq_model = ptq_quantizer.quantize(m2) # Compare model values torch.manual_seed(self.SEED) x = [i.to(device).to(dtype) for i in m.example_inputs()] x2 = copy.deepcopy(x) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + qat_out = qat_model(*x) + ptq_out = ptq_model(*x2) + self._assert_close_4w(qat_out, ptq_out) # Convert QAT model and compare model values - subclass_model = subclass_quantizer.convert(subclass_model) - module_swap_model = module_swap_quantizer.convert(module_swap_model) - subclass_out = subclass_model(*x) - module_swap_out = module_swap_model(*x2) - torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0) + converted_model = qat_quantizer.convert(qat_model) + converted_out = converted_model(*x) + torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) + + # Compare converted state dict + ptq_state_dict = ptq_model.state_dict() + converted_state_dict = converted_model.state_dict() + self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) + for k in ptq_state_dict.keys(): + torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) class _MyQATQuantizer(TwoStepQuantizer): """ @@ -603,5 +624,127 @@ def test_qat_4w_embedding(self): converted = quantizer.convert(model) converted_out = converted(*x) + def test_fake_quantize_config(self): + """ + Test initialization and property setting of `FakeQuantizeConfig`. 
+ """ + # basic configs + per_token_config = FakeQuantizeConfig(8, "per_token") + self.assertEqual(per_token_config.bit_width, 8) + self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN) + self.assertIsNone(per_token_config.group_size) + per_channel_config = FakeQuantizeConfig(4, "per_channel") + self.assertEqual(per_channel_config.bit_width, 4) + self.assertEqual(per_channel_config.granularity, QuantizationGranularity.PER_CHANNEL) + self.assertIsNone(per_channel_config.group_size) + + # initialize per_group config using only group size + per_group_config = FakeQuantizeConfig(4, group_size=32) + self.assertEqual(per_group_config.bit_width, 4) + self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP) + self.assertEqual(per_group_config.group_size, 32) + + # set granularity after initialization, should accept str as before + per_group_config.granularity = "per_token" + self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN) + + # set group_size after initialization, should also update granularity + per_group_config.group_size = 16 + self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP) + self.assertEqual(per_group_config.group_size, 16) + + # bad config1: no granularity or group size provided + with self.assertRaisesRegex(ValueError, "group_size or granularity must be set"): + FakeQuantizeConfig(8) + + # bad config2: 'per_group' but no group size + with self.assertRaisesRegex(ValueError, "no group_size was set"): + FakeQuantizeConfig(8, "per_group") + + # bad config3: group size was set but granularity was not 'per_group' + with self.assertRaisesRegex(ValueError, "group_size was set"): + FakeQuantizeConfig(8, "per_token", group_size=16) + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_fake_quantized_linear_8da4w(self): + """ + Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. + """ + group_size = 128 + torch.manual_seed(self.SEED) + fq_linear = FakeQuantizedLinear( + 256, + 688, + bias=False, + activation_config=FakeQuantizeConfig(8, "per_token", symmetric=False), + weight_config=FakeQuantizeConfig(4, group_size=group_size), + ) + + def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant. + """ + # activations + (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32) + (qmin, qmax) = _get_qmin_qmax(8) + x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax) + + # weights + (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) + zp = zp.to(torch.int32) + (qmin, qmax) = _get_qmin_qmax(4) + w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + return F.linear(x_fq, w_fq) + + # Compare linear values + torch.manual_seed(self.SEED) + x = torch.randn(100, 256) + x2 = copy.deepcopy(x) + fq_out = fq_linear(x) + baseline_out = linear_forward_8da4w(x2, fq_linear.weight) + torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_fake_quantized_linear_4w(self): + """ + Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. 
+ """ + group_size = 128 + weight_config = FakeQuantizeConfig( + bit_width=4, + group_size=group_size, + symmetric=False, + zero_point_domain=ZeroPointDomain.FLOAT, + ) + torch.manual_seed(self.SEED) + fq_linear = FakeQuantizedLinear( + 256, + 688, + bias=False, + activation_config=None, + weight_config=weight_config, + ) + + def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Baseline for int4 weight only fake quantization that simulates the tinygemm kernel. + """ + (qmin, qmax) = _get_qmin_qmax(4, symmetric=False) + (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32) + zp = zp.to(torch.int32) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT, + ) + return F.linear(x, w_fq) + + # Compare linear values + torch.manual_seed(self.SEED) + x = torch.randn(100, 256) + x2 = copy.deepcopy(x) + fq_out = fq_linear(x) + baseline_out = linear_forward_4w(x2, fq_linear.weight) + torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) + + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index 93717271b..9c5828add 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -4,11 +4,76 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List +from dataclasses import dataclass +from enum import Enum +from typing import Any, List, Optional import torch from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.quant_primitives import ZeroPointDomain + + +# TODO: change this to quant_primitives.Granularity +class QuantizationGranularity(Enum): + PER_CHANNEL = "per_channel" + PER_TOKEN = "per_token" + PER_GROUP = "per_group" + + +@dataclass +class FakeQuantizeConfig: + """ + Config for how to fake quantize weights or activations. + + args: + bit_width: number of bits to simulate during fake quantization + granularity: granularity of scales and zero points, one of: + 'per_token', 'per_channel', or 'per_group' + group_size: size of each group for 'per_group' granularity + symmetric: whether to use symmetric (default) or asymmetric quantization + scale_precision: scale dtype (default torch.fp32) + zero_point_precision: zero point dtype (default torch.int32) + zero_point_domain: whether zero point is in integer (default) or float domain + dynamic: whether to use dynamic (defualt) or static scale and zero points + range_learning: whether to learn scale and zero points during training (coming soon) + """ + bit_width: int + granularity: Optional[QuantizationGranularity] = None + group_size: Optional[int] = None + symmetric: bool = True + scale_precision: torch.dtype = torch.float32 + zero_point_precision: torch.dtype = torch.int32 + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT + dynamic: bool = True + range_learning: bool = False + + def __post_init__(self): + """ + Verify that `group_size` and `granularity` are consistent. 
+ """ + if self.group_size is None and self.granularity is None: + raise ValueError("At least one of group_size or granularity must be set") + if self.granularity == QuantizationGranularity.PER_GROUP and self.group_size is None: + raise ValueError("Granularity is 'per_group' but no group_size was set") + if self.granularity != QuantizationGranularity.PER_GROUP and self.group_size is not None: + if self.granularity is None: + self.granularity = QuantizationGranularity.PER_GROUP + else: + raise ValueError( + "Granularity is '%s' but group_size was set" % self.granularity.value + ) + self._initialized = True + + def __setattr__(self, name: str, value: Any): + """ + Support setting `granularity` by string and through `group_size`. + """ + if name == "group_size" and getattr(self, "_initialized", False): + super().__setattr__("granularity", QuantizationGranularity.PER_GROUP) + if name == "granularity" and isinstance(value, str): + value = QuantizationGranularity(value) + super().__setattr__(name, value) class ComposableQATQuantizer(TwoStepQuantizer): diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/quantization/prototype/qat/fake_quantizer.py new file mode 100644 index 000000000..7c7e09f8b --- /dev/null +++ b/torchao/quantization/prototype/qat/fake_quantizer.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch + +from torchao.quantization.utils import ( + get_group_qparams_symmetric, + get_groupwise_affine_qparams, +) +from .api import ( + FakeQuantizeConfig, + QuantizationGranularity, +) +from .utils import ( + _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, + _get_qmin_qmax, +) + + +class FakeQuantizer(torch.nn.Module): + """ + Generic module for applying fake quantization to a tensor, as specified in the config. + """ + def __init__(self, config: FakeQuantizeConfig): + super().__init__() + self.config = config + self.enabled = True + self.scale: Optional[torch.Tensor] = None + self.zero_point: Optional[torch.Tensor] = None + + # TODO: support range learinng + if self.config.range_learning: + raise NotImplementedError("Range learning is not supported yet") + + def forward(self, x: torch.Tensor): + """ + Apply fake quantization to the tensor based on the bit-width, + granularity, symmetry, and other properties specified in the config. + """ + if not self.enabled: + return x + + if self.config.granularity == QuantizationGranularity.PER_TOKEN: + return self._per_token_forward(x) + elif self.config.granularity in [ + QuantizationGranularity.PER_CHANNEL, + QuantizationGranularity.PER_GROUP, + ]: + return self._per_channel_or_group_forward(x) + else: + raise ValueError("Unknown granularity %s" % self.config.granularity) + + def _per_token_forward(self, x: torch.Tensor): + """ + Perform per token fake quantization on the tensor. 
+ """ + if self.config.symmetric: + raise NotImplementedError("Symmetric per token is not supported yet") + if self._should_compute_qparams(): + (self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric( + x, self.config.scale_precision, self.config.zero_point_precision, + ) + qmin, qmax = _get_qmin_qmax(self.config.bit_width) + return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax) + + def _per_channel_or_group_forward(self, x: torch.Tensor): + """ + Perform per channel or per group fake quantization on the tensor. + We express per channel using per group where the group size is the size + of the last dimension of the tensor. + """ + bit_width = self.config.bit_width + granularity = self.config.granularity + scale_precision = self.config.scale_precision + zero_point_precision = self.config.zero_point_precision + zero_point_domain = self.config.zero_point_domain + symmetric = self.config.symmetric + + # get group size + if granularity == QuantizationGranularity.PER_CHANNEL: + group_size = x.size()[-1] + elif granularity == QuantizationGranularity.PER_GROUP: + assert self.config.group_size is not None + group_size = self.config.group_size + else: + raise ValueError("Group size not defined for granularity %s" % granularity) + + # get scales and zero points + if self._should_compute_qparams(): + if symmetric: + (self.scale, self.zero_point) = get_group_qparams_symmetric( + x, bit_width, group_size, scale_precision, + ) + else: + (self.scale, self.zero_point) = get_groupwise_affine_qparams( + x, bit_width, group_size, scale_precision, + ) + self.zero_point = self.zero_point.to(zero_point_precision) + + qmin, qmax = _get_qmin_qmax(bit_width, symmetric) + return _fake_quantize_per_channel_group( + x, self.scale, self.zero_point, qmin, qmax, group_size, zero_point_domain, + ) + + def _should_compute_qparams(self) -> bool: + """ + Return whether we need to compute new scales and zero points. + """ + return self.config.dynamic or self.scale is None or self.zero_point is None diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/quantization/prototype/qat/linear.py index 07276ba84..32f560189 100644 --- a/torchao/quantization/prototype/qat/linear.py +++ b/torchao/quantization/prototype/qat/linear.py @@ -21,6 +21,8 @@ from torchao.quantization.quant_primitives import ZeroPointDomain from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric +from .api import FakeQuantizeConfig +from .fake_quantizer import FakeQuantizer from .utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, @@ -29,6 +31,79 @@ ) +class FakeQuantizedLinear(torch.nn.Linear): + """ + General linear layer with fake quantized weights and activations. + + Specific fake quantization bit widths, granularity, schemes etc. are specified + through separate configs for weights and activations. 
+ + Example usage:: + + activation_config = FakeQuantizeConfig( + bit_width=8, + granularity="per_token", + symmetric=False, + ) + weight_config = FakeQuantizeConfig( + bit_width=4, + group_size=8, + symmetric=True, + ) + fq_linear = FakeQuantizedLinear( + 16, 32, False, activation_config, weight_config, + ) + fq_linear(torch.randn(16)) + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + activation_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfig] = None, + *args, + **kwargs, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + *args, + **kwargs, + ) + if bias: + raise NotImplementedError("bias not supported yet") + + # initialize activation fake quantizer + if activation_config is not None: + self.activation_fake_quantizer = FakeQuantizer(activation_config) + else: + self.activation_fake_quantizer = None + + # initialize weight fake quantizer + if weight_config is not None: + group_size = weight_config.group_size + if group_size is not None and in_features % group_size != 0: + raise ValueError( + "in_features (%s) % group_size (%s) must be == 0" % + (in_features, group_size) + ) + self.weight_fake_quantizer = FakeQuantizer(weight_config) + else: + self.weight_fake_quantizer = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.activation_fake_quantizer is not None: + x = self.activation_fake_quantizer(x) + if self.weight_fake_quantizer is not None: + w = self.weight_fake_quantizer(self.weight) + else: + w = self.weight + return F.linear(x, w) + + # ========================================================= # | Linear int8 dynamic activations + int4 weight QAT | # ========================================================= @@ -77,42 +152,42 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_8da4w(model) + self._convert_qat_linear_8da4w(model) return model - -def _convert_qat_linear_8da4w(module: torch.nn.Module): - """ - Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int8DynActInt4WeightQATLinear): - quantized_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - groupsize=child.groupsize, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (qmin, qmax) = _get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, - ) - quantized_linear.weight = q_weight - quantized_linear.scales = s - quantized_linear.zeros = zp - else: - _convert_qat_linear_8da4w(child) - - -class Int8DynActInt4WeightQATLinear(torch.nn.Linear): + def _convert_qat_linear_8da4w(self, module: torch.nn.Module): + """ + Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. 
+ """ + for name, child in module.named_children(): + if isinstance(child, Int8DynActInt4WeightQATLinear): + config = child.weight_fake_quantizer.config + quantized_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=config.group_size, + precision=child.weight.dtype, + scales_precision=config.scale_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (qmin, qmax) = _get_qmin_qmax(n_bit) + (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, config.group_size) + from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + child.weight, s, zp, qmin, qmax, torch.int8, config.group_size, + ) + quantized_linear.weight = q_weight + quantized_linear.scales = s + quantized_linear.zeros = zp + else: + self._convert_qat_linear_8da4w(child) + + +class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear): """ This module implements a linear layer with int8 dynamic per token fake quantized activations with int4 fake quantized grouped per channel weights. @@ -133,63 +208,39 @@ def __init__( precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32, ) -> None: + activation_config = FakeQuantizeConfig( + bit_width=8, + granularity="per_token", + symmetric=False, + dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + ) + weight_config = FakeQuantizeConfig( + bit_width=4, + group_size=groupsize, + symmetric=True, + dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + ) super().__init__( in_features, out_features, bias, + activation_config, + weight_config, device=device, dtype=precision, ) - assert ( - in_features % groupsize == 0 - ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" - assert not bias, "require bias=False" - self.groupsize = groupsize - self.precision = precision - self.scales_precision = scales_precision - # TODO: make this configurable? 
- self.zero_points_precision = torch.int32 - self._fake_quant_enabled = True def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled + self.activation_fake_quantizer.enabled = enabled + self.weight_fake_quantizer.enabled = enabled def disable_fake_quant(self): self.enable_fake_quant(False) - def forward(self, x: torch.Tensor) -> torch.Tensor: - # activations: int8 dynamic asymmetric quant - if self._fake_quant_enabled: - (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( - x, self.scales_precision, self.zero_points_precision, - ) - (act_qmin, act_qmax) = _get_qmin_qmax(8) - x_fq = _fake_quantize_per_token( - x, act_scales, act_zp, act_qmin, act_qmax, - ) - else: - x_fq = x - - # weights: int4 grouped per channel symmetric quant - if self._fake_quant_enabled: - (weight_scales, weight_zp) = get_group_qparams_symmetric( - self.weight, 4, self.groupsize, self.scales_precision, - ) - # TODO: pass zp dtype to `get_group_qparams_symmetric` instead - weight_zp = weight_zp.to(self.zero_points_precision) - (weight_qmin, weight_qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group( - self.weight, - weight_scales, - weight_zp, - weight_qmin, - weight_qmax, - self.groupsize, - ) - else: - w_fq = self.weight - return F.linear(x_fq, w_fq) - def enable_8da4w_fake_quant(mod: torch.nn.Module): """ @@ -257,46 +308,45 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_4w(model) + self._convert_qat_linear_4w(model) return model - -def _convert_qat_linear_4w(module: torch.nn.Module): - """ - Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATLinear): - in_features = child.in_features - out_features = child.out_features - groupsize = child.groupsize - inner_k_tiles = child.inner_k_tiles - quantized_linear = WeightOnlyInt4Linear( - in_features, - out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - child.weight, n_bit, child.groupsize, - ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), child.inner_k_tiles, - ) - quantized_linear.weight = q_weight - quantized_linear.scales_and_zeros = scales_and_zeros - else: - _convert_qat_linear_4w(child) - - -class Int4WeightOnlyQATLinear(torch.nn.Linear): + def _convert_qat_linear_4w(self, module: torch.nn.Module): + """ + Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. 
+ """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + inner_k_tiles = child.inner_k_tiles + config = child.weight_fake_quantizer.config + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=config.group_size, + inner_k_tiles=inner_k_tiles, + precision=child.weight.dtype, + scales_precision=config.scale_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, config.group_size, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + self._convert_qat_linear_4w(child) + + +class Int4WeightOnlyQATLinear(FakeQuantizedLinear): """ This module implements a linear layer with int4 fake quantized grouped per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, @@ -319,47 +369,36 @@ def __init__( precision: torch.dtype = torch.bfloat16, scales_precision: torch.dtype = torch.bfloat16, ) -> None: + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.inner_k_tiles = inner_k_tiles + weight_config = FakeQuantizeConfig( + bit_width=4, + group_size=groupsize, + symmetric=False, + dynamic=True, + scale_precision=scales_precision, + zero_point_precision=scales_precision, + zero_point_domain=ZeroPointDomain.FLOAT, + ) super().__init__( in_features, out_features, bias, + activation_config=None, + weight_config=weight_config, device=device, dtype=precision, ) - assert not bias, "require bias=False" - assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" - if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): - raise ValueError("Padding for QAT 4w is not supported yet") - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.precision = precision - self.scales_precision = scales_precision - self._fake_quant_enabled = True def enable_fake_quant(self, enabled: bool = True): - self._fake_quant_enabled = enabled + self.activation_fake_quantizer.enabled = enabled + self.weight_fake_quantizer.enabled = enabled def disable_fake_quant(self): self.enable_fake_quant(False) - def forward(self, x: torch.Tensor) -> torch.Tensor: - n_bit = 4 - qmin = 0 - qmax = 2 ** n_bit - 1 - scales, zero_points = get_groupwise_affine_qparams( - self.weight, n_bit, self.groupsize, self.scales_precision, - ) - w_fq = _fake_quantize_per_channel_group( - self.weight, - scales, - zero_points, - qmin, - qmax, - self.groupsize, - ZeroPointDomain.FLOAT, - ) - return F.linear(x, w_fq) - def enable_4w_fake_quant(mod: torch.nn.Module): """ diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py index 354475e65..8f2dd9d13 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/prototype/qat/utils.py @@ -181,7 +181,11 @@ def _choose_qparams_per_token_asymmetric( return scale.to(scales_precision), zero_point.to(zero_points_precision) -def _get_qmin_qmax(n_bit: int): - qmin = -(2 ** (n_bit - 1)) - qmax = 2 ** (n_bit - 1) - 1 +def _get_qmin_qmax(n_bit: int, 
symmetric: bool = True):
+    if symmetric:
+        qmin = -(2 ** (n_bit - 1))
+        qmax = 2 ** (n_bit - 1) - 1
+    else:
+        qmin = 0
+        qmax = 2 ** n_bit - 1
     return (qmin, qmax)
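
As an end-to-end reference, the snippet below shows how the pieces added here fit together for int4 weight-only (tinygemm-style) fake quantization and how fake quantization can be toggled at runtime. This is a minimal sketch that mirrors `test_fake_quantized_linear_4w`, not code from the diff itself; it assumes the prototype module paths introduced in this commit, which may move out of `prototype` later.

```
# Minimal sketch (assumes torchao at this commit; the prototype QAT module
# paths below are the ones introduced by this patch and may move later).
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_primitives import ZeroPointDomain

# int4 weight-only config in the float zero point domain (tinygemm-style),
# matching the setup exercised in test_fake_quantized_linear_4w
weight_config = FakeQuantizeConfig(
    bit_width=4,
    group_size=128,
    symmetric=False,
    zero_point_domain=ZeroPointDomain.FLOAT,
)
fq_linear = FakeQuantizedLinear(
    256, 688, bias=False,
    activation_config=None,   # weight-only: no activation fake quantization
    weight_config=weight_config,
)

x = torch.randn(100, 256)
y = fq_linear(x)              # forward pass with fake quantized weights

# fake quantization can be toggled per tensor via the attached FakeQuantizer
fq_linear.weight_fake_quantizer.enabled = False
y_ref = fq_linear(x)          # with the quantizer disabled and bias=False,
                              # this is a plain F.linear(x, weight)
```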