Add decorator for custom op and inductor decomp registration #434

Merged: 1 commit, Jul 2, 2024
17 changes: 14 additions & 3 deletions test/integration/test_integration.py
@@ -1244,7 +1244,7 @@ def test_autoquant_manual(self, device, dtype):
out3 = mod(example_input)
sqnr2 = SQNR(out, out3)
self.assertTrue(sqnr2 >= 30)


@parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
[
@@ -1376,7 +1376,7 @@ class TestExport(unittest.TestCase):
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
@run_supported_device_dtype
def test_aoti(self, api, test_device, test_dtype):
def test_export(self, api, test_device, test_dtype):
if not TORCH_VERSION_AFTER_2_4:
self.skipTest("aoti compatibility requires 2.4+.")

@@ -1413,9 +1413,20 @@ def forward(self, x):

# make sure it compiles
example_inputs = (x,)
model = torch.export.export(model, example_inputs).module()
from torch._export import capture_pre_autograd_graph
# TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
# we can re-enable this after non-functional IR is enabled in export
# model = torch.export.export(model, example_inputs).module()
model = capture_pre_autograd_graph(model, example_inputs)
after_export = model(x)
self.assertTrue(torch.equal(after_export, ref))
if api is _int8da_int8w_api:
Contributor:
What is this checking?

Contributor Author:
This is because right now we will only see these ops for int8da_int8w quantization; other types of quant (e.g. int4 weight only) will call into the efficient kernels directly.

We should probably figure out a path for executorch. I think we could abstract this with "layout"; what would be a good name here?

targets = [n.target for n in model.graph.nodes]
self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
self.assertTrue(torch.ops.quant.quantize_affine.default in targets)




class TestUtils(unittest.TestCase):
@parameterized.expand(COMMON_DEVICE_DTYPE)
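For context on what the updated export test exercises, below is a minimal sketch of the preservation check: a toy module (QuantizeInput is made up for illustration) that calls the affine quant primitives directly, exported with capture_pre_autograd_graph, and then inspected for preserved torch.ops.quant.* targets. It assumes a torch build where the custom-op registration path in this PR is active (the decorator guards on TORCH_VERSION_AFTER_2_5).

import torch
from torch._export import capture_pre_autograd_graph
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

class QuantizeInput(torch.nn.Module):
    """Toy module that quantizes its input with the affine primitives."""
    def forward(self, x):
        block_size = (1, x.shape[-1])  # one scale/zero_point per row
        scale, zero_point = choose_qparams_affine(
            x, MappingType.ASYMMETRIC, block_size, torch.int8
        )
        return quantize_affine(x, block_size, scale, zero_point, torch.int8)

example_inputs = (torch.randn(4, 8),)
exported = capture_pre_autograd_graph(QuantizeInput().eval(), example_inputs)

# With the custom op registration in place, the high-level ops survive export
# instead of being decomposed into aten primitives.
targets = [n.target for n in exported.graph.nodes]
print(torch.ops.quant.choose_qparams_affine.default in targets)
print(torch.ops.quant.quantize_affine.default in targets)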
126 changes: 108 additions & 18 deletions torchao/quantization/quant_primitives.py
@@ -4,13 +4,17 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum
from enum import Enum, auto
from typing import List, Optional, Tuple, Dict
import torch

from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm
from torchao.utils import TORCH_VERSION_AFTER_2_3
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_5,
)
from torchao.utils import _register_custom_op


__all__ = [
@@ -34,17 +38,17 @@ class MappingType(Enum):
based on this mapping
e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
"""
SYMMETRIC = 0
ASYMMETRIC = 1
SYMMETRIC = auto()
ASYMMETRIC = auto()

class ZeroPointDomain(Enum):
"""Enum that indicate whether zero_point is in integer domain or floating point domain

integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
"""
INT = 0
FLOAT = 1
INT = auto()
FLOAT = auto()

"""
Map from dtype to the bound value of integers
@@ -69,6 +73,10 @@ class ZeroPointDomain(Enum):
})


quant_lib = torch.library.Library("quant", "FRAGMENT")

register_custom_op = _register_custom_op(quant_lib)

# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
"""Get quant_min and quant_max args based on dtype and also
@@ -140,7 +148,7 @@ def quantize_affine(
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
):
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +182,31 @@
Output:
quantized tensor with requested dtype
"""
return _quantize_affine(
input,
block_size,
scale,
zero_point,
output_dtype,
quant_min,
quant_max,
zero_point_domain.name,
)


@register_custom_op
def _quantize_affine(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: str = "INT",
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
"""
# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
@@ -188,12 +221,12 @@ def quantize_affine(
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

if zero_point_domain == ZeroPointDomain.INT:
if zero_point_domain == ZeroPointDomain.INT.name:
quant = torch.clamp(
torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
).to(output_dtype)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT
assert zero_point_domain == ZeroPointDomain.FLOAT.name
mid_point = (quant_max + quant_min + 1) / 2
min_val = zero_point - scale * mid_point
quant = (
@@ -216,7 +249,7 @@ def dequantize_affine(
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
*,
output_dtype: torch.dtype = torch.float32,
):
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +271,32 @@
Output:
dequantized Tensor, with requested dtype or fp32
"""
return _dequantize_affine(
input,
block_size,
scale,
zero_point,
input_dtype,
quant_min,
quant_max,
zero_point_domain.name,
output_dtype=output_dtype,
)

@register_custom_op
def _dequantize_affine(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
input_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: str = "INT",
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
"""

# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +314,16 @@ def dequantize_affine(
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

if zero_point_domain == ZeroPointDomain.INT:
if zero_point_domain == ZeroPointDomain.INT.name:
# Force a copy to avoid input modification due
# to upcoming in-place operations.
dequant = input.to(torch.int32, copy=True)
if zero_point is not None:
dequant -= zero_point.to(torch.int32)
dequant = dequant - zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant *= scale
dequant = dequant * scale
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
mid_point = (quant_max + quant_min + 1) / 2
# This should allocate new memory and avoid input modification
dequant = input - mid_point
@@ -320,8 +379,38 @@ def choose_qparams_affine(
Output:
Tuple of scales and zero_points Tensor with requested dtype
"""
return _choose_qparams_affine(
input,
mapping_type.name,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain.name
)

@register_custom_op
def _choose_qparams_affine(
input: torch.Tensor,
mapping_type: str,
block_size: List[int],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: str = "INT",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""op definition that has compatible signatures with custom op library
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"

if scale_dtype is None:
scale_dtype = input.dtype
@@ -342,21 +431,22 @@ def choose_qparams_affine(
min_val_neg = min_val
max_val_pos = max_val

if mapping_type == MappingType.SYMMETRIC:
if mapping_type == MappingType.SYMMETRIC.name:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
if not preserve_zero:
raise ValueError("preserve_zero == False is not supported for symmetric quantization")
if zero_point_domain != ZeroPointDomain.INT:
if zero_point_domain != ZeroPointDomain.INT.name:
raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
else:
assert mapping_type == MappingType.ASYMMETRIC.name
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
if preserve_zero:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
mid_point = (quant_max + quant_min + 1) / 2
zero_point = min_val_neg + scale * mid_point

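For reference, here is a small round-trip sketch of the public primitives whose bodies now forward to the `_`-prefixed custom ops. The shapes, block size, and dtype are illustrative, and it assumes a torchao build that includes this change.

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

x = torch.randn(2, 8, dtype=torch.float32)
block_size = (1, 8)  # one scale/zero_point per row

# The public API takes enums; internally it forwards the enum .name strings to
# the registered torch.ops.quant.* ops.
scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, torch.int8
)
q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
dq = dequantize_affine(q, block_size, scale, zero_point, torch.int8)

print("max abs error:", (x - dq).abs().max().item())  # small for int8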
52 changes: 51 additions & 1 deletion torchao/utils.py
@@ -13,6 +13,7 @@
"skip_if_compute_capability_less_than",
"benchmark_torch_function_in_microseconds",
"find_multiple",
"_register_custom_op",
"get_model_size_in_bytes",
"unwrap_tensor_subclass",
"TORCH_VERSION_AFTER_2_2",
@@ -65,7 +66,7 @@ def wrapper(*args, **kwargs):

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
import torch.utils.benchmark as benchmark # this avoids importing numpy when torchao module is loaded

# Manual warmup
f(*args, **kwargs)
f(*args, **kwargs)
@@ -84,6 +85,55 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
return n
return n + k - (n % k)

def _register_custom_op(lib):
"""This decorator is used to preserve some high level operators for torch.export.export
while still allowing them to be decomposed for the inductor path

requirement: make sure `fn.__name__[1:]` is the operator name you want to register

NOTE: This should be applied at the top, after all other decorators have been applied
NOTE: We haven't tested the case where `fn` accepts a tensor subclass instance as input,
e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would make
sense for downstream systems (like executorch) to accept as well

Example:
lib = torch.library.Library("my_namespace', "FRAGMENT")

register_custom_op = _register_custom_op(lib)

@register_custom_op
def _the_op_that_needs_to_be_preserved(...)
...

# after this, `_the_op_that_needs_to_be_preserved` will be preserved as
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
# torch.export.export / torch._export.capture_pre_autograd_graph

"""
from torch._inductor.decomposition import register_decomposition

def decorator(fn):
if TORCH_VERSION_AFTER_2_5:
from torch._library.infer_schema import infer_schema

# expecting fn.__name__ starts with `_` and we want to take the rest
# to be the name of the custom op
assert fn.__name__[0] == "_", f"Expecting function name starts with `_`, got {fn.__name__}"
assert not any(c in fn.__name__ for c in ".<>"), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
op_name = fn.__name__[1:]
Contributor:
Can you assert there is no "." or "<" or ">" in fn.__name__? This can happen with lambdas or local functions.

schema = op_name + infer_schema(fn)
lib.define(schema)
lib.impl(op_name, fn, "CompositeImplicitAutograd")

lib_namespace = lib.ns
op = getattr(getattr(torch.ops, lib_namespace), op_name)
register_decomposition([op])(fn)
return op
else:
return fn

return decorator

def get_model_size_in_bytes(model, ignore_embeddings=False):
"""
Returns the model size in bytes. The option to ignore embeddings
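As a usage sketch of the decorator: the `my_ops` namespace and `_scaled_add` function below are hypothetical, and this assumes the torch nightly this PR targets, where `infer_schema` can be called as shown; on older torch versions the decorator simply returns the original function.

import torch
from torchao.utils import _register_custom_op

my_lib = torch.library.Library("my_ops", "FRAGMENT")
register_my_op = _register_custom_op(my_lib)

@register_my_op
def _scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # the leading underscore is stripped, so this defines torch.ops.my_ops.scaled_add
    return x + alpha * y

x, y = torch.randn(3), torch.randn(3)
print(_scaled_add(x, y, 0.5))  # on 2.5+ this resolves to the registered op

# In a torch.export / capture_pre_autograd_graph graph the call should appear
# as torch.ops.my_ops.scaled_add, while inductor can still decompose it because
# the same function was registered as its decomposition.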