Skip to content

Commit

Permalink
[VectorDistribution] Reuse intrinsic layout in chained gemm (#18505)
Browse files Browse the repository at this point in the history
This patch teaches attention codegen pipeline to reuse the intrinsic
layout of output of the first matmul as the lhs of the second matmul.
This is possible for 16x16x16 and 32x32x8 MFMA intrinsic layouts.
  • Loading branch information
Groverkss committed Sep 20, 2024
1 parent 0f15c8d commit 914858f
Show file tree
Hide file tree
Showing 7 changed files with 364 additions and 46 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AMDGPUDialect",
"@llvm-project//mlir:AffineDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ iree_cc_library(
iree::compiler::Codegen::Utils::VectorOpUtils
iree::compiler::Dialect::Encoding::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Utils
PUBLIC
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1020,10 +1020,16 @@ struct DistributeTrivialLayoutConversions final
PatternRewriter &rewriter) const override {
auto input = cast<VectorValue>(toLayoutOp.getInput());
auto output = cast<VectorValue>(toLayoutOp.getOutput());
VectorLayoutInterface currentLayout =
dyn_cast<LayoutAttr>(signature[input]);
VectorLayoutInterface targetLayout =
dyn_cast<LayoutAttr>(signature[output]);
VectorLayoutInterface currentLayout = signature[input];
VectorLayoutInterface targetLayout = signature[output];

if (!currentLayout) {
return rewriter.notifyMatchFailure(toLayoutOp, "No layout set on input");
}

if (!targetLayout) {
return rewriter.notifyMatchFailure(toLayoutOp, "No layout set on output");
}

if (currentLayout != targetLayout) {
return rewriter.notifyMatchFailure(toLayoutOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Utils/Permutation.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -601,6 +602,93 @@ struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
}
};

struct DistributeBatchOuterToLayoutConversions final
: OpDistributionPattern<IREE::VectorExt::ToLayoutOp> {
using OpDistributionPattern::OpDistributionPattern;

LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
Location loc = toLayoutOp.getLoc();
auto input = cast<VectorValue>(toLayoutOp.getInput());
auto output = cast<VectorValue>(toLayoutOp.getOutput());
auto layoutA = dyn_cast<NestedLayoutAttr>(signature[input]);
auto layoutB = dyn_cast<NestedLayoutAttr>(signature[output]);

if (!layoutA || !layoutB) {
return rewriter.notifyMatchFailure(toLayoutOp, "non-nested layout");
}

// Check if everything other than batch and outer tile matches.
if (layoutA.getSubgroupTile() != layoutB.getSubgroupTile()) {
return failure();
}
if (layoutA.getSubgroupStrides() != layoutB.getSubgroupStrides()) {
return failure();
}
if (layoutA.getThreadTile() != layoutB.getThreadTile()) {
return failure();
}
if (layoutA.getThreadStrides() != layoutB.getThreadStrides()) {
return failure();
}
if (layoutA.getElementTile() != layoutB.getElementTile()) {
return failure();
}

auto batchTileA = SmallVector<int64_t>(layoutA.getBatchTile());
auto outerTileA = SmallVector<int64_t>(layoutA.getOuterTile());
auto batchTileB = SmallVector<int64_t>(layoutB.getBatchTile());
auto outerTileB = SmallVector<int64_t>(layoutB.getOuterTile());

// Check if there is a batch/outer tile mismatch.
if (batchTileA == batchTileB && outerTileA == outerTileB) {
return rewriter.notifyMatchFailure(toLayoutOp,
"trivial layout conversion");
}

SmallVector<int64_t> shapeA = layoutA.getDistributedShape();
SmallVector<int64_t> shapeB = layoutB.getDistributedShape();
int64_t rank = layoutA.getRank();

// Interleave batch and outer dims by transposing.

// Build a permutation for interleaving.
auto interleavePermutation =
llvm::to_vector(llvm::seq<int64_t>(shapeA.size()));
for (int i = 0; i < rank; ++i) {
// Batch tile : [0...rank]
// OuterTile : [rank+1...2*rank]
// Interleave : [batch0, outer0, batch1, outer1,...]
interleavePermutation[2 * i] = i;
interleavePermutation[2 * i + 1] = i + rank;
}

auto interleaved = rewriter.create<vector::TransposeOp>(
loc, getDistributed(rewriter, input, layoutA), interleavePermutation);

// Shape cast to match the new layout.

SmallVector<int64_t> transposedShapeB(shapeB);
applyPermutationToVector(transposedShapeB, interleavePermutation);
Type reshapedType = VectorType::get(
transposedShapeB, interleaved.getResultVectorType().getElementType());

auto reshaped =
rewriter.create<vector::ShapeCastOp>(loc, reshapedType, interleaved);

// Inverse transpose to preserve original order.
SmallVector<int64_t> invertedPermutation =
invertPermutationVector(interleavePermutation);

auto layouted = rewriter.create<vector::TransposeOp>(loc, reshaped,
invertedPermutation);

replaceOpWithDistributedValues(rewriter, toLayoutOp, layouted.getResult());
return success();
}
};

} // namespace

void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
Expand All @@ -612,6 +700,7 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
}

}; // namespace mlir::iree_compiler
Loading

0 comments on commit 914858f

Please sign in to comment.