Add decorator for custom op and inductor decomp registration
Summary:
This PR adds a decorator to register a custom op and also an inductor decomposition.

The goal is for the torch.export path to be able to see high-level ops like quantize_affine instead of their decompositions, because some backends like xnnpack want to work with these higher-level ops.

This is a redo of #408; the difference is that this PR preserves the enums on the Python side.
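
For illustration, the registration pattern introduced here looks roughly like the sketch below. This is not part of the diff: the "demo" namespace and the _double op are made up, and the preserve-for-export behavior assumes a torch version where the decorator's 2.5+ branch is active.

import torch
from torchao.quantization.quant_primitives import register_custom_op

# hypothetical fragment library and op, purely for illustration
_demo_lib = torch.library.Library("demo", "FRAGMENT")

@register_custom_op(_demo_lib)
def _double(x: torch.Tensor) -> torch.Tensor:
    # plain eager implementation; the decorator infers the schema, registers
    # torch.ops.demo.double as CompositeImplicitAutograd, and also registers
    # it as an inductor decomposition
    return x + x

# export paths (torch.export.export / capture_pre_autograd_graph) now see
# torch.ops.demo.double as a single node; inductor still decomposes it
out = _double(torch.ones(3))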

Test Plan:
regression tests:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Also check that performance does not regress with python tutorials/quantize_vit/run_vit_b_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jul 2, 2024
1 parent c2cf973 commit 604f69c
Showing 2 changed files with 168 additions and 21 deletions.
17 changes: 14 additions & 3 deletions test/integration/test_integration.py
@@ -1244,7 +1244,7 @@ def test_autoquant_manual(self, device, dtype):
out3 = mod(example_input)
sqnr2 = SQNR(out, out3)
self.assertTrue(sqnr2 >= 30)


@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
@@ -1376,7 +1376,7 @@ class TestExport(unittest.TestCase):
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
@run_supported_device_dtype
def test_aoti(self, api, test_device, test_dtype):
def test_export(self, api, test_device, test_dtype):
if not TORCH_VERSION_AFTER_2_4:
self.skipTest("aoti compatibility requires 2.4+.")

@@ -1413,9 +1413,20 @@ def forward(self, x):

# make sure it compiles
example_inputs = (x,)
model = torch.export.export(model, example_inputs).module()
from torch._export import capture_pre_autograd_graph
# TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
# we can re-enable this after non-functional IR is enabled in export
# model = torch.export.export(model, example_inputs).module()
model = capture_pre_autograd_graph(model, example_inputs)
after_export = model(x)
self.assertTrue(torch.equal(after_export, ref))
if api is _int8da_int8w_api:
targets = [n.target for n in model.graph.nodes]
self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
self.assertTrue(torch.ops.quant.quantize_affine.default in targets)




class TestUtils(unittest.TestCase):
@parameterized.expand(COMMON_DEVICE_DTYPE)
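For context on the test change above: the export test now goes through capture_pre_autograd_graph instead of torch.export.export and then checks which op targets appear in the captured graph. A rough sketch of that flow on a toy model (illustrative names only; a torchao-quantized model is used in the real test):

import torch
from torch._export import capture_pre_autograd_graph

model = torch.nn.Linear(8, 8).eval()  # stand-in for the quantized model under test
example_inputs = (torch.randn(1, 8),)

# functionalized export, disabled in the test because it currently changes
# numerics for the quantized models (see the TODO in the diff):
# model = torch.export.export(model, example_inputs).module()

# pre-autograd capture used instead; for a torchao-quantized model the ops
# registered via the new decorator stay visible as torch.ops.quant.* nodes
captured = capture_pre_autograd_graph(model, example_inputs)
targets = [n.target for n in captured.graph.nodes]
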
172 changes: 154 additions & 18 deletions torchao/quantization/quant_primitives.py
@@ -4,13 +4,16 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum
from enum import Enum, auto
from typing import List, Optional, Tuple, Dict
import torch

from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm
from torchao.utils import TORCH_VERSION_AFTER_2_3
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_5,
)


__all__ = [
@@ -34,17 +37,17 @@ class MappingType(Enum):
based on this mapping
e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
"""
SYMMETRIC = 0
ASYMMETRIC = 1
SYMMETRIC = auto()
ASYMMETRIC = auto()

class ZeroPointDomain(Enum):
"""Enum that indicate whether zero_point is in integer domain or floating point domain
integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
"""
INT = 0
FLOAT = 1
INT = auto()
FLOAT = auto()

"""
Map from dtype to the bound value of integers
@@ -69,6 +72,54 @@ class ZeroPointDomain(Enum):
})


quant_lib = torch.library.Library("quant", "FRAGMENT")

def register_custom_op(lib):
"""This decorator is used to preserve some high level operators for torch.export.export
while still allow them to be decomposed for inductor path
requirement: make sure `fn.__name__[1:]` is the operator name you want to register
NOTE: This should be applied at the top, after all other decorators have been applied
NOTE: We haven't tested the case where `fn` accepts a tensor subclass instance as input,
e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would make
sense for downstream systems (like executorch) to accept as well
Example:
lib = torch.library.Library("my_namespace', "FRAGMENT")
@register_custom_op(lib)
def _the_op_that_needs_to_be_preserved(...):
...
# after this, `_the_op_that_needs_to_be_preserved` will be preserved as
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
# torch.export.export / torch._export.capture_pre_autograd_graph
"""
from torch._inductor.decomposition import register_decomposition

def decorator(fn):
if TORCH_VERSION_AFTER_2_5:
from torch._library.infer_schema import infer_schema

# expecting fn.__name__ starts with `_` and we want to take the rest
# to be the name of the custom op
assert fn.__name__[0] == "_", f"Expecting function name starts with `_`, got {fn.__name__}"
op_name = fn.__name__[1:]
schema = op_name + infer_schema(fn)
lib.define(schema)
lib.impl(op_name, fn, "CompositeImplicitAutograd")

lib_namespace = lib.ns
op = getattr(getattr(torch.ops, lib_namespace), op_name)
register_decomposition([op])(fn)
return op
else:
return fn

return decorator


# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
"""Get quant_min and quant_max args based on dtype and also
@@ -140,7 +191,7 @@ def quantize_affine(
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
):
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +225,31 @@
Output:
quantized tensor with requested dtype
"""
return _quantize_affine(
input,
block_size,
scale,
zero_point,
output_dtype,
quant_min,
quant_max,
zero_point_domain.name,
)


@register_custom_op(quant_lib)
def _quantize_affine(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: str = "INT",
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
"""
# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
@@ -188,12 +264,12 @@ def quantize_affine(
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

if zero_point_domain == ZeroPointDomain.INT:
if zero_point_domain == ZeroPointDomain.INT.name:
quant = torch.clamp(
torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
).to(output_dtype)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT
assert zero_point_domain == ZeroPointDomain.FLOAT.name
mid_point = (quant_max + quant_min + 1) / 2
min_val = zero_point - scale * mid_point
quant = (
@@ -216,7 +292,7 @@ def dequantize_affine(
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
*,
output_dtype: torch.dtype = torch.float32,
):
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +314,34 @@
Output:
dequantized Tensor, with requested dtype or fp32
"""
return _dequantize_affine(
input,
block_size,
scale,
zero_point,
input_dtype,
quant_min,
quant_max,
zero_point_domain.name,
output_dtype=output_dtype,
)


# @register_custom_op(quant_lib, 'dequantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, ScalarType input_dtype, int? quant_min=None, int? quant_max=None, str zero_point_domain="INT", ScalarType output_dtype=float) -> Tensor')
@register_custom_op(quant_lib)
def _dequantize_affine(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
input_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: str = "INT",
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
"""

# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +359,16 @@ def dequantize_affine(
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

if zero_point_domain == ZeroPointDomain.INT:
if zero_point_domain == ZeroPointDomain.INT.name:
# Force a copy to avoid input modification due
# to upcoming in-place operations.
dequant = input.to(torch.int32, copy=True)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant - zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
dequant = dequant * scale
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
mid_point = (quant_max + quant_min + 1) / 2
# This should allocate new memory and avoid input modification
dequant = input - mid_point
@@ -320,8 +424,39 @@ def choose_qparams_affine(
Output:
Tuple of scales and zero_points Tensor with requested dtype
"""
return _choose_qparams_affine(
input,
mapping_type.name,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain.name
)

# @register_custom_op(quant_lib, 'choose_qparams_affine(Tensor input, str mapping_type, int[] block_size, ScalarType target_dtype, int? quant_min=None, int? quant_max=None, float? eps=None, ScalarType? scale_dtype=None, ScalarType? zero_point_dtype=None, bool preserve_zero=True, str zero_point_domain="INT") -> (Tensor, Tensor)')
@register_custom_op(quant_lib)
def _choose_qparams_affine(
input: torch.Tensor,
mapping_type: str,
block_size: List[int],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: str = "INT",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""op definition that has compatible signatures with custom op library
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"

if scale_dtype is None:
scale_dtype = input.dtype
@@ -342,21 +477,22 @@ def choose_qparams_affine(
min_val_neg = min_val
max_val_pos = max_val

if mapping_type == MappingType.SYMMETRIC:
if mapping_type == MappingType.SYMMETRIC.name:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
if not preserve_zero:
raise ValueError("preserve_zero == False is not supported for symmetric quantization")
if zero_point_domain != ZeroPointDomain.INT:
if zero_point_domain != ZeroPointDomain.INT.name:
raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
else:
assert mapping_type == MappingType.ASYMMETRIC.name
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
if preserve_zero:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
mid_point = (quant_max + quant_min + 1) / 2
zero_point = min_val_neg + scale * mid_point

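
As an end-to-end sketch of what the preserved op looks like after this change (not part of the diff; the toy module and the per-tensor scale/zero_point are illustrative, and the graph check assumes a torch version where the 2.5+ registration branch is taken):

import torch
from torch._export import capture_pre_autograd_graph
from torchao.quantization.quant_primitives import quantize_affine

class ToyQuant(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # illustrative per-tensor quantization parameters
        self.register_buffer("scale", torch.tensor([0.1]))
        self.register_buffer("zero_point", torch.tensor([0], dtype=torch.int32))

    def forward(self, x):
        # block_size equal to x.shape means a single block, i.e. per-tensor
        return quantize_affine(x, list(x.shape), self.scale, self.zero_point, torch.int8)

m = ToyQuant().eval()
captured = capture_pre_autograd_graph(m, (torch.randn(2, 4),))

targets = [n.target for n in captured.graph.nodes]
# the high-level op survives export instead of being broken into round/clamp/etc.
assert torch.ops.quant.quantize_affine.default in targets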
