Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search for libdevice relative to shared library #1176

Merged
merged 4 commits into from
Feb 11, 2023

Conversation

malfet
Copy link
Collaborator

@malfet malfet commented Feb 11, 2023

If Triton is shipped as shared library inside python package, search for libdevice.10.bc relative to this library installation path: Triton package installation looks as follows

triton/
       _C/libtriton.so
       third_party/cuda/lib/libdevice.10.bc

Test plan:

CXX=g++-9 pip install git+https://github.com/malfet/triton@5a0fe18751bc15c7035916de55e0892e63c3db77#subdirectory=python

And run following:

import torch

def foo(x: torch.Tensor) -> torch.Tensor:
   return torch.sin(x) + torch.cos(x)

if __name__=="__main__":
    x = torch.rand(3, 3, device="cuda")
    x_eager = foo(x)
    x_pt2 = torch.compile(foo)(x)
    print(torch.allclose(x_eager, x_pt2))

@malfet malfet requested a review from Jokeren as a code owner February 11, 2023 01:49
@Superjomn Superjomn enabled auto-merge (squash) February 11, 2023 01:50
auto-merge was automatically disabled February 11, 2023 01:56

Head branch was pushed to by a user without write access

@ptillet ptillet enabled auto-merge (squash) February 11, 2023 01:57
@ptillet ptillet merged commit 2d4370b into triton-lang:main Feb 11, 2023
@malfet malfet deleted the relative-search branch February 11, 2023 05:24
dllehr-amd pushed a commit to dllehr-amd/triton that referenced this pull request Feb 19, 2023
dllehr-amd pushed a commit to dllehr-amd/triton that referenced this pull request Feb 19, 2023
dllehr-amd pushed a commit to dllehr-amd/triton that referenced this pull request Feb 19, 2023
dllehr-amd pushed a commit to ROCm/triton that referenced this pull request Feb 19, 2023
pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024
ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this pull request Aug 5, 2024
This is related to the issue (triton-lang#1176) that internallinkage breaks
test_subprocess::test_assert.
The bug is caused by wrong __devicelib_assert_fail behavior when it is
optimized by inlining or dead argument elimination.
It is reported to IGC.
Until it is resolved, I changed the workaround specific to
__devicelib_assert_fail, so that other SPIR library functions can be
statically linked.
ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this pull request Aug 16, 2024
This PR fixes triton-lang#1176 
IGC detects the call of `__devicelib_assert_fail` and replace it with a
'safe' implementation.
However, the SYCL library contains a 'fallback' implementation of
assertion, which does not work in our setup.
If we mark the function with `InternalLinkage`, the fallback
implementation is inlined and IGC cannot replace it with the safe
implementation.
By declaring `__devicelib_assert_fail` as an external function in SYCL
library, IGC can correctly insert its implementation.
The diff between the old and new `libsycl-spir64-unknown-unknown.ll` is
as follows:
```diff
@@ -5424,149 +5424,7 @@ declare extern_weak dso_local spir_func noundef i32 @_Z18__spirv_AtomicLoadPU3AS
 declare void @llvm.memcpy.p4.p1.i64(ptr addrspace(4) noalias nocapture writeonly, ptr addrspace(1) noalias nocapture readonly, i64, i1 immarg) triton-lang#16
 
 ; Function Attrs: convergent mustprogress norecurse nounwind
-define weak dso_local spir_func void @__devicelib_assert_fail(ptr addrspace(4) noundef %0, ptr addrspace(4) noundef %1, i32 noundef %2, ptr addrspace(4) noundef %3, i64 noundef %4, i64 noundef %5, i64 noundef %6, i64 noundef %7, i64 noundef %8, i64 noundef %9) local_unnamed_addr triton-lang#14 !srcloc !720 {
-  %11 = tail call spir_func noundef i32 @_Z29__spirv_AtomicCompareExchangePU3AS1iN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagES5_ii(ptr addrspace(1) noundef @SPIR_AssertHappenedMem, i32 noundef 1, i32 noundef 16, i32 noundef 16, i32 noundef 1, i32 noundef 0) triton-lang#54
-  %12 = icmp eq i32 %11, 0
-  br i1 %12, label %13, label %92
-
-13:                                               ; preds = %10
-  store i32 %2, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 4), align 8, !tbaa !721
-  store i64 %4, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 5), align 8, !tbaa !722
-  store i64 %5, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 6), align 8, !tbaa !723
-  store i64 %6, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 7), align 8, !tbaa !724
-  store i64 %7, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 8), align 8, !tbaa !725
-  store i64 %8, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 9), align 8, !tbaa !726
-  store i64 %9, ptr addrspace(1) getelementptr inbounds (%struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 10), align 8, !tbaa !727
-  %14 = icmp eq ptr addrspace(4) %0, null
-  br i1 %14, label %23, label %15
-
-15:                                               ; preds = %20, %13
-  %16 = phi i32 [ %22, %20 ], [ 0, %13 ]
-  %17 = phi ptr addrspace(4) [ %21, %20 ], [ %0, %13 ]
-  %18 = load i8, ptr addrspace(4) %17, align 1, !tbaa !718
-  %19 = icmp eq i8 %18, 0
-  br i1 %19, label %23, label %20
-
-20:                                               ; preds = %15
-  %21 = getelementptr inbounds i8, ptr addrspace(4) %17, i64 1
-  %22 = add nuw nsw i32 %16, 1
-  br label %15, !llvm.loop !728
-
-23:                                               ; preds = %15, %13
-  %24 = phi i32 [ 0, %13 ], [ %16, %15 ]
-  %25 = icmp eq ptr addrspace(4) %1, null
-  br i1 %25, label %34, label %26
-
-26:                                               ; preds = %31, %23
-  %27 = phi i32 [ %33, %31 ], [ 0, %23 ]
-  %28 = phi ptr addrspace(4) [ %32, %31 ], [ %1, %23 ]
-  %29 = load i8, ptr addrspace(4) %28, align 1, !tbaa !718
-  %30 = icmp eq i8 %29, 0
-  br i1 %30, label %34, label %31
-
-31:                                               ; preds = %26
-  %32 = getelementptr inbounds i8, ptr addrspace(4) %28, i64 1
-  %33 = add nuw nsw i32 %27, 1
-  br label %26, !llvm.loop !729
-
-34:                                               ; preds = %26, %23
-  %35 = phi i32 [ 0, %23 ], [ %27, %26 ]
-  %36 = icmp eq ptr addrspace(4) %3, null
-  br i1 %36, label %37, label %40
-
-37:                                               ; preds = %34
-  %38 = tail call i32 @llvm.umin.i32(i32 %24, i32 256)
-  %39 = tail call i32 @llvm.umin.i32(i32 %35, i32 256)
-  br label %52
-
-40:                                               ; preds = %45, %34
-  %41 = phi i32 [ %47, %45 ], [ 0, %34 ]
-  %42 = phi ptr addrspace(4) [ %46, %45 ], [ %3, %34 ]
-  %43 = load i8, ptr addrspace(4) %42, align 1, !tbaa !718
-  %44 = icmp eq i8 %43, 0
-  br i1 %44, label %48, label %45
-
-45:                                               ; preds = %40
-  %46 = getelementptr inbounds i8, ptr addrspace(4) %42, i64 1
-  %47 = add i32 %41, 1
-  br label %40, !llvm.loop !730
-
-48:                                               ; preds = %40
-  %49 = tail call i32 @llvm.umin.i32(i32 %24, i32 256)
-  %50 = tail call i32 @llvm.umin.i32(i32 %35, i32 256)
-  %51 = tail call i32 @llvm.umin.i32(i32 %41, i32 128)
-  br label %52
-
-52:                                               ; preds = %48, %37
-  %53 = phi i32 [ %39, %37 ], [ %50, %48 ]
-  %54 = phi i32 [ %38, %37 ], [ %49, %48 ]
-  %55 = phi i32 [ 0, %37 ], [ %51, %48 ]
-  br label %56
-
-56:                                               ; preds = %62, %52
-  %57 = phi i32 [ 0, %52 ], [ %67, %62 ]
-  %58 = icmp ult i32 %57, %54
-  br i1 %58, label %62, label %59
-
-59:                                               ; preds = %56
-  %60 = zext nneg i32 %54 to i64
-  %61 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 1, i64 %60
-  store i8 0, ptr addrspace(1) %61, align 1, !tbaa !718
-  br label %68
-
-62:                                               ; preds = %56
-  %63 = sext i32 %57 to i64
-  %64 = getelementptr inbounds i8, ptr addrspace(4) %0, i64 %63
-  %65 = load i8, ptr addrspace(4) %64, align 1, !tbaa !718
-  %66 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 1, i64 %63
-  store i8 %65, ptr addrspace(1) %66, align 1, !tbaa !718
-  %67 = add nuw nsw i32 %57, 1
-  br label %56, !llvm.loop !731
-
-68:                                               ; preds = %74, %59
-  %69 = phi i32 [ 0, %59 ], [ %79, %74 ]
-  %70 = icmp ult i32 %69, %53
-  br i1 %70, label %74, label %71
-
-71:                                               ; preds = %68
-  %72 = zext nneg i32 %53 to i64
-  %73 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 2, i64 %72
-  store i8 0, ptr addrspace(1) %73, align 1, !tbaa !718
-  br label %80
-
-74:                                               ; preds = %68
-  %75 = sext i32 %69 to i64
-  %76 = getelementptr inbounds i8, ptr addrspace(4) %1, i64 %75
-  %77 = load i8, ptr addrspace(4) %76, align 1, !tbaa !718
-  %78 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 2, i64 %75
-  store i8 %77, ptr addrspace(1) %78, align 1, !tbaa !718
-  %79 = add nuw nsw i32 %69, 1
-  br label %68, !llvm.loop !732
-
-80:                                               ; preds = %86, %71
-  %81 = phi i32 [ 0, %71 ], [ %91, %86 ]
-  %82 = icmp ult i32 %81, %55
-  br i1 %82, label %86, label %83
-
-83:                                               ; preds = %80
-  %84 = sext i32 %55 to i64
-  %85 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 3, i64 %84
-  store i8 0, ptr addrspace(1) %85, align 1, !tbaa !718
-  tail call spir_func void @_Z19__spirv_AtomicStorePU3AS1iN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagEi(ptr addrspace(1) noundef @SPIR_AssertHappenedMem, i32 noundef 1, i32 noundef 16, i32 noundef 2) triton-lang#54
-  br label %92
-
-86:                                               ; preds = %80
-  %87 = sext i32 %81 to i64
-  %88 = getelementptr inbounds i8, ptr addrspace(4) %3, i64 %87
-  %89 = load i8, ptr addrspace(4) %88, align 1, !tbaa !718
-  %90 = getelementptr inbounds %struct.AssertHappened, ptr addrspace(1) @SPIR_AssertHappenedMem, i64 0, i32 3, i64 %87
-  store i8 %89, ptr addrspace(1) %90, align 1, !tbaa !718
-  %91 = add nuw nsw i32 %81, 1
-  br label %80, !llvm.loop !733
-
-92:                                               ; preds = %83, %10
-  ret void
-}
+declare extern_weak dso_local spir_func void @__devicelib_assert_fail(ptr addrspace(4) noundef %0, ptr addrspace(4) noundef %1, i32 noundef %2, ptr addrspace(4) noundef %3, i64 noundef %4, i64 noundef %5, i64 noundef %6, i64 noundef %7, i64 noundef %8, i64 noundef %9) local_unnamed_addr triton-lang#14
 
 ; Function Attrs: convergent nounwind
 declare extern_weak dso_local spir_func noundef i32 @_Z29__spirv_AtomicCompareExchangePU3AS1iN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagES5_ii(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr triton-lang#15

```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants