Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VectorExt] Teach vectorization to to_layout #18092

Merged
merged 5 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms:VectorExtTransforms",
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
"//compiler/src/iree/compiler/Codegen/LLVMGPU",
"//compiler/src/iree/compiler/Codegen/SPIRV",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_cc_library(
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
iree::compiler::Codegen::Dialect::VectorExt::Transforms::VectorExtTransforms
iree::compiler::Codegen::LLVMCPU
iree::compiler::Codegen::LLVMGPU
iree::compiler::Codegen::SPIRV
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -795,8 +795,8 @@ struct DistributeLayoutConflictResolutions final
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
VectorValue vector = resolutionOp.getInput();
VectorValue result = resolutionOp.getOutput();
auto vector = cast<VectorValue>(resolutionOp.getInput());
auto result = cast<VectorValue>(resolutionOp.getOutput());
LayoutAttr currentLayout = dyn_cast<LayoutAttr>(signature[vector]);
if (!currentLayout)
return failure();
Expand Down Expand Up @@ -848,8 +848,8 @@ struct DistributeLayoutConflictToSharedMemory final
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
auto loc = resolutionOp.getLoc();
VectorValue vector = resolutionOp.getInput();
VectorValue result = resolutionOp.getOutput();
auto vector = cast<VectorValue>(resolutionOp.getInput());
auto result = cast<VectorValue>(resolutionOp.getOutput());
LayoutAttr currentLayout = dyn_cast<LayoutAttr>(signature[vector]);
if (!currentLayout) {
return rewriter.notifyMatchFailure(resolutionOp,
Expand Down Expand Up @@ -1019,10 +1019,12 @@ struct DistributeTrivialLayoutConversions final
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
auto input = cast<VectorValue>(toLayoutOp.getInput());
auto output = cast<VectorValue>(toLayoutOp.getOutput());
VectorLayoutInterface currentLayout =
dyn_cast<LayoutAttr>(signature[toLayoutOp.getInput()]);
dyn_cast<LayoutAttr>(signature[input]);
VectorLayoutInterface targetLayout =
dyn_cast<LayoutAttr>(signature[toLayoutOp.getResult()]);
dyn_cast<LayoutAttr>(signature[output]);

if (currentLayout != targetLayout) {
return rewriter.notifyMatchFailure(toLayoutOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ class DistributionLayout : public AnalysisState {
/// should only be used when you know there will be no layout conflicts.
/// Otherwise, the resolve-like functions should be used.
void setInnerLayout(const VectorLayoutInterface &layout) {
assert(layout && layout.isValidLayout(getValue()).succeeded());
assert(layout &&
layout.isValidLayout(getValue().getType(), getValue().getLoc())
.succeeded());
vectorLayout = layout;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ builtin.module attributes { transform.with_named_sequence } {
builtin.module attributes { transform.with_named_sequence } {
func.func @invalid_rank_nested_layout_anchor(%a: vector<16x16xf16>, %b: vector<16x16xf16>) -> vector<16x16xf16> {
%c = arith.addf %a, %b : vector<16x16xf16>
// expected-error @above {{Rank of vector (2) does not match rank of layout (3)}}
%cl = iree_vector_ext.to_layout %c to #layout : vector<16x16xf16>
// expected-error @above {{Rank of vector (2) does not match rank of layout (3)}}
func.return %cl : vector<16x16xf16>
}

Expand Down Expand Up @@ -442,8 +442,8 @@ builtin.module attributes { transform.with_named_sequence } {
builtin.module attributes { transform.with_named_sequence } {
func.func @invalid_size_nested_layout_anchor(%a: vector<16x16xf16>, %b: vector<16x16xf16>) -> vector<16x16xf16> {
%c = arith.addf %a, %b : vector<16x16xf16>
// expected-error @above {{Vector shape: [16, 16] does not match the layout (nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [1, 1], threads_per_outer = [8, 2], elements_per_thread = [2, 2], subgroup_strides = [0, 0], thread_strides = [1, 8]>) at dim 0. Dimension expected by layout: 32 actual: 16}}
%cl = iree_vector_ext.to_layout %c to #layout2 : vector<16x16xf16>
// expected-error @above {{Vector shape: [16, 16] does not match the layout (nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 4], outers_per_batch = [1, 1], threads_per_outer = [8, 2], elements_per_thread = [2, 2], subgroup_strides = [0, 0], thread_strides = [1, 8]>) at dim 0. Dimension expected by layout: 32 actual: 16}}
func.return %cl : vector<16x16xf16>
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ std::optional<int64_t> LayoutAttr::getShape(const LayoutDimension &dim) const {

// Get the SIMT Vector shape in the order specified by dims. If no dims are
// specified, then return an empty vector.
LogicalResult LayoutAttr::isValidLayout(VectorValue vector) const {
ArrayRef<int64_t> shape = vector.getType().getShape();
LogicalResult LayoutAttr::isValidLayout(ShapedType shapeTy,
Location loc) const {
ArrayRef<int64_t> shape = shapeTy.getShape();
if (shape.size() != getRank()) {
return emitError(vector.getLoc(), "Rank of vector (" +
std::to_string(shape.size()) +
") does not match rank of layout (" +
std::to_string(getRank()) + ").");
return emitError(loc, "Rank of vector (")
<< shape.size() << ") does not match rank of layout (" << getRank()
<< ").";
}
for (auto [idx, layout] : llvm::enumerate(getLayouts())) {
ArrayRef<int64_t> layoutShape = layout.getShapes();
Expand All @@ -72,13 +72,11 @@ LogicalResult LayoutAttr::isValidLayout(VectorValue vector) const {
std::string layoutStr;
llvm::raw_string_ostream layoutOs(layoutStr);
printStripped(layoutOs);
return emitError(vector.getLoc(),
"Vector shape: [" + shapeStr +
"] does not match the layout (" + layoutStr +
") at dim " + std::to_string(idx) +
". Dimension expected by layout: " +
std::to_string(expectedShape) +
" actual: " + std::to_string(shape[idx]));
return emitError(loc, "Vector shape: [")
<< shapeStr << "] does not match the layout (" << layoutStr
<< ") at dim " << idx
<< ". Dimension expected by layout: " << expectedShape
<< " actual: " << shape[idx];
}
}
return success();
Expand Down Expand Up @@ -321,14 +319,14 @@ int64_t NestedLayoutAttr::getRank() const {
return getBatchesPerSubgroup().size();
}

LogicalResult NestedLayoutAttr::isValidLayout(VectorValue vector) const {
LogicalResult NestedLayoutAttr::isValidLayout(ShapedType shapeTy,
Location loc) const {
int64_t rank = getRank();
ArrayRef<int64_t> shape = vector.getType().getShape();
ArrayRef<int64_t> shape = shapeTy.getShape();
if (shape.size() != rank) {
return emitError(vector.getLoc(), "Rank of vector (" +
std::to_string(shape.size()) +
") does not match rank of layout (" +
std::to_string(rank) + ").");
return emitError(loc, "Rank of vector (")
<< shape.size() << ") does not match rank of layout (" << rank
<< ").";
}
// Multiply all shapes in the layout.
for (int i = 0, e = rank; i < e; ++i) {
Expand All @@ -343,13 +341,11 @@ LogicalResult NestedLayoutAttr::isValidLayout(VectorValue vector) const {
std::string layoutStr;
llvm::raw_string_ostream layoutOs(layoutStr);
printStripped(layoutOs);
return emitError(vector.getLoc(),
"Vector shape: [" + shapeStr +
"] does not match the layout (" + layoutStr +
") at dim " + std::to_string(i) +
". Dimension expected by layout: " +
std::to_string(expectedShape) +
" actual: " + std::to_string(shape[i]));
return emitError(loc, "Vector shape: [")
<< shapeStr << "] does not match the layout ("
<< layoutStr + ") at dim " << i
<< ". Dimension expected by layout: " << expectedShape
<< " actual: " << shape[i];
}
}
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ def VectorLayoutInterface : AttrInterface<"VectorLayoutInterface"> {

let methods = [
InterfaceMethod<
/*description=*/"Check if this layout is valid for the given vector type."
"On failure, emits diagnostics to explain the failure.",
/*description=*/"Check if this layout is valid for the given shape.",
/*retTy=*/"LogicalResult",
/*methodName=*/"isValidLayout",
/*args=*/(ins "::mlir::TypedValue<::mlir::VectorType>":$vector)
/*args=*/(ins "::mlir::ShapedType":$shape, "::mlir::Location":$loc)
>,
InterfaceMethod<
/*description=*/"Permutes the given layout.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
#include <numeric>
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"

using namespace mlir;
Expand All @@ -19,7 +18,7 @@ using VectorValue = TypedValue<VectorType>;

// Validate that the layout has the same shape as the input.
LogicalResult ToLayoutOp::verify() {
return getLayout().isValidLayout(getInput());
return getLayout().isValidLayout(getInput().getType(), getLoc());
}

// to_simd -> to_simt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,17 @@ def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [
transforms the value to have that layout.
}];
let arguments = (ins
AnyVector:$input,
AnyShaped:$input,
VectorLayoutInterface:$layout
);
let results = (outs
AnyVector:$output
AnyShaped:$output
);
let extraClassDeclaration = [{}];
let extraClassDeclaration = [{
bool hasTensorSemantics() {
return isa<RankedTensorType>(getOutput().getType());
}
}];
let assemblyFormat = "$input `to` $layout attr-dict `:` type($input)";
let hasVerifier = 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
func.func @invalid_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
// expected-error @+1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
// expected-error @+1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
%2 = iree_vector_ext.to_layout %result to #layout1 : vector<32x32xf16>
return %2 : vector<32x32xf16>
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")

package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)

iree_gentbl_cc_library(
name = "PassesIncGen",
tbl_outs = [
(
["--gen-pass-decls"],
"Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)

iree_compiler_cc_library(
name = "VectorExtTransforms",
srcs = [
"Passes.cpp",
"VectorizeIREEVectorExtOps.cpp",
],
hdrs = [
"Passes.h",
"Passes.h.inc",
],
deps = [
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
"@llvm-project//mlir:VectorTransforms",
"@llvm-project//mlir:VectorUtils",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
# compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
# #
# To disable autogeneration for this file entirely, delete this header. #
################################################################################

iree_add_all_subdirs()

iree_tablegen_library(
NAME
PassesIncGen
TD_FILE
"Passes.td"
OUTS
--gen-pass-decls Passes.h.inc
)

iree_cc_library(
NAME
VectorExtTransforms
HDRS
"Passes.h"
"Passes.h.inc"
SRCS
"Passes.cpp"
"VectorizeIREEVectorExtOps.cpp"
DEPS
::PassesIncGen
LLVMSupport
MLIRArithDialect
MLIRFunctionInterfaces
MLIRIR
MLIRPass
MLIRSupport
MLIRTensorDialect
MLIRTransformUtils
MLIRTransforms
MLIRVectorDialect
MLIRVectorTransforms
MLIRVectorUtils
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
PUBLIC
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"

namespace mlir::iree_compiler {

namespace IREE::VectorExt {
namespace {
#define GEN_PASS_REGISTRATION
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc"
} // namespace
} // namespace IREE::VectorExt

void registerIREEVectorExtPasses() {
// Generated.
IREE::VectorExt::registerPasses();
}
} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_
#define IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_

#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler::IREE::VectorExt {
#define GEN_PASS_DECL
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h.inc" // IWYU pragma: keep
} // namespace mlir::iree_compiler::IREE::VectorExt

namespace mlir::iree_compiler {
/// Register VectorExt passes.
void registerIREEVectorExtPasses();
} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_PASSES_H_
Loading
Loading