Skip to content

Commit

Permalink
[VectorExt] Teach to_layout vectorization
Browse files Browse the repository at this point in the history
  • Loading branch information
Groverkss committed Aug 2, 2024
1 parent 98c9319 commit b1d82d4
Show file tree
Hide file tree
Showing 11 changed files with 298 additions and 1 deletion.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms:VectorExtTransforms",
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
"//compiler/src/iree/compiler/Codegen/LLVMGPU",
"//compiler/src/iree/compiler/Codegen/SPIRV",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_cc_library(
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
iree::compiler::Codegen::Dialect::VectorExt::Transforms::VectorExtTransforms
iree::compiler::Codegen::LLVMCPU
iree::compiler::Codegen::LLVMGPU
iree::compiler::Codegen::SPIRV
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [
let results = (outs
AnyShaped:$output
);
let extraClassDeclaration = [{}];
let extraClassDeclaration = [{
bool hasTensorSemantics() {
return isa<RankedTensorType>(getOutput().getType());
}
}];
let assemblyFormat = "$input `to` $layout attr-dict `:` type($input)";
let hasVerifier = 1;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")

package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)

iree_gentbl_cc_library(
name = "PassesIncGen",
tbl_outs = [
(
["--gen-pass-decls"],
"Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)

iree_compiler_cc_library(
name = "VectorExtTransforms",
srcs = [
"Passes.cpp",
"VectorizeIREEVectorExtOps.cpp",
],
hdrs = [
"Passes.h",
"Passes.h.inc",
],
deps = [
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
"@llvm-project//mlir:VectorTransforms",
"@llvm-project//mlir:VectorUtils",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
# compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
# #
# To disable autogeneration for this file entirely, delete this header. #
################################################################################

iree_add_all_subdirs()

iree_tablegen_library(
NAME
PassesIncGen
TD_FILE
"Passes.td"
OUTS
--gen-pass-decls Passes.h.inc
)

iree_cc_library(
NAME
VectorExtTransforms
HDRS
"Passes.h"
"Passes.h.inc"
SRCS
"Passes.cpp"
"VectorizeIREEVectorExtOps.cpp"
DEPS
::PassesIncGen
LLVMSupport
MLIRArithDialect
MLIRFunctionInterfaces
MLIRIR
MLIRPass
MLIRSupport
MLIRTensorDialect
MLIRTransforms
MLIRVectorDialect
MLIRVectorTransforms
MLIRVectorUtils
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
PUBLIC
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"

namespace mlir::iree_compiler {

namespace IREE::VectorExt {
namespace {
#define GEN_PASS_REGISTRATION
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc"
} // namespace
} // namespace IREE::VectorExt

void registerIREEVectorExtPasses() {
// Generated.
IREE::VectorExt::registerPasses();
}
} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_
#define IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_

#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler::IREE::VectorExt {
#define GEN_PASS_DECL
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc" // IWYU pragma: keep
} // namespace mlir::iree_compiler::IREE::VectorExt

namespace mlir::iree_compiler {
/// Register VectorExt passes.
void registerIREEVectorExtPasses();
} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES
#define IREE_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

def VectorizeIREEVectorExtOpsPass :
Pass<"iree-vector-ext-vectorize-ops", ""> {
let summary = "Vectorizes then lowers a few iree_vector_ext ops before vectorization.";
let dependentDialects = [
"::mlir::vector::VectorDialect",
"::mlir::arith::ArithDialect",
"::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect"
];
}

#endif // IREE_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::VectorExt {

#define GEN_PASS_DEF_VECTORIZEIREEVECTOREXTOPSPASS
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc"

namespace {

struct VectorizeToLayoutOpPattern final
: OpRewritePattern<IREE::VectorExt::ToLayoutOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
PatternRewriter &rewriter) const override {
if (!toLayoutOp.hasTensorSemantics()) {
return failure();
}
if (!toLayoutOp.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(toLayoutOp,
"non-static shape for vectorization");
}

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(toLayoutOp);

Location loc = toLayoutOp.getLoc();
ShapedType inputTy = toLayoutOp.getType();

// Construct the (never used) zero padding value for input.
auto padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(inputTy.getElementType()));

auto newInput = vector::createReadOrMaskedRead(
rewriter, loc, toLayoutOp.getInput(), inputTy.getShape(), padValue,
/*useInBoundsInsteadOfMasking=*/true);

// Create the toLayout operation but with vector types instead.
auto newLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, newInput.getType(), newInput, toLayoutOp.getLayout());

// Create the write back to a tensor.
int64_t rank = inputTy.getRank();
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto empty = rewriter.create<tensor::EmptyOp>(loc, inputTy, ValueRange());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
toLayoutOp,
/*vector=*/newLayoutOp,
/*source=*/empty,
/*indices=*/SmallVector<Value>(rank, zero),
/*inBounds=*/SmallVector<bool>(rank, true));
return success();
}
};

} // namespace

namespace {
struct VectorizeIREEVectorExtOpsPass final
: impl::VectorizeIREEVectorExtOpsPassBase<VectorizeIREEVectorExtOpsPass> {
void runOnOperation() override;
};
} // namespace

void VectorizeIREEVectorExtOpsPass::runOnOperation() {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<VectorizeToLayoutOpPattern>(ctx);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}

} // namespace mlir::iree_compiler::IREE::VectorExt
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: iree-opt %s -pass-pipeline='builtin.module(func.func(iree-vector-ext-vectorize-ops, iree-codegen-generic-vectorization))' | FileCheck %s

#layout = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [1, 1],
batches_per_subgroup = [1, 1],
outers_per_batch = [1, 1],
threads_per_outer = [1, 1],
elements_per_thread = [64, 64],

subgroup_strides = [0, 0],
thread_strides = [0, 0]
>

// CHECK-LABEL: func.func @vectorize_matmul_layout
func.func @vectorize_matmul_layout(%A: tensor<64x64xf32>,
%B: tensor<64x64xf32>,
%C: tensor<64x64xf32>)
-> tensor<64x64xf32> {
%AL = iree_vector_ext.to_layout %A to #layout : tensor<64x64xf32>
%BL = iree_vector_ext.to_layout %B to #layout : tensor<64x64xf32>
%CL = iree_vector_ext.to_layout %C to #layout : tensor<64x64xf32>
// CHECK: %[[A:.+]] = iree_vector_ext.to_layout
// CHECK-SAME: vector<64x64xf32>
// CHECK: %[[B:.+]] = iree_vector_ext.to_layout
// CHECK-SAME: vector<64x64xf32>
// CHECK: %[[C:.+]] = iree_vector_ext.to_layout
// CHECK-SAME: vector<64x64xf32>
%matmul = linalg.matmul ins(%AL, %BL : tensor<64x64xf32>, tensor<64x64xf32>)
outs(%CL: tensor<64x64xf32>) -> tensor<64x64xf32>
// CHECK: vector.contract
// CHECK-SAME: %[[A]], %[[B]], %[[C]]
return %matmul : tensor<64x64xf32>
}
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
Expand All @@ -35,6 +36,7 @@ void registerCodegenPasses() {
registerCodegenVMVXPasses();
registerCodegenWGSLPasses();
registerIREEGPUPasses();
registerIREEVectorExtPasses();
}

} // namespace mlir::iree_compiler

0 comments on commit b1d82d4

Please sign in to comment.