Skip to content

Commit

Permalink
retry version guard fix (#679)
Browse files Browse the repository at this point in the history
* retry version guard fix

* push

* push

* push

* push

* push
  • Loading branch information
msaroufim authored Aug 14, 2024
1 parent 582b6d4 commit 1acd710
Show file tree
Hide file tree
Showing 47 changed files with 257 additions and 227 deletions.
4 changes: 2 additions & 2 deletions benchmarks/benchmark_aq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Int4WeightOnlyQuantizedLinearWeight,
)
from torchao.utils import (
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AT_LEAST_2_4,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
Expand Down Expand Up @@ -105,7 +105,7 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
assert elapsed_time < 1.05 * ref_elapsed_time

if __name__ == "__main__" and TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available():
if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
_bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/intmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pathlib

import torch
from torchao.utils import TORCH_VERSION_AFTER_2_4, TORCH_VERSION_AFTER_2_2
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_2


# Check if CUDA is available, if not, exit the script
Expand Down
4 changes: 2 additions & 2 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import unittest
import tempfile
from torchao.utils import (
TORCH_VERSION_AFTER_2_5,
TORCH_VERSION_AT_LEAST_2_5,
)


Expand Down Expand Up @@ -46,7 +46,7 @@ def test_weights_only(self):
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AFTER_2_5:
if TORCH_VERSION_AT_LEAST_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)
Expand Down
4 changes: 2 additions & 2 deletions test/dtypes/test_bitnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from torchao.prototype.dtypes import BitnetTensor
from torchao.prototype.dtypes.uint2 import unpack_uint2
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@pytest.fixture(autouse=True)
Expand Down
4 changes: 2 additions & 2 deletions test/dtypes/test_uint2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import torch.nn as nn
from torchao.prototype.dtypes import UInt2Tensor
from torchao.prototype.dtypes.uint2 import unpack_uint2
from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@pytest.fixture
Expand Down
6 changes: 3 additions & 3 deletions test/dtypes/test_uintx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from torchao.dtypes.uintx.Uintx import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import TORCH_VERSION_AFTER_2_5
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

from torchao.quantization.quant_primitives import (
MappingType,
Expand Down Expand Up @@ -40,7 +40,7 @@ def forward(self, x):
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
def test_uintx_weight_only_model_quant(bit_width, group_size, device):
scale = 512
fp16 = Linear16(scale, device)
Expand All @@ -54,7 +54,7 @@ def test_uintx_weight_only_model_quant(bit_width, group_size, device):
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build")
def test_uintx_weight_only_quant(bit_width, group_size, device):
input_float = torch.randn((1, 256), dtype=torch.float16, device = device)
mapping_type = MappingType.SYMMETRIC
Expand Down
4 changes: 2 additions & 2 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import torch
import torch.nn as nn

from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


Expand Down
4 changes: 2 additions & 2 deletions test/float8/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

import pytest

from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
Expand Down
4 changes: 2 additions & 2 deletions test/float8/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import pytest

from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torchao.float8 import Float8LinearConfig
Expand Down
4 changes: 2 additions & 2 deletions test/float8/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import fire

from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
Expand Down
4 changes: 2 additions & 2 deletions test/float8/test_fsdp2/test_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import unittest
from typing import Any, List

from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


Expand Down
4 changes: 2 additions & 2 deletions test/float8/test_fsdp_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

import pytest

from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
Expand Down
4 changes: 2 additions & 2 deletions test/float8/test_inference_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import pytest
from unittest.mock import patch
from torchao.utils import (
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AT_LEAST_2_4,
unwrap_tensor_subclass,
)

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
Expand Down
4 changes: 2 additions & 2 deletions test/float8/test_numerics_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

import pytest

from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_4:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
Expand Down
Loading

0 comments on commit 1acd710

Please sign in to comment.