From f3fc52bf6a8acc2452f5300b7f8c788e86428bf7 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 9 Aug 2024 16:38:14 -0700 Subject: [PATCH] Add AffineQuantizedObserver Summary: In our static_quant flow tutorial we were still using observers from `torch.ao` which we plan to deprecate, this PR adds a more general observer for `AffineQuantizedTensor`, and has shown that we can replace the old observers (min max observer), there could be futhre work to improve perf, add new types of observation, e.g. tracking stats other than just min/max, moving average observer, histogram observer. Test Plan: python test/quantization/test_observer.py python tutorials/calibration_flow/static_quant.py Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_observer.py | 39 +++++ torchao/quantization/observer.py | 166 +++++++++++++++++++++ torchao/quantization/quant_primitives.py | 84 +++++++++-- tutorials/calibration_flow/static_quant.py | 37 +++-- 4 files changed, 300 insertions(+), 26 deletions(-) create mode 100644 test/quantization/test_observer.py create mode 100644 torchao/quantization/observer.py diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py new file mode 100644 index 000000000..0e5076051 --- /dev/null +++ b/test/quantization/test_observer.py @@ -0,0 +1,39 @@ +import torch +from torch.testing._internal.common_utils import TestCase +from torchao.quantization.observer import ( + AffineQuantizedMinMaxObserver, + PerTensor, + PerAxis, +) +from torchao.quantization.quant_primitives import ( + MappingType, +) +import unittest +# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao +from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver + +class TestQuantFlow(TestCase): + def _test_obs_helper(self, obs1, obs2): + example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)] + for example_input in example_inputs: + obs1(example_input) + obs2(example_input) + + scale1, zero_point1 = obs1.calculate_qparams() + scale2, zero_point2 = obs2.calculate_qparams() + self.assertTrue(torch.allclose(scale1, scale2)) + self.assertTrue(torch.allclose(zero_point1, zero_point2)) + + def test_min_max_per_tensor_affine(self): + obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int) + ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine) + self._test_obs_helper(obs, ref_obs) + + def test_min_max_per_channel_affine(self): + obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int) + ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine) + self._test_obs_helper(obs, ref_obs) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py new file mode 100644 index 000000000..a8d10f73f --- /dev/null +++ b/torchao/quantization/observer.py @@ -0,0 +1,166 @@ +import torch +from .quant_primitives import ( + _get_reduction_params, + choose_qparams_affine_with_min_max, + MappingType, + ZeroPointDomain, +) + +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import Callable, List, Tuple, Optional, Any +from functools import partial +import logging +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class GranularityType: + pass + +@dataclass(frozen=True) +class PerTensor(GranularityType): + pass + +@dataclass(frozen=True) +class PerAxis(GranularityType): + axis: int + +# borrowed from torch.ao.quantization.observer +class _PartialWrapper: + def __init__(self, p): + self.p = p + + def __call__(self, *args, **keywords): + return self.p(*args, **keywords) + + def __repr__(self): + return self.p.__repr__() + + def with_args(self, *args, **kwargs): + return _with_args(self, *args, **kwargs) + +def _with_args(cls_or_self, *args, **kwargs): + r"""Wrapper that allows creation of class factories. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances. + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) + >>> foo_instance1 = foo_builder() + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1) == id(foo_instance2) + False + """ + r = _PartialWrapper(partial(cls_or_self, *args, **kwargs)) + return r + +def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]: + if isinstance(granularity_type, PerTensor): + return input_shape + elif isinstance(granularity_type, PerAxis): + block_size = list(input_shape) + block_size[granularity_type.axis] = 1 + return tuple(block_size) + raise ValueError(f"Unsupported GranularityType: {granularity_type}") + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + +class AffineQuantizedObserverBase(ABC, torch.nn.Module): + """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) + + Args: + `granularity_type` and `block_size`: The granularity of the quantization, + must specify at least one, if both are specified `block_size` takes precedence + Current supported granularity type are `PerTensor` and `PerAxis` + other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` + """ + with_args = classmethod(_with_args) + + def __init__(self, + mapping_type: MappingType, + target_dtype: torch.dtype, + block_size: Optional[Tuple[int, ...]] = None, + granularity_type: Optional[GranularityType] = None, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain = ZeroPointDomain.INT, + ): + super().__init__() + assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type" + if block_size is not None and granularity_type is not None: + logger.warning("Both block_size and granularity_type are specified, ignoring granularity_type. block_size: {block_size}, granularity_type: {granularity_type}") + self.mapping_type = mapping_type + self.target_dtype = target_dtype + self.block_size = block_size + self.granularity_type = granularity_type + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.scale_dtype = scale_dtype + self.zero_point_dtype = zero_point_dtype + self.preserve_zero = preserve_zero + self.zero_point_domain = zero_point_domain + + @abstractmethod + def forward(self, input: torch.Tensor) -> torch.Tensor: + """ forward function should take the input tensor + and updates internal stats and return the original input Tensor + """ + pass + + @abstractmethod + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate quantization parameter based on the stats attached to the observer module + and returns a tuple of scale and zero_point Tensor + """ + pass + +class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase): + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + if self.block_size is None: + self.block_size = get_block_size(input_detached.shape, self.granularity_type) + + shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input_detached.size()) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + min_val = torch.min(self.min_val, min_val) + max_val = torch.max(self.max_val, max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + # returning original input + return input + + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + self.block_size, + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain + ) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 1d958be84..a37c17403 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -21,6 +21,7 @@ "safe_int_mm", "int_scaled_matmul", "choose_qparams_affine", + "choose_qparams_affine_with_min_max", "quantize_affine", "dequantize_affine", "fake_quantize_affine", @@ -570,9 +571,51 @@ def choose_qparams_affine( zero_point_domain.name ) + +def choose_qparams_affine_with_min_max( + min_val: torch.Tensor, + max_val: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain = ZeroPointDomain.INT, +) -> Tuple[torch.Tensor, torch.Tensor]: + """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` + operator that pass in min_val and max_val directly instead of deriving these from a single input. + This is used for observers in static quantization where min_val and max_val may be obtained through + tracking all the data in calibration data set. + + Args: + Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one + difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val + and then scale/zero_point, we pass in min_val/max_val directly + """ + return _choose_qparams_affine( + None, + mapping_type.name, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain.name, + min_val, + max_val, + ) + + @register_custom_op def _choose_qparams_affine( - input: torch.Tensor, + input: Optional[torch.Tensor], mapping_type: str, block_size: List[int], target_dtype: torch.dtype, @@ -583,23 +626,38 @@ def _choose_qparams_affine( zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, zero_point_domain: str = "INT", + min_val: Optional[torch.Tensor] = None, + max_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """op definition that has compatible signatures with custom op library """ quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}" - if scale_dtype is None: - scale_dtype = input.dtype - if zero_point_dtype is None: - zero_point_dtype = input.dtype + if input is not None: + if scale_dtype is None: + scale_dtype = input.dtype + if zero_point_dtype is None: + zero_point_dtype = input.dtype + if eps is None: + eps = torch.finfo(input.dtype).eps - assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}" - shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) - input = input.view(shape_for_reduction) + assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}" + shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) + input = input.view(shape_for_reduction) + + min_val = torch.amin(input, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + else: + assert min_val is not None and max_val is not None, "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}" + assert min_val.dtype == max_val.dtype, "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}" - min_val = torch.amin(input, dim=reduction_dims, keepdim=False) - max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + if scale_dtype is None: + scale_dtype = min_val.dtype + if zero_point_dtype is None: + zero_point_dtype = min_val.dtype + if eps is None: + eps = torch.finfo(min_val.dtype).eps if preserve_zero: min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) @@ -615,10 +673,12 @@ def _choose_qparams_affine( raise ValueError("preserve_zero == False is not supported for symmetric quantization") if zero_point_domain != ZeroPointDomain.INT.name: raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization") + scale = torch.clamp(scale, min=eps) zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) else: assert mapping_type == MappingType.ASYMMETRIC.name scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.clamp(scale, min=eps) if preserve_zero: zero_point = quant_min - torch.round(min_val_neg / scale) zero_point = torch.clamp(zero_point, quant_min, quant_max) @@ -627,8 +687,4 @@ def _choose_qparams_affine( mid_point = (quant_max + quant_min + 1) / 2 zero_point = min_val_neg + scale * mid_point - if eps is None: - eps = torch.finfo(input.dtype).eps - scale = torch.clamp(scale, min=eps) - return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype) diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index 7911f645e..8106f7e59 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -4,8 +4,6 @@ import torch import copy -# TODO: use the generalized observer for affine qunatization in the future -from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver import torch.nn.functional as F from torch import Tensor from torchao.dtypes import to_affine_quantized_static @@ -13,7 +11,14 @@ from torchao.quantization import quantize_ from torchao.quantization import to_linear_activation_quantized from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter - +from torchao.quantization.observer import ( + AffineQuantizedMinMaxObserver, + PerTensor, + PerAxis, +) +from torchao.quantization.quant_primitives import ( + MappingType, +) class ObservedLinear(torch.nn.Linear): @@ -36,9 +41,12 @@ def from_float(cls, float_linear, act_obs, weight_obs): def insert_observers_(model, act_obs, weight_obs): _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - replacement_fn = lambda m: ObservedLinear.from_float(m, act_obs, weight_obs) - act_obs = copy.deepcopy(act_obs) - weight_obs = copy.deepcopy(weight_obs) + + def replacement_fn(m): + copied_act_obs = copy.deepcopy(act_obs) + copied_weight_obs = copy.deepcopy(weight_obs) + return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs) + _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) # converting observed linear module to linear module with quantzied weights (and quantized activations) @@ -94,8 +102,8 @@ def apply_static_quant2(observed_linear): class ToyLinearModel(torch.nn.Module): def __init__(self, m=64, n=32, k=64): super().__init__() - self.linear1 = torch.nn.Linear(m, n, bias=False) - self.linear2 = torch.nn.Linear(n, k, bias=False) + self.linear1 = torch.nn.Linear(m, k, bias=False) + self.linear2 = torch.nn.Linear(k, n, bias=False) def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) @@ -105,16 +113,21 @@ def forward(self, x): x = self.linear2(x) return x +torch.manual_seed(0) + dtype = torch.bfloat16 -m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda") +m = ToyLinearModel().eval().to(dtype).to("cuda") + +m_for_test = copy.deepcopy(m) + m_bf16 = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=dtype, device="cuda") +print("example inputs shape:", example_inputs[0].shape) m_bf16 = torch.compile(m_bf16, mode='max-autotune') -# TODO: use the generalized observer for affine qunatization in the future -act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda") -weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda") +act_obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32) +weight_obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32) before_quant = m(*example_inputs)