From f5804ec74a772432cc2183de5b3675726c213435 Mon Sep 17 00:00:00 2001 From: Thomas Date: Tue, 23 Mar 2021 14:48:39 -0700 Subject: [PATCH] Fix problem bug in MobileBert with vectorization enable. (#5211) --- .../LinalgToSPIRV/KernelDispatchUtils.cpp | 24 +++++++++++-------- .../LinalgTileAndDistributePass.cpp | 19 ++++++++++++++- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp index f31b97fb5ac0..5ffd3ac9be4b 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp @@ -317,6 +317,14 @@ LogicalResult getOpLaunchConfig(linalg::GenericOp op, const SPIRVCodegenOptions &options, TileSizesListType &tileSizes, LaunchConfigInfo &config) { + // Skip vectorization for non-minor identity inputs as it generates + // transfer_read ops with permutation maps that we currently cannot lower. + // TODO: Remove this restriction once the lowering of the permutation map is + // supported in core. + bool vectorize = options.enableVectorization && + llvm::all_of(op.getIndexingMaps(), [](AffineMap &map) { + return map.isMinorIdentity(); + }); int64_t subgroupSize = targetEnv.getResourceLimits().subgroup_size().getValue().getSExtValue(); config.workgroupSize[0] = subgroupSize; @@ -330,9 +338,10 @@ LogicalResult getOpLaunchConfig(linalg::GenericOp op, // avoid a mismatch in the number of workgroup dispatched, we pick a tile size // to have one element per thread. // TODO: Remove this once we switch to linalg on tensor path. - if (options.enableVectorization) { + if (vectorize) { candidateTileSizes.append({4 * subgroupSize, 2 * subgroupSize}); } + candidateTileSizes.push_back(subgroupSize); // Use the first tile size that can divide the shape. If the shape is not // aligned on any of the tile sizes pick the smallest tile of one element per @@ -351,21 +360,16 @@ LogicalResult getOpLaunchConfig(linalg::GenericOp op, // If the shape is not exactly aligned on the tile size skip the second level // of tiling as it expect the number of iteration to be exactly equal to the // number of processors. - if (outputShape.getShape().back() % lowerTs != 0) return success(); - - // Skip vectorization for non-minor identity inputs as it generates - // transfer_read ops with permutation maps that we currently cannot lower. - // TODO: Remove this restriction once the lowering of the permutation map is - // supported in core. - for (unsigned i = 0, e = op.getNumInputs(); i < e; i++) { - if (!op.getInputIndexingMap(i).isMinorIdentity()) return success(); + if (!vectorize || outputShape.getShape().back() % lowerTs != 0) { + config.vectorize = false; + return success(); } tileSizes.emplace_back(); // Subgroup level. ts.back() = lowerTs / subgroupSize; tileSizes.emplace_back(ts); // Thread level. // Vectorize only if we are processing more than one element per thread. - config.vectorize = options.enableVectorization && (ts.back() > 1); + config.vectorize = vectorize && (ts.back() > 1); return success(); } diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp index 3c6d5e35e2bc..05d5432077ab 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp @@ -120,7 +120,24 @@ class LinalgTileAndDistributePass llvm::dbgs() << "}\n"; } }); - + // Annotate the linalg op with the original types. + for (linalg::LinalgOp op : linalgOps) { + const char inputTypeAttrName[] = "iree.codegen.original_input_types"; + const char outputTypeAttrName[] = "iree.codegen.original_output_types"; + + SmallVector inputTypes; + SmallVector outputTypes; + for (Type type : op.getInputBufferTypes()) inputTypes.push_back(type); + for (Type type : op.getOutputBufferTypes()) outputTypes.push_back(type); + if (!inputTypes.empty()) { + op->setAttr(inputTypeAttrName, + Builder(op).getTypeArrayAttr(inputTypes)); + } + if (!outputTypes.empty()) { + op->setAttr(outputTypeAttrName, + Builder(op).getTypeArrayAttr(outputTypes)); + } + } TileAndFuseOptions tileAndFuseOptions = { getWorkgroupDistributionOptions(), allocateWorkgroupMemory}; if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph,