Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Upstream] Decouple the Triton GPU lowering utils. #992

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 10 additions & 33 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1040,21 +1040,7 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
SmallVector<Value> baseIndex;
RewriterBase::InsertionGuard guard(rewriter);
SmallVector<Value> result;
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(layout)) {
result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter,
blockedLayout, type);
} else if (auto mmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mmaLayout.isVolta())
result =
emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, mmaLayout, type);
if (mmaLayout.isAmpere() || mmaLayout.isHopper())
result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout,
type);
} else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(layout)) {
result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
} else if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(layout)) {
result = emitBaseIndexForWmmaLayout(loc, rewriter, wmmaLayout, type);
} else if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy =
Expand All @@ -1064,8 +1050,13 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
result.erase(result.begin() + sliceLayout.getDim());
// CTAOffset has been added in emitBaseIndexForLayout of parentLayout
return result;
} else if (auto distributedLayout =
mlir::dyn_cast<triton::gpu::DistributedEncodingTrait>(
layout)) {
result =
distributedLayout.emitBaseIndexWithinCTAForLayout(loc, rewriter, type);
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
llvm_unreachable("unsupported emitBaseIndexWithinCTAForLayout");
}
if (withCTAOffset) {
auto CTAOffset =
Expand Down Expand Up @@ -1108,24 +1099,10 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,

inline SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(Attribute layout, RankedTensorType type) {
if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout))
return emitOffsetForBlockedLayout(blockedLayout, type);
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mmaLayout.isVolta())
return emitOffsetForMmaLayoutV1(mmaLayout, type);
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, type);
if (mmaLayout.isHopper())
return emitOffsetForMmaLayoutV3(mmaLayout, type);
}
if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(layout)) {
return emitOffsetForMfmaLayout(mfmaLayout, type);
}
if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(layout)) {
return emitOffsetForWmmaLayout(wmmaLayout, type);
if (auto distributedLayout =
mlir::dyn_cast<triton::gpu::DistributedEncodingTrait>(layout)) {
return distributedLayout.emitOffsetForLayout(type);
}
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout))
return emitOffsetForSliceLayout(sliceLayout, type);
llvm_unreachable("unsupported emitOffsetForLayout");
}

Expand Down
57 changes: 56 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,18 @@ For the Threads Per Warp and Values Per Thread level, the linear id distribution
InterfaceMethod<"Gets the number of contiguous elements per thread.",
"SmallVector<unsigned>",
"getContigPerThread">,

InterfaceMethod<"emit base index",
"SmallVector<Value>",
"emitBaseIndexWithinCTAForLayout",
(ins "Location":$loc,
"RewriterBase&":$rewriter,
"RankedTensorType":$type)>,

InterfaceMethod<"emit offset",
"SmallVector<SmallVector<unsigned>>",
"emitOffsetForLayout",
(ins "RankedTensorType":$type)>,
Comment on lines +530 to +540
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sidenote: you can provide a default implementation changing this to:

Suggested change
InterfaceMethod<"emit base index",
"SmallVector<Value>",
"emitBaseIndexWithinCTAForLayout",
(ins "Location":$loc,
"RewriterBase&":$rewriter,
"RankedTensorType":$type)>,
InterfaceMethod<"emit offset",
"SmallVector<SmallVector<unsigned>>",
"emitOffsetForLayout",
(ins "RankedTensorType":$type)>,
InterfaceMethod<"emit base index",
"SmallVector<Value>",
"emitBaseIndexWithinCTAForLayout",
(ins "Location":$loc,
"RewriterBase&":$rewriter,
"RankedTensorType":$type),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
// IMPL
}]>,
InterfaceMethod<"emit offset",
"SmallVector<SmallVector<unsigned>>",
"emitOffsetForLayout",
(ins "RankedTensorType":$type),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// IMPL
}]>,

Then, in attributes implementing this interface, instead of inheriting like:

[InterfaceFoo]

do:

[DeclareAttrInterfaceMethods<InterfaceFoo>]

See for context

];
}

Expand Down Expand Up @@ -745,6 +757,12 @@ for
// Block encoding is dense stride layout. The elements per thread are contiguous.
return getSizePerThread();
};

