forked from pytorch/pytorch
Commit 7b403c8
Nvfuser moving python tests and files under nvfuser (pytorch#95155)
1. Moving `test_jit_cuda_fuser.py`, `test_nvfuser_dynamo.py`, and `test_nvfuser_frontend.py` under `third_party/nvfuser/python_tests/`.
2. Moving `nvfuser/__init__.py` to `third_party/nvfuser/python/`.
3. Leaving dummy test scripts under `./test/` for CI.
4. Patching `torch/_prims/nvfuser_prims.py` for the view/reshape renaming in nvfuser.
5. Installing `third_party/nvfuser/python` and `third_party/nvfuser/python_tests` to the pytorch root/test directory.

Pull Request resolved: pytorch#95155
Approved by: https://github.com/davidberard98
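Item 4 refers to nvfuser renaming its `view` op to `reshape`. As a hedged sketch only, not the committed patch, a compatibility shim in `torch/_prims/nvfuser_prims.py` could dispatch on whichever name the installed frontend exposes; the helper name `_reshape` and the `fd.ops` attribute lookup are illustrative assumptions:

    # Hedged sketch, not the actual patch: assumes the nvfuser python frontend
    # exposes the op as either `fd.ops.reshape` (new name) or `fd.ops.view`.
    def _reshape(fd, tensor, original_shape, new_shape):
        op = getattr(fd.ops, "reshape", None) or fd.ops.view  # prefer the new name
        return op(tensor, original_shape, new_shape)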
1 parent 5d2eb6d
Showing 9 changed files with 5,857 additions and 5,809 deletions.
test/test_nvfuser_dynamo.py
@@ -1,148 +1,11 @@
 # Owner(s): ["module: nvfuser"]
 
-import unittest
-import warnings
-from functools import partial
-
-import torch
-import torch._dynamo as torchdynamo
-from torch.testing import make_tensor
-from torch.testing._internal.common_utils import (
-    IS_WINDOWS,
-    run_tests,
-    skipIfTorchDynamo,
-    TEST_WITH_ROCM,
-    TestCase,
-)
-from torch.testing._internal.jit_utils import RUN_CUDA
-
-RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM
-
-
-def is_pre_volta():
-    if not RUN_NVFUSER:
-        return False
-    prop = torch.cuda.get_device_properties(torch.cuda.current_device())
-    return prop.major < 7
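For context on the gate above: `torch.cuda.get_device_properties` returns a device-properties struct whose `major`/`minor` fields are the CUDA compute capability. Volta is sm_70, so `prop.major < 7` flags pre-Volta GPUs, on which these tests are skipped:

    import torch

    # Print the current GPU's compute capability, e.g. "Tesla V100 ... sm_70"
    if torch.cuda.is_available():
        prop = torch.cuda.get_device_properties(torch.cuda.current_device())
        print(prop.name, f"sm_{prop.major}{prop.minor}")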
-
-
-def is_networkx_available():
-    try:
-        import networkx  # noqa: F401
-
-        return True
-    except ImportError:
-        return False
-
-
-@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
-@unittest.skipIf(IS_WINDOWS, "TorchDynamo is not supported on Windows")
-@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
-@unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.")
-class TestNvFuserDynamo(TestCase):
-    def test_basic(self):
-        input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
-        input2 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
-
-        @torchdynamo.optimize("nvprims_nvfuser")
-        def func(a, b):
-            return a.sin() + b.cos()
-
-        # No warnings and no errors
-        with warnings.catch_warnings(record=True) as w:
-            nvfuser_result = func(input1, input2)
-            self.assertEqual(len(w), 0)
-        eager_result = func.__wrapped__(input1, input2)
-        self.assertEqual(eager_result, nvfuser_result)
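`test_basic` establishes the comparison pattern the other tests reuse: `torchdynamo.optimize("nvprims_nvfuser")` returns a compiled wrapper, and the undecorated function stays reachable through `__wrapped__`, which supplies the eager reference. A minimal standalone sketch of the same pattern, using the "eager" debug backend so it runs without CUDA:

    import torch
    import torch._dynamo as torchdynamo

    @torchdynamo.optimize("eager")  # debug backend: traces, then runs eagerly
    def f(x):
        return x.sin() + 1.0

    x = torch.ones(3)
    assert torch.equal(f(x), f.__wrapped__(x))  # __wrapped__ bypasses Dynamo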
-
-    @unittest.skipIf(not is_networkx_available(), "networkx not available")
-    def test_min_cut(self):
-        from functorch.compile import default_partition
-        from torch._dynamo.backends.nvfuser import nvprims_fw_bw_partition_fn
-
-        def get_fw_bw_graph(f, inps, partitioner):
-            from functorch.compile import aot_function
-
-            # Helper functions are taken from functorch/test_aotdispatch.py
-            def extract_graph(fx_g, _, graph_cell):
-                graph_cell[0] = fx_g
-                return fx_g
-
-            fw_graph_cell = [None]
-            bw_graph_cell = [None]
-            aot_function(
-                f,
-                fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
-                bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
-                partition_fn=partitioner,
-            )(*inps).sum().backward()
-            return (fw_graph_cell[0], bw_graph_cell[0])
-
-        def get_ins_outs(fx_g):
-            ins = []
-            outs = []
-            for n in fx_g.graph.nodes:
-                if n.op == "placeholder":
-                    ins.append(n)
-                elif n.op == "output":
-                    outs = tuple(n.args[0])
-            return ins, outs
-
-        def get_num_ins_outs(fx_g):
-            return tuple(len(i) for i in get_ins_outs(fx_g))
-
-        def func(x):
-            return x * x * x
-
-        input1 = make_tensor(
-            (3,), device="cpu", dtype=torch.float32, requires_grad=True
-        )
-        fw_graph, bw_graph = get_fw_bw_graph(func, [input1], default_partition)
-        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
-        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
-
-        input1 = make_tensor(
-            (3,), device="cpu", dtype=torch.float32, requires_grad=True
-        )
-        fw_graph, bw_graph = get_fw_bw_graph(func, [input1], nvprims_fw_bw_partition_fn)
-        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
-        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
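One way to read these assertions (an interpretation, not stated in the diff): the tuples count how many tensors each partitioner threads from the forward graph into the backward graph. For func(x) = x * x * x the gradient is 3 * x**2, so only x itself is strictly needed; the min-cut partitioner, which is why networkx is required above, saves one tensor fewer than the default partitioner and recomputes the rest:

    # Default partition: forward returns the output plus two saved tensors
    #   fw: (x,)                -> (out, saved1, saved2)  => (1, 3)
    #   bw: (saved1, saved2, g) -> (grad_x,)              => (3, 1)
    # Min-cut partition: forward saves only x; backward recomputes x*x
    #   fw: (x,)    -> (out, x)                           => (1, 2)
    #   bw: (x, g)  -> (grad_x,)                          => (2, 1)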
-
-    def test_batch_norm_implicit_dtype_promotion(self):
-        input1 = make_tensor((2, 3, 4, 5), device="cuda", dtype=torch.float32)
-        input2 = make_tensor((5, 5), device="cuda", dtype=torch.float32)
-        w = make_tensor((3), device="cuda", dtype=torch.float32)
-        b = make_tensor((3), device="cuda", dtype=torch.float32)
-
-        @torchdynamo.optimize("nvprims_nvfuser")
-        def func(mat1, mat2, w, b):
-            o = torch.matmul(mat1, mat2)
-            return torch.batch_norm(o, w, b, None, None, True, 1e-2, 1e-5, True)
-
-        # No warnings and no errors
-        with torch.cuda.amp.autocast():
-            with warnings.catch_warnings(record=True) as warning:
-                nvfuser_result = func(input1, input2, w, b)
-                self.assertEqual(len(warning), 0)
-            eager_result = func.__wrapped__(input1, input2, w, b)
-            self.assertEqual(eager_result, nvfuser_result)
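The autocast context is what makes this a promotion test: on CUDA, `torch.matmul` autocasts to float16 even for float32 inputs, so `o` is fp16 while `w` and `b` stay fp32, and `torch.batch_norm` must promote implicitly. A quick eager illustration:

    import torch

    a = torch.randn(4, 8, device="cuda")
    b = torch.randn(8, 8, device="cuda")
    with torch.cuda.amp.autocast():
        o = torch.matmul(a, b)  # runs in half precision under autocast
    print(o.dtype)  # torch.float16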
-
-    def test_dtype_correctness(self):
-        input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float16)
-
-        @torchdynamo.optimize("nvprims_nvfuser")
-        def func(a):
-            tmp = a + 1.0
-            # nvfuser would promote output to fp32 in math, FusionDefinition should cast output dtype back
-            return torch.where(tmp > 0, tmp, 0.0)
-
-        # No warnings and no errors
-        with warnings.catch_warnings(record=True) as w:
-            nvfuser_result = func(input1)
-            self.assertEqual(len(w), 0)
-        eager_result = func.__wrapped__(input1)
-        self.assertEqual(eager_result, nvfuser_result)
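For the inline comment above: in eager mode a Python scalar does not promote a half-precision tensor, so the reference result stays float16, and the fused path must cast its internal fp32 math back to match. A quick eager check:

    import torch

    a = torch.ones(4, device="cuda", dtype=torch.float16)
    tmp = a + 1.0                          # scalar add keeps the tensor's dtype
    out = torch.where(tmp > 0, tmp, 0.0)   # scalar branch keeps fp16 as well
    print(tmp.dtype, out.dtype)            # torch.float16 torch.float16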
-
-
-if __name__ == "__main__":
+try:
+    from _nvfuser.test_dynamo import run_tests  # noqa: F403
+except ImportError:
+    def run_tests():
+        return
+    pass
+
+if __name__ == '__main__':
     run_tests()
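The replacement stub is what item 3 of the commit message calls a dummy test script: when `_nvfuser.test_dynamo` is importable its `run_tests` is used; otherwise the local no-op `run_tests` keeps CI invocations of `python test/test_nvfuser_dynamo.py` green either way.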