Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Aug 15, 2024
1 parent dc48906 commit 7d0a9e7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def test_qat_4w_primitives(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skip(TORCH_VERSION_AFTER_2_4, "assert input.dtype == torch.float32" )
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
Expand Down Expand Up @@ -454,7 +454,7 @@ def test_qat_4w_linear(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skip(TORCH_VERSION_AFTER_2_4, "assert input.dtype == torch.float32" )
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
Expand Down

0 comments on commit 7d0a9e7

Please sign in to comment.