From 2a8dc5d5835a84097a399154e7fa34fbf89a1bc3 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 7 May 2024 13:38:44 -0700
Subject: [PATCH] Enable dispatch to tinygemm int4 and int8 kernels for
 unified quantized tensor

Summary:
This adds some dispatch to the tinygemm kernels for CUDA, although we still
need to resolve the implementation mismatch problem for tinygemm first.

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_quant_api.py        | 105 +++++++++++++++++++--
 test/quantization/test_quant_primitives.py |  11 +--
 torchao/quantization/autoquant.py          |   1 +
 torchao/quantization/quant_primitives.py   |  85 +++++++++++++----
 torchao/quantization/subclass.py           |  84 +++++++++++++++--
 5 files changed, 243 insertions(+), 43 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 10d36f0c1..cea659e61 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -9,7 +9,6 @@
 import unittest
 import torch
 import os
-from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
     convert_pt2e,
@@ -36,7 +35,7 @@
 def dynamic_quant(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
+    m = torch.export.export(model, example_inputs).module()
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     m = convert_pt2e(m)
@@ -50,14 +49,14 @@ def _apply_dynamic_quant(model):
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
+        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
     )
     return model

 def capture_and_prepare(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
+    m = torch.export.export(model, example_inputs)
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this
@@ -88,13 +87,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
         return model

 class ToyLinearModel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, m=64, n=32, k=64):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

     def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, self.linear1.in_features).to(torch.float),)

     def forward(self, x):
         x = self.linear1(x)
@@ -104,8 +103,9 @@ def forward(self, x):
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
+        example_inputs = m.example_inputs()
         m = _apply_dynamic_quant(m)
-        quantized = m(*m.example_inputs())
+        quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
         # m = torch.compile(m, mode="max-autotune")
@@ -442,7 +442,94 @@ def get_per_token_block_size(x):
         ref = m_copy(*example_inputs)
         self.assertTrue(torch.equal(res, ref))

+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_int4(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        from torchao.quantization.quant_primitives import ZeroPointDomain
+        import copy
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+
+        # weight only quantization
+        input_quant_func = None
+
+        # use 1024 so that we don't need padding
+        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+
+        def to_quantized(weight):
+            return AffineQuantizedTensor.from_float(
+                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=preserve_zero,
+                zero_point_domain=ZeroPointDomain.FLOAT,
+                input_quant_func=input_quant_func,
+            )
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        self.assertTrue(torch.equal(res, ref))
+
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_int8(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # weight only quantization
+        input_quant_func = None
+
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
+
+        def to_quantized(weight):
+            block_size = (1, weight.shape[1])
+            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+        change_linear_weights_to_int8_woqtensors(m_copy)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)

 if __name__ == "__main__":
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 291039e42..a64439a25 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -327,6 +327,8 @@ def test_not_preserve_zero_not_supported(self):

     def test_tinygemm_get_groupwise_affine_qparams(self):
+        from torchao.quantization.quant_primitives import ZeroPointDomain
+
         input = torch.randn(10, 256)
         n_bit = 4
         scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
@@ -351,16 +353,11 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
             scale_dtype=scale_dtype,
             zero_point_dtype=zero_point_dtype,
             preserve_zero=False,
+            zero_point_domain=ZeroPointDomain.FLOAT,
         )
-        def int_zero_point_to_float(zero_point, scale, qaunt_min, mid_point):
-            return (quant_min - zero_point + mid_point) * scale
-
-        mid_point = 2 ** (n_bit - 1)
-        zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)
-
         self.assertTrue(torch.equal(scale, scale_ref))
-        torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)
+        self.assertTrue(torch.equal(zero_point, zero_point_ref))

 if __name__ == "__main__":
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 4331d9b04..4c0ae53ce 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -9,6 +9,7 @@
     quantize_activation_per_token_absmax,
     safe_int_mm,
 )
+from .utils import TORCH_VERSION_AFTER_2_4
 import torch.nn.functional as F
 try:
     from torch._inductor.utils import do_bench
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 3975284b6..4f39a6055 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -72,6 +72,14 @@ def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
     torch.uint7: (0, 2**7-1),
 })

+class MappingType(Enum):
+    SYMMETRIC = 0
+    ASYMMETRIC = 1
+
+class ZeroPointDomain(Enum):
+    INT = 0
+    FLOAT = 1
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -141,7 +149,8 @@ def quantize_affine(
     zero_point: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_min: Optional[int] = None,
-    quant_max: Optional[int] = None
+    quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
 ):
     """
     Args:
@@ -153,6 +162,12 @@ def quantize_affine(
       output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
       quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
      quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT

     Note:
       How can block_size represent different granularities?
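To make the INT vs FLOAT zero_point convention described above concrete, here is a minimal numeric sketch; the values below are made up for illustration only, and the authoritative logic is the quantize_affine / dequantize_affine changes in the hunks that follow:

import torch

x = torch.tensor([-1.0, -0.25, 0.0, 0.5, 1.0])   # float values in one group
s = torch.tensor(0.1)                            # scale for that group
quant_min, quant_max = 0, 15                     # uint4 range

# ZeroPointDomain.INT: zero_point is an integer offset added after rounding.
zp_int = torch.tensor(8)
q_int = torch.clamp(torch.round(x / s) + zp_int, quant_min, quant_max)
dq_int = (q_int - zp_int) * s

# ZeroPointDomain.FLOAT (tinygemm convention): zero_point is a float value;
# the mid point of the quantized range maps back to it on dequantization.
zp_float = torch.tensor(0.05)
mid_point = (quant_max + quant_min + 1) / 2      # 8 for uint4
min_val = zp_float - s * mid_point               # float value that quantizes to 0
q_float = torch.clamp(torch.round((x - min_val) / s), quant_min, quant_max)
dq_float = (q_float - mid_point) * s + zp_float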
@@ -184,9 +199,19 @@ def quantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    quant = torch.clamp(
-        torch.round(input / scale) + zero_point, quant_min, quant_max
-    ).to(output_dtype)
+    if zero_point_domain == ZeroPointDomain.INT:
+        quant = torch.clamp(
+            torch.round(input / scale) + zero_point, quant_min, quant_max
+        ).to(output_dtype)
+    else:
+        assert zero_point_domain == ZeroPointDomain.FLOAT
+        mid_point = (quant_max + quant_min + 1) / 2
+        min_val = zero_point - scale * mid_point
+        quant = (
+            torch.clamp(
+                torch.round((input - min_val) / scale),
+                quant_min, quant_max)
+        ).to(output_dtype)
     quant = quant.view(original_shape)
     return quant
@@ -199,6 +224,7 @@ def dequantize_affine(
     input_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
 ):
@@ -213,6 +239,12 @@ def dequantize_affine(
       quant_min (Optional[int]): minimum quantized value for input Tensor
       quant_max (Optional[int]): maximum quantized value for input Tensor
       output_dtype (torch.dtype): dtype for output Tensor, default is fp32
+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT

     Output:
       dequantized Tensor, with requested dtype or fp32
@@ -233,18 +265,22 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    dequant = input.to(torch.int32)
-    if zero_point is not None:
-        dequant -= zero_point.to(torch.int32)
-    dequant = dequant.to(output_dtype)
-    dequant *= scale
-    dequant = dequant.view(original_shape)
-    return dequant.to(output_dtype)
+    if zero_point_domain == ZeroPointDomain.INT:
+        dequant = input.to(torch.int32)
+        if zero_point is not None:
+            dequant -= zero_point.to(torch.int32)
+        dequant = dequant.to(output_dtype)
+        dequant *= scale
+    else:
+        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        mid_point = (quant_max + quant_min + 1) / 2
+        dequant = input - mid_point
+        dequant = dequant.to(output_dtype)
+        dequant *= scale
+        if zero_point is not None:
+            dequant += zero_point

-
-class MappingType(Enum):
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    return dequant.view(original_shape).to(output_dtype)

 def choose_qparams_affine(
     input: torch.Tensor,
@@ -256,7 +292,8 @@ def choose_qparams_affine(
     eps: Optional[float] = None,
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = None,
-    preserve_zero = True,
+    preserve_zero: bool = True,
+    zero_point_domain = ZeroPointDomain.INT,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
@@ -280,6 +317,13 @@ def choose_qparams_affine(
         If we don't need zero to be exactly representable, we won't do rounding
         and clamping for zero_point

+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT
+
     Output:
       Tuple of scales and zero_points
      Tensor with requested dtype
    """
@@ -310,15 +354,18 @@ def choose_qparams_affine(
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
+        if zero_point_domain != ZeroPointDomain.INT:
+            raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
+        zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            zero_point = quant_min - min_val_neg / scale
-
+            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            mid_point = (quant_max + quant_min + 1) / 2
+            zero_point = min_val_neg + scale * mid_point
     if eps is None:
         eps = torch.finfo(input.dtype).eps
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 6128720d4..607cb7776 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -14,10 +14,13 @@
     dynamically_quantize_per_channel,
     groupwise_affine_quantize_tensor,
     quant_int8_dynamic_per_token_linear,
+    pack_tinygemm_scales_and_zeros,
     unpack_tinygemm_scales_and_zeros,
+    groupwise_affine_quantize_tensor_from_qparams,
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
+    ZeroPointDomain,
 )
 from .utils import find_multiple
 from typing import Tuple, Optional, Callable
@@ -619,7 +622,13 @@ class AffineQuantizedTensor(torch.Tensor):
       shape (torch.Size): the shape for the Tensor
       quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
       quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes input Tensor as input and outputs an AffineQuantizedTensor object
+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT
+      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
       dtype: dtype for external representation of the tensor, e.g. torch.float32
     """
@@ -633,8 +642,10 @@ def __new__(
         shape: torch.Size,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         input_quant_func: Optional[Callable] = None,
         dtype=None,
+        # TODO: remove args and kwargs
         *args,
         **kwargs
     ):
@@ -658,6 +669,7 @@ def __init__(
         shape: torch.Size,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         input_quant_func: Optional[Callable] = None,
         dtype=None,
         *args,
@@ -669,6 +681,7 @@ def __init__(
         self.block_size = block_size
         self.quant_min = quant_min
         self.quant_max = quant_max
+        self.zero_point_domain = zero_point_domain
         self.input_quant_func = input_quant_func

     def __repr__(self):
@@ -677,18 +690,20 @@ def __repr__(self):
             f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
         )

-    def dequantize(self, output_dtype=torch.float32):
-        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, output_dtype=output_dtype)
+    def dequantize(self, output_dtype=None):
+        if output_dtype is None:
+            output_dtype = self.dtype
+        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

     def __tensor_flatten__(self):
-        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.input_quant_func, self.dtype]
+        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.input_quant_func, self.dtype]

     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        block_size, shape, quant_min, quant_max, input_quant_func, dtype = tensor_attributes
+        block_size, shape, quant_min, quant_max, zero_point_domain, input_quant_func, dtype = tensor_attributes
         return cls(
             int_data,
             scale,
@@ -697,6 +712,7 @@ def __tensor_unflatten__(
             shape if outer_size is None else outer_size,
             quant_min,
             quant_max,
+            zero_point_domain,
             input_quant_func=input_quant_func,
             dtype=dtype,
             strides=outer_stride,
@@ -715,9 +731,11 @@ def from_float(
         scale_dtype = None,
         zero_point_dtype = None,
         input_quant_func = None,
+        preserve_zero = True,
+        zero_point_domain = ZeroPointDomain.INT,
     ):
-        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype)
-        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
         return cls(
             int_data,
             scale,
@@ -726,6 +744,7 @@ def from_float(
             input_float.shape,
             quant_min,
             quant_max,
+            zero_point_domain,
             input_quant_func=input_quant_func,
             dtype=input_float.dtype
         )
@@ -740,7 +759,54 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args) > 2 else None,
             )
-            if weight_qtensor.input_quant_func is not None:
+            if weight_qtensor.input_quant_func is None:
+                is_cuda = args[0].is_cuda
+                is_cpu = args[0].device == torch.device("cpu")
+                # weight only quantization
+                is_int8 = (
+                    weight_qtensor.int_data.dtype == torch.int8 and
+                    (weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128) and
+                    (weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127)
+                )
+                is_uint4 = (
+                    weight_qtensor.int_data.dtype == torch.int32 and
+                    weight_qtensor.quant_min == 0 and
+                    weight_qtensor.quant_max == 15
+                )
+
+                # TODO: enable cpu and mps path as well
+                # TODO: make sure weight dimension matches the expectation of the int4mm kernel
+                # TODO: move this to TinygemmAffineQuantizedTensor
+                if (
+                    is_cuda and
+                    is_uint4 and
+                    weight_qtensor.dtype == torch.bfloat16 and
+                    len(weight_qtensor.shape) == 2 and
+                    weight_qtensor.block_size[0] == 1 and
+                    weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
+                ):
+                    # groupwise int4 quantization
+                    # TODO: currently doing packing on the fly, we'll need to figure out
+                    # the API to do packing before hand
+                    # TODO: expose the arg
+                    innerKTiles = 8
+                    packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
+                    scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
+                    groupsize = weight_qtensor.block_size[-1]
+                    return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
+                elif (
+                    is_cpu and
+                    is_int8 and
+                    len(weight_qtensor.shape) == 2 and
+                    len(weight_qtensor.block_size) == 2 and
+                    weight_qtensor.block_size[0] == 1 and
+                    weight_qtensor.block_size[1] == weight_qtensor.shape[1]
+                ):
+                    # TODO: enable mps path as well
+                    # per channel int8 weight only quantized mm
+                    return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+            else:
+                # dynamic quantization
                 input_tensor = weight_qtensor.input_quant_func(input_tensor)
                 input_tensor = input_tensor.dequantize()
             weight_tensor = weight_qtensor.dequantize()
@@ -777,6 +843,7 @@ def to(self, *args, **kwargs):
             self.shape,
             self.quant_min,
             self.quant_max,
+            self.zero_point_domain,
             self.input_quant_func,
             **kwargs,
         )
@@ -790,6 +857,7 @@ def _apply_fn_to_data(self, fn):
             self.shape,
             self.quant_min,
             self.quant_max,
+            self.zero_point_domain,
             self.input_quant_func,
             dtype=self.dtype,
         )
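For reference, a minimal standalone sketch of the new weight-only int4 path, distilled from test_quantized_tensor_subclass_int4 above; the shapes and quantization settings mirror that test and are just one valid configuration, not requirements of the API:

import torch
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

# bf16 CUDA linear; 1024 features is used (as in the test) so no padding is needed
linear = torch.nn.Linear(1024, 1024, bias=False).to(torch.bfloat16).to("cuda")

groupsize = 32
qweight = AffineQuantizedTensor.from_float(
    linear.weight,
    MappingType.ASYMMETRIC,
    (1, groupsize),          # block_size
    torch.int32,             # target_dtype
    0,                       # quant_min
    15,                      # quant_max
    1e-6,                    # eps
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,
    zero_point_domain=ZeroPointDomain.FLOAT,  # required by the tinygemm branch
    input_quant_func=None,                    # weight-only quantization
)
linear.weight = torch.nn.Parameter(qweight, requires_grad=False)

# F.linear now hits the is_cuda/is_uint4/bfloat16 branch in __torch_function__,
# which packs the weight via aten._convert_weight_to_int4pack and dispatches to
# aten._weight_int4pack_mm.
x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")
y = linear(x)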