SmallVector<Value> emitBaseIndexWithinCTAForLayout(Location loc,
RewriterBase &rewriter,
RankedTensorType type) const;

SmallVector<SmallVector<unsigned>> emitOffsetForLayout(RankedTensorType type) const;
}];

let hasCustomAssemblyFormat = 1;
Expand Down Expand Up @@ -904,6 +922,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
return contigPerThread;
};

SmallVector<Value> emitBaseIndexWithinCTAForLayout(Location loc,
RewriterBase &rewriter,
RankedTensorType type) const;

SmallVector<SmallVector<unsigned>> emitOffsetForLayout(RankedTensorType type) const;
}];

let hasCustomAssemblyFormat = 1;
Expand Down Expand Up @@ -968,6 +991,13 @@ Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3].
SmallVector<unsigned> contigPerThread(rank, 1);
return contigPerThread;
};

SmallVector<Value> emitBaseIndexWithinCTAForLayout(Location loc,
RewriterBase &rewriter,
RankedTensorType type) const;

SmallVector<SmallVector<unsigned>> emitOffsetForLayout(RankedTensorType type) const;

}];
}

Expand Down Expand Up @@ -1169,6 +1199,12 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
return contigPerThread;
};

SmallVector<Value> emitBaseIndexWithinCTAForLayout(Location loc,
RewriterBase &rewriter,
RankedTensorType type) const;

SmallVector<SmallVector<unsigned>> emitOffsetForLayout(RankedTensorType type) const;

}];

let hasCustomAssemblyFormat = 1;
Expand Down Expand Up @@ -1210,6 +1246,15 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
auto parentLayout = mlir::cast<DistributedEncodingTrait>(getParent());
return parentLayout.getContigPerThread();
};


SmallVector<Value> emitBaseIndexWithinCTAForLayout(Location loc,
RewriterBase &rewriter,
RankedTensorType type) const {
llvm_unreachable("SliceEncodingAttr doesn't support emitBaseIndexWithinCTAForLayout interface");
};

SmallVector<SmallVector<unsigned>> emitOffsetForLayout(RankedTensorType type) const;
}];

let hasCustomAssemblyFormat = 1;
Expand Down Expand Up @@ -1256,9 +1301,19 @@ section 9.7.13.4.1 for more details.

let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<unsigned> getContigPerThread() {
SmallVector<unsigned> getContigPerThread() const {
return getSizePerThread();
};

SmallVector<Value> emitBaseIndexWithinCTAForLayout(Location loc,
RewriterBase &rewriter,
RankedTensorType type) const {
llvm_unreachable("DotOperandEncodingAttr doesn't support emitBaseIndexWithinCTAForLayout interface");
};

SmallVector<SmallVector<unsigned>> emitOffsetForLayout(RankedTensorType type) const {
llvm_unreachable("DotOperandEncodingAttr doesn't support emitOffsetForLayout interface");
};
}];
}

Expand Down
57 changes: 57 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
Expand Down Expand Up @@ -642,6 +643,16 @@ BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
return shape;
}

/// Computes the per-thread base index (within a single CTA) for a blocked
/// layout by delegating to the shared helper in TritonGPUToLLVM/Utility.h.
SmallVector<Value> BlockedEncodingAttr::emitBaseIndexWithinCTAForLayout(
    Location loc, RewriterBase &rewriter, RankedTensorType type) const {
  const auto &blocked = *this;
  return emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter, blocked, type);
}

/// Returns the per-thread element offsets for a blocked layout; the actual
/// computation lives in the free-function helper emitOffsetForBlockedLayout.
SmallVector<SmallVector<unsigned>>
BlockedEncodingAttr::emitOffsetForLayout(RankedTensorType type) const {
  const auto &blocked = *this;
  return emitOffsetForBlockedLayout(blocked, type);
}

template <class T>
SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
size_t rank = shape.size();
Expand Down Expand Up @@ -749,6 +760,11 @@ SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
return shape;
}

/// Returns the per-thread element offsets for a slice layout by delegating
/// to emitOffsetForSliceLayout (which recurses into the parent layout).
SmallVector<SmallVector<unsigned>>
SliceEncodingAttr::emitOffsetForLayout(RankedTensorType type) const {
  return emitOffsetForSliceLayout(*this, type);
}
// NOTE: the trailing semicolon after the closing brace was removed; a
// function definition is not a declaration statement (-Wextra-semi).

