Skip to content

Commit

Permalink
[VMVX] Switch to new pass generation tablegen definitions (#18149)
Browse files Browse the repository at this point in the history
This should be NFC. The additional changes:

- Removes an unused using-declaration from
VMVXLowerExecutableTargetPass.
- Rename LowerLinalgMicrokernels.cpp to match the pass name (i.e., it
becomes VMVXLowerLinalgMicrokernels.cpp).

---------

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW committed Aug 8, 2024
1 parent 050a449 commit 4bea50e
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 111 deletions.
3 changes: 1 addition & 2 deletions compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ iree_gentbl_cc_library(
iree_compiler_cc_library(
name = "PassHeaders",
hdrs = [
"PassDetail.h",
"Passes.h",
"Passes.h.inc",
],
Expand All @@ -49,11 +48,11 @@ iree_compiler_cc_library(
name = "VMVX",
srcs = [
"KernelDispatch.cpp",
"LowerLinalgMicrokernels.cpp",
"Passes.cpp",
"VMVXAssignConstantOrdinals.cpp",
"VMVXLinkExecutables.cpp",
"VMVXLowerExecutableTargetPass.cpp",
"VMVXLowerLinalgMicrokernels.cpp",
"VMVXSelectLoweringStrategy.cpp",
],
hdrs = [
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ iree_cc_library(
NAME
PassHeaders
HDRS
"PassDetail.h"
"Passes.h"
"Passes.h.inc"
DEPS
Expand All @@ -47,11 +46,11 @@ iree_cc_library(
"Passes.h"
SRCS
"KernelDispatch.cpp"
"LowerLinalgMicrokernels.cpp"
"Passes.cpp"
"VMVXAssignConstantOrdinals.cpp"
"VMVXLinkExecutables.cpp"
"VMVXLowerExecutableTargetPass.cpp"
"VMVXLowerLinalgMicrokernels.cpp"
"VMVXSelectLoweringStrategy.cpp"
DEPS
::PassHeaders
Expand Down
23 changes: 0 additions & 23 deletions compiler/src/iree/compiler/Codegen/VMVX/PassDetail.h

This file was deleted.

1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/VMVX/PassDetail.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down
29 changes: 4 additions & 25 deletions compiler/src/iree/compiler/Codegen/VMVX/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,9 @@
namespace mlir::iree_compiler {

//------------------------------------------------------------------------------
// VMVX passes
// VMVX Pass Pipelines
//------------------------------------------------------------------------------

// Lowers high level library calls from named ops and generics. This operates
// at the bufferized linalg level.
std::unique_ptr<Pass> createVMVXLowerLinalgMicrokernelsPass();

/// Materialize the encoding of operations. The layout to use for the encoded
/// operations are VMVX specific.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createVMVXMaterializeEncodingPass();

/// Pass to select a lowering strategy for a hal.executable.variant operation.
std::unique_ptr<OperationPass<ModuleOp>> createVMVXSelectLoweringStrategyPass();

/// Pass to lower the module an hal.executable.variant operation to external
/// dialect.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createVMVXLowerExecutableTargetPass();

/// Populates the passes to lower to tiled/distributed/bufferized ops,
/// suitable for library call dispatch and lowering to loops.
void addVMVXDefaultPassPipeline(OpPassManager &funcPassManager,
Expand All @@ -48,20 +31,16 @@ void addVMVXDefaultPassPipeline(OpPassManager &funcPassManager,
// VMVX Linking Passes and Pipelines
//----------------------------------------------------------------------------//

/// Assigns executable constant ordinals across all VMVX variants.
std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
createVMVXAssignConstantOrdinalsPass();

/// Links VMVX HAL executables within the top-level program module.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createVMVXLinkExecutablesPass();

/// Populates passes needed to link HAL executables across VMVX targets.
void buildVMVXLinkingPassPipeline(OpPassManager &variantPassManager);

//----------------------------------------------------------------------------//
// Register VMVX Passes
//----------------------------------------------------------------------------//

#define GEN_PASS_DECL
#include "iree/compiler/Codegen/VMVX/Passes.h.inc" // IWYU pragma: keep

void registerCodegenVMVXPasses();

} // namespace mlir::iree_compiler
Expand Down
17 changes: 5 additions & 12 deletions compiler/src/iree/compiler/Codegen/VMVX/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,32 @@ include "mlir/Pass/PassBase.td"
// VMVX passes (keep alphabetical)
//===---------------------------------------------------------------------===//

def VMVXAssignConstantOrdinals :
def VMVXAssignConstantOrdinalsPass :
Pass<"iree-vmvx-assign-constant-ordinals", "IREE::HAL::ExecutableVariantOp"> {
let summary = "Assigns executable constant ordinals across all VMVX variants.";
let constructor = "mlir::iree_compiler::createVMVXAssignConstantOrdinalsPass()";
}

def VMVXSelectLoweringStrategy :
def VMVXSelectLoweringStrategyPass :
Pass<"iree-vmvx-select-lowering-strategy", "ModuleOp"> {
let summary =
"Select a IREE::HAL::DispatchLoweringPassPipeline for lowering the variant";
let constructor =
"mlir::iree_compiler::createVMVXSelectLoweringStrategyPass()";
}

def VMVXLinkExecutables :
def VMVXLinkExecutablesPass :
Pass<"iree-vmvx-link-executables", "mlir::ModuleOp"> {
let summary = "Links VMVX HAL executables within the top-level program module.";
let constructor = "mlir::iree_compiler::createVMVXLinkExecutablesPass()";
}

def VMVXLowerExecutableTarget :
def VMVXLowerExecutableTargetPass :
InterfacePass<"iree-vmvx-lower-executable-target", "mlir::FunctionOpInterface"> {
let summary =
"Lower executable target using an IREE::HAL::DispatchLoweringPassPipeline";
let constructor =
"mlir::iree_compiler::createVMVXLowerExecutableTargetPass()";
}

def VMVXLowerLinalgMicrokernels :
def VMVXLowerLinalgMicrokernelsPass :
Pass<"iree-vmvx-lower-linalg-microkernels", ""> {
let summary =
"Lowers linalg ops to the VMVX microkernel library";
let constructor = "mlir::iree_compiler::createVMVXLowerLinalgMicrokernelsPass()";
let options = [
Option<"warnOnUnconverted", "warn-on-unconverted", "bool",
/*default=*/"false",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/VMVX/PassDetail.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_VMVXASSIGNCONSTANTORDINALSPASS
#include "iree/compiler/Codegen/VMVX/Passes.h.inc"

namespace {

struct VMVXAssignConstantOrdinalsPass
: public VMVXAssignConstantOrdinalsBase<VMVXAssignConstantOrdinalsPass> {
VMVXAssignConstantOrdinalsPass() = default;
: public impl::VMVXAssignConstantOrdinalsPassBase<
VMVXAssignConstantOrdinalsPass> {
void runOnOperation() override {
auto variantOp = getOperation();

Expand Down Expand Up @@ -56,10 +58,4 @@ struct VMVXAssignConstantOrdinalsPass
};

} // namespace

std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
createVMVXAssignConstantOrdinalsPass() {
return std::make_unique<VMVXAssignConstantOrdinalsPass>();
}

} // namespace mlir::iree_compiler
12 changes: 4 additions & 8 deletions compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Utils/LinkingUtils.h"
#include "iree/compiler/Codegen/VMVX/PassDetail.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "iree/compiler/Utils/ModuleUtils.h"
Expand All @@ -14,11 +13,13 @@

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_VMVXLINKEXECUTABLESPASS
#include "iree/compiler/Codegen/VMVX/Passes.h.inc"

namespace {

struct VMVXLinkExecutablesPass
: public VMVXLinkExecutablesBase<VMVXLinkExecutablesPass> {
VMVXLinkExecutablesPass() = default;
: public impl::VMVXLinkExecutablesPassBase<VMVXLinkExecutablesPass> {
void runOnOperation() override {
auto moduleOp = getOperation();
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
Expand Down Expand Up @@ -79,9 +80,4 @@ struct VMVXLinkExecutablesPass
};

} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createVMVXLinkExecutablesPass() {
return std::make_unique<VMVXLinkExecutablesPass>();
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Codegen/VMVX/PassDetail.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
Expand All @@ -23,18 +22,18 @@

#define DEBUG_TYPE "iree-vmvx-lower-executable-target"

using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttr;

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_VMVXLOWEREXECUTABLETARGETPASS
#include "iree/compiler/Codegen/VMVX/Passes.h.inc"

namespace {

/// Lowers an hal.executable.variant operation to scalar/native-vector code.
class VMVXLowerExecutableTargetPass
: public VMVXLowerExecutableTargetBase<VMVXLowerExecutableTargetPass> {
: public impl::VMVXLowerExecutableTargetPassBase<
VMVXLowerExecutableTargetPass> {
public:
VMVXLowerExecutableTargetPass() = default;
VMVXLowerExecutableTargetPass(const VMVXLowerExecutableTargetPass &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
// clang-format off
registry.insert<IREE::HAL::HALDialect,
Expand Down Expand Up @@ -89,10 +88,4 @@ void VMVXLowerExecutableTargetPass::runOnOperation() {
return signalPassFailure();
}
}

std::unique_ptr<InterfacePass<FunctionOpInterface>>
createVMVXLowerExecutableTargetPass() {
return std::make_unique<VMVXLowerExecutableTargetPass>();
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/VMVX/PassDetail.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
Expand All @@ -24,6 +23,9 @@

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_VMVXLOWERLINALGMICROKERNELSPASS
#include "iree/compiler/Codegen/VMVX/Passes.h.inc"

namespace {

// Permutes raw strides against a projected permutation map returning a
Expand Down Expand Up @@ -918,7 +920,11 @@ struct LinalgFillConversion : public OpRewritePattern<linalg::FillOp> {
} // namespace

class VMVXLowerLinalgMicrokernelsPass
: public VMVXLowerLinalgMicrokernelsBase<VMVXLowerLinalgMicrokernelsPass> {
: public impl::VMVXLowerLinalgMicrokernelsPassBase<
VMVXLowerLinalgMicrokernelsPass> {
using impl::VMVXLowerLinalgMicrokernelsPassBase<
VMVXLowerLinalgMicrokernelsPass>::VMVXLowerLinalgMicrokernelsPassBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Util::UtilDialect, IREE::VMVX::VMVXDialect,
memref::MemRefDialect>();
Expand Down Expand Up @@ -948,9 +954,4 @@ class VMVXLowerLinalgMicrokernelsPass
}
}
};

std::unique_ptr<Pass> createVMVXLowerLinalgMicrokernelsPass() {
return std::make_unique<VMVXLowerLinalgMicrokernelsPass>();
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/VMVX/KernelDispatch.h"
#include "iree/compiler/Codegen/VMVX/PassDetail.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
Expand All @@ -25,13 +24,15 @@ using mlir::iree_compiler::IREE::Codegen::LoweringConfigAttr;

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_VMVXSELECTLOWERINGSTRATEGYPASS
#include "iree/compiler/Codegen/VMVX/Passes.h.inc"

namespace {
/// Selects the lowering strategy for a hal.executable.variant operation.
class VMVXSelectLoweringStrategyPass
: public VMVXSelectLoweringStrategyBase<VMVXSelectLoweringStrategyPass> {
: public impl::VMVXSelectLoweringStrategyPassBase<
VMVXSelectLoweringStrategyPass> {
public:
VMVXSelectLoweringStrategyPass() = default;
VMVXSelectLoweringStrategyPass(const VMVXSelectLoweringStrategyPass &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
// TODO(qedawkins): Once TransformStrategies is deprecated, drop the
// unnecessary dialect registrations.
Expand Down Expand Up @@ -61,10 +62,4 @@ void VMVXSelectLoweringStrategyPass::runOnOperation() {
}
}
}

std::unique_ptr<OperationPass<ModuleOp>>
createVMVXSelectLoweringStrategyPass() {
return std::make_unique<VMVXSelectLoweringStrategyPass>();
}

} // namespace mlir::iree_compiler

0 comments on commit 4bea50e

Please sign in to comment.