Skip to content

Commit

Permalink
Add AffineQuantizedObserver
Browse files Browse the repository at this point in the history
Summary:
In our static_quant flow tutorial we were still using observers from `torch.ao` which we plan to deprecate, this PR adds a more general observer for `AffineQuantizedTensor`, and has shown that we can
replace the old observers (min max observer), there could be futhre work to improve perf, add new types
of observation, e.g. tracking stats other than just min/max, moving average observer, histogram observer.

Test Plan:
python test/quantization/test_observer.py
python tutorials/calibration_flow/static_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Aug 13, 2024
1 parent 88a263a commit f3fc52b
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 26 deletions.
39 changes: 39 additions & 0 deletions test/quantization/test_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
PerTensor,
PerAxis,
)
from torchao.quantization.quant_primitives import (
MappingType,
)
import unittest
# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

class TestQuantFlow(TestCase):
def _test_obs_helper(self, obs1, obs2):
example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
for example_input in example_inputs:
obs1(example_input)
obs2(example_input)

scale1, zero_point1 = obs1.calculate_qparams()
scale2, zero_point2 = obs2.calculate_qparams()
self.assertTrue(torch.allclose(scale1, scale2))
self.assertTrue(torch.allclose(zero_point1, zero_point2))

def test_min_max_per_tensor_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
self._test_obs_helper(obs, ref_obs)

def test_min_max_per_channel_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
self._test_obs_helper(obs, ref_obs)


if __name__ == "__main__":
unittest.main()
166 changes: 166 additions & 0 deletions torchao/quantization/observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import torch
from .quant_primitives import (
_get_reduction_params,
choose_qparams_affine_with_min_max,
MappingType,
ZeroPointDomain,
)

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Callable, List, Tuple, Optional, Any
from functools import partial
import logging
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class GranularityType:
pass

@dataclass(frozen=True)
class PerTensor(GranularityType):
pass

@dataclass(frozen=True)
class PerAxis(GranularityType):
axis: int

# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
def __init__(self, p):
self.p = p

def __call__(self, *args, **keywords):
return self.p(*args, **keywords)

def __repr__(self):
return self.p.__repr__()

def with_args(self, *args, **kwargs):
return _with_args(self, *args, **kwargs)

def _with_args(cls_or_self, *args, **kwargs):
r"""Wrapper that allows creation of class factories.
This can be useful when there is a need to create classes with the same
constructor arguments, but different instances.
Example::
>>> # xdoctest: +SKIP("Undefined vars")
>>> Foo.with_args = classmethod(_with_args)
>>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
>>> foo_instance1 = foo_builder()
>>> foo_instance2 = foo_builder()
>>> id(foo_instance1) == id(foo_instance2)
False
"""
r = _PartialWrapper(partial(cls_or_self, *args, **kwargs))
return r

def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]:
if isinstance(granularity_type, PerTensor):
return input_shape
elif isinstance(granularity_type, PerAxis):
block_size = list(input_shape)
block_size[granularity_type.axis] = 1
return tuple(block_size)
raise ValueError(f"Unsupported GranularityType: {granularity_type}")

ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:

class AffineQuantizedObserverBase(ABC, torch.nn.Module):
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
Args:
`granularity_type` and `block_size`: The granularity of the quantization,
must specify at least one, if both are specified `block_size` takes precedence
Current supported granularity type are `PerTensor` and `PerAxis`
other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
"""
with_args = classmethod(_with_args)

def __init__(self,
mapping_type: MappingType,
target_dtype: torch.dtype,
block_size: Optional[Tuple[int, ...]] = None,
granularity_type: Optional[GranularityType] = None,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain = ZeroPointDomain.INT,
):
super().__init__()
assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type"
if block_size is not None and granularity_type is not None:
logger.warning("Both block_size and granularity_type are specified, ignoring granularity_type. block_size: {block_size}, granularity_type: {granularity_type}")
self.mapping_type = mapping_type
self.target_dtype = target_dtype
self.block_size = block_size
self.granularity_type = granularity_type
self.quant_min = quant_min
self.quant_max = quant_max
self.eps = eps
self.scale_dtype = scale_dtype
self.zero_point_dtype = zero_point_dtype
self.preserve_zero = preserve_zero
self.zero_point_domain = zero_point_domain

@abstractmethod
def forward(self, input: torch.Tensor) -> torch.Tensor:
""" forward function should take the input tensor
and updates internal stats and return the original input Tensor
"""
pass

@abstractmethod
def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calculate quantization parameter based on the stats attached to the observer module
and returns a tuple of scale and zero_point Tensor
"""
pass