//

SmallVector<unsigned>
Expand Down Expand Up @@ -787,6 +803,16 @@ unsigned AMDMfmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
return product<unsigned>(getElemsPerThread(shape, eltTy));
}

/// Computes the per-thread base index (within a single CTA) for an AMD MFMA
/// layout via the shared lowering helper.
SmallVector<Value> AMDMfmaEncodingAttr::emitBaseIndexWithinCTAForLayout(
    Location loc, RewriterBase &rewriter, RankedTensorType type) const {
  const auto &mfma = *this;
  return emitBaseIndexForMfmaLayout(loc, rewriter, mfma, type);
}

/// Returns the per-thread element offsets for an AMD MFMA layout; forwards
/// to the free-function helper emitOffsetForMfmaLayout.
SmallVector<SmallVector<unsigned>>
AMDMfmaEncodingAttr::emitOffsetForLayout(RankedTensorType type) const {
  const auto &mfma = *this;
  return emitOffsetForMfmaLayout(mfma, type);
}

//

SmallVector<unsigned>
Expand Down Expand Up @@ -899,6 +925,27 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
return product<unsigned>(getElemsPerThread(shape, eltTy));
}

/// Computes the per-thread base index (within a single CTA) for an NVIDIA
/// MMA layout, dispatching on the MMA generation: Volta uses the V1 helper,
/// Ampere and Hopper share the V2/V3 helper.
SmallVector<Value> NvidiaMmaEncodingAttr::emitBaseIndexWithinCTAForLayout(
    Location loc, RewriterBase &rewriter, RankedTensorType type) const {
  if (isVolta())
    return emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, *this, type);
  if (isAmpere() || isHopper())
    return emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, *this, type);
  // Previously an unknown generation silently produced an empty vector.
  // Fail loudly instead, matching the convention used by emitOffsetForLayout.
  llvm_unreachable("unsupported emitBaseIndexWithinCTAForLayout");
}

/// Returns the per-thread element offsets for an NVIDIA MMA layout.
/// Dispatches on the MMA generation: Volta (V1), Ampere (V2), Hopper (V3).
/// Any other generation is a programming error and aborts via
/// llvm_unreachable.
SmallVector<SmallVector<unsigned>>
NvidiaMmaEncodingAttr::emitOffsetForLayout(RankedTensorType type) const {
  if (isVolta())
    return emitOffsetForMmaLayoutV1(*this, type);
  if (isAmpere())
    return emitOffsetForMmaLayoutV2(*this, type);
  if (isHopper())
    return emitOffsetForMmaLayoutV3(*this, type);
  llvm_unreachable("unsupported emitOffsetForLayout");
}

//

SmallVector<unsigned>
Expand Down Expand Up @@ -1672,6 +1719,16 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr() {
return {16, 16, 16};
}

/// Computes the per-thread base index (within a single CTA) for an AMD WMMA
/// layout via the shared lowering helper.
SmallVector<Value> AMDWmmaEncodingAttr::emitBaseIndexWithinCTAForLayout(
    Location loc, RewriterBase &rewriter, RankedTensorType type) const {
  const auto &wmma = *this;
  return emitBaseIndexForWmmaLayout(loc, rewriter, wmma, type);
}

/// Returns the per-thread element offsets for an AMD WMMA layout; forwards
/// to the free-function helper emitOffsetForWmmaLayout.
SmallVector<SmallVector<unsigned>>
AMDWmmaEncodingAttr::emitOffsetForLayout(RankedTensorType type) const {
  const auto &wmma = *this;
  return emitOffsetForWmmaLayout(wmma, type);
}

//===----------------------------------------------------------------------===//
// Mma encoding
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ along the row (resp. col) dimension.
}

SmallVector<unsigned> getContigPerThread();

SmallVector<Value> emitBaseIndexWithinCTAForLayout(Location loc,
RewriterBase &rewriter,
RankedTensorType type) const;

SmallVector<SmallVector<unsigned>> emitOffsetForLayout(RankedTensorType type) const;
}];

let hasCustomAssemblyFormat = 1;
Expand Down
Loading