[LLVMCPU] Drop unit dims on memory transfers (#13340)
Jerry Wu authored Jun 23, 2023
1 parent 6c016ca commit 88d92bf
Showing 6 changed files with 44 additions and 21 deletions.
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/CommonPasses.h
@@ -153,7 +153,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createMemrefCopyToLinalgPass();

/// Pass to optimize vector transfer_read and transfer_write.
std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeVectorTransferPass(
bool flatten = false);
bool flatten = false, bool dropUnitDims = true);

/// Pad dynamic alloc op to convert them into static one.
std::unique_ptr<OperationPass<func::FuncOp>> createPadDynamicAlloc();
@@ -88,7 +88,8 @@ static void loopInvariantCodeMotion(func::FuncOp funcOp) {

struct OptimizeVectorTransferPass
: public OptimizeVectorTransferBase<OptimizeVectorTransferPass> {
OptimizeVectorTransferPass(bool flatten) : flatten(flatten) {}
OptimizeVectorTransferPass(bool flatten, bool dropUnitDims)
: flatten(flatten), dropUnitDims(dropUnitDims) {}
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
// Generate vector.shape_cast for dropping leading one dimensions in vector
@@ -125,10 +126,20 @@ struct OptimizeVectorTransferPass
}
}

// TODO(#14191): SPIR-V can't handle the vector.shape_cast created for
// dropping unit dims so this option is disabled in SPIR-V pipeline.
// This option should go away after all backend issues have been resolved.
if (dropUnitDims) {
RewritePatternSet patterns(&getContext());
mlir::vector::populateVectorTransferDropUnitDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

// Second stage of patterns to flatten transfer ops.
if (flatten) {
RewritePatternSet patterns(&getContext());
mlir::vector::populateVectorTransferDropUnitDimsPatterns(patterns);
mlir::vector::populateFlattenVectorTransferPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
@@ -151,13 +162,14 @@ struct OptimizeVectorTransferPass

private:
bool flatten;
bool dropUnitDims;
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeVectorTransferPass(
bool flatten) {
return std::make_unique<OptimizeVectorTransferPass>(flatten);
bool flatten, bool dropUnitDims) {
return std::make_unique<OptimizeVectorTransferPass>(flatten, dropUnitDims);
}

} // namespace iree_compiler
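For context, a minimal sketch of the rewrite that populateVectorTransferDropUnitDimsPatterns performs (illustrative only, not taken from this commit; %src, %c0, and %pad are placeholder values): a transfer over a memref with unit dimensions is redirected through a rank-reduced memref.subview, and a vector.shape_cast restores the original vector type. That shape_cast is what the TODO above says the SPIR-V pipeline cannot yet handle.

// Before: a transfer_read whose memref and vector types carry unit dims.
%v = vector.transfer_read %src[%c0, %c0, %c0], %pad {in_bounds = [true, true, true]}
    : memref<1x1x8xf32>, vector<1x1x8xf32>

// After (roughly): rank-reducing subview, 1-D transfer, shape_cast back.
%sub = memref.subview %src[0, 0, 0] [1, 1, 8] [1, 1, 1]
    : memref<1x1x8xf32> to memref<8xf32>
%v1d = vector.transfer_read %sub[%c0], %pad {in_bounds = [true]}
    : memref<8xf32>, vector<8xf32>
%v   = vector.shape_cast %v1d : vector<8xf32> to vector<1x1x8xf32>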
8 changes: 6 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -552,12 +552,16 @@ void addConvTileAndDecomposeExpertPassPipeline(OpPassManager &passManager,
nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
}

nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Eliminate redundant transfer_read/write to avoid stack allocations.
nestedModulePM.addNestedPass<func::FuncOp>(
createOptimizeVectorTransferPass(/*flatten=*/true));

addBufferizePasses(nestedModulePM);

// Perform memref-based transfer_read/write optimizations.
nestedModulePM.addNestedPass<func::FuncOp>(
createOptimizeVectorTransferPass(/*flatten=*/false));

// Run IREE specific passes before vector lowering expert.
nestedModulePM.addNestedPass<func::FuncOp>(
createRemoveSingleIterationLoopPass());
@@ -217,4 +217,5 @@ hal.executable private @pad_consumer_fusion {
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK-COUNT-7: vector.store %{{.+}}, %[[OUTPUT_SUBVIEW_0]]
// CHECK: %[[OUTPUT_SUBVIEW_1:.+]] = memref.subview %[[OUTPUT_SUBVIEW_0]]
// CHECK-COUNT-7: vector.store %{{.+}}, %[[OUTPUT_SUBVIEW_1]]
4 changes: 3 additions & 1 deletion compiler/src/iree/compiler/Codegen/Passes.td
@@ -236,7 +236,9 @@ def OptimizeVectorTransfer :
let constructor = "mlir::iree_compiler::createOptimizeVectorTransferPass()";
let options = [
Option<"optionFlatten", "flatten", "bool", "false",
"Flatten the vector type of vector transfers where possible (contiguous row-major data).">
"Flatten the vector type of vector transfers where possible (contiguous row-major data).">,
Option<"optionDropUnitDims", "drop-unit-dims", "bool", /*default=*/"true",
"Drop unit dims in vector transfers where possible (might generate vector.shape_cast).">,
];
let dependentDialects = [
"memref::MemRefDialect"
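As a companion to the flatten option described above, a minimal sketch of what populateFlattenVectorTransferPatterns does for contiguous row-major data (illustrative only, not from this commit; %m, %c0, and %pad are placeholders): the memref is collapsed to one dimension, the transfer reads a single flat vector, and a vector.shape_cast recovers the multi-dimensional vector type.

// Before: a 2-D transfer over a contiguous row-major memref.
%v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
    : memref<4x8xf32>, vector<4x8xf32>

// After (roughly): collapse the memref and read one flat vector.
%flat = memref.collapse_shape %m [[0, 1]] : memref<4x8xf32> into memref<32xf32>
%v1d  = vector.transfer_read %flat[%c0], %pad {in_bounds = [true]}
    : memref<32xf32>, vector<32xf32>
%v    = vector.shape_cast %v1d : vector<32xf32> to vector<4x8xf32>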
28 changes: 16 additions & 12 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -208,14 +208,16 @@ static void addMemRefLoweringPasses(OpPassManager &pm) {
pm.addPass(createSPIRVVectorizeLoadStore());
// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
/*flatten=*/false, /*dropUnitDims=*/false));
pm.addNestedPass<func::FuncOp>(createSPIRVBreakDownLargeVectorPass());

// Perform optimizations that need to go across the scf.for region boundary.
pm.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
/*flatten=*/false, /*dropUnitDims=*/false));

// Turn multi-dimension memref into one-dimension. This is needed for SPIR-V
// because we don't use upstream memref descriptors.
@@ -311,8 +313,8 @@ void addSPIRVBaseVectorizePassPipeline(OpPassManager &pm) {

// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
nestedModulePM.addNestedPass<func::FuncOp>(
createOptimizeVectorTransferPass());
nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
/*flatten=*/false, /*dropUnitDims=*/false));
}

void addSPIRVCooperativeMatrixVectorizePassPipeline(OpPassManager &pm,
@@ -370,8 +372,8 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline(OpPassManager &pm,

// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
nestedModulePM.addNestedPass<func::FuncOp>(
createOptimizeVectorTransferPass());
nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
/*flatten=*/false, /*dropUnitDims=*/false));

// Folding subview ops is required for converting vector transfer ops into SPIR-V
// cooperative ops in the next step.
@@ -445,10 +447,12 @@ void addSPIRVMatmulPromoteVectorizePassPipeline(OpPassManager &topPM,
// to hoisting. Because this is before folding all memref subview ops away, we
// still have subview ops using the same indices, which allows for transfer
// read/write forwarding.
nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
/*flatten=*/false, /*dropUnitDims=*/false));

nestedPM.addNestedPass<func::FuncOp>(memref::createFoldMemRefAliasOpsPass());
nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
/*flatten=*/false, /*dropUnitDims=*/false));

// Hoist loop invariant code to avoid pipelining it.
nestedPM.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
@@ -506,8 +510,8 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &pm) {

// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
nestedModulePM.addNestedPass<func::FuncOp>(
createOptimizeVectorTransferPass());
nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
/*flatten=*/false, /*dropUnitDims=*/false));

// Simplify the IR for vector distribution.
nestedModulePM.addNestedPass<func::FuncOp>(
@@ -567,8 +571,8 @@ void addSPIRVWinogradVectorizePassPipeline(OpPassManager &pm) {

// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
nestedModulePM.addNestedPass<func::FuncOp>(
createOptimizeVectorTransferPass());
nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
/*flatten=*/false, /*dropUnitDims=*/false));
}

void addSPIRVTransformDialectPassPipeline(OpPassManager &pm) {