diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp index 969c6cbb6168..18c5e89c4318 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h" #include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h" +#include "iree/compiler/Conversion/Common/Transforms.h" #include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" @@ -134,6 +135,13 @@ void TileAndVectorizeWorkgroups::runOnFunction() { auto funcOp = getOperation(); MLIRContext *context = &getContext(); + // Apply prior vectorization canonicalization passes. + { + OwningRewritePatternList canonicalization(&getContext()); + populateAffineMinSCFCanonicalizationPattern(canonicalization); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalization)); + } + // Promotes workgroups subviews to a full-tile allocated on the stack. if (clEnablePromoteWorkgroupToFullTiles) { OwningRewritePatternList promotionPatterns(&getContext()); diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp index aed81df07c0a..d025827be287 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp @@ -32,6 +32,7 @@ void addLinalgToLLVMPasses(OpPassManager &passManager, LLVMCodegenOptions options) { // Distribute linalg op among a 3d grid of parallel threads. Tile each // workgroup thread memory then vectorize the linalg op. + if (options.usingLinalgOnTensors) { passManager.addPass(createMaterializeCPULaunchConfigurationPass()); } else { @@ -39,6 +40,8 @@ void addLinalgToLLVMPasses(OpPassManager &passManager, } OpPassManager &nestedModulePM = passManager.nest(); + nestedModulePM.addNestedPass(createCanonicalizerPass()); + if (options.useConvImg2Col) { // linalg::ConvInputNHWCFilterHWCFOp -> (Img2Col packing + matmul). // After convolution is tiled and distributed among workgroups its converted @@ -80,7 +83,7 @@ void buildLLVMTransformPassPipeline(OpPassManager &passManager, // HLO -> Linalg on buffers. if (options.usingLinalgOnTensors) { - nestedModulePM.addPass(createLinalgVectorizePass()); + nestedModulePM.addNestedPass(createLinalgVectorizePass()); // Use stack allocation on CPU side. WorkgroupMemoryAllocationFn allocationFn = [](OpBuilder &builder, Location loc, ArrayRef staticShape,