Skip to content

Commit

Permalink
[AMD] Always swap operands of mfma and use mfma.transposed layout (tr…
Browse files Browse the repository at this point in the history
…iton-lang#4767)

This helps to improve writeout to use `global_store_dwordx2`.

Along the way this PR
- Fixed the issue with getOrder for mfma layout
- Fixed the issue with reduceOp when dealing with mfma.transposed layout

In general, getOrder and getThreadOrder can return different values, and
this is the case for mfma.transposed layout. Therefore, we shouldn't
assume order and threadOrder are always the same.
  • Loading branch information
zhanglx13 authored and Luosuu committed Nov 13, 2024
1 parent 30ec95d commit 47b2626
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 38 deletions.
25 changes: 24 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout,
SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

// Returns the dimensions of the tensor from minor (fast-varying) to
// major (slow-varying). For blocked, mma, and dotOperand layouts,
// though the elements are in registers, the order refers to memory
// layout of the original tensor in global memory.
// For shared layout, the order refers to which dimension of the original
// tensor is contiguous in shared memory.
SmallVector<unsigned> getOrder(Attribute layout);

// Returns the dimensions along which warpId's are distributed.
// warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4]
// tells there are 2 warps along dim0 and 4 warps along dim1.
// warpOrder tells the specific order when distributing warp IDs.
// E.g. warpOrder = [0, 1] means the warp IDs are distributed as follows
// [warp0 warp2 warp4 warp6]
// [warp1 warp3 warp5 warp7]
// Note that in most cases, getWarpOrder and getOrder return the same results.
// But this is not guaranteed.
SmallVector<unsigned> getWarpOrder(Attribute layout);

SmallVector<unsigned> getOrder(Attribute layout);
// Returns the dimensions along which threadId's are distributed.
// Similar to warpOrder, threadOrder is necessary to tell the specific thread
// distribution in the warp.
// Note that, in most cases, getThreadOrder and getOrder return the same
// results. But this is not guaranteed. One exception is mfma.transposed layout,
// in which getOrder returns [1, 0] but getThreadOrder returns [0, 1].
SmallVector<unsigned> getThreadOrder(Attribute layout);

CTALayoutAttr getCTALayout(Attribute layout);

Expand Down
4 changes: 2 additions & 2 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentOrder(sliceEncoding.getParent());
}
return getOrder(layout);
return getThreadOrder(layout);
}

} // namespace
Expand Down Expand Up @@ -75,7 +75,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
threadOffset = threadsPerWarp[sliceLayout.getDim()];
} else {
auto threadsPerWarp = getThreadsPerWarp(srcLayout);
auto order = getOrder(srcLayout);
auto order = getThreadOrder(srcLayout);
for (unsigned i = 0; i < order.size(); i++) {
if (order[i] == axis)
break;
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getThreadOrder;
using ::mlir::triton::gpu::getTotalElemsPerThread;

namespace {
Expand Down Expand Up @@ -271,7 +272,7 @@ struct ReduceOpConversion

auto threadsPerWarp =
triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
auto order = getOrder(srcLayout);
auto order = getThreadOrder(srcLayout);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
Value laneIdAxis = multiDimLaneId[axis];
Expand Down
21 changes: 12 additions & 9 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
auto rank = distributedLayout.getWarpsPerCTA().size();
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout);
if (!mfmaLayout)
return order;
// For transposed MFMA layouts, we swap M and N dimensions, which is
// always the first two in order; as we can have an optional batch
// dimension following them.
if (mfmaLayout.getIsTransposed())
std::swap(order[0], order[1]);
return order;
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
Expand All @@ -290,6 +282,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
return {};
};

// Returns the dimensions along which threadId's are distributed for the
// given layout. Aborts for layouts that do not implement
// DistributedEncodingTrait, since no thread order is defined for them.
SmallVector<unsigned> getThreadOrder(Attribute layout) {
  auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout);
  if (!distributedLayout)
    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
  return distributedLayout.getThreadOrder();
};

CTALayoutAttr getCTALayout(Attribute layout) {
if (auto distributedLayout =
mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
Expand Down Expand Up @@ -1536,7 +1536,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
auto order = ::getOrder(*this);
if (getIsTransposed())
std::swap(order[0], order[1]);
return order;
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
unsigned rows, cols;
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
{{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
// For mfma.transposed layout, the element ownership among threads is
// "transposed" within each warp.
if (getIsTransposed())
tileLayout = LinearLayout(
{{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
} else {
assert(getMDim() == 16);
// For mfma with 16x16 output, each of the 64 threads holds 4 elements.
Expand All @@ -521,6 +528,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
{{kRegister, {{0, 1}, {0, 2}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
// For mfma.transposed layout, the element ownership among threads is
// "transposed" within each warp.
if (getIsTransposed())
tileLayout = LinearLayout(
{{kRegister, {{1, 0}, {2, 0}}},
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
}
if (hasBatchDim) {
assert(order[2] == 0);
Expand Down
22 changes: 3 additions & 19 deletions third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,23 +269,6 @@ class BlockedToMFMA : public RewritePattern {
: RewritePattern(tt::DotOp::getOperationName(), 2, context),
mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {}

bool isChainDot(tt::DotOp &dotOp) const {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
ForwardSliceOptions fwdOpt;
fwdOpt.filter = filter;
BackwardSliceOptions bwdOpt;
bwdOpt.omitBlockArguments = true;
bwdOpt.filter = filter;
auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
for (Operation *op : slices) {
if (isa<tt::DotOp>(op) && (op != dotOp))
return true;
}
return false;
}

bool isSecondDot(tt::DotOp &dotOp) const {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
Expand Down Expand Up @@ -400,11 +383,12 @@ class BlockedToMFMA : public RewritePattern {
auto warpsPerTile =
warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});

bool isTransposed = isChainDot(dotOp);
// Always use transposed mfma layout. This enables larger vectorization
// for global store instructions
mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
oldRetType.getContext(),
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
/*instrShape*/ mDim, nDim, isTransposed, CTALayout);
/*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout);

Type mfmaAccType;
if (oldRetType.getElementType().isIntOrIndex())
Expand Down
8 changes: 4 additions & 4 deletions unittest/Dialect/TritonGPU/DialectTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,15 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {

auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

auto mfma3d = createMFMA(32, 32, {2, 4, 1});
ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
}

TEST_F(AMDMfmaLayoutTest, mfma16) {
Expand All @@ -577,15 +577,15 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {

auto tmfma2d = createTransposedMFMA(16, 16, {2, 4});
ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

auto mfma3d = createMFMA(16, 16, {2, 4, 1});
ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1});
ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
}

} // anonymous namespace
Expand Down
4 changes: 2 additions & 2 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,14 +529,14 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
LinearLayout(
{{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}},
{S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
{S("warp"), {{32, 0}, {0, 0}, {0, 0}}},
{S("warp"), {{0, 0}, {0, 0}, {32, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
EXPECT_EQ(toLinearLayout({128, 128}, mfmaT),
LinearLayout(
{{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}},
{S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
{S("warp"), {{32, 0}, {0, 32}, {0, 64}}},
{S("warp"), {{0, 32}, {0, 64}, {32, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
}
Expand Down

0 comments on commit 47b2626

Please sign in to comment.