From 2695fe99e9d655ef396834bec7ee51eccb74a454 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Thu, 8 Aug 2024 14:31:48 -0700 Subject: [PATCH] [GlobalOpt] Switch to new pass generation tablegen definitions. (#18163) This is mostly an NFC change. The revision applies a few small cleanups: - Remove `enableQuantizedMatmulReassociation` option from FuseDequantizationMatmulPass. It should be controlled by the pipeline. - Move testing options to tablegen definitions for PropagateLinalgTransposePass - Switch a couple of passes to follow `create.*Pass` naming convention. - Switch namespaces to the new single-line syntax for FuseSiluHorizontalMatmulPass --------- Signed-off-by: hanhanW --- .../compiler/GlobalOptimization/BUILD.bazel | 1 - .../GlobalOptimization/CMakeLists.txt | 1 - .../CleanupNumericNarrowing.cpp | 12 +-- .../Convert1X1FilterConv2DToMatmul.cpp | 11 +-- .../DataLayoutPropagation.cpp | 13 +-- .../GlobalOptimization/DecomposeConcat.cpp | 15 ++- .../DemoteContractionInputsToBF16.cpp | 11 ++- .../DetachElementwiseFromNamedOps.cpp | 11 +-- .../EraseUnusedLinalgOperands.cpp | 13 +-- .../GlobalOptimization/ExpandTensorShapes.cpp | 14 +-- .../FuseDequantizationMatmul.cpp | 92 ++++++++----------- .../FuseHorizontalContractions.cpp | 18 ++-- .../FuseSiluHorizontalMatmul.cpp | 25 ++--- .../GeneralizeLinalgNamedOps.cpp | 13 +-- .../GlobalLoopInvariantCodeMotion.cpp | 11 +-- .../InferNumericNarrowing.cpp | 11 +-- .../MaterializeHomogeneousEncodings.cpp | 16 ++-- .../GlobalOptimization/OptimizeNumerics.cpp | 12 +-- .../compiler/GlobalOptimization/PassDetail.h | 22 ----- .../compiler/GlobalOptimization/Passes.cpp | 12 +-- .../iree/compiler/GlobalOptimization/Passes.h | 90 +++--------------- .../compiler/GlobalOptimization/Passes.td | 70 +++++--------- .../PropagateLinalgTranspose.cpp | 25 ++--- .../GlobalOptimization/RaiseSpecialOps.cpp | 12 +-- .../RemoveZeroExtentTensors.cpp | 11 +-- .../GlobalOptimization/SimplifyPackUnpack.cpp | 10 +- 
.../test/fuse_dequantization_matmul.mlir | 2 +- 27 files changed, 183 insertions(+), 371 deletions(-) delete mode 100644 compiler/src/iree/compiler/GlobalOptimization/PassDetail.h diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel index efd77583de67..bf0aef892fb6 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel @@ -30,7 +30,6 @@ iree_gentbl_cc_library( iree_compiler_cc_library( name = "PassHeaders", hdrs = [ - "PassDetail.h", "Passes.h", "Passes.h.inc", ], diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt index 7d6435a1cf5f..51a1c5d22f98 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt @@ -23,7 +23,6 @@ iree_cc_library( NAME PassHeaders HDRS - "PassDetail.h" "Passes.h" "Passes.h.inc" DEPS diff --git a/compiler/src/iree/compiler/GlobalOptimization/CleanupNumericNarrowing.cpp b/compiler/src/iree/compiler/GlobalOptimization/CleanupNumericNarrowing.cpp index b55dd29b86ff..d06d1bdeeaef 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/CleanupNumericNarrowing.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/CleanupNumericNarrowing.cpp @@ -5,15 +5,18 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_CLEANUPNUMERICNARROWINGPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { class CleanupNumericNarrowingPass - : public CleanupNumericNarrowingBase { + : public impl::CleanupNumericNarrowingPassBase< + CleanupNumericNarrowingPass> { void runOnOperation() 
override { getOperation()->walk([](IREE::Util::NumericOptionalNarrowOp op) { op.getResult().replaceAllUsesWith(op.getOperand()); @@ -23,9 +26,4 @@ class CleanupNumericNarrowingPass }; } // namespace - -std::unique_ptr createCleanupNumericNarrowingPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp index 1d7b3a0ea11b..a8b4becfff2b 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -14,6 +13,9 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_CONVERT1X1FILTERCONV2DTOMATMULPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { // Converts linalg.conv_2d_input_nhwc_filter_nhwc op to linalg.matmul @@ -157,7 +159,7 @@ class Convert1x1FilterConvToMatmul : public OpRewritePattern { }; struct Convert1X1FilterConv2DToMatmulPass - : public Convert1X1FilterConv2DToMatmulBase< + : public impl::Convert1X1FilterConv2DToMatmulPassBase< Convert1X1FilterConv2DToMatmulPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -176,9 +178,4 @@ struct Convert1X1FilterConv2DToMatmulPass } }; } // namespace - -std::unique_ptr createConvert1X1FilterConv2DToMatmulPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git 
a/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp b/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp index c56e7087c713..44172d3355c3 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp @@ -4,20 +4,20 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -using namespace mlir; - namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_DATALAYOUTPROPAGATIONPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { struct DataLayoutPropagationPass - : public DataLayoutPropagationBase { + : public impl::DataLayoutPropagationPassBase { void runOnOperation() override { MLIRContext *context = &getContext(); FunctionOpInterface funcOp = getOperation(); @@ -43,9 +43,4 @@ struct DataLayoutPropagationPass }; } // namespace - -std::unique_ptr> -createDataLayoutPropagationPass() { - return std::make_unique(); -} } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp b/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp index 445ea4f0df19..eae755f841c8 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -16,6 +15,9 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_DECOMPOSECONCATPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { static Value createTranspose(OpBuilder &builder, Value source, @@ -78,13 +80,16 @@ struct TransposeInnerConcatenation : public OpRewritePattern { } }; -struct DecomposeConcatPass : public DecomposeConcatBase { +struct DecomposeConcatPass + : public impl::DecomposeConcatPassBase { + using impl::DecomposeConcatPassBase< + DecomposeConcatPass>::DecomposeConcatPassBase; + explicit DecomposeConcatPass(bool enableConcatTransposition) { + this->enableConcatTransposition = enableConcatTransposition; + } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } - DecomposeConcatPass(bool enableConcatTransposition) { - this->enableConcatTransposition = enableConcatTransposition; - } DecomposeConcatPass(const DecomposeConcatPass &pass) : DecomposeConcatPass(pass.enableConcatTransposition) {} diff --git a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp index 9df9578140d5..5f72ff5763d9 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp @@ -5,7 +5,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -19,6 +18,9 @@ namespace 
mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_DEMOTECONTRACTIONINPUTSTOBF16PASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { // For narrowable inputs, selects @@ -133,12 +135,13 @@ struct DemoteContractionInputsToBF16Pattern }; class DemoteContractionInputsToBF16Pass - : public DemoteContractionInputsToBF16Base< + : public impl::DemoteContractionInputsToBF16PassBase< DemoteContractionInputsToBF16Pass> { - public: + using impl::DemoteContractionInputsToBF16PassBase< + DemoteContractionInputsToBF16Pass>::DemoteContractionInputsToBF16PassBase; explicit DemoteContractionInputsToBF16Pass(const DemotionOption &option) { - this->demoteOnly.setValue(option); + this->demoteOnly = option; } void runOnOperation() override { MLIRContext *context = &getContext(); diff --git a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp index c90a0f8e865f..524f111e271c 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -27,6 +26,9 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_DETACHELEMENTWISEFROMNAMEDOPSPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { struct DetachElementwisePattern @@ -185,7 +187,7 @@ struct DetachSplatConstantOutsOperands }; struct DetachElementwiseFromNamedOpsPass - : public DetachElementwiseFromNamedOpsBase< + : public impl::DetachElementwiseFromNamedOpsPassBase< 
DetachElementwiseFromNamedOpsPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert createDetachElementwiseFromNamedOpsPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/EraseUnusedLinalgOperands.cpp b/compiler/src/iree/compiler/GlobalOptimization/EraseUnusedLinalgOperands.cpp index a99538290674..3396e4622eb8 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/EraseUnusedLinalgOperands.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/EraseUnusedLinalgOperands.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -14,9 +13,13 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_ERASEUNUSEDLINALGOPERANDSPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { struct EraseUnusedLinalgOperandsPass - : public EraseUnusedLinalgOperandsBase { + : public impl::EraseUnusedLinalgOperandsPassBase< + EraseUnusedLinalgOperandsPass> { void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); @@ -28,10 +31,4 @@ struct EraseUnusedLinalgOperandsPass } }; } // namespace - -std::unique_ptr> -createEraseUnusedLinalgOperands() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp index 51209999c296..c9974d1d68a8 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp @@ 
-11,7 +11,6 @@ #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/Transforms/Patterns.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "iree/compiler/Utils/IntegerSet.h" #include "llvm/ADT/BreadthFirstIterator.h" @@ -29,6 +28,10 @@ #define DEBUG_TYPE "iree-global-opt-expand-tensor-shapes" namespace mlir::iree_compiler::GlobalOptimization { + +#define GEN_PASS_DEF_EXPANDTENSORSHAPESPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { // TODO(benvanik): factor out into a generic util pass base that lets us share @@ -624,10 +627,8 @@ static void expandTensorDims(Operation *op, SymbolTable &symbolTable, // results are always wrapped in a flow.tensor.tie_shape, with the // elision/deduplication/etc left until cleanup. class ExpandTensorShapesPass - : public ExpandTensorShapesBase { + : public impl::ExpandTensorShapesPassBase { public: - ExpandTensorShapesPass() = default; - void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -661,9 +662,4 @@ class ExpandTensorShapesPass }; } // namespace - -std::unique_ptr> createExpandTensorShapesPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp index 832d9fa6b6cd..af3cff98bdbe 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp @@ -6,7 +6,6 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" 
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -24,6 +23,9 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_FUSEDEQUANTIZATIONMATMULPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { //----------------------------------------------------------------------------// @@ -767,19 +769,12 @@ static LogicalResult reassociateDequantMatmul(RewriterBase &rewriter, } struct FuseDequantizationMatmulPass - : public FuseDequantizationMatmulBase { - + : public impl::FuseDequantizationMatmulPassBase< + FuseDequantizationMatmulPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } - FuseDequantizationMatmulPass(bool enableQuantizedMatmulReassociation) { - this->enableQuantizedMatmulReassociation = - enableQuantizedMatmulReassociation; - } - FuseDequantizationMatmulPass(const FuseDequantizationMatmulPass &pass) - : FuseDequantizationMatmulPass(pass.enableQuantizedMatmulReassociation) {} - void runOnOperation() override; }; @@ -789,56 +784,45 @@ void FuseDequantizationMatmulPass::runOnOperation() { MLIRContext *context = &getContext(); auto funcOp = getOperation(); - // Perform reassociation if enabled - if (this->enableQuantizedMatmulReassociation) { - int quantizeBitWidth = 16; - SmallVector> candidates; - for (auto genericOp : - funcOp.getFunctionBody().getOps()) { - if (failed(isContractionWithTwoReductions(genericOp))) { - continue; - } + int quantizeBitWidth = 16; + SmallVector> candidates; + for (auto genericOp : funcOp.getFunctionBody().getOps()) { + if (failed(isContractionWithTwoReductions(genericOp))) { + continue; + } - OpOperand *lhs = genericOp.getDpsInputOperand(0); - OpOperand *rhs = genericOp.getDpsInputOperand(1); - auto lhsOp = lhs->get().getDefiningOp(); - auto rhsOp = rhs->get().getDefiningOp(); - if (!llvm::cast(genericOp.getInputs()[0].getType()) - .hasStaticShape() || - !llvm::cast(genericOp.getInputs()[1].getType()) - 
.hasStaticShape() || - !llvm::cast(genericOp.getResults()[0].getType()) - .hasStaticShape()) { - // Codegen can't handle the dynamic case yet. + OpOperand *lhs = genericOp.getDpsInputOperand(0); + OpOperand *rhs = genericOp.getDpsInputOperand(1); + auto lhsOp = lhs->get().getDefiningOp(); + auto rhsOp = rhs->get().getDefiningOp(); + if (!llvm::cast(genericOp.getInputs()[0].getType()) + .hasStaticShape() || + !llvm::cast(genericOp.getInputs()[1].getType()) + .hasStaticShape() || + !llvm::cast(genericOp.getResults()[0].getType()) + .hasStaticShape()) { + // Codegen can't handle the dynamic case yet. + continue; + } + if (lhsOp) { + if (!failed(isGroupedDequantizationOp(lhsOp))) { + candidates.push_back(std::make_pair(lhsOp, genericOp)); continue; } - if (lhsOp) { - if (!failed(isGroupedDequantizationOp(lhsOp))) { - candidates.push_back(std::make_pair(lhsOp, genericOp)); - continue; - } - } - if (rhsOp) { - if (!failed(isGroupedDequantizationOp(rhsOp))) { - candidates.push_back(std::make_pair(rhsOp, genericOp)); - } - } } - IRRewriter rewriter(context); - for (auto candidate : candidates) { - rewriter.setInsertionPointAfter(candidate.second); - if (failed(reassociateDequantMatmul( - rewriter, candidate.first, candidate.second, quantizeBitWidth))) { - return signalPassFailure(); + if (rhsOp) { + if (!failed(isGroupedDequantizationOp(rhsOp))) { + candidates.push_back(std::make_pair(rhsOp, genericOp)); } } } + IRRewriter rewriter(context); + for (auto candidate : candidates) { + rewriter.setInsertionPointAfter(candidate.second); + if (failed(reassociateDequantMatmul(rewriter, candidate.first, + candidate.second, quantizeBitWidth))) { + return signalPassFailure(); + } + } } - -std::unique_ptr> -createFuseDequantizationMatmulPass(bool enableQuantizedMatmulReassociation) { - return std::make_unique( - enableQuantizedMatmulReassociation); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git 
a/compiler/src/iree/compiler/GlobalOptimization/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseHorizontalContractions.cpp index ca9a4901388b..b71912bb0a83 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/FuseHorizontalContractions.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/FuseHorizontalContractions.cpp @@ -6,11 +6,8 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "iree/compiler/GlobalOptimization/Utils.h" -#include "mlir/IR/Dominance.h" - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" @@ -21,6 +18,7 @@ #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -28,14 +26,18 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#define DEBUG_TYPE "iree-global-opt-fuse-horizontal-contraction" +#define DEBUG_TYPE "iree-global-opt-fuse-horizontal-contractions" namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_FUSEHORIZONTALCONTRACTIONSPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { struct FuseHorizontalContractionsPass - : public FuseHorizontalContractionsBase { + : public impl::FuseHorizontalContractionsPassBase< + FuseHorizontalContractionsPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -449,10 +451,4 @@ void FuseHorizontalContractionsPass::runOnOperation() { return signalPassFailure(); } } - -std::unique_ptr> -createFuseHorizontalContractionsPass() { - return std::make_unique(); -} - } 
// namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseSiluHorizontalMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseSiluHorizontalMatmul.cpp index 3da1d320edb3..b4a2711c3f97 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/FuseSiluHorizontalMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/FuseSiluHorizontalMatmul.cpp @@ -6,7 +6,6 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "iree/compiler/GlobalOptimization/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -23,9 +22,10 @@ #define DEBUG_TYPE "iree-global-opt-fuse-dequantization-matmul" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -namespace mlir { -namespace iree_compiler { -namespace GlobalOptimization { +namespace mlir::iree_compiler::GlobalOptimization { + +#define GEN_PASS_DEF_FUSESILUHORIZONTALMATMULPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" namespace { @@ -169,16 +169,12 @@ class FuseSiluHorizontalMatmulPattern final }; struct FuseSiluHorizontalMatmulPass - : public FuseSiluHorizontalMatmulBase { - + : public impl::FuseSiluHorizontalMatmulPassBase< + FuseSiluHorizontalMatmulPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } - FuseSiluHorizontalMatmulPass() {} - FuseSiluHorizontalMatmulPass(const FuseSiluHorizontalMatmulPass &pass) - : FuseSiluHorizontalMatmulPass() {} - void runOnOperation() override; }; @@ -195,11 +191,4 @@ void FuseSiluHorizontalMatmulPass::runOnOperation() { } } -std::unique_ptr> -createFuseSiluHorizontalMatmulPass() { - return std::make_unique(); -} - -} // namespace GlobalOptimization -} // namespace iree_compiler -} // namespace mlir +} // namespace mlir::iree_compiler::GlobalOptimization diff --git 
a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp index 5f30902c7e2f..92293bc156ba 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -20,10 +19,13 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_GENERALIZELINALGNAMEDOPSPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { struct GeneralizeLinalgNamedOpsPass - : public GeneralizeLinalgNamedOpsBase { - + : public impl::GeneralizeLinalgNamedOpsPassBase< + GeneralizeLinalgNamedOpsPass> { void runOnOperation() override; }; } // namespace @@ -59,9 +61,4 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() { } } -std::unique_ptr> -createGeneralizeLinalgNamedOpsPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp b/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp index 9a30e21cf2c7..0489448d9b0a 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -101,10 +100,13 @@ static LogicalResult hoistLoopInvariants(LoopLikeOpInterface loopOp, namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_GLOBALLOOPINVARIANTCODEMOTIONPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { struct GlobalLoopInvariantCodeMotionPass - : public GlobalLoopInvariantCodeMotionBase< + : public impl::GlobalLoopInvariantCodeMotionPassBase< GlobalLoopInvariantCodeMotionPass> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -129,9 +131,4 @@ struct GlobalLoopInvariantCodeMotionPass }; } // namespace - -std::unique_ptr> -createGlobalLoopInvariantCodeMotionPass() { - return std::make_unique(); -} } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/InferNumericNarrowing.cpp b/compiler/src/iree/compiler/GlobalOptimization/InferNumericNarrowing.cpp index 0edb016eee9b..c4d48aa3b90d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/InferNumericNarrowing.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/InferNumericNarrowing.cpp @@ -10,7 +10,6 @@ #include "iree/compiler/Dialect/Util/Analysis/Explorer.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" @@ -20,6 +19,9 @@ using llvm::SmallPtrSet; namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_INFERNUMERICNARROWINGPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { IntegerType deriveIntegerTypeFromRange(MLIRContext *context, int64_t 
minValue, @@ -47,7 +49,7 @@ IntegerType deriveIntegerTypeFromRange(MLIRContext *context, int64_t minValue, } class InferNumericNarrowingPass - : public InferNumericNarrowingBase { + : public impl::InferNumericNarrowingPassBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } @@ -128,9 +130,4 @@ class InferNumericNarrowingPass }; } // namespace - -std::unique_ptr createInferNumericNarrowingPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp index 5c264befab8b..750ae788f63b 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp @@ -9,7 +9,6 @@ #include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "iree/compiler/Utils/PassUtils.h" #include "llvm/ADT/STLExtras.h" @@ -24,15 +23,17 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_MATERIALIZEHOMOGENEOUSENCODINGSPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + using FunctionLikeNest = MultiOpNest; +namespace { class MaterializeHomogeneousEncodingsPass - : public MaterializeHomogeneousEncodingsBase< + : public impl::MaterializeHomogeneousEncodingsPassBase< MaterializeHomogeneousEncodingsPass> { public: - MaterializeHomogeneousEncodingsPass() = default; - void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } @@ -78,10 +79,5 @@ class MaterializeHomogeneousEncodingsPass } } }; - -std::unique_ptr> 
-createMaterializeHomogeneousEncodingsPass() { - return std::make_unique(); -} - +} // namespace } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp b/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp index 329354e2e84b..38b87498e440 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp @@ -5,7 +5,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -16,6 +15,9 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_OPTIMIZENUMERICSPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { int getNextPotBitWidth(int bitWidth, int minBitWidth = 8) { @@ -254,7 +256,8 @@ struct LinalgFpMatmulToLowP : public OpRewritePattern { } }; -class OptimizeNumericsPass : public OptimizeNumericsBase { +class OptimizeNumericsPass + : public impl::OptimizeNumericsPassBase { void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); @@ -274,9 +277,4 @@ class OptimizeNumericsPass : public OptimizeNumericsBase { }; } // namespace - -std::unique_ptr createOptimizeNumericsPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h b/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h deleted file mode 100644 index b0b79e726b75..000000000000 --- a/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_ -#define IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_ - -#include "iree/compiler/GlobalOptimization/Passes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::iree_compiler::GlobalOptimization { - -#define GEN_PASS_CLASSES -#include "iree/compiler/GlobalOptimization/Passes.h.inc" - -} // namespace mlir::iree_compiler::GlobalOptimization - -#endif // IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_ diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index e095d788cb44..877f9761a7a8 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -101,7 +101,7 @@ void buildGlobalOptimizationPassPipeline( .addPass(createDetachElementwiseFromNamedOpsPass) .addPass(mlir::createLinalgNamedOpConversionPass) .addPass(createConvert1X1FilterConv2DToMatmulPass); - mainPassManager.addPass(createEraseUnusedLinalgOperands()); + mainPassManager.addPass(createEraseUnusedLinalgOperandsPass()); // Expand tensor shapes into SSA values and optimize the whole program. // The more we are able to equate shape dimensions at this level the @@ -116,7 +116,7 @@ void buildGlobalOptimizationPassPipeline( // RaiseSpecialOps, by virtue of implementing various peephole // optimizations, is sensitive to surrounding IR structure. Thus we run // this pass both before unit dim folding + consteval, as well as after. - .addPass(createRaiseSpecialOps) + .addPass(createRaiseSpecialOpsPass) // We decompose and transpose concatenations immediately before folding // unit extent dims because this allows decoupling unit dims in the // concatenation from the transposes that are introduced. 
@@ -138,10 +138,8 @@ void buildGlobalOptimizationPassPipeline( return createDemoteContractionInputsToBF16Pass( clDemoteContractionInputsToBF16Strategy); }) - .addPass([&]() { - return createFuseDequantizationMatmulPass( - clEnableQuantizedMatmulReassociation); - }) + .addPredicatedPass(clEnableQuantizedMatmulReassociation, + createFuseDequantizationMatmulPass) .addPass(IREE::Flow::createCanonicalizerPass) .addPass(mlir::createCSEPass) // Propagate transposes immediately before set encoding/data tiling @@ -226,7 +224,7 @@ void buildGlobalOptimizationPassPipeline( FunctionLikeNest(mainPassManager) // After running const-eval to a fixed point and folding unit extent dims, // try any new raising opportunities. - .addPass(createRaiseSpecialOps) + .addPass(createRaiseSpecialOpsPass) // Strip std.assert & co after we perform optimizations; prior to this we // may use the assertions to derive information during analysis. .addPredicatedPass(transformOptions.options.stripAssertions, diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h index 7643737cda74..62252247135e 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h @@ -36,93 +36,27 @@ struct TransformOptions : public PassPipelineOptions { void buildGlobalOptimizationPassPipeline( OpPassManager &mainPassManager, const TransformOptions &transformOptions); -//===----------------------------------------------------------------------===// -// Input canonicalization and legalization -//===----------------------------------------------------------------------===// +//------------------------------------------------------------------------------ +// Wrappers that do not use tablegen options. +//------------------------------------------------------------------------------ -/// Cleans up any numeric narrowing ops inserted by -/// iree-global-opt-infer-numeric-narrowing.
-std::unique_ptr createCleanupNumericNarrowingPass(); - -/// Converts linalg convolution ops with 1x1 kernels into linalg.matmul. -std::unique_ptr createConvert1X1FilterConv2DToMatmulPass(); - -/// Fuses dequantization and matmul linalg.generic ops -std::unique_ptr -createDecomposeConcatPass(bool enableConcatTransposition = false); +std::unique_ptr createDecomposeConcatPass(bool enableConcatTransposition); // Used by the demoteContractionInputsToBF16 pass to determine which op inputs // to demote. enum class DemotionOption { All, Conv, Matmul, None }; +std::unique_ptr +createDemoteContractionInputsToBF16Pass(DemotionOption option); -/// Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16. -std::unique_ptr createDemoteContractionInputsToBF16Pass( - DemotionOption option = DemotionOption::None); - -/// Detaches elementwise ops from named Linalg ops. -std::unique_ptr createDetachElementwiseFromNamedOpsPass(); - -/// Applies patterns to erase unused linalg operands and remove dead code -/// associated. -std::unique_ptr> -createEraseUnusedLinalgOperands(); - -/// Expands tensor shape dimensions into SSA values across the program. -std::unique_ptr> createExpandTensorShapesPass(); - -/// Fuses dequantization and matmul linalg.generic ops -std::unique_ptr> -createFuseDequantizationMatmulPass( - bool enableQuantizedMatmulReassociation = false); - -/// Horizontally fuses multiple contraction ops. -std::unique_ptr> -createFuseHorizontalContractionsPass(); - -/// Fuses two matmul ops and a linalg.generic Silu op -std::unique_ptr> -createFuseSiluHorizontalMatmulPass(); - -/// Generalizes some named Linalg ops into `linalg.generic` operations since the -/// compiler can handle that better. -std::unique_ptr> -createGeneralizeLinalgNamedOpsPass(); - -/// Infers and inserts util.numeric.optional_narrow ops at points that may be -/// beneficial. 
-std::unique_ptr createInferNumericNarrowingPass(); - -/// Materializes logical encodings to physical encodings if there is a single -/// device target. -std::unique_ptr> -createMaterializeHomogeneousEncodingsPass(); - -/// Optimizes numerics given annotations added via -/// iree-global-opt-infer-numeric-narrowing. -std::unique_ptr createOptimizeNumericsPass(); - -/// Propagates linalg.transpose ops to a restricted set of operations. -std::unique_ptr> -createPropagateLinalgTransposePass(bool enableAggressivePropagation = false); - -/// Performs specialized raisings of various sequences of ops to a -/// representation easier for the compiler to handle. -std::unique_ptr createRaiseSpecialOps(); - -/// Removes tensors that have 0-extents. std::unique_ptr> -createRemoveZeroExtentTensorsPass(); +createPropagateLinalgTransposePass(bool enableAggressivePropagation); -/// Simplifies tensor pack/unpack ops to reshape ops. -std::unique_ptr createSimplifyPackUnpackPass(); +//----------------------------------------------------------------------------// +// Register GlobalOptimization Passes +//----------------------------------------------------------------------------// -/// Hoist loop invariants out of loops with zero-trip-check. -std::unique_ptr> -createGlobalLoopInvariantCodeMotionPass(); - -/// Propagate pack/unpack ops across other ops to improve fusion. 
-std::unique_ptr> -createDataLayoutPropagationPass(); +#define GEN_PASS_DECL +#include "iree/compiler/GlobalOptimization/Passes.h.inc" // IWYU pragma: keep void registerGlobalOptimizationPipeline(); diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td index 141fcf394188..c91916146d2f 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td @@ -9,22 +9,19 @@ include "mlir/Pass/PassBase.td" -def CleanupNumericNarrowing : +def CleanupNumericNarrowingPass : Pass<"iree-global-opt-cleanup-numeric-narrowing", ""> { let summary = "Cleans up any numeric narrowing ops inserted by iree-global-opt-infer-numeric-narrowing."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createCleanupNumericNarrowingPass()"; } -def Convert1X1FilterConv2DToMatmul: +def Convert1X1FilterConv2DToMatmulPass: Pass<"iree-global-opt-convert-1x1-filter-conv2d-to-matmul", ""> { let summary = "Convert linalg convolution ops with 1x1 kernels into linalg matrix multiplication ops."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createConvert1X1FilterConv2DToMatmulPass()"; } -def DecomposeConcat : +def DecomposeConcatPass : Pass<"iree-global-opt-decompose-concat", ""> { let summary = "Decomposes concatenations into a destination and a sequence of slice inserts."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createDecomposeConcatPass()"; let options = [ Option<"enableConcatTransposition", "enable-concat-transposition", "bool", /*default=*/"false", "Allows transposing concatenations such that " @@ -32,12 +29,10 @@ def DecomposeConcat : ]; } -def DemoteContractionInputsToBF16 +def DemoteContractionInputsToBF16Pass : Pass<"iree-global-opt-demote-contraction-inputs-to-bf16", ""> { let summary = "Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16."; - let constructor = 
"mlir::iree_compiler::GlobalOptimization::" - "createDemoteContractionInputsToBF16Pass()"; let options = [Option<"demoteOnly", "demote-only", "mlir::iree_compiler::GlobalOptimization::DemotionOption", @@ -61,106 +56,89 @@ def DemoteContractionInputsToBF16 ]; } -def DetachElementwiseFromNamedOps : +def DetachElementwiseFromNamedOpsPass : Pass<"iree-global-opt-detach-elementwise-from-named-ops", ""> { let summary = "Detaches elementwise ops from named Linalg ops."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createDetachElementwiseFromNamedOpsPass()"; } -def EraseUnusedLinalgOperands : +def EraseUnusedLinalgOperandsPass : Pass<"iree-global-opt-erase-unused-linalg-operands", "mlir::ModuleOp"> { let summary = "Erases unused linalg operand and remove dead code."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createEraseUnusedLinalgOperands()"; } -def ExpandTensorShapes : +def ExpandTensorShapesPass : Pass<"iree-global-opt-expand-tensor-shapes", "mlir::ModuleOp"> { let summary = "Expands tensor shape dimensions into SSA values across the program."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createExpandTensorShapesPass()"; } -def FuseDequantizationMatmul: +def FuseDequantizationMatmulPass: InterfacePass<"iree-global-opt-fuse-dequantization-matmul", "mlir::FunctionOpInterface"> { let summary = "Fuses dequantization and matmul linalg.generic ops."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createFuseDequantizationMatmulPass()"; - let options = [ - Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", "bool", - /*default=*/"false", "Allow reassociation of quantized matmuls (experimental).">, - ]; } -def FuseHorizontalContractions: +def FuseHorizontalContractionsPass: InterfacePass<"iree-global-opt-fuse-horizontal-contractions", "mlir::FunctionOpInterface"> { let summary = "Fuses horizontal contraction ops without fusions"; - let constructor = 
"mlir::iree_compiler::GlobalOptimization::createFuseHorizontalContractionsPass()"; } -def FuseSiluHorizontalMatmul: +def FuseSiluHorizontalMatmulPass: InterfacePass<"iree-global-opt-fuse-silu-horizontal-matmul", "mlir::FunctionOpInterface"> { let summary = "Fuses matmul ops and silu linalg.generic op."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createFuseSiluHorizontalMatmulPass()"; } -def GeneralizeLinalgNamedOps : +def GeneralizeLinalgNamedOpsPass : InterfacePass<"iree-global-opt-generalize-linalg-named-ops", "mlir::FunctionOpInterface"> { let summary = "Convert some Linalg named ops into linalg.generics."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createGeneralizeLinalgNamedOpsPass()"; } -def InferNumericNarrowing : +def InferNumericNarrowingPass : Pass<"iree-global-opt-infer-numeric-narrowing", ""> { let summary = "Infers and inserts util.numeric.optional_narrow ops at points that may be beneficial."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createInferNumericNarrowingPass()"; } -def MaterializeHomogeneousEncodings : +def MaterializeHomogeneousEncodingsPass : Pass<"iree-global-opt-materialize-homogeneous-encodings", "mlir::ModuleOp"> { let summary = "Materializes logical encodings to physical encodings if there is a single device target."; - let constructor = - "mlir::iree_compiler::GlobalOptimization::createMaterializeHomogeneousEncodingsPass()"; } -def OptimizeNumerics : +def OptimizeNumericsPass : Pass<"iree-global-opt-optimize-numerics", ""> { let summary = "Optimizes numerics given annotations added via iree-global-opt-infer-numeric-narrowing."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createOptimizeNumericsPass()"; } -def PropagateLinalgTranspose : +def PropagateLinalgTransposePass : InterfacePass<"iree-global-opt-propagate-linalg-transpose", "mlir::FunctionOpInterface"> { let summary = "Propagates linalg.transpose through a restricted set of ops."; - let constructor = 
"mlir::iree_compiler::GlobalOptimization::createPropagateLinalgTransposePass()"; let options = [ Option<"enableAggressivePropagation", "enable-aggressive-propagation", "bool", /*default=*/"false", "Enable aggressive propagation to named ops.">, + Option<"testSinkingOnly", "test-sinking-only", "bool", /*default=*/"false", + "Flag used for lit-testing sinking patterns only. Not for general usage">, + Option<"testBubblingOnly", "test-bubbling-only", "bool", /*default=*/"false", + "Flag used for lit-testing bubbling patterns only. Not for general usage">, ]; } -def RaiseSpecialOps : +def RaiseSpecialOpsPass : Pass<"iree-global-opt-raise-special-ops", ""> { let summary = "Raises special ops like softmax to the high level linalg.ext representation."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createRaiseSpecialOps()"; } -def RemoveZeroExtentTensors : +def RemoveZeroExtentTensorsPass : InterfacePass<"iree-global-opt-remove-zero-extent-tensors", "mlir::FunctionOpInterface"> { let summary = "Removes tensors that have 0-extents."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createRemoveZeroExtentTensorsPass()"; } -def SimplifyPackUnpack : Pass<"iree-global-opt-simplify-pack-unpack", ""> { +def SimplifyPackUnpackPass : Pass<"iree-global-opt-simplify-pack-unpack", ""> { let summary = "Simplifies tensor pack and unpack ops."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createSimplifyPackUnpackPass()"; } -def GlobalLoopInvariantCodeMotion : InterfacePass<"iree-global-opt-loop-invariant-code-motion", "mlir::FunctionOpInterface"> { +def GlobalLoopInvariantCodeMotionPass : InterfacePass<"iree-global-opt-loop-invariant-code-motion", "mlir::FunctionOpInterface"> { let summary = "Hoist loop invariants out of loops with zero-trip-check."; - let constructor = "mlir::iree_compiler::GlobalOptimization::createGlobalLoopInvariantCodeMotionPass()"; } -def DataLayoutPropagation : InterfacePass<"iree-global-opt-data-layout-propagation", 
"mlir::FunctionOpInterface"> { +def DataLayoutPropagationPass : InterfacePass<"iree-global-opt-data-layout-propagation", "mlir::FunctionOpInterface"> { let summary = "Propagate pack/unpack ops across other ops to improve fusion"; - let constructor = "mlir::iree_compiler::GlobalOptimization::createDataLayoutPropagationPass()"; } #endif // IREE_COMPILER_GLOBALOPTIMIZATION_PASSES diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp index ac2c35627235..2dea4ad2ae4a 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp @@ -14,7 +14,6 @@ #include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -34,6 +33,9 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_PROPAGATELINALGTRANSPOSEPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + //===----------------------------------------------------------------------===// // Transpose permutation helpers //===----------------------------------------------------------------------===// @@ -801,29 +803,18 @@ class NamedOpConversion : public OpRewritePattern { namespace { struct PropagateLinalgTransposePass - : public PropagateLinalgTransposeBase { + : public impl::PropagateLinalgTransposePassBase< + PropagateLinalgTransposePass> { + using impl::PropagateLinalgTransposePassBase< + PropagateLinalgTransposePass>::PropagateLinalgTransposePassBase; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } - PropagateLinalgTransposePass(bool 
enableAggressivePropagation) { + explicit PropagateLinalgTransposePass(bool enableAggressivePropagation) { this->enableAggressivePropagation = enableAggressivePropagation; } - PropagateLinalgTransposePass(const PropagateLinalgTransposePass &pass) - : PropagateLinalgTransposePass(pass.enableAggressivePropagation) {} void runOnOperation() override; - -private: - Option testSinkingOnly{ - *this, "test-sinking-only", - llvm::cl::desc("Flag used for lit-testing sinking patterns only. " - "Not for general usage"), - llvm::cl::init(false)}; - Option testBubblingOnly{ - *this, "test-bubbling-only", - llvm::cl::desc("Flag used for lit-testing bubbling patterns only. " - "Not for general usage"), - llvm::cl::init(false)}; }; } // namespace diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp index 8c267e204b41..b7008690b6ab 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp @@ -9,7 +9,6 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "iree/compiler/GlobalOptimization/Utils.h" #include "llvm/ADT/STLExtras.h" @@ -32,6 +31,9 @@ using transform_ext::StructuredOpMatcher; namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_RAISESPECIALOPSPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { //===----------------------------------------------------------------------===// @@ -998,7 +1000,8 @@ class ConcatenateNegateAndSlicePattern // Pass Implementation //===----------------------------------------------------------------------===// -struct RaiseSpecialOpsPass : public 
RaiseSpecialOpsBase { +struct RaiseSpecialOpsPass + : public impl::RaiseSpecialOpsPassBase { void getDependentDialects(DialectRegistry &registry) const override { registry.insert(); } @@ -1054,9 +1057,4 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase { }; } // namespace - -std::unique_ptr createRaiseSpecialOps() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/RemoveZeroExtentTensors.cpp b/compiler/src/iree/compiler/GlobalOptimization/RemoveZeroExtentTensors.cpp index 9cb5929db000..b6d82c71d522 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RemoveZeroExtentTensors.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RemoveZeroExtentTensors.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -14,6 +13,9 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_REMOVEZEROEXTENTTENSORSPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + /// Check if a `t` is a `tensor` with zero extents.
static std::optional isZeroExtent(Type t) { auto operandType = dyn_cast(t); @@ -77,7 +79,7 @@ struct FoldZeroExtentInserts : public OpRewritePattern { namespace { struct RemoveZeroExtentTensorsPass - : RemoveZeroExtentTensorsBase { + : impl::RemoveZeroExtentTensorsPassBase { void getDependentDialects(DialectRegistry &registry) const override { registry.insert(); } @@ -101,9 +103,4 @@ void RemoveZeroExtentTensorsPass::runOnOperation() { } } -std::unique_ptr> -createRemoveZeroExtentTensorsPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/SimplifyPackUnpack.cpp b/compiler/src/iree/compiler/GlobalOptimization/SimplifyPackUnpack.cpp index 9333f9c712a9..3bd113fd4dd8 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/SimplifyPackUnpack.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/SimplifyPackUnpack.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" @@ -12,9 +11,12 @@ namespace mlir::iree_compiler::GlobalOptimization { +#define GEN_PASS_DEF_SIMPLIFYPACKUNPACKPASS +#include "iree/compiler/GlobalOptimization/Passes.h.inc" + namespace { struct SimplifyPackUnpackPass - : public SimplifyPackUnpackBase { + : public impl::SimplifyPackUnpackPassBase { void runOnOperation() override; }; @@ -30,8 +32,4 @@ void SimplifyPackUnpackPass::runOnOperation() { } } -std::unique_ptr createSimplifyPackUnpackPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir index 403729597504..3ad78055fe8d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-global-opt-fuse-dequantization-matmul{enable-quantized-matmul-reassociation=true},iree-flow-canonicalize))" %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-global-opt-fuse-dequantization-matmul,iree-flow-canonicalize))" %s | FileCheck %s util.func public @grouped_quantized_matmul_reassociate(%arg0: tensor<11008x32x128xi4>, %arg1: tensor<32x128xf32>, %arg2: tensor<11008x32xf32>, %arg3: tensor<11008x32xf32>) -> tensor<11008xf32> { %cst = arith.constant 0.000000e+00 : f32