INTERNAL ASSERT FAILED (Vectorized accesses cannot be inline with computation) #3292

Closed · t-vi opened this issue Oct 28, 2024 · 1 comment · Fixed by #3301

t-vi commented Oct 28, 2024

This is from the PyTorch CI, but I can also reproduce it with a self-built NVFuser:

Internal assert:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/nvfuser/__init__.py", line 181, in execute
    results = self._execute(
              ^^^^^^^^^^^^^^
RuntimeError:  INTERNAL ASSERT FAILED at "/home/tv/data/firma/grid/thunder/Fuser/csrc/device_lower/validation.cpp":799, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Vectorized accesses cannot be inline with computation: T70_l_float[ iblockIdx.x595{( ceilDiv(( ceilDiv(( 5 * 64 ), 4) ), blockDim.x) )}, iblockIdx.y564{( ceilDiv(( 5 * 7 ), 1) )}, iUS565{1}, iV560{4}, ithreadIdx.x562{blockDim.x} ] ca_pos( 3 )
   = pad( T16_g_float[ iS594{( ceilDiv(( ceilDiv(( 5 * 64 ), 4) ), blockDim.x) )}, iS572{( ceilDiv(( 5 * 7 ), 1) )}, iS573{1}, iS568{4}, iS570{blockDim.x} ], {0, 0, 0, 0, 0, 0, 0, i660} )

Exception raised from validateAndCollectVectorizeInfo at /home/tv/data/firma/grid/thunder/Fuser/csrc/device_lower/validation.cpp:799 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xca (0x7f267251572c in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x3e (0x7f267289b31e in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x446331 (0x7f2672846331 in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x38aa49 (0x7f267278aa49 in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #4: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0xe3a (0x7f267278c66a in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #5: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType, long, long, long, long) + 0xc62 (0x7f2672b73ab2 in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x7ae86d (0x7f2672bae86d in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #7: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x3d0 (0x7f2672bb00b0 in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x1dc (0x7f2672ba987c in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #9: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x150 (0x7f2672d47f60 in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x22ff3b (0x7f267262ff3b in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x2801d6 (0x7f26726801d6 in /usr/local/lib/python3.11/dist-packages/nvfuser/_C.cpython-311-x86_64-linux-gnu.so)
frame #12: python3() [0x4d599b]
<omitting python frames>
frame #16: python3() [0x5cf631]
frame #19: python3() [0x5f29b0]
frame #21: <unknown function> + 0x27c8a (0x7f274f457c8a in /lib/x86_64-linux-gnu/libc.so.6)
frame #22: __libc_start_main + 0x85 (0x7f274f457d45 in /lib/x86_64-linux-gnu/libc.so.6)

Repro:

# CUDA devices:
#  0: NVIDIA GeForce RTX 3090
#  1: NVIDIA GeForce RTX 3090
# torch version: 2.6.0a0+git4d9b5a8
# cuda version: 12.1
# nvfuser version: 0.2.21+git628a47e
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1522(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[128, 64], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[128, 64], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[5, 5, 576], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[5, 5, 1792], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T13 = fd.ops.slice(T0, start_indices=[0, 0], end_indices=[5, 64], strides=[1, 1], manual_normalization=0)
    T23 = fd.ops.slice(T1, start_indices=[0, 0], end_indices=[5, 64], strides=[1, 1], manual_normalization=0)
    T30 = fd.ops.reshape(T2, new_shape=[5, 5, 1, 9, 64])
    T31 = fd.ops.permute(T30, dims=[0, 2, 3, 1, 4])
    T50 = fd.ops.slice(T31, start_indices=[0, 0, 0, 0, 0], end_indices=[5, 1, 7, 5, 64], strides=[1, 1, 1, 1, 1], manual_normalization=0)
    T69 = fd.ops.slice(T31, start_indices=[0, 0, 7, 0, 0], end_indices=[5, 1, 8, 5, 64], strides=[1, 1, 1, 1, 1], manual_normalization=0)
    T88 = fd.ops.slice(T31, start_indices=[0, 0, 8, 0, 0], end_indices=[5, 1, 9, 5, 64], strides=[1, 1, 1, 1, 1], manual_normalization=0)
    S89 = fd.define_scalar(5, dtype=DataType.Int)
    S90 = fd.define_scalar(1, dtype=DataType.Int)
    S91 = fd.define_scalar(7, dtype=DataType.Int)
    S92 = fd.define_scalar(5, dtype=DataType.Int)
    S93 = fd.define_scalar(64, dtype=DataType.Int)
    T95 = fd.ops.broadcast_in_dim(T69, shape=[S89, S90, S91, S92, S93], broadcast_dims=[0, 1, 2, 3, 4])
    S96 = fd.define_scalar(5, dtype=DataType.Int)
    S97 = fd.define_scalar(1, dtype=DataType.Int)
    S98 = fd.define_scalar(7, dtype=DataType.Int)
    S99 = fd.define_scalar(5, dtype=DataType.Int)
    S100 = fd.define_scalar(64, dtype=DataType.Int)
    T102 = fd.ops.broadcast_in_dim(T88, shape=[S96, S97, S98, S99, S100], broadcast_dims=[0, 1, 2, 3, 4])
    T108 = fd.ops.reshape(T50, new_shape=[5, 7, 5, 64])
    T114 = fd.ops.reshape(T95, new_shape=[5, 7, 5, 64])
    T120 = fd.ops.reshape(T102, new_shape=[5, 7, 5, 64])
    T136 = fd.ops.slice(T108, start_indices=[0, 0, 0, 0], end_indices=[5, 7, 5, 32], strides=[1, 1, 1, 1], manual_normalization=0)
    T152 = fd.ops.slice(T108, start_indices=[0, 0, 0, 32], end_indices=[5, 7, 5, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T153 = fd.ops.neg(T152)
    T154 = fd.ops.cat([T153, T136], dim=-1, manual_padding=0)
    S155 = fd.define_scalar(5, dtype=DataType.Int)
    S156 = fd.define_scalar(7, dtype=DataType.Int)
    S157 = fd.define_scalar(5, dtype=DataType.Int)
    S158 = fd.define_scalar(64, dtype=DataType.Int)
    T160 = fd.ops.broadcast_in_dim(T13, shape=[S155, S156, S157, S158], broadcast_dims=[2, 3])
    T161 = fd.ops.mul(T108, T160)
    S162 = fd.define_scalar(5, dtype=DataType.Int)
    S163 = fd.define_scalar(7, dtype=DataType.Int)
    S164 = fd.define_scalar(5, dtype=DataType.Int)
    S165 = fd.define_scalar(64, dtype=DataType.Int)
    T167 = fd.ops.broadcast_in_dim(T23, shape=[S162, S163, S164, S165], broadcast_dims=[2, 3])
    T168 = fd.ops.mul(T154, T167)
    T169 = fd.ops.add(T161, T168)
    T185 = fd.ops.slice(T114, start_indices=[0, 0, 0, 0], end_indices=[5, 7, 5, 32], strides=[1, 1, 1, 1], manual_normalization=0)
    T201 = fd.ops.slice(T114, start_indices=[0, 0, 0, 32], end_indices=[5, 7, 5, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T202 = fd.ops.neg(T201)
    T203 = fd.ops.cat([T202, T185], dim=-1, manual_padding=0)
    T204 = fd.ops.mul(T114, T160)
    T205 = fd.ops.mul(T203, T167)
    T206 = fd.ops.add(T204, T205)
    T222 = fd.ops.slice(T108, start_indices=[0, 0, 0, 0], end_indices=[5, 7, 5, 0], strides=[1, 1, 1, 1], manual_normalization=0)
    T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0)
    T239 = fd.ops.slice(T114, start_indices=[0, 0, 0, 0], end_indices=[5, 7, 5, 0], strides=[1, 1, 1, 1], manual_normalization=0)
    T240 = fd.ops.cat([T206, T239], dim=-1, manual_padding=0)
    S241 = fd.define_scalar(0.353553, dtype=DataType.Double)
    T242 = fd.ops.mul(T223, S241)
    T243 = fd.ops.permute(T240, dims=[0, 1, 3, 2])
    S244 = fd.define_scalar(0.353553, dtype=DataType.Double)
    T245 = fd.ops.mul(T243, S244)
    S246 = fd.define_scalar(1.41421, dtype=DataType.Double)
    S247 = fd.ops.reciprocal(S246)
    T248 = fd.ops.mul(T3, S247)
    T249 = fd.ops.erf(T248)
    S250 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T251 = fd.ops.mul(S250, T249)
    S252 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T253 = fd.ops.add(S252, T251)
    T254 = fd.ops.mul(T3, T253)
    fd.add_output(T120)
    fd.add_output(T160)
    fd.add_output(T167)
    fd.add_output(T242)
    fd.add_output(T245)
    fd.add_output(T254)

with FusionDefinition() as fd:
    nvfuser_fusion_id1522(fd)

inputs = [
    torch.testing.make_tensor((128, 64), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((128, 64), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((5, 5, 576), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((5, 5, 1792), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)
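
The zero-width slices `T222`/`T239` (`end_indices=[5, 7, 5, 0]`) feed the `cat`s that produce `T223`/`T240`; the `pad` in the assert message is consistent with one of those empty `cat` inputs. A minimal eager-mode sketch of the same pattern (plain PyTorch, not NVFuser), just to show that the slice is empty and the `cat` is a no-op:

import torch

# Analog of T169: the non-empty cat input, shape [5, 7, 5, 64].
x = torch.randn(5, 7, 5, 64)

# Analog of T222: a zero-width slice, shape [5, 7, 5, 0].
empty = x[:, :, :, 0:0]

# Concatenating an empty tensor leaves the result unchanged, which is
# why the remove_empty pass should be able to drop this input entirely.
out = torch.cat([x, empty], dim=-1)
assert out.shape == x.shape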

naoyam commented Oct 28, 2024

Hmm, it's a bit strange that the pad op was vectorized. I thought it should never be vectorized.

I could not reproduce the error on an RTX 6000. I'll see if I can grab a 3090.

jacobhinkle added a commit that referenced this issue Oct 30, 2024
For dynamic fusions, we detect empty tensors and set their extents to
the immediate constant 0. Later, in the remove_empty preseg pass, we do
a shallow check of whether extents are zero so that we can simplify the
fusion. When the fusion is not dynamic, there is no concretization step
where we would do this extent replacement, so we might have
constant-zero extents that are compound scalars. This caused us to miss
some empty tensors in #3292, particularly one of the inputs to a `cat`.

This PR:
- Uses a deep evaluation of each `getMaybeExpandedExtent()` to determine
whether an axis is empty (see the toy sketch below)
- Adds an ExpressionEvaluator field to `EmptyTensorRemover` to avoid
repeating the deep evaluation when possible. This won't prevent
repeated evaluation of symbolic extents; we could potentially track
those in an `unordered_set` instead.

Fixes #3292

---------

Co-authored-by: Naoya Maruyama <[email protected]>
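
The shallow-vs-deep distinction above can be illustrated with a toy Python sketch (`Scalar`, `shallow_is_zero`, and `deep_is_zero` are hypothetical names for illustration, not the NVFuser ExpressionEvaluator API):

# A toy scalar expression: either an immediate constant or a binary op node.
class Scalar:
    def __init__(self, value=None, op=None, lhs=None, rhs=None):
        self.value, self.op, self.lhs, self.rhs = value, op, lhs, rhs

def shallow_is_zero(extent):
    # Old behavior: only an immediate constant 0 counts as empty.
    return extent.value == 0

def evaluate(e):
    # Deep evaluation: recurse through compound scalar expressions.
    if e.value is not None:
        return e.value
    if e.op == "mul":
        return evaluate(e.lhs) * evaluate(e.rhs)
    if e.op == "sub":
        return evaluate(e.lhs) - evaluate(e.rhs)
    raise NotImplementedError(e.op)

def deep_is_zero(extent):
    # New behavior: a compound extent such as (0 * 64) is also seen as zero.
    return evaluate(extent) == 0

# Without a concretization step, a zero extent can remain a compound scalar:
compound_zero = Scalar(op="mul", lhs=Scalar(value=0), rhs=Scalar(value=64))
assert not shallow_is_zero(compound_zero)  # missed by the shallow check
assert deep_is_zero(compound_zero)         # caught by deep evaluation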