Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKEND][CPU] Convert tt.get_program_id and tt.print (Hello World) #1

Merged
merged 1 commit into from
May 14, 2024

Conversation

minjang
Copy link
Collaborator

@minjang minjang commented May 14, 2024

Summary: tl.program_id needs to be lowered first for any meaningful example. As of now, we think pid will be provided as additional function arguments to the kernel. The CPU launch (parallel for) will give grid ids. Getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented tl.device_print or print, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)

The resulting .llir is valid:

@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}

Tried to compile with a fake main function:

> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5

Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
@minjang minjang requested a review from bertmaher May 14, 2024 03:11
@minjang minjang requested a review from ptillet as a code owner May 14, 2024 03:11
minjang pushed a commit to minjang/triton-cpu that referenced this pull request May 14, 2024
@minjang minjang merged commit 7fb570e into triton-lang:main May 14, 2024
minjang added a commit that referenced this pull request May 15, 2024
Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
ienkovich pushed a commit to ienkovich/triton-cpu that referenced this pull request May 15, 2024
@minjang minjang deleted the more_lowering branch May 28, 2024 17:59
minjang pushed a commit that referenced this pull request Jun 24, 2024
When running
[convert_blocked1d_to_slice0](https://github.com/triton-lang/triton/blob/0ba5f0c3cd029d5c3d1f01b9bf29dac32c27345e/test/Conversion/tritongpu_to_llvm.mlir#L924)
Triton ends up computing a rank of a matrix with 0 columns during linear
layout lowering, which trips up f2reduce, and causes undefined behavior,
detectable through
[UBSAN](https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html).

Fix this by returning the rank (0) early in these cases, without calling
f2reduce.

<details><summary>Stack trace</summary>
<p>

```
third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30: runtime error: shift exponent 18446744073709551615 is too large for 64-bit type 'unsigned long long'
    #0 0x556ee2fea3be in inplace_rref_small third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30
    #1 0x556ee2fea3be in f2reduce::inplace_rref_strided(unsigned long*, unsigned long, unsigned long, unsigned long) third_party/triton/third_party/f2reduce/f2reduce.cpp:470:9
    #2 0x556ee2ea70da in getMatrixRank third_party/triton/lib/Tools/LinearLayout.cpp:125:3
    #3 0x556ee2ea70da in mlir::triton::LinearLayout::checkInvariants(bool) third_party/triton/lib/Tools/LinearLayout.cpp:299:7
    #4 0x556ee2ea656d in mlir::triton::LinearLayout::tryCreate(llvm::MapVector<mlir::StringAttr, std::__u::vector<std::__u::vector<int, std::__u::allocator<int>>, std::__u::allocator<std::__u::vector<int, std::__u::allocator<int>>>>, llvm::DenseMap<mlir::StringAttr, unsigned int, llvm::DenseMapInfo<mlir::StringAttr, void>, llvm::detail::DenseMapPair<mlir::StringAttr, unsigned int>>, llvm::SmallVector<std::__u::pair<mlir::StringAttr, std::__u::vector<std::__u::vector<int, std::__u::allocator<int>>, std::__u::allocator<std::__u::vector<int, std::__u::allocator<int>>>>>, 0u>>, llvm::ArrayRef<std::__u::pair<mlir::StringAttr, int>>, bool) third_party/triton/lib/Tools/LinearLayout.cpp:190:41
    #5 0x556ee2eb2150 in mlir::triton::LinearLayout::divideRight(mlir::triton::LinearLayout const&) third_party/triton/lib/Tools/LinearLayout.cpp:654:51
    #6 0x556ee2ee1c39 in mlir::cvtNeedsSharedMemory(mlir::RankedTensorType, mlir::RankedTensorType) third_party/triton/lib/Analysis/Utility.cpp:652:14
    #7 0x556ee2cf38fd in mlir::triton::getRepShapeForCvtLayout(mlir::triton::gpu::ConvertLayoutOp) third_party/triton/lib/Analysis/Allocation.cpp:66:8
    #8 0x556ee2cf3efa in mlir::triton::getScratchConfigForCvtLayout(mlir::triton::gpu::ConvertLayoutOp, unsigned int&, unsigned int&) third_party/triton/lib/Analysis/Allocation.cpp:95:19
    #9 0x556ee2cf6057 in mlir::triton::AllocationAnalysis::getScratchValueSize(mlir::Operation*) third_party/triton/lib/Analysis/Allocation.cpp:272:24
    #10 0x556ee2cf5499 in operator() third_party/triton/lib/Analysis/Allocation.cpp:343:7
    #11 0x556ee2cf5499 in void llvm::function_ref<void (mlir::Operation*)>::callback_fn<mlir::triton::AllocationAnalysis::getValuesAndSizes()::'lambda'(mlir::Operation*)>(long, mlir::Operation*) third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
    #12 0x556edeeee7a9 in operator() third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
    #13 0x556edeeee7a9 in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:174:5
    #14 0x556edeeee87c in void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:182:9
    #15 0x556ee2cf49e7 in walk<(mlir::WalkOrder)0, mlir::ForwardIterator, (lambda at third_party/triton/lib/Analysis/Allocation.cpp:341:42), mlir::Operation *, void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Visitors.h:313:10
    #16 0x556ee2cf49e7 in walk<(mlir::WalkOrder)0, mlir::ForwardIterator, (lambda at third_party/triton/lib/Analysis/Allocation.cpp:341:42), void> third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h:794:12
    #17 0x556ee2cf49e7 in mlir::triton::AllocationAnalysis::getValuesAndSizes() third_party/triton/lib/Analysis/Allocation.cpp:341:16
    #18 0x556ee2cf4852 in run third_party/triton/lib/Analysis/Allocation.cpp:182:5
    #19 0x556ee2cf4852 in AllocationAnalysis third_party/triton/lib/Analysis/Allocation.cpp:169:5
    #20 0x556ee2cf4852 in mlir::Allocation::run(llvm::DenseMap<mlir::FunctionOpInterface, mlir::Allocation, llvm::DenseMapInfo<mlir::FunctionOpInterface, void>, llvm::detail::DenseMapPair<mlir::FunctionOpInterface, mlir::Allocation>>&) third_party/triton/lib/Analysis/Allocation.cpp:627:3
    #21 0x556ee1677402 in operator() third_party/triton/include/triton/Analysis/Allocation.h:227:26
    #22 0x556ee1677402 in void mlir::CallGraph<mlir::Allocation>::doWalk<(mlir::WalkOrder)0, (mlir::WalkOrder)1, mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::CallOpInterface, mlir::FunctionOpInterface), mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::FunctionOpInterface)>(mlir::FunctionOpInterface, llvm::DenseSet<mlir::FunctionOpInterface, llvm::DenseMapInfo<mlir::FunctionOpInterface, void>>&, mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::CallOpInterface, mlir::FunctionOpInterface), mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp)::'lambda'(mlir::FunctionOpInterface)) third_party/triton/include/triton/Analysis/Utility.h:350:7
    #23 0x556ee16756b3 in walk<(mlir::WalkOrder)0, (mlir::WalkOrder)1, (lambda at third_party/triton/include/triton/Analysis/Allocation.h:222:9), (lambda at third_party/triton/include/triton/Analysis/Allocation.h:224:9)> third_party/triton/include/triton/Analysis/Utility.h:242:7
    #24 0x556ee16756b3 in mlir::ModuleAllocation::ModuleAllocation(mlir::ModuleOp) third_party/triton/include/triton/Analysis/Allocation.h:220:5
    #25 0x556ee2c2bf18 in (anonymous namespace)::AllocateSharedMemory::runOnOperation() third_party/triton/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp:26:22
...
UndefinedBehaviorSanitizer: invalid-shift-exponent third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30 
```
</p>
</details>
minjang added a commit that referenced this pull request Jun 24, 2024
Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Aug 13, 2024
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Aug 13, 2024
…riton-lang#1)

Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
int3 pushed a commit that referenced this pull request Aug 29, 2024
Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
minjang added a commit that referenced this pull request Sep 22, 2024
Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
minjang added a commit that referenced this pull request Oct 22, 2024
Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
minjang added a commit that referenced this pull request Oct 24, 2024
Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments.

I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing.

Test Plan: Tested with a simple example:

```
@triton.jit
def add_kernel(...):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    foo = pid + 42
    tl.device_print("Hello, World!", foo, pid)
```

The resulting .llir is valid:
```
@printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00"

declare !dbg !3 i32 @printf(ptr, ...)

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 {
  %8 = add i32 %4, 42, !dbg !8
  %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4)
  ret void, !dbg !9
}
```

Tried to compile with a fake main function:
```
> % cat main.c
extern void add_kernel(float*, float*, float*, int, int, int, int);

int main() {
    add_kernel(0, 0, 0, 4, 5, 6, 7);
}

> % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c
> % ./a.out
pid (5, 6, 7) Hello, World!: 47, 5
```
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Nov 13, 2024
Adds a `USE_BLOCK_POINTER` flag to the matmul_kernel so we can get IR for pointers-to-tensors instead of tensors-of-pointers.
Devjiu pushed a commit to Devjiu/triton-cpu that referenced this pull request Nov 13, 2024
Adds a `USE_BLOCK_POINTER` flag to the matmul_kernel so we can get IR for pointers-to-tensors instead of tensors-of-pointers.
maryamtahhan pushed a commit to maryamtahhan/triton-cpu that referenced this pull request Nov 14, 2024
This will fix the following problem:
```bash
python: /home/runner/work/triton/triton/llvm-project/llvm/include/llvm/ADT/ilist_iterator.h:168: llvm::ilist_iterator::reference llvm::ilist_iterator<llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, false, false>::operator*() const [OptionsT = llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, IsReverse = false, IsConst = false]: Assertion `!NodePtr->isKnownSentinel()' failed.
Aborted (core dumped)
```

The problem was found when using PyTorch on Intel gpu:

<details>

<summary> Simplified reproducer triton-lang#1:</summary>

```python
from torch._inductor.async_compile import AsyncCompile
async_compile = AsyncCompile()

triton_per_fused_add_embedding_native_layer_norm_0 = async_compile.triton('triton_per_fused_add_embedding_native_layer_norm_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints=[512, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'in_ptr5': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='xpu', index=0, cc={'driver_version': '1.3.30049', 'gpu_eu_count': 448, 'gpu_subslice_count': 56, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 448, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1100', 'platform_name': 'Intel(R) Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 51539607552, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '1.3'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_embedding_native_layer_norm_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'D82C2E8E2C9203D653D1A2B8A0511701E4F7567A195A5128E03B9AA7218348AA', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_add_embedding_native_layer_norm_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 128
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    x0 = xindex
    r1 = rindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
    tmp7 = tl.load(in_ptr2 + (r1 + (128*x0)), xmask, other=0.0)
    tmp9 = tl.load(in_ptr3 + (r1 + (128*x0)), xmask, other=0.0)
    tmp34 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last')
    tmp36 = tl.load(in_ptr5 + (r1), None, eviction_policy='evict_last')
    tmp1 = tl.full([XBLOCK, RBLOCK], 30000, tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert(((0 <= tmp4) & (tmp4 < 30000)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 30000")
''', device_str='xpu')

```
</details>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant