autoquant using aqt (#609)
* autoquant using aqt

Summary:

This changes autoquant to use aqt instead of the old subclass tensors. aqt now first dispatches to a static _quantized_linear_op, which then dispatches to the normal function. This gives autoquant an extension point to modify the kernel functions for the various quantization modes without editing the main kernel function of each class. linear_activation_quantized_tensor got the same treatment.

There are some transposes in the aqt kernels that were not present in the subclass kernels, but they do not seem to affect performance (see benchmark_results.txt for an autoquant perf run).
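
As a rough, self-contained sketch of the dispatch pattern described above: the names BaseQuantizedWeight, AltKernelWeight, and linear_dispatch below are hypothetical stand-ins, not the actual torchao classes, and the "quantized" weights are kept as plain floats to keep the example minimal.

import torch
import torch.nn.functional as F

class BaseQuantizedWeight:
    # stand-in for AffineQuantizedTensor: linear is routed through a static
    # _quantized_linear_op so a subclass can swap the kernel
    def __init__(self, weight):
        self.weight = weight

    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        # default kernel path
        return F.linear(input_tensor, weight_tensor.weight, bias)

def linear_dispatch(input_tensor, weight_tensor, bias=None):
    # stand-in for the linear dispatch handler: it always goes through the
    # weight's own _quantized_linear_op rather than a module-level function
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

class AltKernelWeight(BaseQuantizedWeight):
    # stand-in for an autoquant subclass: same data, different kernel,
    # no edits to linear_dispatch needed
    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        y = torch.mm(input_tensor, weight_tensor.weight.t())
        return y if bias is None else y + bias

x = torch.randn(8, 16)
w = torch.randn(4, 16)
assert torch.allclose(
    linear_dispatch(x, BaseQuantizedWeight(w)),
    linear_dispatch(x, AltKernelWeight(w)),
    atol=1e-5,
)

Because dispatch always routes through the weight's own _quantized_linear_op, an autoquant subclass can benchmark and swap kernels per quantization mode without touching the shared dispatch code.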

Test Plan:

sh benchmarks.sh

python test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles authored Aug 8, 2024
1 parent 1cfe69e commit 934dead
Showing 4 changed files with 86 additions and 24 deletions.
12 changes: 6 additions & 6 deletions test/integration/test_integration.py
@@ -672,6 +672,8 @@ def _test_lin_weight_subclass_impl(
test_dtype=torch.bfloat16,
test_shape=(32, 64, 32),
):
if not "cuda" in test_device:
self.skipTest("test requires cuda")
m, k, n = test_shape
x = torch.randn(m, k, device=test_device, dtype=test_dtype)
lin = torch.nn.Linear(k, n, device=test_device).to(test_dtype)
@@ -709,30 +711,28 @@ def test_int8_weight_only_quant_subclass(self, device, dtype):
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "autoquant+aqt needs newer pytorch")
def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
6 changes: 5 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -155,6 +155,10 @@ def dequantize(self, output_dtype=None):
int_data, scale, zero_point = self.layout_tensor.get_plain()
return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

@staticmethod
def _quantized_linear_op(input_tensor, weight_tensor, bias):
return _quantized_linear_op(input_tensor, weight_tensor, bias)

def __tensor_flatten__(self):
return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

@@ -832,7 +836,7 @@ def _(func, types, args, kwargs):
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
# make the branches easier to understand in `_quantized_linear_op`
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
80 changes: 67 additions & 13 deletions torchao/quantization/autoquant.py
@@ -1,10 +1,16 @@
import torch
import torchao
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)
from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
safe_int_mm,
@@ -252,9 +258,9 @@ class AQMixin():
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
w_qtensor = cls.from_float(weight)
if _is_interpolate_mode(mode):
q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
q_c_op = torch.compile(cls._quantized_linear_op, mode="max-autotune-no-cudagraphs")
else:
func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
func = lambda a,b,c: F.relu(cls._quantized_linear_op(F.relu(a), b, c))
q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
if res < best_time*1.1:
@@ -263,10 +269,48 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
return res

class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
"""
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
"""
@classmethod
def from_float(cls, weight):
# TODO test if this is valid
# in_features = weight.shape[1]
# int8 dynamic quantization only has benefit when in_feature > 16
# if in_features <= 16:
# return weight

# avoid circular dep
from torchao.dtypes import to_affine_quantized
# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
layout_type = PlainLayoutType()
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

block_size = get_weight_block_size(weight)
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
return weight

@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
"""
@@ -298,7 +342,8 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
)
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
with torch.no_grad():
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t()
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8)
print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

# if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
@@ -313,18 +358,27 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
return res_f

class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
class AQWeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
"""
@classmethod
def from_float(cls, weight):
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
block_size = (1, weight.shape[1])
return super(AQWeightOnlyQuantizedLinearWeight, cls).from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)


class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
class AQWeightOnlyQuantizedLinearWeight2(AQWeightOnlyQuantizedLinearWeight, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
uses a different kernel
"""
@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
def _quantized_linear_op(act_mat, w_qtensor, bias):
"""
Performs the quantized linear operations
Expand All @@ -339,8 +393,8 @@ def _quantized_op(act_mat, w_qtensor, bias):
orig_dtype = act_mat.dtype
orig_shape = act_mat.shape
act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2)
y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale
if bias is not None:
y += bias
return y.to(orig_dtype)
Expand All @@ -352,14 +406,14 @@ def _autoquant_test(cls, act_mat, *args):
return torch.inf
return super()._autoquant_test(act_mat, *args)

class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
class AQWeightOnlyQuantizedLinearWeight3(AQWeightOnlyQuantizedLinearWeight, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
uses a different kernel
"""
def _quantized_op(act_mat, w_qtensor, bias):
def _quantized_linear_op(act_mat, w_qtensor, bias):
orig_shape = act_mat.shape
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale)
y=y.reshape(*orig_shape[:-1], y.shape[-1])
if bias is not None:
y += bias
Expand All @@ -377,7 +431,7 @@ def __init__(self):
super().__init__()

@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
def _quantized_linear_op(act_mat, w_qtensor, bias):
return torch.nn.functional.linear(act_mat, w_qtensor, bias)

@classmethod
12 changes: 8 additions & 4 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -56,6 +56,13 @@ def __tensor_unflatten__(
input_quant_func,
)

@staticmethod
def _quantized_linear_op(input_tensor, weight_tensor, bias):
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
aqt = input_quant_func(input_tensor)
return torch.nn.functional.linear(aqt, original_weight_tensor, bias)

@classmethod
def from_float(cls, input_float, input_quant_func):
return cls(input_float, input_quant_func)
@@ -101,10 +108,7 @@ def _(func, types, args, kwargs):
args[2] if len(args) > 2 else None,
)
if isinstance(weight_tensor, LinearActivationQuantizedTensor):
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
aqt = input_quant_func(input_tensor)
return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

raise NotImplementedError("LinearActivationQuantizedTensor: No specialized dispatch found for linear op")

