diff --git a/compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel b/compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel index 816c3f8b5fec..6ec47bd1d906 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel @@ -28,7 +28,6 @@ iree_gentbl_cc_library( iree_compiler_cc_library( name = "PassHeaders", hdrs = [ - "PassDetail.h", "Passes.h", "Passes.h.inc", ], @@ -49,11 +48,11 @@ iree_compiler_cc_library( name = "VMVX", srcs = [ "KernelDispatch.cpp", - "LowerLinalgMicrokernels.cpp", "Passes.cpp", "VMVXAssignConstantOrdinals.cpp", "VMVXLinkExecutables.cpp", "VMVXLowerExecutableTargetPass.cpp", + "VMVXLowerLinalgMicrokernels.cpp", "VMVXSelectLoweringStrategy.cpp", ], hdrs = [ diff --git a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt index d3cf8e79f3d3..2e6b1195c4f7 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt @@ -23,7 +23,6 @@ iree_cc_library( NAME PassHeaders HDRS - "PassDetail.h" "Passes.h" "Passes.h.inc" DEPS @@ -47,11 +46,11 @@ iree_cc_library( "Passes.h" SRCS "KernelDispatch.cpp" - "LowerLinalgMicrokernels.cpp" "Passes.cpp" "VMVXAssignConstantOrdinals.cpp" "VMVXLinkExecutables.cpp" "VMVXLowerExecutableTargetPass.cpp" + "VMVXLowerLinalgMicrokernels.cpp" "VMVXSelectLoweringStrategy.cpp" DEPS ::PassHeaders diff --git a/compiler/src/iree/compiler/Codegen/VMVX/PassDetail.h b/compiler/src/iree/compiler/Codegen/VMVX/PassDetail.h deleted file mode 100644 index 58a2fa1e362b..000000000000 --- a/compiler/src/iree/compiler/Codegen/VMVX/PassDetail.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_CODEGEN_VMVX_PASS_DETAIL_H_ -#define IREE_COMPILER_CODEGEN_VMVX_PASS_DETAIL_H_ - -#include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::iree_compiler { - -#define GEN_PASS_CLASSES -#include "iree/compiler/Codegen/VMVX/Passes.h.inc" - -} // namespace mlir::iree_compiler - -#endif // IREE_COMPILER_CODEGEN_VMVX_PASS_DETAIL_H_ diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp index a18598328aa6..ef169fd9ccb4 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp @@ -8,7 +8,6 @@ #include "iree/compiler/Codegen/Common/CPU/Passes.h" #include "iree/compiler/Codegen/Common/Passes.h" -#include "iree/compiler/Codegen/VMVX/PassDetail.h" #include "iree/compiler/Codegen/VMVX/Passes.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.h b/compiler/src/iree/compiler/Codegen/VMVX/Passes.h index 63aa1db7acc1..98b27f6ab83b 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.h +++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.h @@ -19,26 +19,9 @@ namespace mlir::iree_compiler { //------------------------------------------------------------------------------ -// VMVX passes +// VMVX Pass Pipelines //------------------------------------------------------------------------------ -// Lowers high level library calls from named ops and generics. This operates -// at the bufferized linalg level. -std::unique_ptr createVMVXLowerLinalgMicrokernelsPass(); - -/// Materialize the encoding of operations. The layout to use for the encoded -/// operations are VMVX specific. -std::unique_ptr> -createVMVXMaterializeEncodingPass(); - -/// Pass to select a lowering strategy for a hal.executable.variant operation. -std::unique_ptr> createVMVXSelectLoweringStrategyPass(); - -/// Pass to lower the module an hal.executable.variant operation to external -/// dialect. -std::unique_ptr> -createVMVXLowerExecutableTargetPass(); - /// Populates the passes to lower to tiled/distributed/bufferized ops, /// suitable for library call dispatch and lowering to loops. void addVMVXDefaultPassPipeline(OpPassManager &funcPassManager, @@ -48,13 +31,6 @@ void addVMVXDefaultPassPipeline(OpPassManager &funcPassManager, // VMVX Linking Passes and Pipelines //----------------------------------------------------------------------------// -/// Assigns executable constant ordinals across all VMVX variants. -std::unique_ptr> -createVMVXAssignConstantOrdinalsPass(); - -/// Links VMVX HAL executables within the top-level program module. -std::unique_ptr> createVMVXLinkExecutablesPass(); - /// Populates passes needed to link HAL executables across VMVX targets. void buildVMVXLinkingPassPipeline(OpPassManager &variantPassManager); @@ -62,6 +38,9 @@ void buildVMVXLinkingPassPipeline(OpPassManager &variantPassManager); // Register VMVX Passes //----------------------------------------------------------------------------// +#define GEN_PASS_DECL +#include "iree/compiler/Codegen/VMVX/Passes.h.inc" // IWYU pragma: keep + void registerCodegenVMVXPasses(); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/VMVX/Passes.td b/compiler/src/iree/compiler/Codegen/VMVX/Passes.td index b8b703a957d5..19b3c6a8c364 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/Passes.td +++ b/compiler/src/iree/compiler/Codegen/VMVX/Passes.td @@ -13,39 +13,32 @@ include "mlir/Pass/PassBase.td" // VMVX passes (keep alphabetical) //===---------------------------------------------------------------------===// -def VMVXAssignConstantOrdinals : +def VMVXAssignConstantOrdinalsPass : Pass<"iree-vmvx-assign-constant-ordinals", "IREE::HAL::ExecutableVariantOp"> { let summary = "Assigns executable constant ordinals across all VMVX variants."; - let constructor = "mlir::iree_compiler::createVMVXAssignConstantOrdinalsPass()"; } -def VMVXSelectLoweringStrategy : +def VMVXSelectLoweringStrategyPass : Pass<"iree-vmvx-select-lowering-strategy", "ModuleOp"> { let summary = "Select a IREE::HAL::DispatchLoweringPassPipeline for lowering the variant"; - let constructor = - "mlir::iree_compiler::createVMVXSelectLoweringStrategyPass()"; } -def VMVXLinkExecutables : +def VMVXLinkExecutablesPass : Pass<"iree-vmvx-link-executables", "mlir::ModuleOp"> { let summary = "Links VMVX HAL executables within the top-level program module."; - let constructor = "mlir::iree_compiler::createVMVXLinkExecutablesPass()"; } -def VMVXLowerExecutableTarget : +def VMVXLowerExecutableTargetPass : InterfacePass<"iree-vmvx-lower-executable-target", "mlir::FunctionOpInterface"> { let summary = "Lower executable target using an IREE::HAL::DispatchLoweringPassPipeline"; - let constructor = - "mlir::iree_compiler::createVMVXLowerExecutableTargetPass()"; } -def VMVXLowerLinalgMicrokernels : +def VMVXLowerLinalgMicrokernelsPass : Pass<"iree-vmvx-lower-linalg-microkernels", ""> { let summary = "Lowers linalg ops to the VMVX microkernel library"; - let constructor = "mlir::iree_compiler::createVMVXLowerLinalgMicrokernelsPass()"; let options = [ Option<"warnOnUnconverted", "warn-on-unconverted", "bool", /*default=*/"false", diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp index ae7f5000760b..dd9f1751060c 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp @@ -4,18 +4,20 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/Codegen/VMVX/PassDetail.h" #include "iree/compiler/Codegen/VMVX/Passes.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" #include "mlir/Pass/Pass.h" namespace mlir::iree_compiler { +#define GEN_PASS_DEF_VMVXASSIGNCONSTANTORDINALSPASS +#include "iree/compiler/Codegen/VMVX/Passes.h.inc" + namespace { struct VMVXAssignConstantOrdinalsPass - : public VMVXAssignConstantOrdinalsBase { - VMVXAssignConstantOrdinalsPass() = default; + : public impl::VMVXAssignConstantOrdinalsPassBase< + VMVXAssignConstantOrdinalsPass> { void runOnOperation() override { auto variantOp = getOperation(); @@ -56,10 +58,4 @@ struct VMVXAssignConstantOrdinalsPass }; } // namespace - -std::unique_ptr> -createVMVXAssignConstantOrdinalsPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp index fc25fcca5a6e..571378ad1056 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp @@ -5,7 +5,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Codegen/Utils/LinkingUtils.h" -#include "iree/compiler/Codegen/VMVX/PassDetail.h" #include "iree/compiler/Codegen/VMVX/Passes.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" #include "iree/compiler/Utils/ModuleUtils.h" @@ -14,11 +13,13 @@ namespace mlir::iree_compiler { +#define GEN_PASS_DEF_VMVXLINKEXECUTABLESPASS +#include "iree/compiler/Codegen/VMVX/Passes.h.inc" + namespace { struct VMVXLinkExecutablesPass - : public VMVXLinkExecutablesBase { - VMVXLinkExecutablesPass() = default; + : public impl::VMVXLinkExecutablesPassBase { void runOnOperation() override { auto moduleOp = getOperation(); auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); @@ -79,9 +80,4 @@ struct VMVXLinkExecutablesPass }; } // namespace - -std::unique_ptr> createVMVXLinkExecutablesPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp index e487e9871b0d..94d6fbb9bfd8 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp @@ -8,7 +8,6 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/Utils/Utils.h" -#include "iree/compiler/Codegen/VMVX/PassDetail.h" #include "iree/compiler/Codegen/VMVX/Passes.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" @@ -23,18 +22,18 @@ #define DEBUG_TYPE "iree-vmvx-lower-executable-target" -using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttr; - namespace mlir::iree_compiler { +#define GEN_PASS_DEF_VMVXLOWEREXECUTABLETARGETPASS +#include "iree/compiler/Codegen/VMVX/Passes.h.inc" + namespace { /// Lowers an hal.executable.variant operation to scalar/native-vector code. class VMVXLowerExecutableTargetPass - : public VMVXLowerExecutableTargetBase { + : public impl::VMVXLowerExecutableTargetPassBase< + VMVXLowerExecutableTargetPass> { public: - VMVXLowerExecutableTargetPass() = default; - VMVXLowerExecutableTargetPass(const VMVXLowerExecutableTargetPass &pass) {} void getDependentDialects(DialectRegistry ®istry) const override { // clang-format off registry.insert> -createVMVXLowerExecutableTargetPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp similarity index 99% rename from compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp rename to compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp index fc84b24db00f..8d7230ec979d 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp @@ -4,7 +4,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/compiler/Codegen/VMVX/PassDetail.h" #include "iree/compiler/Codegen/VMVX/Passes.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" @@ -24,6 +23,9 @@ namespace mlir::iree_compiler { +#define GEN_PASS_DEF_VMVXLOWERLINALGMICROKERNELSPASS +#include "iree/compiler/Codegen/VMVX/Passes.h.inc" + namespace { // Permutes raw strides against a projected permutation map returning a @@ -918,7 +920,11 @@ struct LinalgFillConversion : public OpRewritePattern { } // namespace class VMVXLowerLinalgMicrokernelsPass - : public VMVXLowerLinalgMicrokernelsBase { + : public impl::VMVXLowerLinalgMicrokernelsPassBase< + VMVXLowerLinalgMicrokernelsPass> { + using impl::VMVXLowerLinalgMicrokernelsPassBase< + VMVXLowerLinalgMicrokernelsPass>::VMVXLowerLinalgMicrokernelsPassBase; + void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -948,9 +954,4 @@ class VMVXLowerLinalgMicrokernelsPass } } }; - -std::unique_ptr createVMVXLowerLinalgMicrokernelsPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXSelectLoweringStrategy.cpp index 376622e0d1e3..7ab5d266dbdb 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXSelectLoweringStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXSelectLoweringStrategy.cpp @@ -7,7 +7,6 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/VMVX/KernelDispatch.h" -#include "iree/compiler/Codegen/VMVX/PassDetail.h" #include "iree/compiler/Codegen/VMVX/Passes.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" @@ -25,13 +24,15 @@ using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttr; namespace mlir::iree_compiler { +#define GEN_PASS_DEF_VMVXSELECTLOWERINGSTRATEGYPASS +#include "iree/compiler/Codegen/VMVX/Passes.h.inc" + namespace { /// Selects the lowering strategy for a hal.executable.variant operation. class VMVXSelectLoweringStrategyPass - : public VMVXSelectLoweringStrategyBase { + : public impl::VMVXSelectLoweringStrategyPassBase< + VMVXSelectLoweringStrategyPass> { public: - VMVXSelectLoweringStrategyPass() = default; - VMVXSelectLoweringStrategyPass(const VMVXSelectLoweringStrategyPass &pass) {} void getDependentDialects(DialectRegistry ®istry) const override { // TODO(qedawkins): Once TransformStrategies is deprecated, drop the // unnecessary dialect registrations. @@ -61,10 +62,4 @@ void VMVXSelectLoweringStrategyPass::runOnOperation() { } } } - -std::unique_ptr> -createVMVXSelectLoweringStrategyPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler