diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index e69d0ed6f..28dd37740 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -230,6 +230,7 @@ def world_size(self) -> int: return 2 @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required") + @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="https://github.com/pytorch/ao/issues/652") @skip_if_lt_x_gpu(2) def test_fsdp2(self): optim_classes = [low_bit_optim.Adam8bit, low_bit_optim.Adam4bit]