From 33c0c1cc2f9b67adc2c4d32282b67ec11632b0a3 Mon Sep 17 00:00:00 2001
From: Alexander Efimov
Date: Thu, 3 Oct 2024 18:08:41 +0200
Subject: [PATCH 1/7] [AMD] Fix shared layout order for batch dimension in
 pipeline passes (#4796)

The batch dimension should be the slowest one; other cases are not supported
by the MFMA/WMMA/MMA pipeline.
---
 test/TritonGPU/loop-pipeline-hip.mlir         | 35 +++++++++++++++++++
 .../TritonAMDGPUTransforms/StreamPipeline.cpp | 21 +++++++++--
 .../StreamPipelineV2.cpp                      | 16 ++++++++-
 3 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir
index 55e1b65fa..28c815feb 100644
--- a/test/TritonGPU/loop-pipeline-hip.mlir
+++ b/test/TritonGPU/loop-pipeline-hip.mlir
@@ -198,3 +198,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 } // end module
+
+// -----
+
+// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
+// CHECK: #triton_gpu.shared<{{.*}} order = [2, 1, 0]
+// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
+
+// CHECK-LABEL: tt.func public @slowest_dim_is_batch
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @slowest_dim_is_batch(%arg0: tensor<1x512x!tt.ptr<f32>, #blocked2>, %arg1: tensor<64x8x32x!tt.ptr<f32>, #blocked1>, %arg2: tensor<64x1x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<64x1x32xf32, #blocked>
+    %cst_0 = arith.constant dense<512> : tensor<1x512xi32, #blocked2>
+    %cst_1 = arith.constant dense<128> : tensor<64x8x32xi32, #blocked1>
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %33:3 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %arg0, %arg10 = %arg1) -> (tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<64x8x32x!tt.ptr<f32>, #blocked1>) : i32 {
+      %39 = tt.load %arg9 : tensor<1x512x!tt.ptr<f32>, #blocked2>
+      %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr<f32>, #blocked1>
+      %41 = tt.reshape %39 {allow_reorder = true} : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5>
+      %43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+      %44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+      %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked>
+      %46 = tt.addptr %arg9, %cst_0 : tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<1x512xi32, #blocked2>
+      %47 = tt.addptr %arg10, %cst_1 : tensor<64x8x32x!tt.ptr<f32>, #blocked1>, tensor<64x8x32xi32, #blocked1>
+      scf.yield %45, %46, %47 : tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<64x8x32x!tt.ptr<f32>, #blocked1>
+    }
+    tt.store %arg2, %33#0 : tensor<64x1x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
index 224df9028..784ce52e1 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -403,9 +403,24 @@ void LoopPipeliner::createBufferTypes() {
   //   unsigned bitWidth = dotOpEnc.getMMAv2kWidth()
   //                           ? 32 / dotOpEnc.getMMAv2kWidth()
   //                           : ty.getElementType().getIntOrFloatBitWidth();
-  auto sharedEnc = ttg::SharedEncodingAttr::get(
-      ty.getContext(), dotOpEnc, ty.getShape(),
-      ttg::getOrder(ty.getEncoding()), CTALayout, eType);
+  auto srcOrder = ttg::getOrder(ty.getEncoding());
+  SmallVector<unsigned> sharedOrder;
+  int rank = srcOrder.size();
+  // TODO: rework this when shared -> dotOp conversions support arbitrary
+  // shared memory ordering.
+  if (rank == 3) {
+    // Move the batch dimension (dim #0) to be the last so that it will be the
+    // slowest varying dimension.
+    for (unsigned i = 0; i < rank; ++i)
+      if (srcOrder[i] != 0)
+        sharedOrder.emplace_back(srcOrder[i]);
+    sharedOrder.emplace_back(0);
+  } else {
+    sharedOrder = srcOrder;
+  }
+  auto sharedEnc =
+      ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(),
+                                   sharedOrder, CTALayout, eType);
   loadsBufferType[loadOp] = triton::MemDescType::get(
       bufferShape, eType, sharedEnc,
       triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()),
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
index f1d04b727..027f06652 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
@@ -207,8 +207,22 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
       auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
       auto order = ttg::getOrder(srcTy.getEncoding());
       unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+      SmallVector<unsigned> sharedOrder;
+      int rank = order.size();
+      // TODO: rework this when shared -> dotOp conversions support arbitrary
+      // shared memory ordering.
+      if (rank == 3) {
+        // Move the batch dimension (dim #0) to be the last so that it will be
+        // the slowest varying dimension.
+        for (unsigned i = 0; i < rank; ++i)
+          if (order[i] != 0)
+            sharedOrder.emplace_back(order[i]);
+        sharedOrder.emplace_back(0);
+      } else {
+        sharedOrder = order;
+      }
       tempAttr = ttg::SharedEncodingAttr::get(
-          val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
+          val.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, CTALayout,
           bitWidth, /*needTrans=*/false);
     }
     // Check that the shared encodings needed by the users are compatible.
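
For illustration, here is a minimal standalone sketch of the reordering both
hunks above perform, with plain std::vector standing in for llvm::SmallVector
(`batchDimLast` is an illustrative name, not part of the patch): for a rank-3
tensor, dim 0 (the batch dimension) is moved to the back of the order so it
becomes the slowest-varying dimension, while rank-2 orders pass through
unchanged.

#include <cassert>
#include <vector>

// Returns `srcOrder` with dim 0 moved to the back for rank-3 tensors,
// mirroring the logic added to StreamPipeline.cpp and StreamPipelineV2.cpp.
std::vector<unsigned> batchDimLast(const std::vector<unsigned> &srcOrder) {
  if (srcOrder.size() != 3)
    return srcOrder; // rank-2 orders are used unchanged
  std::vector<unsigned> sharedOrder;
  for (unsigned d : srcOrder)
    if (d != 0)
      sharedOrder.push_back(d); // keep the relative order of non-batch dims
  sharedOrder.push_back(0);     // the batch dim becomes the slowest varying
  return sharedOrder;
}

int main() {
  // The case from the new test: #blocked1 has order [2, 0, 1].
  assert(batchDimLast({2, 0, 1}) == (std::vector<unsigned>{2, 1, 0}));
  assert(batchDimLast({2, 1, 0}) == (std::vector<unsigned>{2, 1, 0}));
}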
From 14951165c4f34813185db40e09598bc6942d0925 Mon Sep 17 00:00:00 2001
From: Alexander Weinrauch
Date: Thu, 3 Oct 2024 17:15:52 +0100
Subject: [PATCH 2/7] [AMD] Add missing i16 for wmma and disable some tests
 (#4843)

---
 python/test/unit/language/test_core.py                      | 3 ++-
 third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp | 2 ++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 039f7ac1a..fbc1eb31b 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3330,7 +3330,8 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
     if is_hip():
         # hip does not support tf32 precision, so use ieee for all tests
         input_precision = "ieee"
-        if "gfx11" in triton.runtime.driver.active.get_current_target().arch:
+        arch = triton.runtime.driver.active.get_current_target().arch
+        if "gfx11" in arch or "gfx12" in arch:
            if in_dtype_str == "float32":
                pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d")
            if out_dtype_str == "float16":
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
index 2c93d6f0e..936844325 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
@@ -162,6 +162,8 @@ std::string getTypeStr(Type ty) {
     scalarName = "bf16";
   } else if (ty.isInteger(32)) {
     scalarName = "i32";
+  } else if (ty.isInteger(16)) {
+    scalarName = "i16";
   } else if (ty.isInteger(8)) {
     scalarName = "iu8";
   } else if (ty.isInteger(4)) {

From b8d8ce988479cc283b013cf55a13e8da97507277 Mon Sep 17 00:00:00 2001
From: Alexander Efimov
Date: Thu, 3 Oct 2024 23:53:28 +0200
Subject: [PATCH 3/7] [Backend] Bypass conversion for suitable blocked to
 dotOperand layout (#4538)

This PR extends the shared-memory bypass to blocked->dotOperand conversions
and adds the corresponding bypass check to DecomposeUnsupportedConversions
and ReduceDataDuplication. It is a preparation step towards improving codegen
and the efficiency of skinny dot cases.
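
To make the bypass condition concrete, here is a small self-contained sketch
of the rank-2 essentials that the new isBlockedToDotShortcut helper (in the
diff below) checks: one thread must own the whole K dimension, and the non-K
distribution must match the dot operand's parent layout. The hypothetical
`BlockedLayout2D` struct stands in for the real MLIR attributes, the
CTA-layout checks are elided, and the warpsPerCTA/order values in `main` are
made up to complete the no-op example quoted in the code comment.

#include <array>
#include <cassert>
#include <cstdint>

struct BlockedLayout2D { // hypothetical stand-in for BlockedEncodingAttr
  std::array<unsigned, 2> sizePerThread, threadsPerWarp, warpsPerCTA, order;
};

bool isShortcut(const BlockedLayout2D &src, const BlockedLayout2D &parent,
                unsigned opIdx, std::array<std::int64_t, 2> shape) {
  unsigned kDim = opIdx == 0 ? 1 : 0; // opIdx 0: K is the last dim
  unsigned nonKDim = opIdx == 0 ? 0 : 1;
  // 1) One thread holds the whole K dimension of the tensor.
  bool wholeK = src.sizePerThread[kDim] == shape[kDim];
  // 2) The non-K distribution matches the dot operand's parent layout.
  bool nonKMatch =
      src.order == parent.order &&
      src.sizePerThread[nonKDim] == parent.sizePerThread[nonKDim] &&
      src.threadsPerWarp[nonKDim] == parent.threadsPerWarp[nonKDim] &&
      src.warpsPerCTA[nonKDim] == parent.warpsPerCTA[nonKDim];
  return wholeK && nonKMatch;
}

int main() {
  // The no-op example from the comment in Utility.cpp below:
  // 128x32 tensor, sizePerThread [2, 32] vs. a parent with [2, 8].
  BlockedLayout2D src{{2, 32}, {32, 1}, {4, 1}, {1, 0}};
  BlockedLayout2D parent{{2, 8}, {32, 1}, {4, 1}, {1, 0}};
  assert(isShortcut(src, parent, /*opIdx=*/0, {128, 32}));
}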
---
 include/triton/Analysis/Utility.h             |   2 +
 lib/Analysis/Utility.cpp                      |  76 ++++++++++-
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp |  32 ++++++
 .../DecomposeUnsupportedConversions.cpp       |   2 +
 .../Transforms/ReduceDataDuplication.cpp      |  18 +--
 .../decompose-unsupported-conversions.mlir    | 104 +++++++++++++---
 .../tritongpu_to_llvm_block_dot_shortcut.mlir |  47 ++++++++
 test/TritonGPU/reduce-data-duplication.mlir   |  34 +++++-
 8 files changed, 277 insertions(+), 38 deletions(-)
 create mode 100644 test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir

diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 9e3eff155..ae05e2049 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -203,6 +203,8 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
 
 bool atomicNeedsSharedMemory(Value result);
 
+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
+
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 
 bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index b9468aa3e..030dd6710 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -536,6 +536,75 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }
 
+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+  auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+  if (blockedLayout == nullptr || dotOperandLayout == nullptr)
+    return false;
+  auto parentLayout =
+      dyn_cast<BlockedEncodingAttr>(dotOperandLayout.getParent());
+  if (parentLayout == nullptr)
+    return false;
+  auto opShape = srcTy.getShape();
+  auto rank = opShape.size();
+
+  int kDim = dotOperandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  int nonKDim = dotOperandLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
+  auto ctaLayout = blockedLayout.getCTALayout();
+
+  // The following logic checks that a source blocked layout matches a
+  // destination dot operand layout. This means that a tensor in the source
+  // layout can be converted into the destination layout without any data
+  // movement between registers or threads.
+  //
+  // It is considered a match if
+  // 1) each thread in the source layout holds a whole copy of all elements
+  //    along the K dimension of the tensor, and
+  // 2) the distribution of data along all other, non-K dimensions (batch/M/N)
+  //    matches between the source and destination parent layouts.
+  //
+  // The first condition comes from a property of a dot operand layout with a
+  // blocked parent: its size per thread along the K dimension equals the size
+  // of the tensor along K. The second condition comes from another property:
+  // a dot operand layout inherits the non-K dimensions from its parent layout.
+  //
+  // clang-format off
+  //
+  // For example, the following conversion is a no-op:
+  //   tensor<128x32xf16, #blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1]}>>
+  //     ->
+  //   tensor<128x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>>
+  //
+  // clang-format on
+  bool ctaLayoutCompatible =
+      ctaLayout.getCTASplitNum()[kDim] == 1 &&
+      blockedLayout.getCTALayout() == parentLayout.getCTALayout();
+  bool threadHoldsWholeKDim =
+      blockedLayout.getSizePerThread()[kDim] == opShape[kDim];
+  bool nonKDimCompatible =
+      blockedLayout.getOrder() == parentLayout.getOrder() &&
+      blockedLayout.getSizePerThread()[nonKDim] ==
+          parentLayout.getSizePerThread()[nonKDim] &&
+      blockedLayout.getThreadsPerWarp()[nonKDim] ==
+          parentLayout.getThreadsPerWarp()[nonKDim] &&
+      blockedLayout.getWarpsPerCTA()[nonKDim] ==
+          parentLayout.getWarpsPerCTA()[nonKDim];
+  bool matrixDimsCompatible =
+      ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible;
+  if (rank == 2)
+    return matrixDimsCompatible;
+
+  // Additional check for the batch dimension, if it is present.
+  assert(rank == 3);
+  bool bDimCompatible =
+      blockedLayout.getSizePerThread()[0] ==
+          parentLayout.getSizePerThread()[0] &&
+      blockedLayout.getThreadsPerWarp()[0] ==
+          parentLayout.getThreadsPerWarp()[0] &&
+      blockedLayout.getWarpsPerCTA()[0] == parentLayout.getWarpsPerCTA()[0];
+  return matrixDimsCompatible && bDimCompatible;
+}
+
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
@@ -625,12 +694,13 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
 }
 
 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
-  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut` and
-  // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
-  // checks.
+  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
+  // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
+  // subsumed by the linear-layout checks.
   // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
+         !isBlockedToDotShortcut(srcTy, dstTy) &&
          !isMmaToDotShortcut(srcTy, dstTy) && !isMfmaToDotShortcut(srcTy, dstTy);
 }
 
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 30cb79276..893afc659 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -232,6 +232,36 @@ struct ConvertLayoutOpConversion
   const TargetInfoBase &targetInfo;
 };
 
+struct ConvertLayoutOpBlockedToDotOpShortcutConversion
+    : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
+  const TargetInfoBase &targetInfo;
+  explicit ConvertLayoutOpBlockedToDotOpShortcutConversion(
+      LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
+      PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
+  }
+
+  LogicalResult
+  matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = op.getContext();
+
+    const auto &shape = op.getType().getShape();
+    auto srcTy = op.getSrc().getType();
+    auto dstTy = op.getType();
+    auto dstDotEncoding = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+    if (!dstDotEncoding)
+      return failure();
+    if (!isa<BlockedEncodingAttr>(srcTy.getEncoding()) ||
+        !isa<BlockedEncodingAttr>(dstDotEncoding.getParent()))
+      return failure();
+    if (cvtNeedsSharedMemory(srcTy, dstTy))
+      return failure();
+    rewriter.replaceOp(op, adaptor.getSrc());
+    return success();
+  }
+};
+
 struct ConvertLayoutOpUsingLinearLayoutsConversion
     : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
   const TargetInfoBase &targetInfo;
@@ -657,5 +687,7 @@ void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
   // one left.
   mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
       typeConverter, targetInfo, patterns, benefit.getBenefit() + 1);
+  patterns.add<ConvertLayoutOpBlockedToDotOpShortcutConversion>(
+      typeConverter, targetInfo, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
 }
diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
index a477db216..1346cc143 100644
--- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -83,6 +83,8 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     OpBuilder builder(cvtOp);
     auto srcType = cast<RankedTensorType>(cvtOp.getSrc().getType());
     auto dstType = cast<RankedTensorType>(cvtOp.getType());
+    if (!cvtNeedsSharedMemory(srcType, dstType))
+      return;
     auto srcBlocked =
         dyn_cast<triton::gpu::BlockedEncodingAttr>(srcType.getEncoding());
     auto dstDotOp =
diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
index 8c1f18e45..b1e296c1b 100644
--- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -42,22 +42,8 @@ class TritonGPUReduceDataDuplicationPass
           dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
       if (!dstDotOp)
         return;
-      if (auto srcMmaEncoding =
-              dyn_cast<triton::gpu::NvidiaMmaEncodingAttr>(srcEncoding)) {
-
-        if (srcMmaEncoding.getVersionMajor() != 2 ||
-            (srcMmaEncoding.getWarpsPerCTA()[1] == 1 &&
-             dstDotOp.getParent() == srcMmaEncoding))
-          return;
-      }
-      if (auto srcMfmaEncoding =
-              dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(srcEncoding)) {
-
-        if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 &&
-            srcMfmaEncoding.getIsTransposed() &&
-            dstDotOp.getParent() == srcMfmaEncoding)
-          return;
-      }
+      if (!cvtNeedsSharedMemory(srcType, dstType))
+        return;
       auto srcOrder = triton::gpu::getOrder(srcEncoding);
auto rank = srcOrder.size(); SmallVector sharedOrder; diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir index 0d6220c80..1bd288449 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -1,15 +1,15 @@ -// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx1130 | FileCheck %s +// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s -// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> -// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}> -// CHECK: wmma_to_wmma_dot_op +// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK-LABEL: wmma_to_wmma_dot_op #mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) { - // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]], #triton_gpu.shared_memory> - // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>> + // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return } @@ -17,17 +17,89 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> -// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}> -// CHECK: wmma_to_wmma_dot3d_op +// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK-LABEL: wmma_to_wmma_dot3d_op #mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) { - // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[WMMA]]> -> tensor<2x16x16xf16, #[[BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[SHARED]], 
#triton_gpu.shared_memory> - // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>> + // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> %0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return } } + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.local_alloc + // CHECK: triton_gpu.convert_layout + // CHECK-NOT: triton_gpu.local_alloc + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot_op_shortcut_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.local_alloc + // CHECK: triton_gpu.convert_layout + // CHECK-NOT: triton_gpu.local_alloc + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: triton_gpu.local_alloc + // CHECK: triton_gpu.local_load + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func 
@neg_blocked_to_dot_op_incompatible_threads_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: triton_gpu.local_alloc + // CHECK: triton_gpu.local_load + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: triton_gpu.local_alloc + // CHECK: triton_gpu.local_load + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir new file mode 100644 index 000000000..49128064a --- /dev/null +++ b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir @@ -0,0 +1,47 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s + +// CHECK-LABEL: blocked_to_dot_op_shortcut_warp32 +#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_warp64 +#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot_op_shortcut_warp64(%arg0: tensor<32x32xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp32 +#blocked = #triton_gpu.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot3d_op_shortcut_warp32(%arg0: tensor<8x32x32xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} + +// ----- + 
+// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp64 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot3d_op_shortcut_warp64(%arg0: tensor<8x32x32xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir index e98a6108d..9fca92c9b 100644 --- a/test/TritonGPU/reduce-data-duplication.mlir +++ b/test/TritonGPU/reduce-data-duplication.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s -// CHECK: #[[SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} -// CHECK: apply_swizzle -// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]], #triton_gpu.shared_memory> +// CHECK: #[[$SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} +// CHECK-LABEL: apply_swizzle +// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[$SHARED]], #triton_gpu.shared_memory> #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> @@ -12,3 +12,31 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : tt.return } } + +// ----- + +// CHECK-LABEL: conversion_shortcut_blocked_dotop_warp32 +// CHECK-NOT: triton_gpu.local_alloc +// CHECK: triton_gpu.convert_layout +// CHECK-NOT: triton_gpu.local_alloc +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @conversion_shortcut_blocked_dotop_warp32(%arg0: tensor<64x64xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: conversion_shortcut_blocked_dotop_warp64 +// CHECK-NOT: triton_gpu.local_alloc +// CHECK: triton_gpu.convert_layout +// CHECK-NOT: triton_gpu.local_alloc +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"triton_gpu.target" = "hip:gfx940", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @conversion_shortcut_blocked_dotop_warp64(%arg0: tensor<64x64xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} From 219c177de4a051bb05059e4aa133eea450411a53 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Thu, 3 Oct 2024 21:24:31 -0400 Subject: [PATCH 4/7] [PROTON] Add metric 
 percentage features (#4836)

Add a Proton feature to print each non-exclusive metric's percentage of the
total model.
---
 third_party/proton/proton/viewer.py    | 13 ++++++++++++-
 third_party/proton/test/test_viewer.py |  7 +++++--
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py
index 9fe0e7e67..ae689589c 100644
--- a/third_party/proton/proton/viewer.py
+++ b/third_party/proton/proton/viewer.py
@@ -105,6 +105,7 @@ def get_min_time_bytes(df, device_info):
 def derive_metrics(gf, metrics, raw_metrics, device_info):
     derived_metrics = []
     original_metrics = []
+    exclusive_metrics = ["util"] + list(derivable_metrics.keys()) + list(avg_time_factor_dict.factor.keys())
     internal_frame_indices = gf.dataframe["device_id"].isna()
 
     def get_time_seconds(df):
@@ -133,6 +134,7 @@ def get_time_seconds(df):
             gf.dataframe[f"{metric} (inc)"] = (get_time_seconds(gf.dataframe) /
                                                time_factor_dict.factor[metric_time_unit])
             derived_metrics.append(f"{metric} (inc)")
+            metric_name = match_available_metrics([time_factor_dict.name], raw_metrics)[0]
         elif metric in avg_time_factor_dict.factor:
             metric_time_unit = avg_time_factor_dict.name + "/" + metric.split("/")[1]
             gf.dataframe[f"{metric} (inc)"] = (get_time_seconds(gf.dataframe) / gf.dataframe['count'] /
@@ -141,7 +143,12 @@ def get_time_seconds(df):
             derived_metrics.append(f"{metric} (inc)")
         else:
             original_metrics.append(metric)
-
+        if metric not in exclusive_metrics:
+            single_frame = gf.dataframe[metric_name]
+            total = gf.dataframe[metric_name].iloc[0]
+            metric = metric.split("/")[0]
+            gf.dataframe[f"{metric}/% (inc)"] = (single_frame / total) * 100.0
+            derived_metrics.append(f"{metric}/% (inc)")
     if original_metrics:
         original_metrics = match_available_metrics(original_metrics, raw_metrics)
     return derived_metrics + original_metrics
@@ -227,6 +234,10 @@ def main():
 - flop[<8/16/32/64>]/s, gflop[<8/16/32/64>]/s, tflop[<8/16/32/64>]/s: flops / time
 - byte/s, gbyte/s, tbyte/s: bytes / time
 - util: max(sum(flops) / peak_flops_time, sum(bytes) / peak_bandwidth_time)
+
+For inclusive metrics (e.g. time), an additional column is printed showing
+each frame's percentage of the full model.
+
 """,
     )
     argparser.add_argument(
diff --git a/third_party/proton/test/test_viewer.py b/third_party/proton/test/test_viewer.py
index 998825bbc..b2d4d39f9 100644
--- a/third_party/proton/test/test_viewer.py
+++ b/third_party/proton/test/test_viewer.py
@@ -118,8 +118,11 @@ def test_util():
 def test_time_derivation():
     derivation_metrics_test(
         metrics=["time/s", "time/ms", "time/us", "time/ns"], expected_data={
-            'time/s (inc)': [0.0004096, 0.0002048, 0.0002048], 'time/ms (inc)': [0.4096, 0.2048, 0.2048],
-            'time/us (inc)': [409.6, 204.8, 204.8], 'time/ns (inc)': [409600.0, 204800.0, 204800.0]
+            'time/s (inc)': [0.0004096, 0.0002048, 0.0002048],
+            'time/ms (inc)': [0.4096, 0.2048, 0.2048],
+            'time/us (inc)': [409.6, 204.8, 204.8],
+            'time/ns (inc)': [409600.0, 204800.0, 204800.0],
+            'time/% (inc)': [100.0, 50.0, 50.0],
         }, sample_file=cuda_example_file)

From 5f9bb95d657268b3c33f9e7e5dbbde4510d9f704 Mon Sep 17 00:00:00 2001
From: SJW <48454132+sjw36@users.noreply.github.com>
Date: Thu, 3 Oct 2024 23:54:27 -0500
Subject: [PATCH 5/7] [Backend] Copy attributes to new loop in
 RewriteTensorPointer (#4848)

This fixes tl.num_stages being lost in translation.
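
As a side note on the Proton change in PATCH 4 above, here is a tiny
standalone arithmetic sketch of the derived `time/%` column, using the values
from the updated `test_time_derivation` expectations (the root frame is row 0,
matching the `.iloc[0]` lookup in the viewer):

#include <cstdio>
#include <vector>

int main() {
  // Inclusive times from the updated test expectations:
  // the root frame first, then its two children.
  std::vector<double> timeS = {0.0004096, 0.0002048, 0.0002048};
  double total = timeS[0]; // row 0 is the root frame
  for (double t : timeS)
    std::printf("time/%% (inc) = %.1f\n", t / total * 100.0); // 100.0, 50.0, 50.0
}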
---
 lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp | 1 +
 test/Triton/rewrite-tensor-pointer.mlir                | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
index 4d40e0f31..bb22489ea 100644
--- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
+++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
@@ -413,6 +413,7 @@ class RewriteTensorPointerPass
     auto newForOp = builder.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
                                                op.getUpperBound(), op.getStep(),
                                                newIterOperands);
+    newForOp->setAttrs(op->getAttrs());
 
     // Create value mapping. Note that for tensor pointers, we use identity
     // mapping. It may refer to a value in the old loop, but we will rewrite it
diff --git a/test/Triton/rewrite-tensor-pointer.mlir b/test/Triton/rewrite-tensor-pointer.mlir
index 1d659fa03..26625c3a0 100644
--- a/test/Triton/rewrite-tensor-pointer.mlir
+++ b/test/Triton/rewrite-tensor-pointer.mlir
@@ -111,7 +111,7 @@ tt.func public @rewrite_for(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>) {
     %4 = arith.addf %arg3, %3 : tensor<128x32xf16>
     %5 = tt.advance %arg4, [%c32_i32, %c0_i32] : !tt.ptr<tensor<128x32xf16>>
     scf.yield %4, %5 : tensor<128x32xf16>, !tt.ptr<tensor<128x32xf16>>
-  }
+  } {tt.num_stages = 3 : i32}
   %2 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>>
   tt.store %2, %1#0 : tensor<128x32x!tt.ptr<f16>>
   tt.return
@@ -138,6 +138,7 @@
 // CHECK: %[[EXTSI3:.*]] = arith.extsi %[[C0_I32]] : i32 to i64
 // CHECK: %[[ADDI1:.*]] = arith.addi %[[ARG5]], %[[EXTSI3]] : i64
 // CHECK: scf.yield %{{.*}}, %[[ADDI0]], %[[ADDI1]] : tensor<128x32xf16>, i64, i64
+// CHECK: tt.num_stages = 3

 // -----
 tt.func public @rewrite_if(%arg0: !tt.ptr<f16>, %arg1: i1) -> tensor<128x32xf16> {

From c2570b7b3efe066ef32602e437378bcb4db1e9b3 Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Mon, 7 Oct 2024 21:38:37 -0400
Subject: [PATCH 6/7] Add dotOperandDpasToLinearLayout (Tensor B) (#2422)

`dotOperandMfmaToLinearLayout` for tensor B was added recently in upstream
commit
https://github.com/triton-lang/triton/commit/ff02a46bd51efd405786dbabd9258c41ccfe7efe.
CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/11170777152 Signed-off-by: Whitney Tsang --- .../TritonGPU/IR/LinearLayoutConversions.cpp | 4 + test/Conversion/intel/dot_layout_offset.mlir | 312 +++++++++--------- .../IR/LinearLayoutConversions.h | 4 + .../IR/LinearLayoutConversions.cpp | 9 + 4 files changed, 168 insertions(+), 161 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index c35b186fb..7d508f234 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1,5 +1,6 @@ #include +#include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -827,6 +828,9 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { if (auto mfmaLayout = llvm::dyn_cast(getParent())) { return dotOperandMfmaToLinearLayout(*this, shape); } + if (auto dpasLayout = llvm::dyn_cast(getParent())) { + return dotOperandDpasToLinearLayout(*this, shape); + } return std::nullopt; } diff --git a/test/Conversion/intel/dot_layout_offset.mlir b/test/Conversion/intel/dot_layout_offset.mlir index 26e9d4d60..92129848d 100644 --- a/test/Conversion/intel/dot_layout_offset.mlir +++ b/test/Conversion/intel/dot_layout_offset.mlir @@ -344,317 +344,307 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : // CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAL_142]]) // CHECK: %[[THREAD_ID_I32:.*]] = llvm.trunc %[[THREAD_ID_I64]] : i64 to i32 // CHECK: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_145]] : i32 // CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_I32]], %[[VAL_145]] : i32 - // CHECK: %[[VAL_147:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_147]] : i32 + // CHECK-COUNT-3: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[WARP_ID_N:.*]] = llvm.urem %[[WARP_ID]], %[[VAL_149]] : i32 - // CHECK: %[[VAL_151:.*]] = llvm.udiv %[[WARP_ID]], %[[VAL_149]] : i32 + // CHECK: %[[VAL_150:.*]] = llvm.and %[[LANE_ID]], %[[VAL_149]] : i32 + // CHECK: %[[VAL_151:.*]] = llvm.icmp "eq" %[[VAL_150]], %[[CST_0]] : i32 // CHECK: %[[VAL_152:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[WARP_ID_M:.*]] = llvm.urem %[[VAL_151]], %[[VAL_152]] : i32 - // CHECK: %[[VAL_154:.*]] = llvm.udiv %[[VAL_151]], %[[VAL_152]] : i32 + // CHECK: %[[VAL_153:.*]] = llvm.select %[[VAL_151]], %[[CST_0]], %[[VAL_152]] : i1, i32 + // CHECK: %[[VAL_154:.*]] = llvm.xor %[[CST_0]], %[[VAL_153]] : i32 // CHECK: %[[VAL_155:.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK: %[[ROUNDED_WARP_ID_N:.*]] = llvm.urem %[[WARP_ID_N]], %[[VAL_155]] : i32 - // CHECK: %[[warpShape_N:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[warpOffset:.*]] = llvm.mul %[[ROUNDED_WARP_ID_N]], %[[warpShape_N]] : i32 - // CHECK: %[[VAL_159:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_160:.*]] = llvm.udiv %[[LANE_ID]], %[[VAL_159]] : i32 - // CHECK: %[[VAL_161:.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK: %[[laneRowIndex:.*]] = llvm.mul %[[VAL_160]], %[[VAL_161]] : i32 - // CHECK: %[[VAL_163:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: 
%[[laneColIndex:.*]] = llvm.urem %[[LANE_ID]], %[[VAL_163]] : i32 - // CHECK: %[[multiDimBase_N:.*]] = llvm.add %[[laneColIndex]], %[[warpOffset]] : i32 - // CHECK: %[[VAL_166:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_167:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_168:.*]] = llvm.urem %[[VAL_166]], %[[VAL_167]] : i32 - // CHECK: %[[VAL_169:.*]] = llvm.udiv %[[VAL_166]], %[[VAL_167]] : i32 - // CHECK: %[[VAL_170:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_171:.*]] = llvm.urem %[[VAL_169]], %[[VAL_170]] : i32 - // CHECK: %[[VAL_172:.*]] = llvm.udiv %[[VAL_169]], %[[VAL_170]] : i32 - // CHECK: %[[VAL_173:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_174:.*]] = llvm.urem %[[VAL_171]], %[[VAL_173]] : i32 - // CHECK: %[[VAL_175:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_176:.*]] = llvm.urem %[[VAL_168]], %[[VAL_175]] : i32 - // CHECK: %[[VAL_177:.*]] = llvm.mlir.constant(32 : i32) : i32 - // CHECK: %[[CTAOffset_M:.*]] = llvm.mul %[[VAL_174]], %[[VAL_177]] : i32 - // CHECK: %[[VAL_179:.*]] = llvm.mlir.constant(32 : i32) : i32 - // CHECK: %[[CTAOffset_N:.*]] = llvm.mul %[[VAL_176]], %[[VAL_179]] : i32 - // CHECK: %[[VAL_181:.*]] = llvm.add %[[laneRowIndex]], %[[CTAOffset_M]] : i32 - // CHECK: %[[VAL_182:.*]] = llvm.add %[[multiDimBase_N]], %[[CTAOffset_N]] : i32 + // CHECK: %[[VAL_156:.*]] = llvm.and %[[LANE_ID]], %[[VAL_155]] : i32 + // CHECK: %[[VAL_157:.*]] = llvm.icmp "eq" %[[VAL_156]], %[[CST_0]] : i32 + // CHECK: %[[VAL_158:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: %[[VAL_159:.*]] = llvm.select %[[VAL_157]], %[[CST_0]], %[[VAL_158]] : i1, i32 + // CHECK: %[[VAL_160:.*]] = llvm.xor %[[VAL_154]], %[[VAL_159]] : i32 + // CHECK: %[[VAL_161:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: %[[VAL_162:.*]] = llvm.and %[[LANE_ID]], %[[VAL_161]] : i32 + // CHECK: %[[VAL_163:.*]] = llvm.icmp "eq" %[[VAL_162]], %[[CST_0]] : i32 + // CHECK: %[[VAL_164:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: %[[VAL_165:.*]] = llvm.select %[[VAL_163]], %[[CST_0]], %[[VAL_164]] : i1, i32 + // CHECK: %[[VAL_182:.*]] = llvm.xor %[[VAL_160]], %[[VAL_165]] : i32 + // CHECK: %[[VAL_167:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAL_168:.*]] = llvm.and %[[LANE_ID]], %[[VAL_167]] : i32 + // CHECK: %[[VAL_169:.*]] = llvm.icmp "eq" %[[VAL_168]], %[[CST_0]] : i32 + // CHECK: %[[VAL_170:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: %[[VAL_171:.*]] = llvm.select %[[VAL_169]], %[[CST_0]], %[[VAL_170]] : i1, i32 + // CHECK: %[[VAL_181:.*]] = llvm.xor %[[CST_0]], %[[VAL_171]] : i32 // COM: There are total [2, 4] repetitions of tensor shape [32, 32] per warp of B. // COM: The repetitions are clustered as [1, 2] for B operand. The repetitions orders are [0, 0], [0, 1], [1, 0], [1, 1], [0, 2], [0, 3], [1, 2], [1, 3] // COM: Offsets of rep [0, 0]. 
// CHECK: %[[VAL_183:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_184:.*]] = llvm.add %[[VAL_181]], %[[VAL_183]] : i32 + // CHECK: %[[VAL_184:.*]] = llvm.xor %[[VAL_181]], %[[VAL_183]] : i32 // CHECK: %[[VAL_185:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_186:.*]] = llvm.add %[[VAL_182]], %[[VAL_185]] : i32 + // CHECK: %[[VAL_186:.*]] = llvm.xor %[[VAL_182]], %[[VAL_185]] : i32 // CHECK: %[[VAL_187:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_188:.*]] = llvm.add %[[VAL_181]], %[[VAL_187]] : i32 + // CHECK: %[[VAL_188:.*]] = llvm.xor %[[VAL_181]], %[[VAL_187]] : i32 // CHECK: %[[VAL_189:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_190:.*]] = llvm.add %[[VAL_182]], %[[VAL_189]] : i32 + // CHECK: %[[VAL_190:.*]] = llvm.xor %[[VAL_182]], %[[VAL_189]] : i32 // CHECK: %[[VAL_191:.*]] = llvm.mlir.constant(4 : i32) : i32 - // CHECK: %[[VAL_192:.*]] = llvm.add %[[VAL_181]], %[[VAL_191]] : i32 + // CHECK: %[[VAL_192:.*]] = llvm.xor %[[VAL_181]], %[[VAL_191]] : i32 // CHECK: %[[VAL_193:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_194:.*]] = llvm.add %[[VAL_182]], %[[VAL_193]] : i32 + // CHECK: %[[VAL_194:.*]] = llvm.xor %[[VAL_182]], %[[VAL_193]] : i32 // CHECK: %[[VAL_195:.*]] = llvm.mlir.constant(5 : i32) : i32 - // CHECK: %[[VAL_196:.*]] = llvm.add %[[VAL_181]], %[[VAL_195]] : i32 + // CHECK: %[[VAL_196:.*]] = llvm.xor %[[VAL_181]], %[[VAL_195]] : i32 // CHECK: %[[VAL_197:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_198:.*]] = llvm.add %[[VAL_182]], %[[VAL_197]] : i32 + // CHECK: %[[VAL_198:.*]] = llvm.xor %[[VAL_182]], %[[VAL_197]] : i32 // CHECK: %[[VAL_199:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_200:.*]] = llvm.add %[[VAL_181]], %[[VAL_199]] : i32 + // CHECK: %[[VAL_200:.*]] = llvm.xor %[[VAL_181]], %[[VAL_199]] : i32 // CHECK: %[[VAL_201:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_202:.*]] = llvm.add %[[VAL_182]], %[[VAL_201]] : i32 + // CHECK: %[[VAL_202:.*]] = llvm.xor %[[VAL_182]], %[[VAL_201]] : i32 // CHECK: %[[VAL_203:.*]] = llvm.mlir.constant(9 : i32) : i32 - // CHECK: %[[VAL_204:.*]] = llvm.add %[[VAL_181]], %[[VAL_203]] : i32 + // CHECK: %[[VAL_204:.*]] = llvm.xor %[[VAL_181]], %[[VAL_203]] : i32 // CHECK: %[[VAL_205:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_206:.*]] = llvm.add %[[VAL_182]], %[[VAL_205]] : i32 + // CHECK: %[[VAL_206:.*]] = llvm.xor %[[VAL_182]], %[[VAL_205]] : i32 // CHECK: %[[VAL_207:.*]] = llvm.mlir.constant(12 : i32) : i32 - // CHECK: %[[VAL_208:.*]] = llvm.add %[[VAL_181]], %[[VAL_207]] : i32 + // CHECK: %[[VAL_208:.*]] = llvm.xor %[[VAL_181]], %[[VAL_207]] : i32 // CHECK: %[[VAL_209:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_210:.*]] = llvm.add %[[VAL_182]], %[[VAL_209]] : i32 + // CHECK: %[[VAL_210:.*]] = llvm.xor %[[VAL_182]], %[[VAL_209]] : i32 // CHECK: %[[VAL_211:.*]] = llvm.mlir.constant(13 : i32) : i32 - // CHECK: %[[VAL_212:.*]] = llvm.add %[[VAL_181]], %[[VAL_211]] : i32 + // CHECK: %[[VAL_212:.*]] = llvm.xor %[[VAL_181]], %[[VAL_211]] : i32 // CHECK: %[[VAL_213:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_214:.*]] = llvm.add %[[VAL_182]], %[[VAL_213]] : i32 + // CHECK: %[[VAL_214:.*]] = llvm.xor %[[VAL_182]], %[[VAL_213]] : i32 // COM: Offsets of rep [0, 1]. 
// CHECK: %[[VAL_215:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_216:.*]] = llvm.add %[[VAL_181]], %[[VAL_215]] : i32 + // CHECK: %[[VAL_216:.*]] = llvm.xor %[[VAL_181]], %[[VAL_215]] : i32 // CHECK: %[[VAL_217:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_218:.*]] = llvm.add %[[VAL_182]], %[[VAL_217]] : i32 + // CHECK: %[[VAL_218:.*]] = llvm.xor %[[VAL_182]], %[[VAL_217]] : i32 // CHECK: %[[VAL_219:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_220:.*]] = llvm.add %[[VAL_181]], %[[VAL_219]] : i32 + // CHECK: %[[VAL_220:.*]] = llvm.xor %[[VAL_181]], %[[VAL_219]] : i32 // CHECK: %[[VAL_221:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_222:.*]] = llvm.add %[[VAL_182]], %[[VAL_221]] : i32 + // CHECK: %[[VAL_222:.*]] = llvm.xor %[[VAL_182]], %[[VAL_221]] : i32 // CHECK: %[[VAL_223:.*]] = llvm.mlir.constant(4 : i32) : i32 - // CHECK: %[[VAL_224:.*]] = llvm.add %[[VAL_181]], %[[VAL_223]] : i32 + // CHECK: %[[VAL_224:.*]] = llvm.xor %[[VAL_181]], %[[VAL_223]] : i32 // CHECK: %[[VAL_225:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_226:.*]] = llvm.add %[[VAL_182]], %[[VAL_225]] : i32 + // CHECK: %[[VAL_226:.*]] = llvm.xor %[[VAL_182]], %[[VAL_225]] : i32 // CHECK: %[[VAL_227:.*]] = llvm.mlir.constant(5 : i32) : i32 - // CHECK: %[[VAL_228:.*]] = llvm.add %[[VAL_181]], %[[VAL_227]] : i32 + // CHECK: %[[VAL_228:.*]] = llvm.xor %[[VAL_181]], %[[VAL_227]] : i32 // CHECK: %[[VAL_229:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_230:.*]] = llvm.add %[[VAL_182]], %[[VAL_229]] : i32 + // CHECK: %[[VAL_230:.*]] = llvm.xor %[[VAL_182]], %[[VAL_229]] : i32 // CHECK: %[[VAL_231:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_232:.*]] = llvm.add %[[VAL_181]], %[[VAL_231]] : i32 + // CHECK: %[[VAL_232:.*]] = llvm.xor %[[VAL_181]], %[[VAL_231]] : i32 // CHECK: %[[VAL_233:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_234:.*]] = llvm.add %[[VAL_182]], %[[VAL_233]] : i32 + // CHECK: %[[VAL_234:.*]] = llvm.xor %[[VAL_182]], %[[VAL_233]] : i32 // CHECK: %[[VAL_235:.*]] = llvm.mlir.constant(9 : i32) : i32 - // CHECK: %[[VAL_236:.*]] = llvm.add %[[VAL_181]], %[[VAL_235]] : i32 + // CHECK: %[[VAL_236:.*]] = llvm.xor %[[VAL_181]], %[[VAL_235]] : i32 // CHECK: %[[VAL_237:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_238:.*]] = llvm.add %[[VAL_182]], %[[VAL_237]] : i32 + // CHECK: %[[VAL_238:.*]] = llvm.xor %[[VAL_182]], %[[VAL_237]] : i32 // CHECK: %[[VAL_239:.*]] = llvm.mlir.constant(12 : i32) : i32 - // CHECK: %[[VAL_240:.*]] = llvm.add %[[VAL_181]], %[[VAL_239]] : i32 + // CHECK: %[[VAL_240:.*]] = llvm.xor %[[VAL_181]], %[[VAL_239]] : i32 // CHECK: %[[VAL_241:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_242:.*]] = llvm.add %[[VAL_182]], %[[VAL_241]] : i32 + // CHECK: %[[VAL_242:.*]] = llvm.xor %[[VAL_182]], %[[VAL_241]] : i32 // CHECK: %[[VAL_243:.*]] = llvm.mlir.constant(13 : i32) : i32 - // CHECK: %[[VAL_244:.*]] = llvm.add %[[VAL_181]], %[[VAL_243]] : i32 + // CHECK: %[[VAL_244:.*]] = llvm.xor %[[VAL_181]], %[[VAL_243]] : i32 // CHECK: %[[VAL_245:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_246:.*]] = llvm.add %[[VAL_182]], %[[VAL_245]] : i32 + // CHECK: %[[VAL_246:.*]] = llvm.xor %[[VAL_182]], %[[VAL_245]] : i32 // COM: Offsets of rep [1, 0]. 
// CHECK: %[[VAL_247:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_248:.*]] = llvm.add %[[VAL_181]], %[[VAL_247]] : i32 + // CHECK: %[[VAL_248:.*]] = llvm.xor %[[VAL_181]], %[[VAL_247]] : i32 // CHECK: %[[VAL_249:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_250:.*]] = llvm.add %[[VAL_182]], %[[VAL_249]] : i32 + // CHECK: %[[VAL_250:.*]] = llvm.xor %[[VAL_182]], %[[VAL_249]] : i32 // CHECK: %[[VAL_251:.*]] = llvm.mlir.constant(17 : i32) : i32 - // CHECK: %[[VAL_252:.*]] = llvm.add %[[VAL_181]], %[[VAL_251]] : i32 + // CHECK: %[[VAL_252:.*]] = llvm.xor %[[VAL_181]], %[[VAL_251]] : i32 // CHECK: %[[VAL_253:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_254:.*]] = llvm.add %[[VAL_182]], %[[VAL_253]] : i32 + // CHECK: %[[VAL_254:.*]] = llvm.xor %[[VAL_182]], %[[VAL_253]] : i32 // CHECK: %[[VAL_255:.*]] = llvm.mlir.constant(20 : i32) : i32 - // CHECK: %[[VAL_256:.*]] = llvm.add %[[VAL_181]], %[[VAL_255]] : i32 + // CHECK: %[[VAL_256:.*]] = llvm.xor %[[VAL_181]], %[[VAL_255]] : i32 // CHECK: %[[VAL_257:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_258:.*]] = llvm.add %[[VAL_182]], %[[VAL_257]] : i32 + // CHECK: %[[VAL_258:.*]] = llvm.xor %[[VAL_182]], %[[VAL_257]] : i32 // CHECK: %[[VAL_259:.*]] = llvm.mlir.constant(21 : i32) : i32 - // CHECK: %[[VAL_260:.*]] = llvm.add %[[VAL_181]], %[[VAL_259]] : i32 + // CHECK: %[[VAL_260:.*]] = llvm.xor %[[VAL_181]], %[[VAL_259]] : i32 // CHECK: %[[VAL_261:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_262:.*]] = llvm.add %[[VAL_182]], %[[VAL_261]] : i32 + // CHECK: %[[VAL_262:.*]] = llvm.xor %[[VAL_182]], %[[VAL_261]] : i32 // CHECK: %[[VAL_263:.*]] = llvm.mlir.constant(24 : i32) : i32 - // CHECK: %[[VAL_264:.*]] = llvm.add %[[VAL_181]], %[[VAL_263]] : i32 + // CHECK: %[[VAL_264:.*]] = llvm.xor %[[VAL_181]], %[[VAL_263]] : i32 // CHECK: %[[VAL_265:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_266:.*]] = llvm.add %[[VAL_182]], %[[VAL_265]] : i32 + // CHECK: %[[VAL_266:.*]] = llvm.xor %[[VAL_182]], %[[VAL_265]] : i32 // CHECK: %[[VAL_267:.*]] = llvm.mlir.constant(25 : i32) : i32 - // CHECK: %[[VAL_268:.*]] = llvm.add %[[VAL_181]], %[[VAL_267]] : i32 + // CHECK: %[[VAL_268:.*]] = llvm.xor %[[VAL_181]], %[[VAL_267]] : i32 // CHECK: %[[VAL_269:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_270:.*]] = llvm.add %[[VAL_182]], %[[VAL_269]] : i32 + // CHECK: %[[VAL_270:.*]] = llvm.xor %[[VAL_182]], %[[VAL_269]] : i32 // CHECK: %[[VAL_271:.*]] = llvm.mlir.constant(28 : i32) : i32 - // CHECK: %[[VAL_272:.*]] = llvm.add %[[VAL_181]], %[[VAL_271]] : i32 + // CHECK: %[[VAL_272:.*]] = llvm.xor %[[VAL_181]], %[[VAL_271]] : i32 // CHECK: %[[VAL_273:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_274:.*]] = llvm.add %[[VAL_182]], %[[VAL_273]] : i32 + // CHECK: %[[VAL_274:.*]] = llvm.xor %[[VAL_182]], %[[VAL_273]] : i32 // CHECK: %[[VAL_275:.*]] = llvm.mlir.constant(29 : i32) : i32 - // CHECK: %[[VAL_276:.*]] = llvm.add %[[VAL_181]], %[[VAL_275]] : i32 + // CHECK: %[[VAL_276:.*]] = llvm.xor %[[VAL_181]], %[[VAL_275]] : i32 // CHECK: %[[VAL_277:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_278:.*]] = llvm.add %[[VAL_182]], %[[VAL_277]] : i32 + // CHECK: %[[VAL_278:.*]] = llvm.xor %[[VAL_182]], %[[VAL_277]] : i32 // COM: Offsets of rep [1, 1]. 
// CHECK: %[[VAL_279:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_280:.*]] = llvm.add %[[VAL_181]], %[[VAL_279]] : i32
+ // CHECK: %[[VAL_280:.*]] = llvm.xor %[[VAL_181]], %[[VAL_279]] : i32
// CHECK: %[[VAL_281:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_282:.*]] = llvm.add %[[VAL_182]], %[[VAL_281]] : i32
+ // CHECK: %[[VAL_282:.*]] = llvm.xor %[[VAL_182]], %[[VAL_281]] : i32
// CHECK: %[[VAL_283:.*]] = llvm.mlir.constant(17 : i32) : i32
- // CHECK: %[[VAL_284:.*]] = llvm.add %[[VAL_181]], %[[VAL_283]] : i32
+ // CHECK: %[[VAL_284:.*]] = llvm.xor %[[VAL_181]], %[[VAL_283]] : i32
// CHECK: %[[VAL_285:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_286:.*]] = llvm.add %[[VAL_182]], %[[VAL_285]] : i32
+ // CHECK: %[[VAL_286:.*]] = llvm.xor %[[VAL_182]], %[[VAL_285]] : i32
// CHECK: %[[VAL_287:.*]] = llvm.mlir.constant(20 : i32) : i32
- // CHECK: %[[VAL_288:.*]] = llvm.add %[[VAL_181]], %[[VAL_287]] : i32
+ // CHECK: %[[VAL_288:.*]] = llvm.xor %[[VAL_181]], %[[VAL_287]] : i32
// CHECK: %[[VAL_289:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_290:.*]] = llvm.add %[[VAL_182]], %[[VAL_289]] : i32
+ // CHECK: %[[VAL_290:.*]] = llvm.xor %[[VAL_182]], %[[VAL_289]] : i32
// CHECK: %[[VAL_291:.*]] = llvm.mlir.constant(21 : i32) : i32
- // CHECK: %[[VAL_292:.*]] = llvm.add %[[VAL_181]], %[[VAL_291]] : i32
+ // CHECK: %[[VAL_292:.*]] = llvm.xor %[[VAL_181]], %[[VAL_291]] : i32
// CHECK: %[[VAL_293:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_294:.*]] = llvm.add %[[VAL_182]], %[[VAL_293]] : i32
+ // CHECK: %[[VAL_294:.*]] = llvm.xor %[[VAL_182]], %[[VAL_293]] : i32
// CHECK: %[[VAL_295:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_296:.*]] = llvm.add %[[VAL_181]], %[[VAL_295]] : i32
+ // CHECK: %[[VAL_296:.*]] = llvm.xor %[[VAL_181]], %[[VAL_295]] : i32
// CHECK: %[[VAL_297:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_298:.*]] = llvm.add %[[VAL_182]], %[[VAL_297]] : i32
+ // CHECK: %[[VAL_298:.*]] = llvm.xor %[[VAL_182]], %[[VAL_297]] : i32
// CHECK: %[[VAL_299:.*]] = llvm.mlir.constant(25 : i32) : i32
- // CHECK: %[[VAL_300:.*]] = llvm.add %[[VAL_181]], %[[VAL_299]] : i32
+ // CHECK: %[[VAL_300:.*]] = llvm.xor %[[VAL_181]], %[[VAL_299]] : i32
// CHECK: %[[VAL_301:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_302:.*]] = llvm.add %[[VAL_182]], %[[VAL_301]] : i32
+ // CHECK: %[[VAL_302:.*]] = llvm.xor %[[VAL_182]], %[[VAL_301]] : i32
// CHECK: %[[VAL_303:.*]] = llvm.mlir.constant(28 : i32) : i32
- // CHECK: %[[VAL_304:.*]] = llvm.add %[[VAL_181]], %[[VAL_303]] : i32
+ // CHECK: %[[VAL_304:.*]] = llvm.xor %[[VAL_181]], %[[VAL_303]] : i32
// CHECK: %[[VAL_305:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_306:.*]] = llvm.add %[[VAL_182]], %[[VAL_305]] : i32
+ // CHECK: %[[VAL_306:.*]] = llvm.xor %[[VAL_182]], %[[VAL_305]] : i32
// CHECK: %[[VAL_307:.*]] = llvm.mlir.constant(29 : i32) : i32
- // CHECK: %[[VAL_308:.*]] = llvm.add %[[VAL_181]], %[[VAL_307]] : i32
+ // CHECK: %[[VAL_308:.*]] = llvm.xor %[[VAL_181]], %[[VAL_307]] : i32
// CHECK: %[[VAL_309:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_310:.*]] = llvm.add %[[VAL_182]], %[[VAL_309]] : i32
+ // CHECK: %[[VAL_310:.*]] = llvm.xor %[[VAL_182]], %[[VAL_309]] : i32
// COM: Offsets of rep [0, 2].
// CHECK: %[[VAL_311:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_312:.*]] = llvm.add %[[VAL_181]], %[[VAL_311]] : i32
+ // CHECK: %[[VAL_312:.*]] = llvm.xor %[[VAL_181]], %[[VAL_311]] : i32
// CHECK: %[[VAL_313:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_314:.*]] = llvm.add %[[VAL_182]], %[[VAL_313]] : i32
+ // CHECK: %[[VAL_314:.*]] = llvm.xor %[[VAL_182]], %[[VAL_313]] : i32
// CHECK: %[[VAL_315:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_316:.*]] = llvm.add %[[VAL_181]], %[[VAL_315]] : i32
+ // CHECK: %[[VAL_316:.*]] = llvm.xor %[[VAL_181]], %[[VAL_315]] : i32
// CHECK: %[[VAL_317:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_318:.*]] = llvm.add %[[VAL_182]], %[[VAL_317]] : i32
+ // CHECK: %[[VAL_318:.*]] = llvm.xor %[[VAL_182]], %[[VAL_317]] : i32
// CHECK: %[[VAL_319:.*]] = llvm.mlir.constant(4 : i32) : i32
- // CHECK: %[[VAL_320:.*]] = llvm.add %[[VAL_181]], %[[VAL_319]] : i32
+ // CHECK: %[[VAL_320:.*]] = llvm.xor %[[VAL_181]], %[[VAL_319]] : i32
// CHECK: %[[VAL_321:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_322:.*]] = llvm.add %[[VAL_182]], %[[VAL_321]] : i32
+ // CHECK: %[[VAL_322:.*]] = llvm.xor %[[VAL_182]], %[[VAL_321]] : i32
// CHECK: %[[VAL_323:.*]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: %[[VAL_324:.*]] = llvm.add %[[VAL_181]], %[[VAL_323]] : i32
+ // CHECK: %[[VAL_324:.*]] = llvm.xor %[[VAL_181]], %[[VAL_323]] : i32
// CHECK: %[[VAL_325:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_326:.*]] = llvm.add %[[VAL_182]], %[[VAL_325]] : i32
+ // CHECK: %[[VAL_326:.*]] = llvm.xor %[[VAL_182]], %[[VAL_325]] : i32
// CHECK: %[[VAL_327:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_328:.*]] = llvm.add %[[VAL_181]], %[[VAL_327]] : i32
+ // CHECK: %[[VAL_328:.*]] = llvm.xor %[[VAL_181]], %[[VAL_327]] : i32
// CHECK: %[[VAL_329:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_330:.*]] = llvm.add %[[VAL_182]], %[[VAL_329]] : i32
+ // CHECK: %[[VAL_330:.*]] = llvm.xor %[[VAL_182]], %[[VAL_329]] : i32
// CHECK: %[[VAL_331:.*]] = llvm.mlir.constant(9 : i32) : i32
- // CHECK: %[[VAL_332:.*]] = llvm.add %[[VAL_181]], %[[VAL_331]] : i32
+ // CHECK: %[[VAL_332:.*]] = llvm.xor %[[VAL_181]], %[[VAL_331]] : i32
// CHECK: %[[VAL_333:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_334:.*]] = llvm.add %[[VAL_182]], %[[VAL_333]] : i32
+ // CHECK: %[[VAL_334:.*]] = llvm.xor %[[VAL_182]], %[[VAL_333]] : i32
// CHECK: %[[VAL_335:.*]] = llvm.mlir.constant(12 : i32) : i32
- // CHECK: %[[VAL_336:.*]] = llvm.add %[[VAL_181]], %[[VAL_335]] : i32
+ // CHECK: %[[VAL_336:.*]] = llvm.xor %[[VAL_181]], %[[VAL_335]] : i32
// CHECK: %[[VAL_337:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_338:.*]] = llvm.add %[[VAL_182]], %[[VAL_337]] : i32
+ // CHECK: %[[VAL_338:.*]] = llvm.xor %[[VAL_182]], %[[VAL_337]] : i32
// CHECK: %[[VAL_339:.*]] = llvm.mlir.constant(13 : i32) : i32
- // CHECK: %[[VAL_340:.*]] = llvm.add %[[VAL_181]], %[[VAL_339]] : i32
+ // CHECK: %[[VAL_340:.*]] = llvm.xor %[[VAL_181]], %[[VAL_339]] : i32
// CHECK: %[[VAL_341:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_342:.*]] = llvm.add %[[VAL_182]], %[[VAL_341]] : i32
+ // CHECK: %[[VAL_342:.*]] = llvm.xor %[[VAL_182]], %[[VAL_341]] : i32
// COM: Offsets of rep [0, 3].
// CHECK: %[[VAL_343:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[VAL_344:.*]] = llvm.add %[[VAL_181]], %[[VAL_343]] : i32
+ // CHECK: %[[VAL_344:.*]] = llvm.xor %[[VAL_181]], %[[VAL_343]] : i32
// CHECK: %[[VAL_345:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_346:.*]] = llvm.add %[[VAL_182]], %[[VAL_345]] : i32
+ // CHECK: %[[VAL_346:.*]] = llvm.xor %[[VAL_182]], %[[VAL_345]] : i32
// CHECK: %[[VAL_347:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[VAL_348:.*]] = llvm.add %[[VAL_181]], %[[VAL_347]] : i32
+ // CHECK: %[[VAL_348:.*]] = llvm.xor %[[VAL_181]], %[[VAL_347]] : i32
// CHECK: %[[VAL_349:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_350:.*]] = llvm.add %[[VAL_182]], %[[VAL_349]] : i32
+ // CHECK: %[[VAL_350:.*]] = llvm.xor %[[VAL_182]], %[[VAL_349]] : i32
// CHECK: %[[VAL_351:.*]] = llvm.mlir.constant(4 : i32) : i32
- // CHECK: %[[VAL_352:.*]] = llvm.add %[[VAL_181]], %[[VAL_351]] : i32
+ // CHECK: %[[VAL_352:.*]] = llvm.xor %[[VAL_181]], %[[VAL_351]] : i32
// CHECK: %[[VAL_353:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_354:.*]] = llvm.add %[[VAL_182]], %[[VAL_353]] : i32
+ // CHECK: %[[VAL_354:.*]] = llvm.xor %[[VAL_182]], %[[VAL_353]] : i32
// CHECK: %[[VAL_355:.*]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: %[[VAL_356:.*]] = llvm.add %[[VAL_181]], %[[VAL_355]] : i32
+ // CHECK: %[[VAL_356:.*]] = llvm.xor %[[VAL_181]], %[[VAL_355]] : i32
// CHECK: %[[VAL_357:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_358:.*]] = llvm.add %[[VAL_182]], %[[VAL_357]] : i32
+ // CHECK: %[[VAL_358:.*]] = llvm.xor %[[VAL_182]], %[[VAL_357]] : i32
// CHECK: %[[VAL_359:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAL_360:.*]] = llvm.add %[[VAL_181]], %[[VAL_359]] : i32
+ // CHECK: %[[VAL_360:.*]] = llvm.xor %[[VAL_181]], %[[VAL_359]] : i32
// CHECK: %[[VAL_361:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_362:.*]] = llvm.add %[[VAL_182]], %[[VAL_361]] : i32
+ // CHECK: %[[VAL_362:.*]] = llvm.xor %[[VAL_182]], %[[VAL_361]] : i32
// CHECK: %[[VAL_363:.*]] = llvm.mlir.constant(9 : i32) : i32
- // CHECK: %[[VAL_364:.*]] = llvm.add %[[VAL_181]], %[[VAL_363]] : i32
+ // CHECK: %[[VAL_364:.*]] = llvm.xor %[[VAL_181]], %[[VAL_363]] : i32
// CHECK: %[[VAL_365:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_366:.*]] = llvm.add %[[VAL_182]], %[[VAL_365]] : i32
+ // CHECK: %[[VAL_366:.*]] = llvm.xor %[[VAL_182]], %[[VAL_365]] : i32
// CHECK: %[[VAL_367:.*]] = llvm.mlir.constant(12 : i32) : i32
- // CHECK: %[[VAL_368:.*]] = llvm.add %[[VAL_181]], %[[VAL_367]] : i32
+ // CHECK: %[[VAL_368:.*]] = llvm.xor %[[VAL_181]], %[[VAL_367]] : i32
// CHECK: %[[VAL_369:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_370:.*]] = llvm.add %[[VAL_182]], %[[VAL_369]] : i32
+ // CHECK: %[[VAL_370:.*]] = llvm.xor %[[VAL_182]], %[[VAL_369]] : i32
// CHECK: %[[VAL_371:.*]] = llvm.mlir.constant(13 : i32) : i32
- // CHECK: %[[VAL_372:.*]] = llvm.add %[[VAL_181]], %[[VAL_371]] : i32
+ // CHECK: %[[VAL_372:.*]] = llvm.xor %[[VAL_181]], %[[VAL_371]] : i32
// CHECK: %[[VAL_373:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_374:.*]] = llvm.add %[[VAL_182]], %[[VAL_373]] : i32
+ // CHECK: %[[VAL_374:.*]] = llvm.xor %[[VAL_182]], %[[VAL_373]] : i32
// COM: Offsets of rep [1, 2].
// CHECK: %[[VAL_375:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_376:.*]] = llvm.add %[[VAL_181]], %[[VAL_375]] : i32
+ // CHECK: %[[VAL_376:.*]] = llvm.xor %[[VAL_181]], %[[VAL_375]] : i32
// CHECK: %[[VAL_377:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_378:.*]] = llvm.add %[[VAL_182]], %[[VAL_377]] : i32
+ // CHECK: %[[VAL_378:.*]] = llvm.xor %[[VAL_182]], %[[VAL_377]] : i32
// CHECK: %[[VAL_379:.*]] = llvm.mlir.constant(17 : i32) : i32
- // CHECK: %[[VAL_380:.*]] = llvm.add %[[VAL_181]], %[[VAL_379]] : i32
+ // CHECK: %[[VAL_380:.*]] = llvm.xor %[[VAL_181]], %[[VAL_379]] : i32
// CHECK: %[[VAL_381:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_382:.*]] = llvm.add %[[VAL_182]], %[[VAL_381]] : i32
+ // CHECK: %[[VAL_382:.*]] = llvm.xor %[[VAL_182]], %[[VAL_381]] : i32
// CHECK: %[[VAL_383:.*]] = llvm.mlir.constant(20 : i32) : i32
- // CHECK: %[[VAL_384:.*]] = llvm.add %[[VAL_181]], %[[VAL_383]] : i32
+ // CHECK: %[[VAL_384:.*]] = llvm.xor %[[VAL_181]], %[[VAL_383]] : i32
// CHECK: %[[VAL_385:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_386:.*]] = llvm.add %[[VAL_182]], %[[VAL_385]] : i32
+ // CHECK: %[[VAL_386:.*]] = llvm.xor %[[VAL_182]], %[[VAL_385]] : i32
// CHECK: %[[VAL_387:.*]] = llvm.mlir.constant(21 : i32) : i32
- // CHECK: %[[VAL_388:.*]] = llvm.add %[[VAL_181]], %[[VAL_387]] : i32
+ // CHECK: %[[VAL_388:.*]] = llvm.xor %[[VAL_181]], %[[VAL_387]] : i32
// CHECK: %[[VAL_389:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_390:.*]] = llvm.add %[[VAL_182]], %[[VAL_389]] : i32
+ // CHECK: %[[VAL_390:.*]] = llvm.xor %[[VAL_182]], %[[VAL_389]] : i32
// CHECK: %[[VAL_391:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_392:.*]] = llvm.add %[[VAL_181]], %[[VAL_391]] : i32
+ // CHECK: %[[VAL_392:.*]] = llvm.xor %[[VAL_181]], %[[VAL_391]] : i32
// CHECK: %[[VAL_393:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_394:.*]] = llvm.add %[[VAL_182]], %[[VAL_393]] : i32
+ // CHECK: %[[VAL_394:.*]] = llvm.xor %[[VAL_182]], %[[VAL_393]] : i32
// CHECK: %[[VAL_395:.*]] = llvm.mlir.constant(25 : i32) : i32
- // CHECK: %[[VAL_396:.*]] = llvm.add %[[VAL_181]], %[[VAL_395]] : i32
+ // CHECK: %[[VAL_396:.*]] = llvm.xor %[[VAL_181]], %[[VAL_395]] : i32
// CHECK: %[[VAL_397:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_398:.*]] = llvm.add %[[VAL_182]], %[[VAL_397]] : i32
+ // CHECK: %[[VAL_398:.*]] = llvm.xor %[[VAL_182]], %[[VAL_397]] : i32
// CHECK: %[[VAL_399:.*]] = llvm.mlir.constant(28 : i32) : i32
- // CHECK: %[[VAL_400:.*]] = llvm.add %[[VAL_181]], %[[VAL_399]] : i32
+ // CHECK: %[[VAL_400:.*]] = llvm.xor %[[VAL_181]], %[[VAL_399]] : i32
// CHECK: %[[VAL_401:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_402:.*]] = llvm.add %[[VAL_182]], %[[VAL_401]] : i32
+ // CHECK: %[[VAL_402:.*]] = llvm.xor %[[VAL_182]], %[[VAL_401]] : i32
// CHECK: %[[VAL_403:.*]] = llvm.mlir.constant(29 : i32) : i32
- // CHECK: %[[VAL_404:.*]] = llvm.add %[[VAL_181]], %[[VAL_403]] : i32
+ // CHECK: %[[VAL_404:.*]] = llvm.xor %[[VAL_181]], %[[VAL_403]] : i32
// CHECK: %[[VAL_405:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_406:.*]] = llvm.add %[[VAL_182]], %[[VAL_405]] : i32
+ // CHECK: %[[VAL_406:.*]] = llvm.xor %[[VAL_182]], %[[VAL_405]] : i32
// COM: Offsets of rep [1, 3].
// CHECK: %[[VAL_407:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAL_408:.*]] = llvm.add %[[VAL_181]], %[[VAL_407]] : i32
+ // CHECK: %[[VAL_408:.*]] = llvm.xor %[[VAL_181]], %[[VAL_407]] : i32
// CHECK: %[[VAL_409:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_410:.*]] = llvm.add %[[VAL_182]], %[[VAL_409]] : i32
+ // CHECK: %[[VAL_410:.*]] = llvm.xor %[[VAL_182]], %[[VAL_409]] : i32
// CHECK: %[[VAL_411:.*]] = llvm.mlir.constant(17 : i32) : i32
- // CHECK: %[[VAL_412:.*]] = llvm.add %[[VAL_181]], %[[VAL_411]] : i32
+ // CHECK: %[[VAL_412:.*]] = llvm.xor %[[VAL_181]], %[[VAL_411]] : i32
// CHECK: %[[VAL_413:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_414:.*]] = llvm.add %[[VAL_182]], %[[VAL_413]] : i32
+ // CHECK: %[[VAL_414:.*]] = llvm.xor %[[VAL_182]], %[[VAL_413]] : i32
// CHECK: %[[VAL_415:.*]] = llvm.mlir.constant(20 : i32) : i32
- // CHECK: %[[VAL_416:.*]] = llvm.add %[[VAL_181]], %[[VAL_415]] : i32
+ // CHECK: %[[VAL_416:.*]] = llvm.xor %[[VAL_181]], %[[VAL_415]] : i32
// CHECK: %[[VAL_417:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_418:.*]] = llvm.add %[[VAL_182]], %[[VAL_417]] : i32
+ // CHECK: %[[VAL_418:.*]] = llvm.xor %[[VAL_182]], %[[VAL_417]] : i32
// CHECK: %[[VAL_419:.*]] = llvm.mlir.constant(21 : i32) : i32
- // CHECK: %[[VAL_420:.*]] = llvm.add %[[VAL_181]], %[[VAL_419]] : i32
+ // CHECK: %[[VAL_420:.*]] = llvm.xor %[[VAL_181]], %[[VAL_419]] : i32
// CHECK: %[[VAL_421:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_422:.*]] = llvm.add %[[VAL_182]], %[[VAL_421]] : i32
+ // CHECK: %[[VAL_422:.*]] = llvm.xor %[[VAL_182]], %[[VAL_421]] : i32
// CHECK: %[[VAL_423:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_424:.*]] = llvm.add %[[VAL_181]], %[[VAL_423]] : i32
+ // CHECK: %[[VAL_424:.*]] = llvm.xor %[[VAL_181]], %[[VAL_423]] : i32
// CHECK: %[[VAL_425:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_426:.*]] = llvm.add %[[VAL_182]], %[[VAL_425]] : i32
+ // CHECK: %[[VAL_426:.*]] = llvm.xor %[[VAL_182]], %[[VAL_425]] : i32
// CHECK: %[[VAL_427:.*]] = llvm.mlir.constant(25 : i32) : i32
- // CHECK: %[[VAL_428:.*]] = llvm.add %[[VAL_181]], %[[VAL_427]] : i32
+ // CHECK: %[[VAL_428:.*]] = llvm.xor %[[VAL_181]], %[[VAL_427]] : i32
// CHECK: %[[VAL_429:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_430:.*]] = llvm.add %[[VAL_182]], %[[VAL_429]] : i32
+ // CHECK: %[[VAL_430:.*]] = llvm.xor %[[VAL_182]], %[[VAL_429]] : i32
// CHECK: %[[VAL_431:.*]] = llvm.mlir.constant(28 : i32) : i32
- // CHECK: %[[VAL_432:.*]] = llvm.add %[[VAL_181]], %[[VAL_431]] : i32
+ // CHECK: %[[VAL_432:.*]] = llvm.xor %[[VAL_181]], %[[VAL_431]] : i32
// CHECK: %[[VAL_433:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_434:.*]] = llvm.add %[[VAL_182]], %[[VAL_433]] : i32
+ // CHECK: %[[VAL_434:.*]] = llvm.xor %[[VAL_182]], %[[VAL_433]] : i32
// CHECK: %[[VAL_435:.*]] = llvm.mlir.constant(29 : i32) : i32
- // CHECK: %[[VAL_436:.*]] = llvm.add %[[VAL_181]], %[[VAL_435]] : i32
+ // CHECK: %[[VAL_436:.*]] = llvm.xor %[[VAL_181]], %[[VAL_435]] : i32
// CHECK: %[[VAL_437:.*]] = llvm.mlir.constant(24 : i32) : i32
- // CHECK: %[[VAL_438:.*]] = llvm.add %[[VAL_182]], %[[VAL_437]] : i32
+ // CHECK: %[[VAL_438:.*]] = llvm.xor %[[VAL_182]], %[[VAL_437]] : i32
tt.print " x: " {hex = false, isSigned = array<i32: 0>} : %cst : tensor<32x32xf16, #dot_operand_b>
tt.return
}
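Why these checks switch from llvm.add to llvm.xor: the DPAS operand offsets are now derived via linear layouts, which combine their power-of-two basis vectors over GF(2), i.e. with xor. A minimal C++ sketch of the idea follows; it is illustrative only, assumes base and offset use disjoint bits, and is not the actual lowering code:

  #include <cstdint>

  // `base` is a lane coordinate, `off` a per-rep basis offset (hypothetical
  // names). With power-of-two reps the two use disjoint bits, so xor equals
  // add numerically, but xor matches LinearLayout semantics and composes
  // safely with swizzling.
  inline uint32_t repCoordinate(uint32_t base, uint32_t off) {
    return base ^ off; // previously emitted as: base + off
  }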
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
index 2758e6341..249849520 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
@@ -18,6 +18,10 @@ namespace mlir::triton::gpu {
LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                                unsigned opIdx = 2);
+std::optional<LinearLayout>
+dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
+                             ArrayRef<int64_t> shape);
+
} // namespace mlir::triton::gpu
#endif // TRITON_DIALECT_TRITONINTELGPU_IR_LINEARLAYOUTCONVERSIONS_H
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
index d056fb229..90e950bd0 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
@@ -582,4 +582,13 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                  CTALayoutAttr::getDefault(ctx, rank), shape);
}
+std::optional<LinearLayout>
+dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
+                             ArrayRef<int64_t> shape) {
+  auto dpasLayout = cast<intel::DpasEncodingAttr>(dotDpasLayout.getParent());
+  if (dotDpasLayout.getOpIdx() == 0)
+    return std::nullopt;
+  return DPAStoLinearLayout(shape, dpasLayout, dotDpasLayout.getOpIdx());
+}
+
} // namespace mlir::triton::gpu
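For context, a caller in a conversion path might consult the new helper and fall back when it declines (opIdx == 0). The sketch below is illustrative only; `dotTy`, `lowerWithLinearLayout`, and `legacyLowering` are hypothetical names, not part of this patch:

  // Hypothetical call site: prefer the linear-layout path when available.
  if (auto dotEnc = dyn_cast<DotOperandEncodingAttr>(dotTy.getEncoding())) {
    if (std::optional<LinearLayout> ll =
            dotOperandDpasToLinearLayout(dotEnc, dotTy.getShape()))
      return lowerWithLinearLayout(*ll); // hypothetical helper
    return legacyLowering(dotTy);        // hypothetical fallback
  }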
From 2202ca754c53de28c4fb14e9e7cebe7ec0e5d22f Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Tue, 8 Oct 2024 03:38:30 -0400
Subject: [PATCH 7/7] [FA][ScheduleLoad] Fix bug exposed by causal=true (#2433)

Cannot move ops that are used by other ops in another region.

Signed-off-by: Whitney Tsang
---
 test/TritonIntelGPU/schedule-load.mlir        | 24 +++++++++++++++++++
 .../TritonIntelGPUTransforms/ScheduleLoad.cpp |  9 ++++++-
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/test/TritonIntelGPU/schedule-load.mlir b/test/TritonIntelGPU/schedule-load.mlir
index c984cfabd..6352b6d03 100644
--- a/test/TritonIntelGPU/schedule-load.mlir
+++ b/test/TritonIntelGPU/schedule-load.mlir
@@ -303,3 +303,27 @@ module attributes {"triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-wa
tt.return
}
}
+
+ // -----
+
+ tt.func public @test(%arg0: !tt.ptr<tensor<16x16xf16>>, %arg1: !tt.ptr<tensor<8x32xf16>>) {
+   %lb = arith.constant 0 : i32
+   %ub = tt.get_program_id x : i32
+   %st = arith.constant 32 : i32
+   %zero = arith.constant dense<0.000000e+00> : tensor<8x16xf32>
+   %common = tt.load %arg1 {DotIdx = 0 : i32} : !tt.ptr<tensor<8x32xf16>>
+   // COM: Check %common is not moved in the loop.
+   // CHECK: tt.load %arg1
+   // CHECK-COUNT-2: scf.for
+   scf.for %iv0 = %lb to %ub step %st : i32 {
+     %load1 = tt.load %arg0 {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
+     %extract1 = triton_intel_gpu.extract %common[0] : tensor<8x32xf16> -> tensor<8x16xf16>
+     %dot1 = tt.dot %extract1, %load1, %zero, inputPrecision = tf32 {"schedule-group" = 0 : i32} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
+   }
+   scf.for %iv1 = %lb to %ub step %st : i32 {
+     %load2 = tt.load %arg0 {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
+     %extract2 = triton_intel_gpu.extract %common[0] : tensor<8x32xf16> -> tensor<8x16xf16>
+     %dot2 = tt.dot %extract2, %load2, %zero, inputPrecision = tf32 {"schedule-group" = 0 : i32} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
+   }
+   tt.return
+ }
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp
index 07c9e0610..41e975f9f 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/ScheduleLoad.cpp
@@ -71,8 +71,15 @@ class ScheduleLoadPass
    for (SmallVector<tt::DotOp> &dots : dotsGroup) {
      SmallVector<Value> notVisited = getNotVisitedUses(dots);
      for (Value val : notVisited) {
-       if (Operation *op = val.getDefiningOp())
+       if (Operation *op = val.getDefiningOp()) {
+         // Cannot move an op that is used by other ops in another region.
+         Region *rgn = dots.begin()->getOperation()->getParentRegion();
+         if (any_of(val.getUsers(), [&](Operation *user) {
+               return user->getParentRegion() != rgn;
+             }))
+           continue;
          op->moveBefore(dots.begin()->getOperation());
+       }
      }
    }
  });
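The guard above can be read in isolation: hoisting a def next to the first dot of a group is only legal if every user of the value sits in that dot's region; otherwise the def would no longer dominate users in sibling regions (the two scf.for bodies in the new test). A standalone C++ sketch using generic MLIR APIs follows; the names are illustrative, not the pass's own code:

  // Returns true when moving `val`'s defining op next to `anchor` keeps all
  // uses dominated; `anchor` stands for the group's first dot (hypothetical).
  static bool safeToMoveBefore(mlir::Value val, mlir::Operation *anchor) {
    mlir::Region *dest = anchor->getParentRegion();
    return llvm::all_of(val.getUsers(), [&](mlir::Operation *user) {
      return user->getParentRegion() == dest;
    });
  }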