From 4c7fa069848aa32a24e47cae0e5996bc0eeeb70a Mon Sep 17 00:00:00 2001
From: Huy Do
Date: Wed, 15 Nov 2023 19:35:28 -0800
Subject: [PATCH] Add some more validation checks for torch.linalg.eigh and torch.compile (#1580)

* Add some more validation checks for torch.linalg.eigh and torch.compile

* Update test

* Also update smoke_test.py

* Fix lint
---
 check_binary.sh                          |  6 ++++++
 test/smoke_test/smoke_test.py            |  3 +++
 test_example_code/torch_compile_smoke.py | 12 ++++++++++++
 3 files changed, 21 insertions(+)
 create mode 100644 test_example_code/torch_compile_smoke.py

diff --git a/check_binary.sh b/check_binary.sh
index 30b44b535..9e7d03a54 100755
--- a/check_binary.sh
+++ b/check_binary.sh
@@ -404,6 +404,12 @@ if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != 'cpu-cxx11-abi' && "$DESIRE
     echo "Test that linalg works"
     python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.svd(torch.mm(x.t(), x)))"
 
+    echo "Test that linalg.eigh works"
+    python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.eigh(torch.mm(x.t(), x)))"
+
+    echo "Checking that basic torch.compile works"
+    python ${TEST_CODE_DIR}/torch_compile_smoke.py
+
     popd
   fi # if libtorch
 fi # if cuda
diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py
index 3d1b6af64..64efc7601 100644
--- a/test/smoke_test/smoke_test.py
+++ b/test/smoke_test/smoke_test.py
@@ -193,6 +193,9 @@ def smoke_test_linalg() -> None:
             A = torch.randn(20, 16, 50, 100, device="cuda").type(dtype)
             torch.linalg.svd(A)
 
+        A = torch.rand(3, 3, device="cuda")
+        L, Q = torch.linalg.eigh(torch.mm(A.t(), A))
+
 
 def smoke_test_compile() -> None:
     supported_dtypes = [torch.float16, torch.float32, torch.float64]
diff --git a/test_example_code/torch_compile_smoke.py b/test_example_code/torch_compile_smoke.py
new file mode 100644
index 000000000..7a12a013e
--- /dev/null
+++ b/test_example_code/torch_compile_smoke.py
@@ -0,0 +1,12 @@
+import torch
+
+
+def foo(x: torch.Tensor) -> torch.Tensor:
+    return torch.sin(x) + torch.cos(x)
+
+
+if __name__ == "__main__":
+    x = torch.rand(3, 3, device="cuda")
+    x_eager = foo(x)
+    x_pt2 = torch.compile(foo)(x)
+    print(torch.allclose(x_eager, x_pt2))