[CUDA codegen] add vectorization infrastructure
Enable vectorization for element-wise ops
ThomasRaoux committed Apr 1, 2021
1 parent 1bdc3a4 commit 8b69e12
Showing 11 changed files with 247 additions and 51 deletions.
1 change: 1 addition & 0 deletions iree/compiler/Conversion/LinalgToNVVM/BUILD
@@ -25,6 +25,7 @@ cc_library(
"KernelConfig.cpp",
"Passes.cpp",
"TileAndDistribute.cpp",
"VectorizationPass.cpp",
],
hdrs = [
"KernelConfig.h",
1 change: 1 addition & 0 deletions iree/compiler/Conversion/LinalgToNVVM/CMakeLists.txt
@@ -21,6 +21,7 @@ iree_cc_library(
"KernelConfig.cpp"
"Passes.cpp"
"TileAndDistribute.cpp"
"VectorizationPass.cpp"
DEPS
MLIRGPU
MLIRGPUToNVVMTransforms
11 changes: 11 additions & 0 deletions iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp
@@ -17,9 +17,11 @@
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

@@ -183,6 +185,14 @@ struct ConvertToNVVMPass
// Apply in-dialect lowering first. In-dialect lowering will replace ops
// which need to be lowered further, which is not supported by a single
// conversion pass.
// Run Vector -> Vector transformations ahead of conversion to LLVM.
{
OwningRewritePatternList patterns(&getContext());
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
vector::populateVectorSlicesLoweringPatterns(patterns);
vector::populateVectorContractLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
}
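    // Editor's note: these vector-to-vector rewrites reduce higher-level ops
    // such as vector.contract and the slice ops to simpler vector forms that
    // populateVectorToLLVMConversionPatterns (registered below) can convert
    // to LLVM directly.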
{
OwningRewritePatternList patterns(&getContext());
populateGpuRewritePatterns(patterns);
@@ -203,6 +213,7 @@ struct ConvertToNVVMPass
IREE::HAL::InterfaceWorkgroupSizeOp, NVVM::BlockDimXOp,
NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(m.getContext());
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
LLVMConversionTarget target(getContext());
populateStdToLLVMFuncOpConversionPattern(converter, llvmPatterns);
47 changes: 6 additions & 41 deletions iree/compiler/Conversion/LinalgToNVVM/KernelConfig.cpp
@@ -24,57 +24,22 @@ using namespace mlir::iree_compiler;

static constexpr unsigned cudaWarpSize = 32;

-/// Fills `inputTypes` and `outputTypes` with the original input/output types
-/// for all tiles for `op`.
-/// Copied from iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
-/// This should be moved to a common location if still needed in the future.
-static void getInputOutputTypes(linalg::LinalgOp op,
-                                SmallVectorImpl<ShapedType> &inputTypes,
-                                SmallVectorImpl<ShapedType> &outputTypes) {
-  // NOTE: Special treatment to let the flow.dispatch.workgroups path be able
-  // to query launch configurations. This should be cleaned up after
-  // flow.dispatch.workgroups becomes the default path.
-  auto inputTypeAttr =
-      op->getAttrOfType<ArrayAttr>("iree.codegen.original_input_types");
-  auto outputTypeAttr =
-      op->getAttrOfType<ArrayAttr>("iree.codegen.original_output_types");
-  if (outputTypeAttr && inputTypeAttr) {
-    for (Type type : inputTypeAttr.getAsValueRange<TypeAttr>())
-      inputTypes.push_back(type.cast<ShapedType>());
-    for (Type type : outputTypeAttr.getAsValueRange<TypeAttr>())
-      outputTypes.push_back(type.cast<ShapedType>());
-  } else {
-    for (Type type : op.getInputBufferTypes())
-      inputTypes.push_back(type.cast<ShapedType>());
-    for (Type type : op.getOutputBufferTypes())
-      outputTypes.push_back(type.cast<ShapedType>());
-  }
-}
-
static LaunchConfig getOpLaunchConfig(linalg::GenericOp op) {
  LaunchConfig config;
  size_t numLoops = getNumOuterParallelLoops(op);
  if (numLoops == 0) return config;

-  SmallVector<ShapedType, 4> inputTypes, outputTypes;
-  getInputOutputTypes(op, inputTypes, outputTypes);
-
  config.setWorkgroupSize({cudaWarpSize, 1, 1});
-  SmallVector<int64_t, 4> candidateTileSizes;
-  candidateTileSizes.append({4 * cudaWarpSize, 2 * cudaWarpSize, cudaWarpSize});
-  // Use the first tile size that can divide the shape. If the shape is not
-  // aligned on any of the tile sizes pick the smallest tile of one element per
-  // thread.
-  int64_t lowerTs = cudaWarpSize;
-  for (int64_t size : candidateTileSizes) {
-    if (outputTypes[0].getShape().back() % size != 0) continue;
-    lowerTs = size;
-    break;
-  }
+  // Pick a fixed tile size independent of the original shape.
+  // TODO(thomasraoux): Currently the original shape information is lost during
+  // tiling at the flow level. We need a way to access it to be able to make a
+  // better choice of tile size.
+  int64_t lowerTs = 4 * cudaWarpSize;
SmallVector<int64_t, 4> ts;
ts.resize(numLoops, 1);
ts.back() = lowerTs;
config.setTileSizes(op, ts, 0); // Workgroup level.
config.setTileSizes(op, {}, 1); // Subgroup level.
ts.back() = lowerTs / cudaWarpSize;
config.setTileSizes(op, ts, 2); // Thread level.
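  // Editor's note (worked example, hypothetical rank-2 elementwise op): with
  // cudaWarpSize = 32 and lowerTs = 4 * 32 = 128, the tile sizes are {1, 128}
  // at the workgroup level and {1, 128 / 32} = {1, 4} at the thread level, so
  // each of the 32 threads covers 4 contiguous elements, matching the vec4
  // shape the vectorizer targets later.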
return config;
9 changes: 8 additions & 1 deletion iree/compiler/Conversion/LinalgToNVVM/Passes.cpp
@@ -19,6 +19,7 @@
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
@@ -37,7 +38,13 @@ static void addLinalgToNVVMPasses(OpPassManager &pm) {

// Distribute linalg onto threads within the workgroup.
pm.addPass(createTileAndDistributeToThreads());
-  // TODO: Linalg -> vector
+  pm.addNestedPass<ModuleOp>(createCanonicalizerPass());
+  pm.addNestedPass<ModuleOp>(createCSEPass());
+
+  // Linalg -> vector
+  pm.nest<ModuleOp>().addNestedPass<FuncOp>(createVectorizationPass());
+  pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
+  pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCSEPass());
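  // Editor's note: the canonicalizer/CSE runs before and after the
  // vectorization pass first simplify the tiled loop nests so the
  // vectorization patterns can match, then clean up the vector IR produced.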

pm.addNestedPass<ModuleOp>(createLowerAffinePass());
pm.addNestedPass<ModuleOp>(createCanonicalizerPass());
3 changes: 3 additions & 0 deletions iree/compiler/Conversion/LinalgToNVVM/Passes.h
@@ -24,6 +24,9 @@ namespace iree_compiler {
/// Performs the final conversion to NVVM+LLVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToNVVMPass();

/// Convert Linalg ops to Vector.
std::unique_ptr<OperationPass<FuncOp>> createVectorizationPass();

/// Perform tiling and distribution to threads.
std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
createTileAndDistributeToThreads();
33 changes: 24 additions & 9 deletions iree/compiler/Conversion/LinalgToNVVM/TileAndDistribute.cpp
@@ -29,15 +29,16 @@
namespace mlir {
namespace iree_compiler {

+static constexpr int32_t kNumGPUDims = 3;
+
static SmallVector<linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
    OpBuilder &builder, Location loc, unsigned numDims) {
-  static constexpr int32_t kNumGPUDims = 3;
  assert(numDims <= kNumGPUDims);
  SmallVector<linalg::ProcInfo, 2> procInfo(numDims);
  std::array<StringRef, kNumGPUDims> dimAttr{"x", "y", "z"};
  Type indexType = builder.getIndexType();
  for (unsigned i = 0; i < numDims; ++i) {
-    StringAttr attr =
-        builder.getStringAttr(dimAttr[std::min<unsigned>(i, kNumGPUDims)]);
+    StringAttr attr = builder.getStringAttr(dimAttr[i]);
    procInfo[numDims - 1 - i] = {
        builder.create<gpu::ThreadIdOp>(loc, indexType, attr),
        builder.create<gpu::BlockDimOp>(loc, indexType, attr)};
@@ -47,16 +48,20 @@ static SmallVector<linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(

/// Patterns for thread level tiling.
static void populateTilingToInvocationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    ArrayRef<int64_t> tileSizes) {
  linalg::TileSizeComputationFunction getInnerTileSizeFn =
-      [](OpBuilder &builder, Operation *operation) {
-        ArrayRef<int64_t> tileSizes = {4};
+      [tileSizes](OpBuilder &builder, Operation *operation) {
+        if (tileSizes.empty()) return SmallVector<Value, 4>();
        SmallVector<Value, 4> tileSizesVal;
        tileSizesVal.reserve(tileSizes.size());
-        for (auto val : tileSizes) {
+        for (auto val : llvm::enumerate(tileSizes)) {
+          // Only tile the last 3 dimensions. Use tile size of 0 for any higher
+          // dimension as we only support distributing on 3 dimensions.
+          int64_t t =
+              (tileSizes.size() - val.index()) <= kNumGPUDims ? val.value() : 0;
          tileSizesVal.push_back(
-              builder.create<ConstantIndexOp>(operation->getLoc(), val));
+              builder.create<ConstantIndexOp>(operation->getLoc(), t));
        }
return tileSizesVal;
};
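  // Editor's note: for example, thread-level tile sizes {1, 1, 1, 4} on a
  // rank-4 op materialize as constants {0, 1, 1, 4}; the leading dimension is
  // left untiled because only three dimensions can map to gpu.thread_id x/y/z.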
@@ -174,13 +179,23 @@ struct TileAndDistributeToThreads
}

{
+      SmallVector<int64_t, 4> threadTileSize =
+          llvm::to_vector<4>(config->getTileSizes(rootOp, 2));
      // Apply last level of tiling and distribute to threads.
      OwningRewritePatternList threadLevelTilingPatterns(context);
-      populateTilingToInvocationPatterns(context, threadLevelTilingPatterns);
+      populateTilingToInvocationPatterns(context, threadLevelTilingPatterns,
+                                         threadTileSize);
      (void)applyPatternsAndFoldGreedily(
          funcOp, std::move(threadLevelTilingPatterns));
      applyCanonicalizationPatternsForTiling(context, funcOp);
    }
+    {
+      OwningRewritePatternList patterns(context);
+      // Apply canonicalization patterns.
+      linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
+      populateAffineMinSCFCanonicalizationPattern(patterns);
+      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+    }
}
}
};
135 changes: 135 additions & 0 deletions iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
@@ -0,0 +1,135 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
#include "iree/compiler/Conversion/Common/Transforms.h"
#include "iree/compiler/Conversion/LinalgToNVVM/Passes.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
namespace iree_compiler {

//====---------------------------------------------------------------------===//
// Patterns for vectorization
//====---------------------------------------------------------------------===//

static void populateVectorizationPatterns(OwningRewritePatternList &patterns) {
linalg::insertVectorizationPatterns<linalg::FillOp, linalg::GenericOp,
linalg::ContractionOpInterface>(
patterns, linalg::LinalgVectorizationOptions(),
linalg::LinalgTransformationFilter(
Identifier::get(getVectorizeMarker(), patterns.getContext())));
}
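// Editor's note: the filter above restricts vectorization to linalg ops
// carrying the `vectorize` marker, which earlier passes in this pipeline are
// expected to have attached; unmarked ops are left untouched.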

static Optional<SmallVector<int64_t, 4>> getNativeVectorSize(Operation *op) {
if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
// Map elementwise ops to vec4.
SmallVector<int64_t, 4> nativeSize(vecType.getRank(), 1);
nativeSize.back() = 4;
return nativeSize;
}
} else if (auto vt = dyn_cast<VectorTransferOpInterface>(op)) {
auto rank = vt.getVectorType().getRank();
SmallVector<int64_t, 4> nativeSize(rank, 1);
nativeSize.back() = 4;
return nativeSize;
}
return llvm::None;
}
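// Editor's note: with the vec4 native shape above, an elementwise op on, say,
// vector<128xf32> is unrolled into 32 independent ops on vector<4xf32>,
// lining up with the 4 elements per thread chosen by the thread-level tiling.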

static void populateVectorUnrollPatterns(OwningRewritePatternList &patterns) {
patterns.add<vector::UnrollVectorPattern>(
patterns.getContext(),
vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
}

namespace {
struct VectorizationPass
: public PassWrapper<VectorizationPass, OperationPass<FuncOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}
void runOnOperation() override {
auto funcOp = getOperation();
MLIRContext *context = &getContext();

{
// Step 1. Vectorize
OwningRewritePatternList vectorizationPatterns(context);
populateVectorizationPatterns(vectorizationPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorizationPatterns));
}
// TODO: This should be a folding of Add into Contract in core but while
// they live in different dialects, it is not possible without unnatural
// dependencies.
funcOp.walk([&](Operation *op) {
if (auto contract = canonicalizeContractionAdd(op))
op->replaceAllUsesWith(contract);
});

{
      // Step 2. Unroll the vectors to native size and canonicalize.
OwningRewritePatternList vectorUnrollPatterns(context);
populateVectorUnrollPatterns(vectorUnrollPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorUnrollPatterns));

OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext());
vector::populateVectorToVectorCanonicalizationPatterns(
canonicalizationPatterns1);
vector::populateVectorToVectorTransformationPatterns(
canonicalizationPatterns1);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(canonicalizationPatterns1));

OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext());
vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(canonicalizationPatterns2));

linalg::hoistRedundantVectorTransfers(funcOp);
}
{
      // Step 3. Lower the transfer ops generated to SCF.
RewritePatternSet vectorToLoopsPatterns(context);
VectorTransferToSCFOptions vectorToSCFOptions;
vectorToSCFOptions.setUnroll(true);
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
vectorToSCFOptions);
populateStdLegalizationPatternsForSPIRVLowering(vectorToLoopsPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorToLoopsPatterns));
}
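    // Editor's note: setUnroll(true) above asks the transfer-to-SCF lowering
    // to emit fully unrolled code rather than scf.for loops, keeping the
    // generated kernel free of extra loop overhead.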
}
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> createVectorizationPass() {
return std::make_unique<VectorizationPass>();
}

static PassRegistration<VectorizationPass> pass(
"iree-codegen-cuda-vectorization", "Pass to convert linalg into Vector.");

} // namespace iree_compiler
} // namespace mlir
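(Editor's note: given the registration above, the pass can presumably be exercised standalone via iree-opt -iree-codegen-cuda-vectorization, which is what the new vectorization.mlir lit test referenced below would drive through IreeFileCheck.)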
1 change: 1 addition & 0 deletions iree/compiler/Conversion/LinalgToNVVM/test/BUILD
@@ -30,6 +30,7 @@ iree_lit_test_suite(
"convert_to_nvvm.mlir",
"distribute_to_thread.mlir",
"pipeline_test.mlir",
"vectorization.mlir",
],
include = ["*.mlir"],
),
1 change: 1 addition & 0 deletions iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt
@@ -17,6 +17,7 @@ iree_lit_test_suite(
"convert_to_nvvm.mlir"
"distribute_to_thread.mlir"
"pipeline_test.mlir"
"vectorization.mlir"
DATA
iree::tools::IreeFileCheck
iree::tools::iree-opt
(Diff for the eleventh changed file, presumably the new iree/compiler/Conversion/LinalgToNVVM/test/vectorization.mlir lit test, did not load.)
