MatMul and blocksparse matmul incorrect precision in some shape. #1808
Comments
I have also encountered this problem. Have you resolved it?
Unfortunately, I didn't find the reason.
I faced a similar issue when playing around with blocksparse matrix multiplication, here is my code:
In it I generate two tensors containing only integer values (or floating-point values, depending on a flag). Am I overlooking something here? For now I'll stick to …
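Integer-valued inputs are a special case worth noting: float32 represents every integer up to 2^24 exactly, so as long as all products and partial sums stay below that bound, the matmul involves no rounding at all and any backend must agree bit-for-bit. A NumPy sketch of that (sizes and value ranges are illustrative, not taken from the comment above):

```python
import numpy as np

rng = np.random.default_rng(0)
K = 256
# Integer-valued float32 tensors, as in the comment above.
a = rng.integers(0, 10, size=(64, K)).astype(np.float32)
b = rng.integers(0, 10, size=(K, 64)).astype(np.float32)

# Every product is < 100 and every dot product is < 256 * 100 = 25600,
# well below 2**24, so all intermediates are exactly representable in
# float32 and no rounding ever occurs, regardless of summation order.
c32 = a @ b
c64 = a.astype(np.float64) @ b.astype(np.float64)

print(np.array_equal(c32, c64))  # exact agreement, no tolerance needed
```

This is why all-ones or small-integer inputs can pass while random floating-point inputs fail: the former exercise no rounding at all.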
I am using Triton 2.0.0, PyTorch 2.0.0, Python 3.9.16, and CUDA 11.6 on a PC running CentOS release 7.4.1708 with an NVIDIA A100. I am using the `matmul` and `blocksparse/matmul` ops from https://github.com/openai/triton/tree/main/python/triton/ops, with test code modeled on [test_matmul.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_matmul.py) and [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py). When I compare the Triton matmul with torch.matmul, the results differ under `torch.allclose(atol=1e-5, rtol=0)`, as follows:
Matmul Test
The testing code is as follows:
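The original code block did not survive in this issue text. Below is a CPU stand-in for the comparison it describes — NumPy float32 against a float64 reference instead of Triton against torch.matmul — with the failing shape and the `allclose(atol=1e-5, rtol=0)` check taken from the issue; everything else is an assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
M = K = N = 2048  # one of the shapes reported as failing

a = rng.standard_normal((M, K)).astype(np.float32)
b = rng.standard_normal((K, N)).astype(np.float32)

c = a @ b                                          # float32 accumulation
ref = a.astype(np.float64) @ b.astype(np.float64)  # higher-precision reference

diff = np.abs(c - ref)
print(f"total difference: {diff.sum():.6f}")
print("allclose(atol=1e-5, rtol=0):", np.allclose(c, ref, atol=1e-5, rtol=0))
```

Even this pure-CPU comparison reports a nonzero total difference and fails the `atol=1e-5, rtol=0` check at this shape, so part of the reported discrepancy is inherent to float32 accumulation rather than specific to any one kernel.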
This code prints a `total difference` greater than 0.0, and `torch.allclose` returns False. I then observed some characteristics:

The `diff` increases as the shape increases, so I guess it may be related to the accumulated rounding error of the calculation. But when I use `M, K, N = 4096, 4096, 4096` and run this code on my machine, it passes ✅ the `allclose` check with diff = 0.000000. Is it really related to the `shape`? Only some shapes trigger the problem.

Moreover, I tried some special input data at shape `M, N, K = 2048, 2048, 2048`
. I take `a = torch.ones`, `b = torch.ones` to run the code, and the result always passes ✅, so at least in some cases the problem is not purely shape-related. I take `a = torch.ones`, `b = torch.randn` to run the code, and every row of the result matrix is identical, with the incorrect elements at the same positions in every row.

Blocksparse Matmul Test
The incorrect precision also appears in the blocksparse matmul function. The test code is as follows; it uses only the forward test from [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py):
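The blocksparse test snippet is also missing from this text. For reference, in `sdd` mode a blocksparse matmul computes, for each nonzero block `(i, j)` of the layout, the product of the corresponding row-block of `a` with the column-block of `b`. A dense NumPy reference of that semantics (block size, layout shape, and names are illustrative assumptions, loosely following the triton tests):

```python
import numpy as np

rng = np.random.default_rng(0)
block = 16
H, W = 4, 4                                # block grid of the sparse output
layout = rng.integers(0, 2, size=(H, W))   # 0/1 block mask

M, K, N = H * block, 64, W * block
a = rng.standard_normal((M, K)).astype(np.float32)
b = rng.standard_normal((K, N)).astype(np.float32)

# Dense reference: full matmul, then zero out the masked-off blocks.
mask = np.kron(layout, np.ones((block, block), dtype=np.float32))
ref = (a @ b) * mask

# Block-by-block computation, the way an sdd blocksparse kernel works.
out = np.zeros((M, N), dtype=np.float32)
for i in range(H):
    for j in range(W):
        if layout[i, j]:
            out[i*block:(i+1)*block, j*block:(j+1)*block] = (
                a[i*block:(i+1)*block, :] @ b[:, j*block:(j+1)*block]
            )

# The two agree up to float32 rounding (summation order may differ).
print(np.allclose(out, ref, atol=1e-4))
```

Mathematically the two paths are identical; any difference between them comes purely from floating-point summation order, which is the same class of discrepancy discussed above.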
This code also prints a `total difference` greater than 0.0, and `torch.allclose` returns False. I then observed some characteristics:

With `M, N, K = 256, 256, 256`, the code always passes ✅. With `N, K = 4096, 4096`, more than half of the tested range prints `❌ Triton and Torch differ`.

So what could be causing the incorrect precision, and how can the problem be solved?
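The shape dependence described above is consistent with plain float32 accumulation error: the rounding error grows with the reduction dimension K, so a fixed absolute tolerance of 1e-5 can pass at small K and fail at large K even when no kernel bug is involved. A NumPy sketch of that effect (sizes are illustrative; exact values vary by BLAS and hardware):

```python
import numpy as np

def max_diff(m, k, n, seed=0):
    """Max abs difference between a float32 matmul and a float64 reference."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((m, k)).astype(np.float32)
    b = rng.standard_normal((k, n)).astype(np.float32)
    c32 = a @ b
    c64 = a.astype(np.float64) @ b.astype(np.float64)
    return float(np.abs(c32 - c64).max())

# The error grows with K, so a fixed atol eventually fails.
for k in (64, 512, 4096):
    d = max_diff(256, k, 256)
    print(f"K={k:5d}  max diff={d:.2e}  within atol=1e-5: {d < 1e-5}")
```

This suggests comparing against a float64 reference with a relative (or K-scaled) tolerance rather than `atol=1e-5, rtol=0`, which is what the upstream operator tests effectively do.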