diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index cef78b32e..232fbef81 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -423,7 +423,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-    @unittest.skip(TORCH_VERSION_AFTER_2_4, "assert input.dtype == torch.float32" )
+    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
     @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_linear(self):
         from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
@@ -454,7 +454,7 @@ def test_qat_4w_linear(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-    @unittest.skip(TORCH_VERSION_AFTER_2_4, "assert input.dtype == torch.float32" )
+    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
     @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer