Skip to content

Commit

Permalink
Remove input_quant_func from AffineQuantizedTensor subclass (#243)
Browse files Browse the repository at this point in the history
* Remove input_quant_func from AffineQuantizedTensor subclass

Summary:
Currently we have a input_quant_func in the AffineQuantizedTensor, which is a bit convoluted, we want to use a
separate LinearActAffineQuantizedTensor subclass for activation quantization (dynamic quantization) instead

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:

* Add dispatch for dynamic quantization in `AffineQuantizedTensor`

Summary:
This PR added dispatch for int8act-int8 weight dynamic quantization that's calling `int_scaled_matmul` kernel in the end

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_dyn_quant

Reviewers:

Subscribers:

Tasks:

Tags:

* Fix test
  • Loading branch information
jerryzh168 authored May 16, 2024
1 parent cae3d82 commit cda787c
Show file tree
Hide file tree
Showing 2 changed files with 286 additions and 57 deletions.
86 changes: 73 additions & 13 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,10 @@ def test_eval_wrapper(self):
# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.subclass import (
AffineQuantizedTensor,
LinearActQuantizedTensor,
)
from torchao.quantization.quant_primitives import MappingType
import copy

Expand All @@ -409,6 +412,7 @@ def test_quantized_tensor_subclass_8da4w(self):
quant_max = 7

# TODO: make a general helper function?
# input settings
def get_per_token_block_size(x):
block_size = []
for i in range(len(x.shape)-1):
Expand All @@ -421,13 +425,18 @@ def get_per_token_block_size(x):
input_target_dtype = torch.int8
input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

def dynamic_quant(linear):
# note: order is important
linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
dynamic_quant(m.linear1)
dynamic_quant(m.linear2)
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)

# reference
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
Expand Down Expand Up @@ -461,9 +470,6 @@ def test_quantized_tensor_subclass_int4(self):
preserve_zero = False
zero_point_dtype = torch.bfloat16

# weight only quantization
input_quant_func = None

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
Expand All @@ -475,7 +481,6 @@ def to_quantized(weight):
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.FLOAT,
input_quant_func=input_quant_func,
)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
Expand Down Expand Up @@ -506,16 +511,13 @@ def test_quantized_tensor_subclass_int8(self):
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# weight only quantization
input_quant_func = None

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def to_quantized(weight):
block_size = (1, weight.shape[1])
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
Expand All @@ -532,5 +534,63 @@ def to_quantized(weight):
torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.subclass import LinearActQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import ZeroPointDomain
import copy

# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float)

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def dynamic_quant(linear):
# note: order is important
linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)

dynamic_quant(m.linear1)
dynamic_quant(m.linear2)
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(m_copy)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

self.assertTrue(torch.equal(res, ref))


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit cda787c

Please sign in to comment.