From f9724430a36537e1bac9b970b697c153b0e65b9a Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 15 May 2024 11:41:28 -0700 Subject: [PATCH] Fix CI after quantize op change in PyTorch core Summary: https://github.com/pytorch/pytorch/pull/125781 recently changed the numerics of the quantize op subtly. This commit fixes the numerics mismatch caused by this PR by making our quantize ops consistent with the ones in core. Test Plan: python test/quantization/test_quant_primitives.py -k test_quantize_dequantize_group_sym python test/quantization/test_quant_api.py TestQuantFlow.test_quantized_tensor_subclass_8da4w Reviewers: jerryzh168, cpuhrsch Subscribers: jerryzh168, cpuhrsch, supriyar --- test/quantization/test_qat.py | 16 ++++++++-------- test/quantization/test_quant_primitives.py | 2 +- torchao/quantization/prototype/qat.py | 2 +- torchao/quantization/quant_primitives.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index a0587d3ff..fe2db8066 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -18,7 +18,7 @@ fake_quantize_per_token, ) from torchao.quantization.quant_primitives import get_group_qparams_symmetric -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 # TODO: put this in a common test utils file @@ -58,7 +58,7 @@ def _get_qmin_qmax(self, n_bit: int): qmax = 2 ** (n_bit - 1) - 1 return (qmin, qmax) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = self._get_qmin_qmax(n_bit) @@ -84,7 +84,7 @@ def test_fake_quantize_per_channel_group(self): ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_fake_quantize_per_token(self): (qmin, qmax) = self._get_qmin_qmax(8) @@ -130,7 +130,7 @@ def _set_ptq_weight( ptq_linear.scales = s ptq_linear.zeros = zp - @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_qat_8da4w_linear(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear @@ -155,7 +155,7 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer @@ -189,7 +189,7 @@ def test_qat_8da4w_quantizer(self): for k in ptq_state_dict.keys(): torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer @@ -201,7 +201,7 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. @@ -254,7 +254,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): qat_out2 = qat_model2(*x2) torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index a64439a25..0fb48d761 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -156,7 +156,7 @@ def test_quantize_activation_per_token_abs_max_zero_input(self): quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_quantize_dequantize_group_sym(self): input = torch.randn(10, 10) mapping_type = MappingType.SYMMETRIC diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py index d15e841d7..314543bb8 100644 --- a/torchao/quantization/prototype/qat.py +++ b/torchao/quantization/prototype/qat.py @@ -209,7 +209,7 @@ def forward(ctx, input, scales, zero_points, quant_min, quant_max): # which rounds first before adding the zero points. However, this # is what `quantize_per_channel_group` and `quantize_per_token` # do and here we try to match that behavior as closely as possible. - q = input.div(scales).add(zero_points).round() + q = input.mul(1.0 / scales).add(zero_points).round() dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales) # TODO: do we need this mask? mask = torch.logical_and((q >= quant_min), (q <= quant_max)) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 4f39a6055..30c685448 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -201,7 +201,7 @@ def quantize_affine( if zero_point_domain == ZeroPointDomain.INT: quant = torch.clamp( - torch.round(input / scale) + zero_point, quant_min, quant_max + torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max ).to(output_dtype) else: assert zero_point_domain == ZeroPointDomain.FLOAT