Skip to content

Commit

Permalink
Fix barrier insertion after assert op (triton-lang#5114)
Browse files Browse the repository at this point in the history
This will fix the following problem:
```bash
python: /home/runner/work/triton/triton/llvm-project/llvm/include/llvm/ADT/ilist_iterator.h:168: llvm::ilist_iterator::reference llvm::ilist_iterator<llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, false, false>::operator*() const [OptionsT = llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, IsReverse = false, IsConst = false]: Assertion `!NodePtr->isKnownSentinel()' failed.
Aborted (core dumped)
```

The problem was found when using PyTorch on Intel gpu:

<details>

<summary> Simplified reproducer triton-lang#1:</summary>

```python
from torch._inductor.async_compile import AsyncCompile
async_compile = AsyncCompile()

triton_per_fused_add_embedding_native_layer_norm_0 = async_compile.triton('triton_per_fused_add_embedding_native_layer_norm_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints=[512, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'in_ptr5': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='xpu', index=0, cc={'driver_version': '1.3.30049', 'gpu_eu_count': 448, 'gpu_subslice_count': 56, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 448, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1100', 'platform_name': 'Intel(R) Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 51539607552, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '1.3'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_embedding_native_layer_norm_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'D82C2E8E2C9203D653D1A2B8A0511701E4F7567A195A5128E03B9AA7218348AA', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_add_embedding_native_layer_norm_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 128
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    x0 = xindex
    r1 = rindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
    tmp7 = tl.load(in_ptr2 + (r1 + (128*x0)), xmask, other=0.0)
    tmp9 = tl.load(in_ptr3 + (r1 + (128*x0)), xmask, other=0.0)
    tmp34 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last')
    tmp36 = tl.load(in_ptr5 + (r1), None, eviction_policy='evict_last')
    tmp1 = tl.full([XBLOCK, RBLOCK], 30000, tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert(((0 <= tmp4) & (tmp4 < 30000)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 30000")
''', device_str='xpu')

```
</details>
  • Loading branch information
anmyachev authored Nov 12, 2024
1 parent 126d546 commit 3e359b3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
// know about the op to split the block.
void llAssert(Operation *op, Value condition, StringRef message,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter::InsertionGuard guard(rewriter);

auto ctx = rewriter.getContext();
auto loc = op->getLoc();
Expand Down Expand Up @@ -87,6 +86,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
rewriter.create<cf::BranchOp>(loc, thenBlock);
rewriter.setInsertionPointToEnd(prevBlock);
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
rewriter.setInsertionPointToStart(thenBlock);
}

protected:
Expand Down
19 changes: 16 additions & 3 deletions python/test/unit/test_debug.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import pytest
import torch
import triton.language as tl
Expand All @@ -10,8 +9,8 @@
@pytest.mark.parametrize('env_var', [True, False])
@pytest.mark.parametrize('jit_flag', [True, False])
@pytest.mark.forked
def test_device_assert(cond, opt_flag, env_var, jit_flag, device):
os.environ['TRITON_DEBUG'] = str(int(env_var))
def test_device_assert(monkeypatch, cond, opt_flag, env_var, jit_flag, device):
monkeypatch.setenv("TRITON_DEBUG", str(int(env_var)))
torch.zeros([1], dtype=torch.int32, device=device)

@triton.jit(debug=jit_flag)
Expand All @@ -34,6 +33,20 @@ def _kernel(COND: tl.constexpr):
getattr(torch, device).synchronize()


def test_device_assert_barrier(monkeypatch, device):
monkeypatch.setenv("TRITON_DEBUG", "1")
tensor = torch.zeros([16], dtype=torch.int32, device=device)

@triton.jit
def _kernel(in_ptr0):
xindex = tl.arange(0, 8)
tmp0 = tl.load(in_ptr0 + xindex)
tl.device_assert(tmp0 < 1)

_kernel[(1, )](tensor)
getattr(torch, device).synchronize()


@pytest.mark.parametrize("cond", [False, True])
def test_static_assert(cond):

Expand Down

0 comments on commit 3e359b3

Please sign in to comment.