[AMD] Always swap operands of mfma and use mfma.transposed layout #4767

Merged · 6 commits · Sep 30, 2024

Changes from all commits
25 changes: 24 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout,
SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

// Returns the dimensions of the tensor from minor (fast-varying) to
// major (slow-varying). For blocked, mma, and dotOperand layouts, the
// elements live in registers, but the order still refers to the memory
// layout of the original tensor in global memory.
// For the shared layout, the order refers to which dimension of the
// original tensor is contiguous in shared memory.
SmallVector<unsigned> getOrder(Attribute layout);

// Returns the dimensions along which warp IDs are distributed.
// warpsPerCTA only gives the warp layout within the CTA, e.g. warpsPerCTA =
// [2, 4] means there are 2 warps along dim0 and 4 warps along dim1.
// warpOrder specifies the order in which warp IDs are distributed,
// e.g. warpOrder = [0, 1] distributes warp IDs as follows:
//   [warp0 warp2 warp4 warp6]
//   [warp1 warp3 warp5 warp7]
// Note that getWarpOrder and getOrder return the same result in most cases,
// but this is not guaranteed.
SmallVector<unsigned> getWarpOrder(Attribute layout);

// Returns the dimensions along which thread IDs are distributed.
// Similar to warpOrder, threadOrder specifies how thread IDs are distributed
// within a warp.
// Note that getThreadOrder and getOrder return the same result in most
// cases, but this is not guaranteed. One exception is the mfma.transposed
// layout, where getOrder returns [1, 0] but getThreadOrder returns [0, 1].
SmallVector<unsigned> getThreadOrder(Attribute layout);

CTALayoutAttr getCTALayout(Attribute layout);

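To make the order conventions above concrete, here is a minimal standalone sketch; the `delinearize` helper below is a simplified illustration, not the `mlir::LLVM::delinearize` used in ReduceOpToLLVM.cpp. With warpsPerCTA = [2, 4] and warpOrder = [0, 1], dim0 varies fastest, which reproduces the [warp0 warp2 warp4 warp6] / [warp1 warp3 warp5 warp7] arrangement from the comment.

```cpp
#include <cstdio>
#include <vector>

// Delinearize `id` into per-dimension coordinates, filling the dimension
// listed first in `order` fastest and the one listed last slowest.
static std::vector<unsigned> delinearize(unsigned id,
                                         const std::vector<unsigned> &shape,
                                         const std::vector<unsigned> &order) {
  std::vector<unsigned> coords(shape.size());
  for (unsigned dim : order) {
    coords[dim] = id % shape[dim];
    id /= shape[dim];
  }
  return coords;
}

int main() {
  std::vector<unsigned> warpsPerCTA = {2, 4}; // 2 warps along dim0, 4 along dim1
  std::vector<unsigned> warpOrder = {0, 1};   // dim0 varies fastest
  for (unsigned w = 0; w < 8; ++w) {
    auto c = delinearize(w, warpsPerCTA, warpOrder);
    std::printf("warp%u -> (dim0=%u, dim1=%u)\n", w, c[0], c[1]);
  }
  return 0;
}
```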
4 changes: 2 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -36,7 +36,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentOrder(sliceEncoding.getParent());
}
return getOrder(layout);
return getThreadOrder(layout);
}

} // namespace
@@ -75,7 +75,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
threadOffset = threadsPerWarp[sliceLayout.getDim()];
} else {
auto threadsPerWarp = getThreadsPerWarp(srcLayout);
auto order = getOrder(srcLayout);
auto order = getThreadOrder(srcLayout);
for (unsigned i = 0; i < order.size(); i++) {
if (order[i] == axis)
break;
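This call-site change matters because the thread offset along the reduction axis is derived from the thread order. The sketch below is an assumed standalone helper, not `ReduceOpHelper` itself: the lane-ID stride along `axis` is the product of `threadsPerWarp` over the dimensions that vary faster than `axis`, so the result changes for mfma.transposed once the thread order is [0, 1] instead of [1, 0].

```cpp
#include <vector>

// Lane-ID stride along `axis`: multiply the warp's thread extents over all
// dimensions that come before `axis` in the thread order (i.e. vary faster).
static unsigned threadOffsetOnAxis(const std::vector<unsigned> &threadsPerWarp,
                                   const std::vector<unsigned> &threadOrder,
                                   unsigned axis) {
  unsigned offset = 1;
  for (unsigned dim : threadOrder) {
    if (dim == axis)
      break;
    offset *= threadsPerWarp[dim];
  }
  return offset;
}
```

For example, with an assumed threadsPerWarp = [4, 16] and axis = 0, the offset is 16 under thread order [1, 0] but 1 under [0, 1].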
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -9,6 +9,7 @@ using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getThreadOrder;
using ::mlir::triton::gpu::getTotalElemsPerThread;

namespace {
@@ -271,7 +272,7 @@ struct ReduceOpConversion

auto threadsPerWarp =
triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
auto order = getOrder(srcLayout);
auto order = getThreadOrder(srcLayout);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
Value laneIdAxis = multiDimLaneId[axis];
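Same idea at this call site: the lane ID is delinearized with the thread order before the coordinate along the reduction axis is picked out. A self-contained sketch of that step (illustrative only, not the `mlir::LLVM::delinearize` used in this file):

```cpp
#include <vector>

// Coordinate of `laneId` along the reduction `axis`, delinearizing with the
// thread order (the dimension listed first varies fastest).
static unsigned laneIdOnAxis(unsigned laneId,
                             const std::vector<unsigned> &threadsPerWarp,
                             const std::vector<unsigned> &threadOrder,
                             unsigned axis) {
  for (unsigned dim : threadOrder) {
    unsigned coord = laneId % threadsPerWarp[dim];
    if (dim == axis)
      return coord;
    laneId /= threadsPerWarp[dim];
  }
  return 0; // callers pass an axis that appears in threadOrder
}
```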
21 changes: 12 additions & 9 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -256,14 +256,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
auto rank = distributedLayout.getWarpsPerCTA().size();
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout);
if (!mfmaLayout)
return order;
// For transposed MFMA layouts, we swap M and N dimensions, which is
// always the first two in order; as we can have an optional batch
// dimension following them.
if (mfmaLayout.getIsTransposed())
std::swap(order[0], order[1]);
return order;
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
@@ -290,6 +282,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
return {};
};

SmallVector<unsigned> getThreadOrder(Attribute layout) {
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getThreadOrder();
else
llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
return {};
};

CTALayoutAttr getCTALayout(Attribute layout) {
if (auto distributedLayout =
mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
@@ -1536,7 +1536,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
auto order = ::getOrder(*this);
if (getIsTransposed())
std::swap(order[0], order[1]);
return order;
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
unsigned rows, cols;
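The effect of the getThreadOrder override can be summarized with a small sketch (illustrative only; the real method delegates to ::getOrder): the base order is minor-to-major, and the transposed case swaps the first two dimensions, matching the expectations in the DialectTest changes below.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Thread order of an MFMA layout: start from the minor-to-major order
// {rank-1, ..., 1, 0}, then swap the first two dims if the layout is
// transposed.
static std::vector<unsigned> mfmaThreadOrder(unsigned rank, bool isTransposed) {
  std::vector<unsigned> order(rank);
  for (unsigned i = 0; i < rank; ++i)
    order[i] = rank - 1 - i;
  if (isTransposed)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  assert((mfmaThreadOrder(2, false) == std::vector<unsigned>{1, 0}));
  assert((mfmaThreadOrder(2, true) == std::vector<unsigned>{0, 1}));
  assert((mfmaThreadOrder(3, true) == std::vector<unsigned>{1, 2, 0}));
  return 0;
}
```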
14 changes: 14 additions & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
{{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
// For the mfma.transposed layout, the element ownership among threads is
// "transposed" within each warp.
if (getIsTransposed())
tileLayout = LinearLayout(
{{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
} else {
assert(getMDim() == 16);
// For mfma with 16x16 output, each of the 64 threads holds 4 elements.
@@ -521,6 +528,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
{{kRegister, {{0, 1}, {0, 2}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
// For the mfma.transposed layout, the element ownership among threads is
// "transposed" within each warp.
if (getIsTransposed())
tileLayout = LinearLayout(
{{kRegister, {{1, 0}, {2, 0}}},
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
}
if (hasBatchDim) {
assert(order[2] == 0);
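A standalone sketch of how these basis vectors are read (this is not the LinearLayout API, only the XOR-of-bases rule it encodes): an index's set bits select basis vectors, which are XOR-ed to give the coordinate within the tile. Applying the transposed 16x16 register bases {{1, 0}, {2, 0}} shows that a thread's four values sit at consecutive offsets along the first listed output dimension (outDimNames[order[0]]), which is what enables the wider vectorized global stores targeted by this PR.

```cpp
#include <array>
#include <cstdio>
#include <vector>

using Vec2 = std::array<unsigned, 2>;

// XOR together the basis vectors selected by the set bits of `index`.
static Vec2 applyBases(unsigned index, const std::vector<Vec2> &bases) {
  Vec2 out = {0, 0};
  for (unsigned bit = 0; bit < bases.size(); ++bit) {
    if (index & (1u << bit)) {
      out[0] ^= bases[bit][0];
      out[1] ^= bases[bit][1];
    }
  }
  return out;
}

int main() {
  // kRegister bases of the transposed 16x16 tile above.
  std::vector<Vec2> regBases = {{{1, 0}}, {{2, 0}}};
  for (unsigned reg = 0; reg < 4; ++reg) {
    Vec2 c = applyBases(reg, regBases);
    std::printf("register %u -> (%u, %u)\n", reg, c[0], c[1]);
  }
  // Prints (0,0), (1,0), (2,0), (3,0): a thread's four accumulator values
  // are adjacent along the first listed output dimension.
  return 0;
}
```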
@@ -269,23 +269,6 @@ class BlockedToMFMA : public RewritePattern {
: RewritePattern(tt::DotOp::getOperationName(), 2, context),
mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {}

bool isChainDot(tt::DotOp &dotOp) const {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
ForwardSliceOptions fwdOpt;
fwdOpt.filter = filter;
BackwardSliceOptions bwdOpt;
bwdOpt.omitBlockArguments = true;
bwdOpt.filter = filter;
auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
for (Operation *op : slices) {
if (isa<tt::DotOp>(op) && (op != dotOp))
return true;
}
return false;
}

bool isSecondDot(tt::DotOp &dotOp) const {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
@@ -400,11 +383,12 @@ class BlockedToMFMA : public RewritePattern {
auto warpsPerTile =
warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});

bool isTransposed = isChainDot(dotOp);
// Always use the transposed mfma layout. This enables larger vectorization
// for global store instructions.
mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
oldRetType.getContext(),
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
/*instrShape*/ mDim, nDim, isTransposed, CTALayout);
/*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout);

Type mfmaAccType;
if (oldRetType.getElementType().isIntOrIndex())
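A back-of-the-envelope sketch of the vectorization argument in the comment above (the numbers are illustrative assumptions for a row-major fp32 result, not measurements): the store vector width is bounded by how many of a thread's result elements are contiguous along the tensor's fastest-varying dimension, and the transposed accumulator layout raises that count from 1 to 4.

```cpp
#include <algorithm>
#include <cstdio>

// Upper bound on elements per global store, limited by both the number of a
// thread's contiguous elements and the widest store (assumed 128-bit here).
static unsigned storeVectorWidth(unsigned contiguousElemsPerThread,
                                 unsigned elemBits,
                                 unsigned maxVectorBits = 128) {
  return std::min(contiguousElemsPerThread, maxVectorBits / elemBits);
}

int main() {
  // Assumed example: the non-transposed layout leaves 1 element contiguous
  // along the fast dim, the transposed one leaves 4.
  std::printf("non-transposed: %u element(s) per store\n",
              storeVectorWidth(1, 32));
  std::printf("transposed:     %u element(s) per store\n",
              storeVectorWidth(4, 32));
  return 0;
}
```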
8 changes: 4 additions & 4 deletions unittest/Dialect/TritonGPU/DialectTest.cpp
@@ -559,15 +559,15 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {

auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

auto mfma3d = createMFMA(32, 32, {2, 4, 1});
ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
}

TEST_F(AMDMfmaLayoutTest, mfma16) {
@@ -577,15 +577,15 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {

auto tmfma2d = createTransposedMFMA(16, 16, {2, 4});
ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));

auto mfma3d = createMFMA(16, 16, {2, 4, 1});
ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));

auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1});
ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
}

} // anonymous namespace
4 changes: 2 additions & 2 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -529,14 +529,14 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
LinearLayout(
{{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}},
{S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
{S("warp"), {{32, 0}, {0, 0}, {0, 0}}},
{S("warp"), {{0, 0}, {0, 0}, {32, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
EXPECT_EQ(toLinearLayout({128, 128}, mfmaT),
LinearLayout(
{{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}},
{S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
{S("warp"), {{32, 0}, {0, 32}, {0, 64}}},
{S("warp"), {{0, 32}, {0, 64}, {32, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
}