Add dispatch for dynamic quantization in AffineQuantizedTensor
Summary:
This PR adds dispatch for int8-activation / int8-weight dynamic quantization, which ultimately calls the `int_scaled_matmul` kernel.
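
For context, a minimal sketch of what this compute path does, assuming per-token activation scales and per-channel weight scales as in the test below. `int_scaled_matmul` is assumed to come from `torchao.kernel.intmm` with signature `int_scaled_matmul(a, b, scales1)`; the wrapper function here is purely illustrative, not the actual dispatch code in this PR:

# Hypothetical sketch of the int8-act / int8-weight dynamic-quant path.
# Assumes torchao's int_scaled_matmul(a, b, scales1), which fuses the int8
# matmul with the per-row (per-token) activation rescale.
import torch
from torchao.kernel.intmm import int_scaled_matmul

def int8_dynamic_linear(x_int8, x_scales, w_int8_t, w_scales, bias=None):
    # x_int8:   (..., K) int8 activations, one scale per token
    # w_int8_t: (K, N) int8 weight (transposed), one scale per output channel
    x_2d = x_int8.reshape(-1, x_int8.shape[-1])
    y = int_scaled_matmul(x_2d, w_int8_t, x_scales.reshape(-1, 1))
    y = y * w_scales                       # apply per-channel weight scales
    y = y.reshape(*x_int8.shape[:-1], y.shape[-1])
    return y if bias is None else y + bias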

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_dyn_quant

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 15, 2024
1 parent d71a65a commit 94a058c
Showing 2 changed files with 227 additions and 134 deletions.
67 changes: 64 additions & 3 deletions test/quantization/test_quant_api.py
@@ -412,6 +412,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         quant_max = 7
 
         # TODO: make a general helper function?
+        # input settings
         def get_per_token_block_size(x):
             block_size = []
             for i in range(len(x.shape)-1):
@@ -422,18 +423,20 @@ def get_per_token_block_size(x):
         # input settings
         input_mapping_type = MappingType.ASYMMETRIC
         input_target_dtype = torch.int8
         input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
         def dynamic_quant(linear):
-            linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_mapping_type, get_per_token_block_size, input_target_dtype, quant_min, quant_max, eps), requires_grad=False)
+            # note: order is important
+            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
+            linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
 
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
         dynamic_quant(m.linear1)
         dynamic_quant(m.linear2)
-        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
-        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear1.weight, LinearActAffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActAffineQuantizedTensor)
 
         # reference
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -531,5 +534,63 @@ def to_quantized(weight):
         torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
 
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_int8_dyn_quant(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.subclass import LinearActAffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        from torchao.quantization.quant_primitives import ZeroPointDomain
+        import copy
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float)
+
+        # use 1024 so that we don't need padding
+        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+
+        def dynamic_quant(linear):
+            # note: order is important
+            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
+            linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+
+        dynamic_quant(m.linear1)
+        dynamic_quant(m.linear2)
+        assert isinstance(m.linear1.weight, LinearActAffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActAffineQuantizedTensor)
+        assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+        change_linear_weights_to_int8_dqtensors(m_copy)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        self.assertTrue(torch.equal(res, ref))
+
+
 if __name__ == "__main__":
     unittest.main()
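
A side note on the block_size convention the two helpers above exercise (assumed semantics, consistent with how the test uses AffineQuantizedTensor: each block of block_size elements shares one scale/zero_point). A small standalone illustration, with assumed shapes:

# Standalone illustration of the block_size helpers from the test above.
import torch

x = torch.randn(2, 16, 1024)   # e.g. (batch, seq, hidden) activations
w = torch.randn(1024, 1024)    # (out_features, in_features) linear weight

def get_per_token_block_size(x):
    block_size = list(x.shape)
    for i in range(len(block_size) - 1):
        block_size[i] = 1
    return block_size

def get_weight_block_size(w):
    return (1, w.shape[1])

print(get_per_token_block_size(x))  # [1, 1, 1024] -> one scale per token
print(get_weight_block_size(w))     # (1, 1024)    -> one scale per output channel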