From 8b69e12c12ce964e2b1ecb1164d80a9866d97d19 Mon Sep 17 00:00:00 2001
From: thomasraoux
Date: Wed, 31 Mar 2021 13:16:09 -0700
Subject: [PATCH] [CUDA codegen] add vectorization infrastructure

Enable vectorization for element-wise ops
---
 iree/compiler/Conversion/LinalgToNVVM/BUILD   |   1 +
 .../Conversion/LinalgToNVVM/CMakeLists.txt    |   1 +
 .../Conversion/LinalgToNVVM/ConvertToNVVM.cpp |  11 ++
 .../Conversion/LinalgToNVVM/KernelConfig.cpp  |  47 +-----
 .../Conversion/LinalgToNVVM/Passes.cpp        |   9 +-
 .../compiler/Conversion/LinalgToNVVM/Passes.h |   3 +
 .../LinalgToNVVM/TileAndDistribute.cpp        |  33 +++--
 .../LinalgToNVVM/VectorizationPass.cpp        | 135 ++++++++++++++++++
 .../Conversion/LinalgToNVVM/test/BUILD        |   1 +
 .../LinalgToNVVM/test/CMakeLists.txt          |   1 +
 .../LinalgToNVVM/test/vectorization.mlir      |  56 ++++++++
 11 files changed, 247 insertions(+), 51 deletions(-)
 create mode 100644 iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
 create mode 100644 iree/compiler/Conversion/LinalgToNVVM/test/vectorization.mlir

diff --git a/iree/compiler/Conversion/LinalgToNVVM/BUILD b/iree/compiler/Conversion/LinalgToNVVM/BUILD
index fcf643fbf66b8..a3291e6e36e2c 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToNVVM/BUILD
@@ -25,6 +25,7 @@ cc_library(
         "KernelConfig.cpp",
         "Passes.cpp",
         "TileAndDistribute.cpp",
+        "VectorizationPass.cpp",
     ],
     hdrs = [
         "KernelConfig.h",
diff --git a/iree/compiler/Conversion/LinalgToNVVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToNVVM/CMakeLists.txt
index 4d90c13fa2af9..229063473a69b 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToNVVM/CMakeLists.txt
@@ -21,6 +21,7 @@ iree_cc_library(
     "KernelConfig.cpp"
     "Passes.cpp"
    "TileAndDistribute.cpp"
+    "VectorizationPass.cpp"
  DEPS
    MLIRGPU
    MLIRGPUToNVVMTransforms
diff --git a/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp
index 75708aef28876..179df7c4849a4 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp
+++ b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp
@@ -17,9 +17,11 @@
 #include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
@@ -183,6 +185,14 @@ struct ConvertToNVVMPass
     // Apply in-dialect lowering first. In-dialect lowering will replace ops
     // which need to be lowered further, which is not supported by a single
     // conversion pass.
+    // Run Vector -> Vector transformations ahead of conversion to LLVM.
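+    // vector.contract and the slices ops have no direct LLVM equivalent, so
+    // they are rewritten here into simpler vector ops that convert 1-1 below.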
+    {
+      OwningRewritePatternList patterns(&getContext());
+      vector::populateVectorToVectorCanonicalizationPatterns(patterns);
+      vector::populateVectorSlicesLoweringPatterns(patterns);
+      vector::populateVectorContractLoweringPatterns(patterns);
+      (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
+    }
     {
       OwningRewritePatternList patterns(&getContext());
       populateGpuRewritePatterns(patterns);
@@ -203,6 +213,7 @@ struct ConvertToNVVMPass
             IREE::HAL::InterfaceWorkgroupSizeOp, NVVM::BlockDimXOp,
             NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(m.getContext());
     populateStdToLLVMConversionPatterns(converter, llvmPatterns);
+    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
     populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
     LLVMConversionTarget target(getContext());
     populateStdToLLVMFuncOpConversionPattern(converter, llvmPatterns);
diff --git a/iree/compiler/Conversion/LinalgToNVVM/KernelConfig.cpp b/iree/compiler/Conversion/LinalgToNVVM/KernelConfig.cpp
index 313c77a16750c..61fe98f19bedc 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/KernelConfig.cpp
+++ b/iree/compiler/Conversion/LinalgToNVVM/KernelConfig.cpp
@@ -24,57 +24,22 @@ using namespace mlir::iree_compiler;
 
 static constexpr unsigned cudaWarpSize = 32;
 
-/// Fills `inputTypes` and `outputTypes` with the original input/output types
-/// for all tiles for `op`.
-/// Copied from iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
-/// This should be moved to a common location if still needed in the future.
-static void getInputOutputTypes(linalg::LinalgOp op,
-                                SmallVectorImpl<ShapedType> &inputTypes,
-                                SmallVectorImpl<ShapedType> &outputTypes) {
-  // NOTE: Special treatment to let the flow.dispatch.workgroups path to be
-  // able to query launch configurations. This should be cleaned up after the
-  // flow.dispatch.workgroups become the default path.
-  auto inputTypeAttr =
-      op->getAttrOfType<ArrayAttr>("iree.codegen.original_input_types");
-  auto outputTypeAttr =
-      op->getAttrOfType<ArrayAttr>("iree.codegen.original_output_types");
-  if (outputTypeAttr && inputTypeAttr) {
-    for (Type type : inputTypeAttr.getAsValueRange<TypeAttr>())
-      inputTypes.push_back(type.cast<ShapedType>());
-    for (Type type : outputTypeAttr.getAsValueRange<TypeAttr>())
-      outputTypes.push_back(type.cast<ShapedType>());
-  } else {
-    for (Type type : op.getInputBufferTypes())
-      inputTypes.push_back(type.cast<ShapedType>());
-    for (Type type : op.getOutputBufferTypes())
-      outputTypes.push_back(type.cast<ShapedType>());
-  }
-}
-
 static LaunchConfig getOpLaunchConfig(linalg::GenericOp op) {
   LaunchConfig config;
   size_t numLoops = getNumOuterParallelLoops(op);
   if (numLoops == 0) return config;
 
-  SmallVector<ShapedType, 4> inputTypes, outputTypes;
-  getInputOutputTypes(op, inputTypes, outputTypes);
-
   config.setWorkgroupSize({cudaWarpSize, 1, 1});
-  SmallVector<int64_t, 4> candidateTileSizes;
-  candidateTileSizes.append({4 * cudaWarpSize, 2 * cudaWarpSize, cudaWarpSize});
-  // Use the first tile size that can divide the shape. If the shape is not
-  // aligned on any of the tile sizes pick the smallest tile of one element per
-  // thread.
-  int64_t lowerTs = cudaWarpSize;
-  for (int64_t size : candidateTileSizes) {
-    if (outputTypes[0].getShape().back() % size != 0) continue;
-    lowerTs = size;
-    break;
-  }
+  // Pick a fixed tile size independent of the original shape.
+  // TODO(thomasraoux): Currently the original shape information is lost during
+  // tiling at the flow level. We need a way to access it to be able to make a
+  // better choice of tile size.
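+  // With lowerTs = 4 * cudaWarpSize = 128 and a 32-thread workgroup, the
+  // innermost dimension gets a workgroup tile of 128 and a thread tile of
+  // 128 / 32 = 4, matching the vec4 native size used during vectorization.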
+  int64_t lowerTs = 4 * cudaWarpSize;
   SmallVector<int64_t, 4> ts;
   ts.resize(numLoops, 1);
   ts.back() = lowerTs;
   config.setTileSizes(op, ts, 0);  // Workgroup level.
+  config.setTileSizes(op, {}, 1);  // Subgroup level.
   ts.back() = lowerTs / cudaWarpSize;
   config.setTileSizes(op, ts, 2);  // Thread level.
   return config;
diff --git a/iree/compiler/Conversion/LinalgToNVVM/Passes.cpp b/iree/compiler/Conversion/LinalgToNVVM/Passes.cpp
index 9dc6d9f05c5c7..c873b35e34167 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToNVVM/Passes.cpp
@@ -19,6 +19,7 @@
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
 #include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassOptions.h"
@@ -37,7 +38,13 @@ static void addLinalgToNVVMPasses(OpPassManager &pm) {
 
   // Distribute linalg onto threads within the workgroup.
   pm.addPass(createTileAndDistributeToThreads());
-  // TODO: Linalg -> vector
+  pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+  pm.addNestedPass<FuncOp>(createCSEPass());
+
+  // Linalg -> vector
+  pm.nest<ModuleOp>().addNestedPass<FuncOp>(createVectorizationPass());
+  pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
+  pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCSEPass());
 
   pm.addNestedPass<FuncOp>(createLowerAffinePass());
   pm.addNestedPass<FuncOp>(createCanonicalizerPass());
diff --git a/iree/compiler/Conversion/LinalgToNVVM/Passes.h b/iree/compiler/Conversion/LinalgToNVVM/Passes.h
index 1d1f6eb1fb964..75f5a5693f728 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToNVVM/Passes.h
@@ -24,6 +24,9 @@ namespace iree_compiler {
 /// Performs the final conversion to NVVM+LLVM dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertToNVVMPass();
 
+/// Convert Linalg ops to Vector.
+std::unique_ptr<OperationPass<FuncOp>> createVectorizationPass();
+
 /// Perform tiling and distribution to threads.
 std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
 createTileAndDistributeToThreads();
diff --git a/iree/compiler/Conversion/LinalgToNVVM/TileAndDistribute.cpp b/iree/compiler/Conversion/LinalgToNVVM/TileAndDistribute.cpp
index e362cff08c07d..f7abf41c22982 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/TileAndDistribute.cpp
+++ b/iree/compiler/Conversion/LinalgToNVVM/TileAndDistribute.cpp
@@ -29,15 +29,16 @@
 namespace mlir {
 namespace iree_compiler {
 
+static constexpr int32_t kNumGPUDims = 3;
+
 static SmallVector<linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
     OpBuilder &builder, Location loc, unsigned numDims) {
-  static constexpr int32_t kNumGPUDims = 3;
+  assert(numDims <= kNumGPUDims);
   SmallVector<linalg::ProcInfo, 2> procInfo(numDims);
   std::array<StringRef, kNumGPUDims> dimAttr{"x", "y", "z"};
   Type indexType = builder.getIndexType();
   for (unsigned i = 0; i < numDims; ++i) {
-    StringAttr attr =
-        builder.getStringAttr(dimAttr[std::min<unsigned>(i, kNumGPUDims)]);
+    StringAttr attr = builder.getStringAttr(dimAttr[i]);
     procInfo[numDims - 1 - i] = {
         builder.create<gpu::ThreadIdOp>(loc, indexType, attr),
         builder.create<gpu::BlockDimOp>(loc, indexType, attr)};
@@ -47,16 +48,20 @@ static SmallVector<linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
 
 /// Patterns for thread level tiling.
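+/// For example, with thread-level tile sizes [1, 1, 4] each thread computes a
+/// 4-element slice of the innermost dimension, which the vectorization pass
+/// then turns into vec4 operations.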
 static void populateTilingToInvocationPatterns(
-    MLIRContext *context, OwningRewritePatternList &patterns) {
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    ArrayRef<int64_t> tileSizes) {
   linalg::TileSizeComputationFunction getInnerTileSizeFn =
-      [](OpBuilder &builder, Operation *operation) {
-        ArrayRef<int64_t> tileSizes = {4};
+      [tileSizes](OpBuilder &builder, Operation *operation) {
         if (tileSizes.empty()) return SmallVector<Value, 4>();
         SmallVector<Value, 4> tileSizesVal;
         tileSizesVal.reserve(tileSizes.size());
-        for (auto val : tileSizes) {
+        for (auto val : llvm::enumerate(tileSizes)) {
+          // Only tile the last 3 dimensions. Use tile size of 0 for any higher
+          // dimension as we only support distributing on 3 dimensions.
+          int64_t t =
+              (tileSizes.size() - val.index()) <= kNumGPUDims ? val.value() : 0;
           tileSizesVal.push_back(
-              builder.create<ConstantIndexOp>(operation->getLoc(), val));
+              builder.create<ConstantIndexOp>(operation->getLoc(), t));
         }
         return tileSizesVal;
       };
@@ -174,13 +179,23 @@ struct TileAndDistributeToThreads
     }
 
     {
+      SmallVector<int64_t, 4> threadTileSize =
+          llvm::to_vector<4>(config->getTileSizes(rootOp, 2));
       // Apply last level of tiling and distribute to threads.
       OwningRewritePatternList threadLevelTilingPatterns(context);
-      populateTilingToInvocationPatterns(context, threadLevelTilingPatterns);
+      populateTilingToInvocationPatterns(context, threadLevelTilingPatterns,
+                                         threadTileSize);
       (void)applyPatternsAndFoldGreedily(
           funcOp, std::move(threadLevelTilingPatterns));
       applyCanonicalizationPatternsForTiling(context, funcOp);
     }
+    {
+      OwningRewritePatternList patterns(context);
+      // Apply canonicalization patterns.
+      linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
+      populateAffineMinSCFCanonicalizationPattern(patterns);
+      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+    }
     }
   }
 };
diff --git a/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp b/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
new file mode 100644
index 0000000000000..2d0c194831b79
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
@@ -0,0 +1,135 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
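+
+// This pass converts the linalg ops tagged for vectorization into vector ops,
+// unrolls the resulting vectors to the native vec4 size, and lowers the
+// leftover vector transfer ops. See Steps 1-3 in runOnOperation below.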
+
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
+#include "iree/compiler/Conversion/Common/Transforms.h"
+#include "iree/compiler/Conversion/LinalgToNVVM/Passes.h"
+#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+//====---------------------------------------------------------------------===//
+// Patterns for vectorization
+//====---------------------------------------------------------------------===//
+
+static void populateVectorizationPatterns(OwningRewritePatternList &patterns) {
+  linalg::insertVectorizationPatterns<linalg::FillOp, linalg::GenericOp,
+                                      linalg::ContractionOpInterface>(
+      patterns, linalg::LinalgVectorizationOptions(),
+      linalg::LinalgTransformationFilter(
+          Identifier::get(getVectorizeMarker(), patterns.getContext())));
+}
+
+static Optional<SmallVector<int64_t, 4>> getNativeVectorSize(Operation *op) {
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
+      // Map elementwise ops to vec4.
+      SmallVector<int64_t, 4> nativeSize(vecType.getRank(), 1);
+      nativeSize.back() = 4;
+      return nativeSize;
+    }
+  } else if (auto vt = dyn_cast<VectorTransferOpInterface>(op)) {
+    auto rank = vt.getVectorType().getRank();
+    SmallVector<int64_t, 4> nativeSize(rank, 1);
+    nativeSize.back() = 4;
+    return nativeSize;
+  }
+  return llvm::None;
+}
+
+static void populateVectorUnrollPatterns(OwningRewritePatternList &patterns) {
+  patterns.add<vector::UnrollVectorPattern>(
+      patterns.getContext(),
+      vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
+}
+
+namespace {
+struct VectorizationPass
+    : public PassWrapper<VectorizationPass, OperationPass<FuncOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+  void runOnOperation() override {
+    auto funcOp = getOperation();
+    MLIRContext *context = &getContext();
+
+    {
+      // Step 1. Vectorize
+      OwningRewritePatternList vectorizationPatterns(context);
+      populateVectorizationPatterns(vectorizationPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(vectorizationPatterns));
+    }
+    // TODO: This should be a folding of Add into Contract in core but while
+    // they live in different dialects, it is not possible without unnatural
+    // dependencies.
+    funcOp.walk([&](Operation *op) {
+      if (auto contract = canonicalizeContractionAdd(op))
+        op->replaceAllUsesWith(contract);
+    });
+
+    {
+      // Step 2. Unroll the vectors to native size and canonicalize.
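+      // For example, if a larger tile had produced an addf on
+      // vector<1x1x8xf32>, it would be unrolled into two addf ops on
+      // vector<1x1x4xf32> (the shape returned by getNativeVectorSize);
+      // vectors that are already 1x1x4 are left untouched.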
+      OwningRewritePatternList vectorUnrollPatterns(context);
+      populateVectorUnrollPatterns(vectorUnrollPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(vectorUnrollPatterns));
+
+      OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext());
+      vector::populateVectorToVectorCanonicalizationPatterns(
+          canonicalizationPatterns1);
+      vector::populateVectorToVectorTransformationPatterns(
+          canonicalizationPatterns1);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(canonicalizationPatterns1));
+
+      OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext());
+      vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(canonicalizationPatterns2));
+
+      linalg::hoistRedundantVectorTransfers(funcOp);
+    }
+    {
+      // Step 3. Lower the remaining transfer ops to unrolled 1-D transfers
+      // and legalize the resulting std ops.
+      RewritePatternSet vectorToLoopsPatterns(context);
+      VectorTransferToSCFOptions vectorToSCFOptions;
+      vectorToSCFOptions.setUnroll(true);
+      populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
+                                            vectorToSCFOptions);
+      populateStdLegalizationPatternsForSPIRVLowering(vectorToLoopsPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(vectorToLoopsPatterns));
+    }
+  }
+};
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createVectorizationPass() {
+  return std::make_unique<VectorizationPass>();
+}
+
+static PassRegistration<VectorizationPass> pass(
+    "iree-codegen-cuda-vectorization", "Pass to convert linalg into Vector.");
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/BUILD b/iree/compiler/Conversion/LinalgToNVVM/test/BUILD
index 620ce8e35c803..d68b579abf263 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/BUILD
@@ -30,6 +30,7 @@ iree_lit_test_suite(
             "convert_to_nvvm.mlir",
             "distribute_to_thread.mlir",
             "pipeline_test.mlir",
+            "vectorization.mlir",
        ],
        include = ["*.mlir"],
    ),
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt
index 0097d0e372585..8b44f283e54b3 100644
--- a/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt
@@ -17,6 +17,7 @@ iree_lit_test_suite(
     "convert_to_nvvm.mlir"
     "distribute_to_thread.mlir"
     "pipeline_test.mlir"
+    "vectorization.mlir"
  DATA
    iree::tools::IreeFileCheck
    iree::tools::iree-opt
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/vectorization.mlir b/iree/compiler/Conversion/LinalgToNVVM/test/vectorization.mlir
new file mode 100644
index 0000000000000..544d7c621bc2f
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/vectorization.mlir
@@ -0,0 +1,56 @@
+// RUN: iree-opt -iree-codegen-cuda-vectorization %s | IreeFileCheck %s
+
+func @add_dispatch_0() attributes {cuda_workgroup_size = dense<[32, 1, 1]> : vector<3xi64>} {
+  %c128 = constant 128 : index
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c1024 = constant 1024 : index
+  %0 = hal.interface.binding.subspan @legacy_io::@ro0[%c0] : memref<1024x1024x1024xf32>
+  %1 = hal.interface.binding.subspan @legacy_io::@ro1[%c0] : memref<1024x1024x1024xf32>
+  %2 = hal.interface.binding.subspan @legacy_io::@wo2[%c0] : memref<1024x1024x1024xf32>
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %workgroup_count_x = hal.interface.workgroup.count[0] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %workgroup_count_y = hal.interface.workgroup.count[1] : index
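+  // The IR below is representative TileAndDistribute output: 1x1x128
+  // workgroup tiles, further tiled to 1x1x4 per thread and tagged with the
+  // "vectorize" marker.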
+  %workgroup_id_z = hal.interface.workgroup.id[2] : index
+  %workgroup_count_z = hal.interface.workgroup.count[2] : index
+  scf.for %arg0 = %workgroup_id_z to %c1024 step %workgroup_count_z {
+    scf.for %arg1 = %workgroup_id_y to %c1024 step %workgroup_count_y {
+      %3 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+      %4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_x]
+      scf.for %arg2 = %3 to %c1024 step %4 {
+        %5 = memref.subview %0[%arg0, %arg1, %arg2] [1, 1, 128] [1, 1, 1] : memref<1024x1024x1024xf32> to memref<1x1x128xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+        %6 = memref.subview %1[%arg0, %arg1, %arg2] [1, 1, 128] [1, 1, 1] : memref<1024x1024x1024xf32> to memref<1x1x128xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+        %7 = memref.subview %2[%arg0, %arg1, %arg2] [1, 1, 128] [1, 1, 1] : memref<1024x1024x1024xf32> to memref<1x1x128xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+        %8 = "gpu.thread_id"() {dimension = "x"} : () -> index
+        %9 = "gpu.block_dim"() {dimension = "x"} : () -> index
+        %10 = "gpu.thread_id"() {dimension = "y"} : () -> index
+        %11 = "gpu.block_dim"() {dimension = "y"} : () -> index
+        %12 = "gpu.thread_id"() {dimension = "z"} : () -> index
+        %13 = "gpu.block_dim"() {dimension = "z"} : () -> index
+        scf.for %arg3 = %12 to %c1 step %13 {
+          scf.for %arg4 = %10 to %c1 step %11 {
+            %14 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%8]
+            %15 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%9]
+            scf.for %arg5 = %14 to %c128 step %15 {
+              %16 = memref.subview %5[%arg3, %arg4, %arg5] [1, 1, 4] [1, 1, 1] : memref<1x1x128xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>> to memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+              %17 = memref.subview %6[%arg3, %arg4, %arg5] [1, 1, 4] [1, 1, 1] : memref<1x1x128xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>> to memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+              %18 = memref.subview %7[%arg3, %arg4, %arg5] [1, 1, 4] [1, 1, 1] : memref<1x1x128xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>> to memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+              linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%16, %17 : memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>) outs(%18 : memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>) attrs = {__internal_linalg_transform__ = "vectorize", is_root_op, launch_info_key = "__op_num_0__"} {
+              ^bb0(%arg6: f32, %arg7: f32, %arg8: f32):  // no predecessors
+                %19 = addf %arg6, %arg7 : f32
+                linalg.yield %19 : f32
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @add_dispatch_0()
+// CHECK: vector.transfer_read {{.*}} : memref<1024x1024x1024xf32>, vector<4xf32>
+// CHECK: vector.transfer_read {{.*}} : memref<1024x1024x1024xf32>, vector<4xf32>
+// CHECK: addf %{{.*}}, %{{.*}} : vector<1x1x4xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<1024x1024x1024xf32>
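+// Roughly, the vectorized inner-loop body checked above looks like the
+// following (illustrative only: SSA names are invented here, and the shape
+// casts between vector<4xf32> and vector<1x1x4xf32> are elided):
+//   %lhs = vector.transfer_read %0[...] : memref<1024x1024x1024xf32>, vector<4xf32>
+//   %rhs = vector.transfer_read %1[...] : memref<1024x1024x1024xf32>, vector<4xf32>
+//   %sum = addf %lhs, %rhs : vector<1x1x4xf32>
+//   vector.transfer_write %sum, %2[...] : vector<4xf32>, memref<1024x1024x1024xf32>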