diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 87efb3962..0d11093fd 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1428,6 +1428,7 @@ class TestAOTI(unittest.TestCase):
     @parameterized.expand(
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
+    @run_supported_device_dtype
     def test_aoti(self, api, test_device, test_dtype):
         if not TORCH_VERSION_AFTER_2_4:
             self.skipTest("aoti compatibility requires 2.4+.")
@@ -1442,11 +1443,6 @@ def test_aoti(self, api, test_device, test_dtype):
         if test_dtype != torch.bfloat16:
             self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")
 
-        if test_device == "cuda" and not torch.cuda.is_available():
-            self.skipTest(f"Need CUDA available.")
-        if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
-            self.skipTest("Need CUDA and SM80+ available.")
-
         m, k, n = 32, 64, 32
 
         class test_model(nn.Module):
@@ -1479,5 +1475,55 @@ def forward(self, x):
         torch._export.aot_compile(model, example_inputs)
 
 
+class TestExport(unittest.TestCase):
+    """Check that models transformed by TENSOR_SUBCLASS_APIS match eager output after ``torch.export``."""
+
+    @parameterized.expand(
+        list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
+    )
+    @run_supported_device_dtype
+    def test_export(self, api, test_device, test_dtype):
+        """Apply ``api`` to a small MLP, export it, and compare exported vs. eager output."""
+        if not TORCH_VERSION_AFTER_2_4:
+            self.skipTest("export compatibility requires 2.4+.")
+
+        logger.info(f"TestExport: {api}, {test_device}, {test_dtype}")
+
+        if test_dtype != torch.bfloat16:
+            self.skipTest(f"{api} in {test_dtype} is not supported for export yet")
+
+        m, k, n = 32, 64, 32
+
+        class test_model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = nn.Linear(k, n)
+                self.relu = nn.ReLU()
+                self.lin2 = nn.Linear(n, n)
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.relu(x)
+                x = self.lin2(x)
+                return x
+
+        x = torch.randn(m, k, dtype=test_dtype, device=test_device)
+
+        model = test_model().to(dtype=test_dtype, device=test_device).eval()
+
+        # The return value is discarded, so `api` is relied on to modify `model` in place.
+        kwargs = {"dtype": test_dtype}
+        api(model, **kwargs)
+
+        # Eager output of the transformed model; reference for the exported module.
+        ref = model(x)
+
+        # Export and re-run: the exported module must match the eager output exactly.
+        example_inputs = (x,)
+        exported = torch.export.export(model, example_inputs).module()
+        after_export = exported(x)
+        self.assertTrue(torch.equal(after_export, ref))
+
+
 if __name__ == "__main__":
     unittest.main()