From 755077cde95db74da41f3769293f83521699cabe Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Sun, 29 Sep 2024 23:07:37 -0500 Subject: [PATCH] [AMD] Always swap operands of mfma and use mfma.transposed layout (#4767) This helps to improve writeout to use `global_store_dwordx2`. Along the way this PR - Fixed the issue with getOrder for mfma layout - Fixed the issue with reduceOp when dealing with mfma.transposed layout In general, getOrder and getThreadOrder can return different values, and this is the case for mfma.transposed layout. Therefore, we shouldn't assume order and threadOrder are always the same. --- include/triton/Dialect/TritonGPU/IR/Dialect.h | 25 ++++++++++++++++++- lib/Analysis/Utility.cpp | 4 +-- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 3 ++- lib/Dialect/TritonGPU/IR/Dialect.cpp | 21 +++++++++------- .../TritonGPU/IR/LinearLayoutConversions.cpp | 14 +++++++++++ .../AccelerateAMDMatmul.cpp | 22 +++------------- unittest/Dialect/TritonGPU/DialectTest.cpp | 8 +++--- .../TritonGPU/LinearLayoutConversionsTest.cpp | 4 +-- 8 files changed, 63 insertions(+), 38 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 3b012a630541..74ea99b58891 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout, SmallVector getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); +// Returns the dimensions of the tensor from minor (fast-varying) to +// major (slow-varying). For blocked, mma, and dotOperand layouts, +// though the elements are in registers, the order refers to memory +// layout of the original tensor in global memory. +// For shared Layout, the order refers to which dimension of the original tensor +// is contiguous in shared memory. +SmallVector getOrder(Attribute layout); + +// Returns the dimensions along which warpId's are distributed. 
+// warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4] +// tells there are 2 warps along dim0 and 4 warps along dim1. +// warpOrder tells the specific order when distributing warp IDs. +// E.g. warpOrder = [0, 1] means the warp IDs are distributed as follows +// [warp0 warp2 warp4 warp6] +// [warp1 warp3 warp5 warp7] +// Note that in most cases, getWarpOrder and getOrder return the same results. +// But this is not guaranteed. SmallVector getWarpOrder(Attribute layout); -SmallVector getOrder(Attribute layout); +// Returns the dimensions along which threadId's are distributed. +// Similar to warpOrder, threadOrder is necessary to tell the specific thread +// distribution in the warp. +// Note that, in most cases, getThreadOrder and getOrder return the same +// results. But this is not guaranteed. One exception is mfma.transposed layout, +// in which getOrder returns [1, 0] but getThreadOrder returns [0, 1]. +SmallVector getThreadOrder(Attribute layout); CTALayoutAttr getCTALayout(Attribute layout); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 56630c731858..b9468aa3e380 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -36,7 +36,7 @@ SmallVector getParentOrder(Attribute layout) { if (auto sliceEncoding = mlir::dyn_cast(layout)) { return getParentOrder(sliceEncoding.getParent()); } - return getOrder(layout); + return getThreadOrder(layout); } } // namespace @@ -75,7 +75,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { threadOffset = threadsPerWarp[sliceLayout.getDim()]; } else { auto threadsPerWarp = getThreadsPerWarp(srcLayout); - auto order = getOrder(srcLayout); + auto order = getThreadOrder(srcLayout); for (unsigned i = 0; i < order.size(); i++) { if (order[i] == axis) break; diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 16c9991a17b0..414328be50cf 100644 --- 
a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -9,6 +9,7 @@ using namespace mlir::triton; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getThreadOrder; using ::mlir::triton::gpu::getTotalElemsPerThread; namespace { @@ -271,7 +272,7 @@ struct ReduceOpConversion auto threadsPerWarp = triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); - auto order = getOrder(srcLayout); + auto order = getThreadOrder(srcLayout); SmallVector multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); Value laneIdAxis = multiDimLaneId[axis]; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a454fef56674..48f31bdf2a9d 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -256,14 +256,6 @@ SmallVector getOrder(Attribute layout) { auto rank = distributedLayout.getWarpsPerCTA().size(); SmallVector order(rank); std::iota(order.rbegin(), order.rend(), 0); - auto mfmaLayout = dyn_cast(layout); - if (!mfmaLayout) - return order; - // For transposed MFMA layouts, we swap M and N dimensions, which is - // always the first two in order; as we can have an optional batch - // dimension following them. 
- if (mfmaLayout.getIsTransposed()) - std::swap(order[0], order[1]); return order; } if (auto dotLayout = dyn_cast(layout)) { @@ -290,6 +282,14 @@ SmallVector getOrder(Attribute layout) { return {}; }; +SmallVector getThreadOrder(Attribute layout) { + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getThreadOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getThreadOrder"); + return {}; +}; + CTALayoutAttr getCTALayout(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) { @@ -1536,7 +1536,10 @@ SmallVector AMDMfmaEncodingAttr::getWarpOrder() const { return ::getWarpOrder(*this); } SmallVector AMDMfmaEncodingAttr::getThreadOrder() const { - return ::getOrder(*this); + auto order = ::getOrder(*this); + if (getIsTransposed()) + std::swap(order[0], order[1]); + return order; } SmallVector AMDMfmaEncodingAttr::getThreadsPerWarp() const { unsigned rows, cols; diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 286b1eac519c..f576ce215417 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}}, {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}}, {outDimNames[order[0]], outDimNames[order[1]]}); + // For mfma.transposed layout, the element ownership among threads is + // "transposed" within each warp. + if (getIsTransposed()) + tileLayout = LinearLayout( + {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}}, + {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); } else { assert(getMDim() == 16); // For mfma with 16x16 output, each of the 64 threads holds 4 elements.
@@ -521,6 +528,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { {{kRegister, {{0, 1}, {0, 2}}}, {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}}, {outDimNames[order[0]], outDimNames[order[1]]}); + // For mfma.transposed layout, the element ownership among threads is + // "transposed" within each warp. + if (getIsTransposed()) + tileLayout = LinearLayout( + {{kRegister, {{1, 0}, {2, 0}}}, + {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); } if (hasBatchDim) { assert(order[2] == 0); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index bf976a8138dc..21b74ecf99fa 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -269,23 +269,6 @@ class BlockedToMFMA : public RewritePattern { : RewritePattern(tt::DotOp::getOperationName(), 2, context), mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {} - bool isChainDot(tt::DotOp &dotOp) const { - auto filter = [&dotOp](Operation *op) { - return op->getParentRegion() == dotOp->getParentRegion(); - }; - ForwardSliceOptions fwdOpt; - fwdOpt.filter = filter; - BackwardSliceOptions bwdOpt; - bwdOpt.omitBlockArguments = true; - bwdOpt.filter = filter; - auto slices = getSlice(dotOp, bwdOpt, fwdOpt); - for (Operation *op : slices) { - if (isa(op) && (op != dotOp)) - return true; - } - return false; - } - bool isSecondDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); @@ -400,11 +383,12 @@ class BlockedToMFMA : public RewritePattern { auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); - bool isTransposed = isChainDot(dotOp); + // Always use transposed mfma layout.
This enables larger vectorization + // for global store instructions mfmaEnc = ttg::AMDMfmaEncodingAttr::get( oldRetType.getContext(), /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile, - /*instrShape*/ mDim, nDim, isTransposed, CTALayout); + /*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout); Type mfmaAccType; if (oldRetType.getElementType().isIntOrIndex()) diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index 67b1d9e9bce0..e3f521f1b3da 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -559,7 +559,7 @@ TEST_F(AMDMfmaLayoutTest, mfma32) { auto tmfma2d = createTransposedMFMA(32, 32, {2, 4}); ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u)); - ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u)); + ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); auto mfma3d = createMFMA(32, 32, {2, 4, 1}); ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); @@ -567,7 +567,7 @@ TEST_F(AMDMfmaLayoutTest, mfma32) { auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1}); ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u)); - ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u)); + ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); } TEST_F(AMDMfmaLayoutTest, mfma16) { @@ -577,7 +577,7 @@ TEST_F(AMDMfmaLayoutTest, mfma16) { auto tmfma2d = createTransposedMFMA(16, 16, {2, 4}); ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u)); - ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u)); + ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); auto mfma3d = createMFMA(16, 16, {2, 4, 1}); ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); @@ -585,7 +585,7 @@ TEST_F(AMDMfmaLayoutTest, mfma16) { auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1}); 
ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u)); - ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u)); + ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); } } // anonymous namespace diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 0b7a0f78211d..7d918602a705 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -529,14 +529,14 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) { LinearLayout( {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}}, {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, - {S("warp"), {{32, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, {S("block"), {}}}, {S("dim0"), S("dim1")})); EXPECT_EQ(toLinearLayout({128, 128}, mfmaT), LinearLayout( {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}}, {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, - {S("warp"), {{32, 0}, {0, 32}, {0, 64}}}, + {S("warp"), {{0, 32}, {0, 64}, {32, 0}}}, {S("block"), {}}}, {S("dim0"), S("dim1")})); }