From b1d82d4c1d86ba399ccfa68a61c0af591b1eb1c3 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 2 Aug 2024 12:52:51 +0000 Subject: [PATCH] [VectorExt] Teach to_layout vectorization --- .../src/iree/compiler/Codegen/BUILD.bazel | 1 + .../src/iree/compiler/Codegen/CMakeLists.txt | 1 + .../Dialect/VectorExt/IR/VectorExtOps.td | 6 +- .../Dialect/VectorExt/Transforms/BUILD.bazel | 55 ++++++++++++ .../VectorExt/Transforms/CMakeLists.txt | 48 +++++++++++ .../Dialect/VectorExt/Transforms/Passes.cpp | 22 +++++ .../Dialect/VectorExt/Transforms/Passes.h | 23 +++++ .../Dialect/VectorExt/Transforms/Passes.td | 22 +++++ .../Transforms/VectorizeIREEVectorExtOps.cpp | 86 +++++++++++++++++++ .../test/vectorize_vector_ext_ops.mlir | 33 +++++++ compiler/src/iree/compiler/Codegen/Passes.cpp | 2 + 11 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir diff --git a/compiler/src/iree/compiler/Codegen/BUILD.bazel b/compiler/src/iree/compiler/Codegen/BUILD.bazel index a07556c695fc9..f2ead8784caba 100644 --- a/compiler/src/iree/compiler/Codegen/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/BUILD.bazel @@ -26,6 +26,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms", + "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms:VectorExtTransforms", "//compiler/src/iree/compiler/Codegen/LLVMCPU", "//compiler/src/iree/compiler/Codegen/LLVMGPU", "//compiler/src/iree/compiler/Codegen/SPIRV", diff --git a/compiler/src/iree/compiler/Codegen/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/CMakeLists.txt index bf8407b1119c2..625c999a0b676 100644 --- a/compiler/src/iree/compiler/Codegen/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/CMakeLists.txt @@ -24,6 +24,7 @@ iree_cc_library( iree::compiler::Codegen::Common::GPU::CommonGPUPasses iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms + iree::compiler::Codegen::Dialect::VectorExt::Transforms::VectorExtTransforms iree::compiler::Codegen::LLVMCPU iree::compiler::Codegen::LLVMGPU iree::compiler::Codegen::SPIRV diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td index af95bc2294621..bb5c634248678 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td @@ -41,7 +41,11 @@ def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [ let results = (outs AnyShaped:$output ); - let extraClassDeclaration = [{}]; + let extraClassDeclaration = [{ + bool hasTensorSemantics() { + return isa(getOutput().getType()); + } + }]; let assemblyFormat = "$input `to` $layout attr-dict `:` type($input)"; let hasVerifier = 1; } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel new file mode 100644 index 0000000000000..c732a03cd2a58 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel @@ -0,0 +1,55 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_gentbl_cc_library( + name = "PassesIncGen", + tbl_outs = [ + ( + ["--gen-pass-decls"], + "Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +iree_compiler_cc_library( + name = "VectorExtTransforms", + srcs = [ + "Passes.cpp", + "VectorizeIREEVectorExtOps.cpp", + ], + hdrs = [ + "Passes.h", + "Passes.h.inc", + ], + deps = [ + ":PassesIncGen", + "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorDialect", + "@llvm-project//mlir:VectorTransforms", + "@llvm-project//mlir:VectorUtils", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..e8851ae0de05e --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt @@ -0,0 +1,48 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + +iree_cc_library( + NAME + VectorExtTransforms + HDRS + "Passes.h" + "Passes.h.inc" + SRCS + "Passes.cpp" + "VectorizeIREEVectorExtOps.cpp" + DEPS + ::PassesIncGen + LLVMSupport + MLIRArithDialect + MLIRFunctionInterfaces + MLIRIR + MLIRPass + MLIRSupport + MLIRTensorDialect + MLIRTransforms + MLIRVectorDialect + MLIRVectorTransforms + MLIRVectorUtils + iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.cpp new file mode 100644 index 0000000000000..a81ca0b78ad45 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.cpp @@ -0,0 +1,22 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h" + +namespace mlir::iree_compiler { + +namespace IREE::VectorExt { +namespace { +#define GEN_PASS_REGISTRATION +#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc" +} // namespace +} // namespace IREE::VectorExt + +void registerIREEVectorExtPasses() { + // Generated. + IREE::VectorExt::registerPasses(); +} +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h new file mode 100644 index 0000000000000..17967f61dca67 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h @@ -0,0 +1,23 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_ +#define IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_ + +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::VectorExt { +#define GEN_PASS_DECL +#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc" // IWYU pragma: keep +} // namespace mlir::iree_compiler::IREE::VectorExt + +namespace mlir::iree_compiler { +/// Register VectorExt passes. +void registerIREEVectorExtPasses(); +} // namespace mlir::iree_compiler + +#endif // IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td new file mode 100644 index 0000000000000..222ea4fd0e73a --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.td @@ -0,0 +1,22 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES +#define IREE_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def VectorizeIREEVectorExtOpsPass : + Pass<"iree-vector-ext-vectorize-ops", ""> { + let summary = "Vectorizes then lowers a few iree_vector_ext ops before vectorization."; + let dependentDialects = [ + "::mlir::vector::VectorDialect", + "::mlir::arith::ArithDialect", + "::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect" + ]; +} + +#endif // IREE_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp new file mode 100644 index 0000000000000..9e7a5e76c5b97 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp @@ -0,0 +1,86 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" +#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::iree_compiler::IREE::VectorExt { + +#define GEN_PASS_DEF_VECTORIZEIREEVECTOREXTOPSPASS +#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc" + +namespace { + +struct VectorizeToLayoutOpPattern final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp, + PatternRewriter &rewriter) const override { + if (!toLayoutOp.hasTensorSemantics()) { + return failure(); + } + if (!toLayoutOp.getType().hasStaticShape()) { + return rewriter.notifyMatchFailure(toLayoutOp, + "non-static shape for vectorization"); + } + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(toLayoutOp); + + Location loc = toLayoutOp.getLoc(); + ShapedType inputTy = toLayoutOp.getType(); + + // Construct the (never used) zero padding value for input. + auto padValue = rewriter.create( + loc, rewriter.getZeroAttr(inputTy.getElementType())); + + auto newInput = vector::createReadOrMaskedRead( + rewriter, loc, toLayoutOp.getInput(), inputTy.getShape(), padValue, + /*useInBoundsInsteadOfMasking=*/true); + + // Create the toLayout operation but with vector types instead. + auto newLayoutOp = rewriter.create( + loc, newInput.getType(), newInput, toLayoutOp.getLayout()); + + // Create the write back to a tensor. + int64_t rank = inputTy.getRank(); + auto zero = rewriter.create(loc, 0); + auto empty = rewriter.create(loc, inputTy, ValueRange()); + rewriter.replaceOpWithNewOp( + toLayoutOp, + /*vector=*/newLayoutOp, + /*source=*/empty, + /*indices=*/SmallVector(rank, zero), + /*inBounds=*/SmallVector(rank, true)); + return success(); + } +}; + +} // namespace + +namespace { +struct VectorizeIREEVectorExtOpsPass final + : impl::VectorizeIREEVectorExtOpsPassBase { + void runOnOperation() override; +}; +} // namespace + +void VectorizeIREEVectorExtOpsPass::runOnOperation() { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace mlir::iree_compiler::IREE::VectorExt diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir new file mode 100644 index 0000000000000..581d7d7bca416 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir @@ -0,0 +1,33 @@ +// RUN: iree-opt %s -pass-pipeline='builtin.module(func.func(iree-vector-ext-vectorize-ops, iree-codegen-generic-vectorization))' | FileCheck %s + +#layout = #iree_vector_ext.nested_layout< + subgroups_per_workgroup = [1, 1], + batches_per_subgroup = [1, 1], + outers_per_batch = [1, 1], + threads_per_outer = [1, 1], + elements_per_thread = [64, 64], + + subgroup_strides = [0, 0], + thread_strides = [0, 0] +> + +// CHECK-LABEL: func.func @vectorize_matmul_layout +func.func @vectorize_matmul_layout(%A: tensor<64x64xf32>, + %B: tensor<64x64xf32>, + %C: tensor<64x64xf32>) + -> tensor<64x64xf32> { + %AL = iree_vector_ext.to_layout %A to #layout : tensor<64x64xf32> + %BL = iree_vector_ext.to_layout %B to #layout : tensor<64x64xf32> + %CL = iree_vector_ext.to_layout %C to #layout : tensor<64x64xf32> + // CHECK: %[[A:.+]] = iree_vector_ext.to_layout + // CHECK-SAME: vector<64x64xf32> + // CHECK: %[[B:.+]] = iree_vector_ext.to_layout + // CHECK-SAME: vector<64x64xf32> + // CHECK: %[[C:.+]] = iree_vector_ext.to_layout + // CHECK-SAME: vector<64x64xf32> + %matmul = linalg.matmul ins(%AL, %BL : tensor<64x64xf32>, tensor<64x64xf32>) + outs(%CL: tensor<64x64xf32>) -> tensor<64x64xf32> + // CHECK: vector.contract + // CHECK-SAME: %[[A]], %[[B]], %[[C]] + return %matmul : tensor<64x64xf32> +} diff --git a/compiler/src/iree/compiler/Codegen/Passes.cpp b/compiler/src/iree/compiler/Codegen/Passes.cpp index 236071717d737..f465ec8067c77 100644 --- a/compiler/src/iree/compiler/Codegen/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/Passes.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Common/Passes.h" #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h" +#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h" #include "iree/compiler/Codegen/LLVMCPU/Passes.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" #include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h" @@ -35,6 +36,7 @@ void registerCodegenPasses() { registerCodegenVMVXPasses(); registerCodegenWGSLPasses(); registerIREEGPUPasses(); + registerIREEVectorExtPasses(); } } // namespace mlir::iree_compiler