Fix CI after quantize op change in PyTorch core
Summary: pytorch/pytorch#125781 recently made a subtle change to the
numerics of the quantize op. This commit fixes the resulting numerics
mismatch by making our quantize ops consistent with the ones in
PyTorch core.
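
The subtle part is that dividing by the scale and multiplying by its
reciprocal are not bitwise-identical in float32; the one-ULP difference
can flip round() for values near a .5 boundary, which is enough to break
exact-match (atol=0, rtol=0) comparisons against the core ops. A minimal,
self-contained sketch (tensor values and scale ranges are illustrative,
not taken from the PR) that counts how often the two formulations disagree:

import torch

torch.manual_seed(0)
x = torch.randn(100_000, dtype=torch.float32)
# Small positive scales, similar to what symmetric quantization produces.
s = torch.rand(100_000, dtype=torch.float32) * 0.01 + 1e-4

q_div = torch.round(x / s)          # old torchao formulation
q_mul = torch.round(x * (1.0 / s))  # reciprocal-multiply formulation used in core

# The two usually agree, but a handful of elements can land on different
# integers; how many (if any) depends on the sampled values.
mismatches = (q_div != q_mul).sum().item()
print(f"{mismatches} / {x.numel()} elements round differently")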

Test Plan:
python test/quantization/test_quant_primitives.py -k test_quantize_dequantize_group_sym
python test/quantization/test_quant_api.py TestQuantFlow.test_quantized_tensor_subclass_8da4w

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
andrewor14 committed May 15, 2024
1 parent 10da375 commit f972443
Showing 4 changed files with 11 additions and 11 deletions.
16 changes: 8 additions & 8 deletions test/quantization/test_qat.py
@@ -18,7 +18,7 @@
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
@@ -58,7 +58,7 @@ def _get_qmin_qmax(self, n_bit: int):
qmax = 2 ** (n_bit - 1) - 1
return (qmin, qmax)

-@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_fake_quantize_per_channel_group(self):
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
@@ -84,7 +84,7 @@ def test_fake_quantize_per_channel_group(self):
)
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

-@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_fake_quantize_per_token(self):
(qmin, qmax) = self._get_qmin_qmax(8)

@@ -130,7 +130,7 @@ def _set_ptq_weight(
ptq_linear.scales = s
ptq_linear.zeros = zp

-@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_qat_8da4w_linear(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
@@ -155,7 +155,7 @@ def test_qat_8da4w_linear(self):
ptq_out = ptq_linear(x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

-@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
@@ -189,7 +189,7 @@ def test_qat_8da4w_quantizer(self):
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

-@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

@@ -201,7 +201,7 @@ def test_qat_8da4w_quantizer_meta_weights(self):
qat_model = qat_quantizer.prepare(m)
self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

-@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
@@ -254,7 +254,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
qat_out2 = qat_model2(*x2)
torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)

-@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
2 changes: 1 addition & 1 deletion test/quantization/test_quant_primitives.py
@@ -156,7 +156,7 @@ def test_quantize_activation_per_token_abs_max_zero_input(self):
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)


-@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
2 changes: 1 addition & 1 deletion torchao/quantization/prototype/qat.py
@@ -209,7 +209,7 @@ def forward(ctx, input, scales, zero_points, quant_min, quant_max):
# which rounds first before adding the zero points. However, this
# is what `quantize_per_channel_group` and `quantize_per_token`
# do and here we try to match that behavior as closely as possible.
-q = input.div(scales).add(zero_points).round()
+q = input.mul(1.0 / scales).add(zero_points).round()
dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
# TODO: do we need this mask?
mask = torch.logical_and((q >= quant_min), (q <= quant_max))
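
For context, a minimal sketch of the fake-quantize round trip that the hunk
above patches, using the tensor names from the diff; the function name and
standalone signature are illustrative rather than the library's API, and only
the reciprocal-multiply line reflects the actual change:

import torch

def fake_quantize_sketch(x, scales, zero_points, quant_min, quant_max):
    # Quantize with the reciprocal-multiply formulation (the fix above),
    # clamp to the integer range, then dequantize back to float.
    q = x.mul(1.0 / scales).add(zero_points).round()
    dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
    # Values that fell inside [quant_min, quant_max] before clamping;
    # a straight-through backward would typically zero gradients elsewhere.
    mask = torch.logical_and(q >= quant_min, q <= quant_max)
    return dq, mask

x = torch.randn(2, 8)
scales = torch.full((2, 1), 0.1)
zero_points = torch.zeros(2, 1)
dq, mask = fake_quantize_sketch(x, scales, zero_points, quant_min=-8, quant_max=7)
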
2 changes: 1 addition & 1 deletion torchao/quantization/quant_primitives.py
@@ -201,7 +201,7 @@ def quantize_affine(

if zero_point_domain == ZeroPointDomain.INT:
quant = torch.clamp(
-torch.round(input / scale) + zero_point, quant_min, quant_max
+torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
).to(output_dtype)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT
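
Similarly, a sketch of the integer-zero-point quantize path as it reads after
the change; the function name and trimmed-down signature are illustrative,
not the full quantize_affine API:

import torch

def quantize_int_zero_point_sketch(x, scale, zero_point, quant_min, quant_max,
                                   output_dtype=torch.int8):
    # Multiply by the reciprocal of the scale (rather than dividing) so the
    # rounding matches the quantize ops in PyTorch core, then shift by the
    # zero point and clamp into the target integer range.
    return torch.clamp(
        torch.round(x * (1.0 / scale)) + zero_point, quant_min, quant_max
    ).to(output_dtype)

x = torch.randn(4, 8)
scale = torch.full((4, 1), 0.05)
zero_point = torch.zeros(4, 1, dtype=torch.int32)
q = quantize_int_zero_point_sketch(x, scale, zero_point, quant_min=-128, quant_max=127)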
