[Codegen][GPU] Add support for all other intrinsics to TileAndFuse (iree-org#18179)

This adds the ConcretizeMmaShapes pass to the LLVMGPUTileAndFuse
pipeline to add support for the other intrinsic types, in particular
MFMA and WMMA variants that require reshaping of the accumulator to
match the requirements of the layout.

This also reworks the reshaping code to use SingleSubgroupLayout
instead of VectorExt::PerDimLayoutAttr, dropping an unneeded dialect
dependency and simplifying the IR for cases where reshaping is not
needed. In particular, when a layout has a unit `outer` dimension, no
additional reshaping is required, so the reshapes are omitted in such
cases. In the future there is still the option to perform such
reshaping in order to pre-swizzle the data needed for the MMA during
the store to shared memory, but the details of how best to implement
that are left as a TODO.
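
As a rough illustration of the reshaping rule described above, here is a
minimal standalone sketch (assuming plain `std::vector` containers and a
hypothetical `concretizeShape` helper; this is not the actual pass code):
a dimension with a unit `outer` is left untouched, while a non-unit
`outer` splits the dimension into `[outer, thread * element]`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the concretization rule: for each operand dimension, a unit
// |outer| keeps the dimension as a single contiguous [thread * element]
// slice, while a non-unit |outer| expands it into [outer, native / outer].
// Returns the reassociation groups and fills |resultShape|.
std::vector<std::vector<int64_t>>
concretizeShape(const std::vector<int64_t> &outerSizes,
                const std::vector<int64_t> &opaqueSizes,
                std::vector<int64_t> &resultShape) {
  std::vector<std::vector<int64_t>> reassociations;
  int64_t idx = 0;
  for (size_t i = 0; i < outerSizes.size(); ++i) {
    int64_t outer = outerSizes[i];
    int64_t native = opaqueSizes[i];
    if (outer == 1) {
      // Unit outer: no expand_shape needed for this dimension.
      resultShape.push_back(native);
      reassociations.push_back({idx++});
      continue;
    }
    // Non-unit outer: [outer, native / outer] == [outer, thread * element].
    assert(native % outer == 0 && "invalid mma layout");
    resultShape.push_back(outer);
    resultShape.push_back(native / outer);
    reassociations.push_back({idx, idx + 1});
    idx += 2;
  }
  return reassociations;
}
```

For example, a hypothetical accumulator fragment with outer sizes [4, 1]
and opaque sizes [32, 32] would produce the expanded shape [4, 8, 32]
with reassociation groups {{0, 1}, {2}}; when every `outer` is 1, the
shape is unchanged and no expand_shape is emitted.
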
qedawkins authored Aug 13, 2024
1 parent 3901e62 commit 7812c77
Showing 9 changed files with 747 additions and 92 deletions.
77 changes: 36 additions & 41 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -664,9 +664,7 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
llvm::zip_equal(subgroupLayout.outer, subgroupLayout.thread,
subgroupLayout.element)) {
if (outer != 1) {
// TODO: Support this case. Might need a reshape since this makes the
// slice non-contiguous.
return failure();
rankReducedShape.push_back(outer);
}
rankReducedShape.push_back(thread * element);
}
@@ -690,6 +688,7 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides(
subgroupLayout.element)) {
if (dimSize == 1) {
vtids.push_back(zero);
continue;
}

// ((tid floordiv stride) mod size) * element.
@@ -702,7 +701,12 @@
}

int64_t idx = 0;
for (int64_t element : subgroupLayout.element) {
for (auto [element, outer] :
llvm::zip_equal(subgroupLayout.element, subgroupLayout.outer)) {
if (outer != 1) {
canonicalSizes.push_back(builder.getIndexAttr(outer));
canonicalOffsets.push_back(zero);
}
canonicalSizes.push_back(builder.getIndexAttr(element));
canonicalOffsets.push_back(vtids[idx++]);
}
@@ -716,13 +720,6 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides(
Value laneId, ArrayRef<int64_t> permutation,
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
break;
default:
return failure();
}

MMAAttr::SingleSubgroupLayout subgroupLayout;
switch (fragment) {
@@ -758,47 +755,33 @@ LogicalResult MMAAttr::materializeOperandConcreteShape(
std::optional<ArrayRef<int64_t>> permutation,
SmallVector<ReassociationIndices> &reassociations,
RankedTensorType &resultType) const {
OpaqueMmaLayout opaqueLayout =
getOpaqueMFMALayout(operand.getContext(), getIntrinsic().getValue());
// TODO(Max191): The `getConcreteMFMALayout` function creates some
// `PerDimLayoutAttr` that are not used by this function. This means that
// any pass that uses `materializeOperandConcreteShape` needs to be
// dependent on the VectorExt dialect. Ideally, the `getConcreteMFMALayout`
// function should be refactored so we can reuse the shape information of
// the layout without needing to create any `PerDimLayoutAttr`.
ConcreteMmaLayout layout =
getConcreteMFMALayout(operand.getContext(), getIntrinsic().getValue());
SmallVector<ArrayRef<int64_t>> concreteSizes;

SmallVector<int64_t, 2> outerSizes;
SmallVector<int64_t, 2> opaqueSizes;
auto [m, n, k] = getMNKShape();
switch (fragment) {
case IREE::GPU::MMAFragment::Lhs: {
concreteSizes.push_back(layout.aMLayout.getShapes());
concreteSizes.push_back(layout.aKLayout.getShapes());
opaqueSizes.push_back(opaqueLayout.mSize);
opaqueSizes.push_back(opaqueLayout.kSize);
outerSizes = getASingleSubgroupLayout().outer;
opaqueSizes.append({m, k});
break;
}
case IREE::GPU::MMAFragment::Rhs: {
concreteSizes.push_back(layout.bKLayout.getShapes());
concreteSizes.push_back(layout.bNLayout.getShapes());
opaqueSizes.push_back(opaqueLayout.kSize);
opaqueSizes.push_back(opaqueLayout.nSize);
outerSizes = getBSingleSubgroupLayout().outer;
opaqueSizes.append({k, n});
break;
}
case IREE::GPU::MMAFragment::Acc: {
concreteSizes.push_back(layout.cMLayout.getShapes());
concreteSizes.push_back(layout.cNLayout.getShapes());
opaqueSizes.push_back(opaqueLayout.mSize);
opaqueSizes.push_back(opaqueLayout.nSize);
outerSizes = getCSingleSubgroupLayout().outer;
opaqueSizes.append({m, n});
break;
}
}
if (permutation.has_value()) {
if (permutation.value().size() != opaqueSizes.size()) {
if (permutation.value().size() != outerSizes.size()) {
return failure();
}
applyPermutationToVector(concreteSizes, permutation.value());
applyPermutationToVector(opaqueSizes, permutation.value());
applyPermutationToVector(outerSizes, permutation.value());
}

// Inner tile must have sizes matching the opaque layout.
@@ -819,11 +802,23 @@ LogicalResult MMAAttr::materializeOperandConcreteShape(
return ReassociationIndices({idx});
});
int idx = reInds.size();
for (ArrayRef<int64_t> sizes : concreteSizes) {
resultShape.append(SmallVector<int64_t>(sizes));
reInds.push_back(
llvm::to_vector(llvm::seq<int64_t>(idx, idx + sizes.size())));
idx += sizes.size();
for (auto [outer, native] : llvm::zip_equal(outerSizes, opaqueSizes)) {
// Skip expansion if the outer dim is unit as the SingleSubgroupLayout gives
// a guarantee that the |element| counts are contiguous within the layout,
// and a unit outer implies a single offset and size for that dimension.
if (outer == 1) {
resultShape.push_back(native);
reInds.push_back(ReassociationIndices({idx++}));
continue;
}

// Reshape to [outer, native / outer] == [outer, thread * element]. This
// corresponds to |outer| repetitions of the thread/element sublayout.
resultShape.push_back(outer);
assert(native % outer == 0 && "invalid mma layout");
resultShape.push_back(native / outer);
reInds.push_back(ReassociationIndices{idx, idx + 1});
idx += 2;
}

reassociations = reInds;
@@ -73,16 +73,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
lhsElemType, rhsElemType, initElemType};

SmallVector<GPUMatmulShapeType> intrinsics;
SmallVector<IREE::GPU::MmaInterfaceAttr> supportedMmas;
for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
IREE::GPU::MMAIntrinsic type = mma.getIntrinsic().getValue();
// TODO: Drop this once all intrinsics are supported.
if (type != IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x16_F16 &&
type != IREE::GPU::MMAIntrinsic::MFMA_I32_16x16x32_I8) {
continue;
}
supportedMmas.push_back(mma);

auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
if (mma.getSubgroupSize() != targetSubgroupSize)
@@ -185,7 +176,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
// Similarly the reduction tile size is just the post-packing tile count.
reductionTileSizes[kDim] = schedule->kTileCount;

IREE::GPU::MmaInterfaceAttr mmaKind = supportedMmas[schedule->index];
IREE::GPU::MmaInterfaceAttr mmaKind =
target.getWgp().getMma()[schedule->index];

// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
@@ -9,7 +9,6 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -66,6 +65,13 @@ struct ConcretizeMmaOperandShape final : OpRewritePattern<MultiMmaOp> {
return failure();
}

// Early exit if the operand is unaffected.
if (llvm::all_of(reassociations, [](ReassociationIndices reassoc) {
return reassoc.size() == 1;
})) {
return failure();
}

// Create the expand_shape.
Location loc = mmaOp->getLoc();
Value concreteOperand = rewriter
@@ -16,6 +16,7 @@ def DistributeMmaToLanesPass :
"::mlir::arith::ArithDialect",
"::mlir::affine::AffineDialect",
"::mlir::scf::SCFDialect",
"::mlir::tensor::TensorDialect",
];
}

@@ -25,7 +26,6 @@
let dependentDialects = [
"::mlir::tensor::TensorDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
"::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect",
];
let options = [
Option<"concretizeInputs", "concretize-inputs",
@@ -463,7 +463,8 @@ convertContractionToMultiMma(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
FailureOr<Operation *> distributeMultiMmaOp(RewriterBase &rewriter,
IREE::GPU::MultiMmaOp mmaOp) {
if (!mmaOp.hasTensorSemantics() || mmaOp.hasThreadSemantics()) {
return failure();
return rewriter.notifyMatchFailure(
mmaOp, "mmaOp must have vector and subgroup for distribution.");
}

OpBuilder::InsertionGuard g(rewriter);
@@ -508,7 +509,7 @@ FailureOr<Operation *> distributeMultiMmaOp(RewriterBase &rewriter,
if (failed(mmaOp.getKind().populateOperandOffsetsSizesStrides(
rewriter, loc, IREE::GPU::MMAFragment::Lhs, laneId, lhsPermutation,
lhsOffsets, lhsSizes, lhsStrides))) {
return failure();
return mmaOp->emitOpError("failed to populate lhs offsets");
}
// Extract the rank-reduced slice of the lhs based on the expected inner
// vector shape.
@@ -528,7 +529,7 @@ FailureOr<Operation *> distributeMultiMmaOp(RewriterBase &rewriter,
if (failed(mmaOp.getKind().populateOperandOffsetsSizesStrides(
rewriter, loc, IREE::GPU::MMAFragment::Rhs, laneId, rhsPermutation,
rhsOffsets, rhsSizes, rhsStrides))) {
return failure();
return mmaOp->emitOpError("failed to populate rhs offsets");
}
// Extract the rank-reduced slice of the rhs based on the expected inner
// vector shape.
@@ -548,7 +549,7 @@ FailureOr<Operation *> distributeMultiMmaOp(RewriterBase &rewriter,
if (failed(mmaOp.getKind().populateOperandOffsetsSizesStrides(
rewriter, loc, IREE::GPU::MMAFragment::Acc, laneId, accPermutation,
accOffsets, accSizes, accStrides))) {
return failure();
return mmaOp->emitOpError("failed to populate acc offsets");
}
// Extract the rank-reduced slice of the accumulator based on the expected
// inner vector shape.