class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
def forward(self, input: torch.Tensor):
if input.numel() == 0:
return input

input_detached = input.detach()
if self.block_size is None:
self.block_size = get_block_size(input_detached.shape, self.granularity_type)

shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input_detached.size())
input_detached = input_detached.view(shape_for_reduction)
min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
self.min_val = min_val
self.max_val = max_val
else:
min_val = torch.min(self.min_val, min_val)
max_val = torch.max(self.max_val, max_val)
self.min_val.copy_(min_val)
self.max_val.copy_(max_val)
# returning original input
return input

def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
return choose_qparams_affine_with_min_max(
self.min_val,
self.max_val,
self.mapping_type,
self.block_size,
self.target_dtype,
self.quant_min,
self.quant_max,
self.eps,
self.scale_dtype,
self.zero_point_dtype,
self.preserve_zero,
self.zero_point_domain
)
84 changes: 70 additions & 14 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"safe_int_mm",
"int_scaled_matmul",
"choose_qparams_affine",
"choose_qparams_affine_with_min_max",
"quantize_affine",
"dequantize_affine",
"fake_quantize_affine",
Expand Down Expand Up @@ -570,9 +571,51 @@ def choose_qparams_affine(
zero_point_domain.name
)


def choose_qparams_affine_with_min_max(
min_val: torch.Tensor,
max_val: torch.Tensor,
mapping_type: MappingType,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`
operator that pass in min_val and max_val directly instead of deriving these from a single input.
This is used for observers in static quantization where min_val and max_val may be obtained through
tracking all the data in calibration data set.
Args:
Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one
difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val
and then scale/zero_point, we pass in min_val/max_val directly
"""
return _choose_qparams_affine(
None,
mapping_type.name,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain.name,
min_val,
max_val,
)


@register_custom_op
def _choose_qparams_affine(
input: torch.Tensor,
input: Optional[torch.Tensor],
mapping_type: str,
block_size: List[int],
target_dtype: torch.dtype,
Expand All @@ -583,23 +626,38 @@ def _choose_qparams_affine(
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: str = "INT",
min_val: Optional[torch.Tensor] = None,
max_val: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""op definition that has compatible signatures with custom op library
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"

if scale_dtype is None:
scale_dtype = input.dtype
if zero_point_dtype is None:
zero_point_dtype = input.dtype
if input is not None:
if scale_dtype is None:
scale_dtype = input.dtype
if zero_point_dtype is None:
zero_point_dtype = input.dtype
if eps is None:
eps = torch.finfo(input.dtype).eps

assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}"
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
input = input.view(shape_for_reduction)
assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}"
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
input = input.view(shape_for_reduction)

min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
else:
assert min_val is not None and max_val is not None, "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}"
assert min_val.dtype == max_val.dtype, "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}"

min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
if scale_dtype is None:
scale_dtype = min_val.dtype
if zero_point_dtype is None:
zero_point_dtype = min_val.dtype
if eps is None:
eps = torch.finfo(min_val.dtype).eps

if preserve_zero:
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
Expand All @@ -615,10 +673,12 @@ def _choose_qparams_affine(
raise ValueError("preserve_zero == False is not supported for symmetric quantization")
if zero_point_domain != ZeroPointDomain.INT.name:
raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
scale = torch.clamp(scale, min=eps)
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
else:
assert mapping_type == MappingType.ASYMMETRIC.name
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.clamp(scale, min=eps)
if preserve_zero:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
Expand All @@ -627,8 +687,4 @@ def _choose_qparams_affine(
mid_point = (quant_max + quant_min + 1) / 2
zero_point = min_val_neg + scale * mid_point

if eps is None:
eps = torch.finfo(input.dtype).eps
scale = torch.clamp(scale, min=eps)

return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype)
Loading

0 comments on commit f3fc52b

Please sign in to comment.