Skip to content

Commit

Permalink
[VectorExt] Teach vectorization to to_layout (iree-org#18092)
Browse files Browse the repository at this point in the history
This patch adds a pass to vectorize iree_vector_ext.to_layout
operations. This allows us to set layouts at linalg level and distribute
later in the pipeline.
  • Loading branch information
Groverkss committed Aug 6, 2024
1 parent e22b78d commit b324f2a
Show file tree
Hide file tree
Showing 18 changed files with 344 additions and 44 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms:VectorExtTransforms",
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
"//compiler/src/iree/compiler/Codegen/LLVMGPU",
"//compiler/src/iree/compiler/Codegen/SPIRV",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_cc_library(
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
iree::compiler::Codegen::Dialect::VectorExt::Transforms::VectorExtTransforms
iree::compiler::Codegen::LLVMCPU
iree::compiler::Codegen::LLVMGPU
iree::compiler::Codegen::SPIRV
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -795,8 +795,8 @@ struct DistributeLayoutConflictResolutions final
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
VectorValue vector = resolutionOp.getInput();
VectorValue result = resolutionOp.getOutput();
auto vector = cast<VectorValue>(resolutionOp.getInput());
auto result = cast<VectorValue>(resolutionOp.getOutput());
LayoutAttr currentLayout = dyn_cast<LayoutAttr>(signature[vector]);
if (!currentLayout)
return failure();
Expand Down Expand Up @@ -848,8 +848,8 @@ struct DistributeLayoutConflictToSharedMemory final
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
auto loc = resolutionOp.getLoc();
VectorValue vector = resolutionOp.getInput();
VectorValue result = resolutionOp.getOutput();
auto vector = cast<VectorValue>(resolutionOp.getInput());
auto result = cast<VectorValue>(resolutionOp.getOutput());
LayoutAttr currentLayout = dyn_cast<LayoutAttr>(signature[vector]);
if (!currentLayout) {
return rewriter.notifyMatchFailure(resolutionOp,
Expand Down Expand Up @@ -1019,10 +1019,12 @@ struct DistributeTrivialLayoutConversions final
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
auto input = cast<VectorValue>(toLayoutOp.getInput());
auto output = cast<VectorValue>(toLayoutOp.getOutput());
VectorLayoutInterface currentLayout =
dyn_cast<LayoutAttr>(signature[toLayoutOp.getInput()]);
dyn_cast<LayoutAttr>(signature[input]);
VectorLayoutInterface targetLayout =
dyn_cast<LayoutAttr>(signature[toLayoutOp.getResult()]);
dyn_cast<LayoutAttr>(signature[output]);

if (currentLayout != targetLayout) {
return rewriter.notifyMatchFailure(toLayoutOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ class DistributionLayout : public AnalysisState {
/// should only be used when you know there will be no layout conflicts.
/// Otherwise, the resolve-like functions should be used.
void setInnerLayout(const VectorLayoutInterface &layout) {
assert(layout && layout.isValidLayout(getValue()).succeeded());
assert(layout &&
layout.isValidLayout(getValue().getType(), getValue().getLoc())
.succeeded());
vectorLayout = layout;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ builtin.module attributes { transform.with_named_sequence } {
builtin.module attributes { transform.with_named_sequence } {
func.func @invalid_rank_nested_layout_anchor(%a: vector<16x16xf16>, %b: vector<16x16xf16>) -> vector<16x16xf16> {
%c = arith.addf %a, %b : vector<16x16xf16>
// expected-error @above {{Rank of vector (2) does not match rank of layout (3)}}
%cl = iree_vector_ext.to_layout %c to #layout : vector<16x16xf16>
// expected-error @above {{Rank of vector (2) does not match rank of layout (3)}}
func.return %cl : vector<16x16xf16>
}

Expand Down Expand Up @@ -442,8 +442,8 @@ builtin.module attributes { transform.with_named_sequence } {
builtin.module attributes { transform.with_named_sequence } {
func.func @invalid_size_nested_layout_anchor(%a: vector<16x16xf16>, %b: vector<16x16xf16>) -> vector<16x16xf16> {
%c = arith.addf %a, %b : vector<16x16xf16>
// expected-error @above {{Vector shape: [16, 16] does not match the layout (nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [1, 1], threads_per_outer = [8, 2], elements_per_thread = [2, 2], subgroup_strides = [0, 0], thread_strides = [1, 8]>) at dim 0. Dimension expected by layout: 32 actual: 16}}
%cl = iree_vector_ext.to_layout %c to #layout2 : vector<16x16xf16>
// expected-error @above {{Vector shape: [16, 16] does not match the layout (nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [1, 1], threads_per_outer = [8, 2], elements_per_thread = [2, 2], subgroup_strides = [0, 0], thread_strides = [1, 8]>) at dim 0. Dimension expected by layout: 32 actual: 16}}
func.return %cl : vector<16x16xf16>
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ std::optional<int64_t> LayoutAttr::getShape(const LayoutDimension &dim) const {

// Get the SIMT Vector shape in the order specified by dims. If no dims are
// specified, then return an empty vector.
LogicalResult LayoutAttr::isValidLayout(VectorValue vector) const {
ArrayRef<int64_t> shape = vector.getType().getShape();
LogicalResult LayoutAttr::isValidLayout(ShapedType shapeTy,
Location loc) const {
ArrayRef<int64_t> shape = shapeTy.getShape();
if (shape.size() != getRank()) {
return emitError(vector.getLoc(), "Rank of vector (" +
std::to_string(shape.size()) +
") does not match rank of layout (" +
std::to_string(getRank()) + ").");
return emitError(loc, "Rank of vector (")
<< shape.size() << ") does not match rank of layout (" << getRank()
<< ").";
}
for (auto [idx, layout] : llvm::enumerate(getLayouts())) {
ArrayRef<int64_t> layoutShape = layout.getShapes();
Expand All @@ -72,13 +72,11 @@ LogicalResult LayoutAttr::isValidLayout(VectorValue vector) const {
std::string layoutStr;
llvm::raw_string_ostream layoutOs(layoutStr);
printStripped(layoutOs);
return emitError(vector.getLoc(),
"Vector shape: [" + shapeStr +
"] does not match the layout (" + layoutStr +
") at dim " + std::to_string(idx) +
". Dimension expected by layout: " +
std::to_string(expectedShape) +
" actual: " + std::to_string(shape[idx]));
return emitError(loc, "Vector shape: [")
<< shapeStr << "] does not match the layout (" << layoutStr
<< ") at dim " << idx
<< ". Dimension expected by layout: " << expectedShape
<< " actual: " << shape[idx];
}
}
return success();
Expand Down Expand Up @@ -321,14 +319,14 @@ int64_t NestedLayoutAttr::getRank() const {
return getBatchesPerSubgroup().size();
}

LogicalResult NestedLayoutAttr::isValidLayout(VectorValue vector) const {
LogicalResult NestedLayoutAttr::isValidLayout(ShapedType shapeTy,
Location loc) const {
int64_t rank = getRank();
ArrayRef<int64_t> shape = vector.getType().getShape();
ArrayRef<int64_t> shape = shapeTy.getShape();
if (shape.size() != rank) {
return emitError(vector.getLoc(), "Rank of vector (" +
std::to_string(shape.size()) +
") does not match rank of layout (" +
std::to_string(rank) + ").");
return emitError(loc, "Rank of vector (")
<< shape.size() << ") does not match rank of layout (" << rank
<< ").";
}
// Multiply all shapes in the layout.
for (int i = 0, e = rank; i < e; ++i) {
Expand All @@ -343,13 +341,11 @@ LogicalResult NestedLayoutAttr::isValidLayout(VectorValue vector) const {
std::string layoutStr;
llvm::raw_string_ostream layoutOs(layoutStr);
printStripped(layoutOs);
return emitError(vector.getLoc(),
"Vector shape: [" + shapeStr +
"] does not match the layout (" + layoutStr +
") at dim " + std::to_string(i) +
". Dimension expected by layout: " +
std::to_string(expectedShape) +
" actual: " + std::to_string(shape[i]));
return emitError(loc, "Vector shape: [")
<< shapeStr << "] does not match the layout ("
<< layoutStr + ") at dim " << i
<< ". Dimension expected by layout: " << expectedShape
<< " actual: " << shape[i];
}
}
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ def VectorLayoutInterface : AttrInterface<"VectorLayoutInterface"> {

let methods = [
InterfaceMethod<
/*description=*/"Check if this layout is valid for the given vector type."
"On failure, emits diagnostics to explain the failure.",
/*description=*/"Check if this layout is valid for the given shape.",
/*retTy=*/"LogicalResult",
/*methodName=*/"isValidLayout",
/*args=*/(ins "::mlir::TypedValue<::mlir::VectorType>":$vector)
/*args=*/(ins "::mlir::ShapedType":$shape, "::mlir::Location":$loc)
>,
InterfaceMethod<
/*description=*/"Permutes the given layout.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
#include <numeric>
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"

using namespace mlir;
Expand All @@ -19,7 +18,7 @@ using VectorValue = TypedValue<VectorType>;

// Validate that the layout has the same shape as the input.
LogicalResult ToLayoutOp::verify() {
return getLayout().isValidLayout(getInput());
return getLayout().isValidLayout(getInput().getType(), getLoc());
}

// to_simd -> to_simt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,17 @@ def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [
transforms the value to have that layout.
}];
let arguments = (ins
AnyVector:$input,
AnyShaped:$input,
VectorLayoutInterface:$layout
);
let results = (outs
AnyVector:$output
AnyShaped:$output
);
let extraClassDeclaration = [{}];
let extraClassDeclaration = [{
bool hasTensorSemantics() {
return isa<RankedTensorType>(getOutput().getType());
}
}];
let assemblyFormat = "$input `to` $layout attr-dict `:` type($input)";
let hasVerifier = 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
func.func @invalid_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
// expected-error @+1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
// expected-error @+1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
%2 = iree_vector_ext.to_layout %result to #layout1 : vector<32x32xf16>
return %2 : vector<32x32xf16>
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")

package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)

iree_gentbl_cc_library(
name = "PassesIncGen",
tbl_outs = [
(
["--gen-pass-decls"],
"Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)

iree_compiler_cc_library(
name = "VectorExtTransforms",
srcs = [
"Passes.cpp",
"VectorizeIREEVectorExtOps.cpp",
],
hdrs = [
"Passes.h",
"Passes.h.inc",
],
deps = [
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
"@llvm-project//mlir:VectorTransforms",
"@llvm-project//mlir:VectorUtils",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
# compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
# #
# To disable autogeneration for this file entirely, delete this header. #
################################################################################

iree_add_all_subdirs()

iree_tablegen_library(
NAME
PassesIncGen
TD_FILE
"Passes.td"
OUTS
--gen-pass-decls Passes.h.inc
)

iree_cc_library(
NAME
VectorExtTransforms
HDRS
"Passes.h"
"Passes.h.inc"
SRCS
"Passes.cpp"
"VectorizeIREEVectorExtOps.cpp"
DEPS
::PassesIncGen
LLVMSupport
MLIRArithDialect
MLIRFunctionInterfaces
MLIRIR
MLIRPass
MLIRSupport
MLIRTensorDialect
MLIRTransformUtils
MLIRTransforms
MLIRVectorDialect
MLIRVectorTransforms
MLIRVectorUtils
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
PUBLIC
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"

namespace mlir::iree_compiler {

namespace IREE::VectorExt {
namespace {
#define GEN_PASS_REGISTRATION
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc"
} // namespace
} // namespace IREE::VectorExt

void registerIREEVectorExtPasses() {
// Generated.
IREE::VectorExt::registerPasses();
}
} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_
#define IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_

#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler::IREE::VectorExt {
#define GEN_PASS_DECL
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc" // IWYU pragma: keep
} // namespace mlir::iree_compiler::IREE::VectorExt

namespace mlir::iree_compiler {
/// Register VectorExt passes.
void registerIREEVectorExtPasses();
} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_
Loading

0 comments on commit b324f2a

Please sign in to comment.