diff --git a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
index 107894de4c19..9150ab3519a6 100644
--- a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
@@ -42,5 +42,12 @@ IREE::HAL::ExecutableEntryPointOp getEntryPoint(FuncOp funcOp) {
   return nullptr;
 }
 
+Value getViewSource(Value view) {
+  while (auto viewOp = view.getDefiningOp<ViewLikeOpInterface>()) {
+    view = viewOp.getViewSource();
+  }
+  return view;
+}
+
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
index 2807df32b7d4..2c65971dbee7 100644
--- a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
@@ -41,6 +41,10 @@ unsigned getNumOuterParallelLoops(linalg::LinalgOp op);
 /// Returns the entry point op for the `funcOp`. Returns `nullptr` on failure.
 IREE::HAL::ExecutableEntryPointOp getEntryPoint(FuncOp funcOp);
 
+/// Recursively follows ops that implement ViewLikeOpInterface back to their
+/// source. Can be used to get the untiled operands of a tiled operation.
+Value getViewSource(Value view);
+
 }  // namespace iree_compiler
 }  // namespace mlir
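Note on the new helper: getViewSource simply chases view-producing ops (e.g. memref.subview) back to the original buffer. A minimal usage sketch in the spirit of the callers added later in this patch (rootLinalgOp is illustrative):

    // Recover the untiled problem size from a tiled op's output operand:
    // subviews introduced by workgroup tiling are peeled off until the
    // original binding or allocation is reached.
    Value untiled = getViewSource(rootLinalgOp.getOutput(0));
    auto fullShape = untiled.getType().cast<ShapedType>().getShape();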
diff --git a/iree/compiler/Conversion/LinalgToLLVM/MaterializeCPULaunchConfigurationPass.cpp b/iree/compiler/Conversion/LinalgToLLVM/MaterializeCPULaunchConfigurationPass.cpp
index 8e757f017654..c8734b7d515e 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/MaterializeCPULaunchConfigurationPass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/MaterializeCPULaunchConfigurationPass.cpp
@@ -61,14 +61,16 @@ void MaterializeCPULaunchConfigurationPass::runOnOperation() {
     SmallVector<linalg::LinalgOp, 4> linalgOps;
     SmallVector<Operation *, 4> tiledLoops;
     if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) {
-      return signalPassFailure();
+      // Nothing to do here. Continue.
+      continue;
     }
     linalg::Aliases aliases;
     linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
     Optional<LaunchConfig> launchConfigOpt =
         initCPULaunchConfig(context, dependenceGraph, linalgOps);
     if (!launchConfigOpt) {
-      return;
+      // Nothing to do here. Continue.
+      continue;
     }
     LaunchConfig &launchConfig = *launchConfigOpt;
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h
index 7cce13bf35a3..41d17b8624c2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h
@@ -26,6 +26,7 @@
 #define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_CODEGENOPTIONUTILS_H_
 
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 namespace iree_compiler {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp
index 02f853beceac..ec5658e78c19 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp
@@ -40,6 +40,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
 #include "iree/compiler/Conversion/Common/LaunchConfig.h"
 #include "iree/compiler/Conversion/Common/Transforms.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
@@ -73,29 +74,17 @@
 namespace mlir {
 namespace iree_compiler {
 
-namespace {
-
-constexpr unsigned kWorkgroupDimCount = 3;
+static constexpr unsigned kMaxWorkgroupDimCount = 3;
 
-int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }
-
-static size_t getNumOuterParallelDims(linalg::LinalgOp op) {
-  ArrayRef<Attribute> iterators = op.iterator_types().getValue();
-  auto parallels = iterators.take_while(
-      [](Attribute attr) { return linalg::isParallelIteratorType(attr); });
-  return parallels.size();
-}
+static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }
 
 /// Returns the root Linalg op that dictates tiling and distribution policy.
-linalg::LinalgOp getRootLinalgOp(FuncOp funcOp) {
+static linalg::LinalgOp getRootLinalgOp(FuncOp funcOp,
+                                        const SPIRVCodegenOptions &options) {
   SmallVector<linalg::LinalgOp, 4> linalgOps;
   SmallVector<Operation *, 4> tiledLoops;
   if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) return {};
 
-  SPIRVCodegenOptions options;
-  options.enableVectorization = true;
-  options.usingLinalgOnTensors = true;
-
   linalg::Aliases aliases;
   linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
   Optional<LaunchConfig> launchConfigOpt = initGPULaunchConfig(
@@ -125,26 +114,29 @@ linalg::LinalgOp getRootLinalgOp(FuncOp funcOp) {
 // TODO(antiagainst): This is quite fragile. We need a better way to pass the
 // information down from the upper layer, which readily has it. Probably via
 // linalg.tile op.
-LogicalResult getInputOutputTypesForAllTiles(
-    linalg::LinalgOp rootOp, SmallVectorImpl<Type> &inputTypes,
-    SmallVectorImpl<Type> &outputTypes) {
+static std::tuple<SmallVector<Type>, SmallVector<Type>> getInputOutputTypes(
+    linalg::LinalgOp rootOp) {
+  SmallVector<Type> inputTypes, outputTypes;
   for (Value inputBuffer : rootOp.getInputBuffers()) {
     if (auto subviewOp = inputBuffer.getDefiningOp<memref::SubViewOp>()) {
       inputTypes.push_back(subviewOp.getViewSource().getType());
     } else if (auto allocOp = inputBuffer.getDefiningOp<memref::AllocOp>()) {
       inputTypes.push_back(allocOp.getType());
     } else {
-      return failure();
+      inputTypes.clear();
+      break;
     }
   }
 
   for (Value outputBuffer : rootOp.getOutputBuffers()) {
     auto subviewOp = outputBuffer.getDefiningOp<memref::SubViewOp>();
-    if (!subviewOp) return failure();
+    if (!subviewOp) {
+      outputTypes.clear();
+      break;
+    }
     outputTypes.push_back(subviewOp.getViewSource().getType());
   }
-
-  return success();
+  return std::make_tuple(std::move(inputTypes), std::move(outputTypes));
 }
 
 /// Assuming the given `rootOp` is the tiled root Linalg op, returns the
@@ -153,10 +145,10 @@ LogicalResult getInputOutputTypesForAllTiles(
 ///
 /// TODO(antiagainst): This pass can be shared between CPU and GPU. But the
 /// following query scopes it to GPU for now.
-llvm::Optional<
-    std::pair<llvm::SmallVector<int64_t, 4>, llvm::SmallVector<int64_t, 4>>>
-getTileSizeAndWorkgroupSize(Operation *rootOp, ArrayRef<Type> inputTypes,
-                            ArrayRef<Type> outputTypes) {
+static LogicalResult getTileSizeAndWorkgroupSize(
+    Operation *rootOp, ArrayRef<Type> inputTypes, ArrayRef<Type> outputTypes,
+    SmallVector<int64_t, 4> &tileSize, SmallVector<int64_t, 4> &workgroupSize,
+    const SPIRVCodegenOptions &options) {
   // Build necessary structures to query the tile sizes for distributing to
   // workgroups.
   linalg::Aliases aliases;
@@ -165,10 +157,6 @@ getTileSizeAndWorkgroupSize(Operation *rootOp, ArrayRef<Type> inputTypes,
   linalgOps.assign(ops.begin(), ops.end());
   linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
 
-  SPIRVCodegenOptions options;
-  options.enableVectorization = true;
-  options.usingLinalgOnTensors = true;
-
   // NOTE: Launch configuration expects the original input/output type to
   // decide the configuration. But we have already tiled the Linalg ops here.
   // Use an attribute to send it over for now.
@@ -186,85 +174,44 @@ getTileSizeAndWorkgroupSize(Operation *rootOp, ArrayRef<Type> inputTypes,
   Optional<LaunchConfig> launchConfig = initGPULaunchConfig(
       rootOp->getContext(), dependenceGraph, options, linalgOps);
   if (!launchConfig) {
-    rootOp->emitError("unable to find launch configuration");
-    return llvm::None;
+    return rootOp->emitError("unable to find launch configuration");
   }
 
-  ArrayRef<int64_t> tileSize = launchConfig->getTileSizes(rootOp, 0);
-  ArrayRef<int64_t> workgroupSize = launchConfig->getWorkgroupSize();
+  tileSize = llvm::to_vector<4>(launchConfig->getTileSizes(rootOp, 0));
+  workgroupSize = llvm::to_vector<4>(launchConfig->getWorkgroupSize());
 
   // Clean up internal markers that are set during launch configuration
   // preparation.
   launchConfig->finalize(rootOp->getParentOfType<FuncOp>());
 
-  return std::make_pair(llvm::to_vector<4>(tileSize),
-                        llvm::to_vector<4>(workgroupSize));
+  return success();
 }
 
-/// Replaces hal.interface.workgroup.size op with the constant value chosen
-/// from tiling scheme.
-class ConcretizeWorkgroupSizeOp final
-    : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp> {
- public:
-  ConcretizeWorkgroupSizeOp(MLIRContext *context,
-                            SmallVector<int64_t, 4> workloadSize,
-                            SmallVector<int64_t, 4> tileSize,
-                            PatternBenefit benefit = 1)
-      : OpRewritePattern(context, benefit),
-        workloadSize(std::move(workloadSize)),
-        tileSize(std::move(tileSize)) {}
-
-  LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupSizeOp op,
-                                PatternRewriter &rewriter) const override {
-    unsigned dimIndex = op.dimension().getZExtValue();
-
-    if (dimIndex < kWorkgroupDimCount && tileSize[dimIndex] != 0) {
-      rewriter.replaceOpWithNewOp<ConstantOp>(
-          op, rewriter.getIndexAttr(tileSize[dimIndex]));
-      return success();
-    }
-
-    return failure();
-  }
-
- private:
-  SmallVector<int64_t, 4> workloadSize;
-  SmallVector<int64_t, 4> tileSize;
-};
-
+namespace {
 /// Replaces hal.interface.workgroup.count op with the constant value chosen
 /// from tiling scheme.
 class ConcretizeWorkgroupCountOp final
     : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupCountOp> {
  public:
   ConcretizeWorkgroupCountOp(MLIRContext *context,
-                             SmallVector<int64_t, 4> workloadSize,
-                             SmallVector<int64_t, 4> tileSize,
+                             ArrayRef<int64_t> numWorkgroups,
                              PatternBenefit benefit = 1)
       : OpRewritePattern(context, benefit),
-        workloadSize(std::move(workloadSize)),
-        tileSize(std::move(tileSize)) {}
+        numWorkgroups(numWorkgroups.begin(), numWorkgroups.end()) {}
 
   LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupCountOp op,
                                 PatternRewriter &rewriter) const override {
     unsigned dimIndex = op.dimension().getZExtValue();
 
-    if (dimIndex >= kWorkgroupDimCount) return failure();
-
-    int64_t dimSize = workloadSize[dimIndex];
-    int64_t dimTile = tileSize[dimIndex];
-
-    if (dimSize == ShapedType::kDynamicSize || dimTile == 0) return failure();
-
-    int64_t count = ceilDiv(dimSize, dimTile);
-    rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getIndexAttr(count));
+    if (dimIndex >= numWorkgroups.size()) return failure();
+    rewriter.replaceOpWithNewOp<ConstantOp>(
+        op, rewriter.getIndexAttr(numWorkgroups[dimIndex]));
     return success();
   }
 
  private:
-  SmallVector<int64_t, 4> workloadSize;
-  SmallVector<int64_t, 4> tileSize;
+  SmallVector<int64_t, 4> numWorkgroups;
 };
 
 // Canonicalizes away a trip-one scf.for loop by inlining its body and removing
@@ -282,12 +229,11 @@ class ConcretizeWorkgroupCountOp final
 // Such scf.for loops can be inlined if %lb is smaller than upper bound.
 class RemoveTripOneLoop final : public OpRewritePattern<scf::ForOp> {
  public:
-  RemoveTripOneLoop(MLIRContext *context, SmallVector<int64_t, 4> workloadSize,
-                    SmallVector<int64_t, 4> tileSize,
-                    PatternBenefit benefit = 1)
+  RemoveTripOneLoop(MLIRContext *context, ArrayRef<int64_t> workloadSize,
+                    ArrayRef<int64_t> tileSize, PatternBenefit benefit = 1)
       : OpRewritePattern(context, benefit),
-        workloadSize(std::move(workloadSize)),
-        tileSize(std::move(tileSize)) {}
+        workloadSize(workloadSize.begin(), workloadSize.end()),
+        tileSize(tileSize.begin(), tileSize.end()) {}
 
   LogicalResult matchAndRewrite(scf::ForOp op,
                                 PatternRewriter &rewriter) const override {
@@ -358,6 +304,72 @@ class RemoveTripOneLoop final : public OpRewritePattern<scf::ForOp> {
   SmallVector<int64_t, 4> tileSize;
 };
 
+static void removeOneTripTiledLoops(MLIRContext *context, FuncOp funcOp,
+                                    linalg::LinalgOp rootLinalgOp,
+                                    ArrayRef<int64_t> halWorkgroupSize) {
+  if (rootLinalgOp.getNumOutputs() != 1) return;
+  unsigned numParallelDims = getNumOuterParallelLoops(rootLinalgOp);
+  unsigned numTiledDims = std::min(numParallelDims, kMaxWorkgroupDimCount);
+
+  Value untiledOutputOperand = getViewSource(rootLinalgOp.getOutput(0));
+  ArrayRef<int64_t> outputShape =
+      untiledOutputOperand.getType().cast<ShapedType>().getShape();
+  if (outputShape.size() < numParallelDims) return;
+
+  // TODO(ravishankarm, antiagainst): It is pure coincidence that the
+  // workload is derivable from the output shape. There is no requirement
+  // for this, but it is the case for all operations we are interested in.
+  auto workloadSize = llvm::to_vector<4>(llvm::reverse(
+      outputShape.take_front(numParallelDims).take_back(numTiledDims)));
+  if (llvm::any_of(workloadSize, [](int64_t dim) {
+        return dim == ShapedType::kDynamicSize;
+      })) {
+    return;
+  }
+  LLVM_DEBUG({
+    llvm::dbgs() << "Queried workload size: ";
+    llvm::interleaveComma(workloadSize, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+  SmallVector<int64_t, 3> numWorkgroups;
+  assert(halWorkgroupSize.size() == workloadSize.size());
+  for (auto pair : llvm::zip(workloadSize, halWorkgroupSize)) {
+    auto workload = std::get<0>(pair);
+    auto size = std::get<1>(pair);
+    numWorkgroups.push_back(ceilDiv(workload, size));
+  }
+  numWorkgroups.resize(kMaxWorkgroupDimCount, 1);
+  WorkgroupCountRegionBuilder regionBuilder = [&](OpBuilder &b, Location loc,
+                                                  std::array<Value, 3>) {
+    std::array<Value, 3> returnValues;
+    for (unsigned i = 0; i < kMaxWorkgroupDimCount; ++i) {
+      returnValues[i] = b.create<ConstantIndexOp>(loc, numWorkgroups[i]);
+    }
+    return returnValues;
+  };
+
+  OpBuilder builder(context);
+  if (failed(defineWorkgroupCountRegion(builder, funcOp, regionBuilder))) {
+    return;
+  }
+
+  {
+    OwningRewritePatternList workgroupCountPatterns(context);
+    workgroupCountPatterns.insert<ConcretizeWorkgroupCountOp>(context,
+                                                              numWorkgroups);
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(workgroupCountPatterns));
+  }
+  {
+    OwningRewritePatternList removeTripOneLoopPatterns(context);
+    removeTripOneLoopPatterns.insert<RemoveTripOneLoop>(context, workloadSize,
+                                                        halWorkgroupSize);
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(removeTripOneLoopPatterns));
+  }
+}
+
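To make the arithmetic above concrete, using the conv2d_static_shape test updated later in this patch: the untiled output is memref<1x112x112x32xf32>, so with four parallel dims and three tiled dims the workload is reverse((112, 112, 32)) = (32, 112, 112); with halWorkgroupSize = (16, 4, 4) this gives

    numWorkgroups = (ceilDiv(32, 16), ceilDiv(112, 4), ceilDiv(112, 4))
                  = (2, 28, 28)

which is exactly the set of constants checked in the updated test (%c2, %c28, %c28).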
 /// Concretizes hal.interface.workgroup.* ops with constants from the chosen
 /// tiling scheme when possible and performs loop canonicalization afterwards.
 class ConcretizeTileAmongWorkgroupsPass
@@ -375,10 +387,9 @@ class ConcretizeTileAmongWorkgroupsPass
   void runOnOperation() override {
     IREE::HAL::ExecutableTargetOp targetOp = getOperation();
     ModuleOp module = targetOp.getInnerModule();
-
     for (FuncOp funcOp : module.getOps<FuncOp>()) {
       if (!funcOp.isPublic()) continue;
-      if (failed(runOnFunction(funcOp))) return signalPassFailure();
+      (void)runOnFunction(funcOp);
     }
   }
 
@@ -386,158 +397,84 @@ class ConcretizeTileAmongWorkgroupsPass
   LogicalResult runOnFunction(FuncOp funcOp) {
     MLIRContext &context = getContext();
 
-    // 1. Get the root op first. We need it to figure out the original problem
-    // size, which then affects the tiling and distribution policy.
-
-    linalg::LinalgOp rootOp = getRootLinalgOp(funcOp);
-    if (!rootOp) {
-      LLVM_DEBUG(llvm::dbgs() << "unable to find root Linalg op\n");
-      // It can happen for ops that are not abstractly tiled during dispatch
-      // region formation. So don't trigger pass failure.
+    // 1. Get the Linalg operations within the function. The callee here
+    // succeeds only for functions with a single basic block.
+    SmallVector<linalg::LinalgOp, 4> linalgOps;
+    SmallVector<Operation *, 4> tiledLoops;
+    if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) {
+      return failure();
+    }
+    // If there are no Linalg ops, there is nothing to do. Return.
+    if (linalgOps.empty()) return success();
+
+    // 2. Get the launch configuration to use for the function.
+    linalg::Aliases aliases;
+    linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
+    Optional<LaunchConfig> launchConfig = initGPULaunchConfig(
+        funcOp.getContext(), dependenceGraph, options, linalgOps);
+    if (!launchConfig) {
+      // Having no config implies that there is nothing to do here. Return.
       return success();
     }
-    LLVM_DEBUG(llvm::dbgs() << "Root op: " << rootOp << "\n");
-
-    size_t numTilableDims = getNumOuterParallelDims(rootOp);
-
-    // 2. Figure out the original problem size.
-    SmallVector<Type, 4> inputTypes, outputTypes;
-    SmallVector<int64_t, 4> workloadSize;
-    if (succeeded(
-            getInputOutputTypesForAllTiles(rootOp, inputTypes, outputTypes))) {
-      if (outputTypes.size() != 1) {
-        return rootOp.emitError("only support ops with one result right now");
-      }
-
-      // Flow/HAL processor id/size/count ops' indices follow the reverse order
-      // of the shape dimensions.
-      workloadSize = llvm::to_vector<4>(llvm::reverse(
-          outputTypes.front().cast<ShapedType>().getShape().take_front(
-              numTilableDims)));
-    } else {
-      // This can happen for dynamic shapes.
-      LLVM_DEBUG(llvm::dbgs()
-                 << "unable to find input/output type for all tiles");
-
-      inputTypes.clear();
-      outputTypes.clear();
-
-      workloadSize.assign(numTilableDims, ShapedType::kDynamicSize);
+    // 3. The root operation determines the tile size to use. This has already
+    // been computed by the launch configuration.
+    // TODO(ravishankarm): The configuration actually makes sure that all tile
+    // sizes for the parallel loops are consistent, but use the root operation
+    // for now.
+    Operation *rootOp =
+        launchConfig->getRootOperation(llvm::to_vector<4>(llvm::map_range(
+            linalgOps,
+            [](linalg::LinalgOp op) { return op.getOperation(); })));
+
+    unsigned numParallelDims = getNumOuterParallelLoops(rootOp);
+    unsigned numTiledDims = std::min(numParallelDims, kMaxWorkgroupDimCount);
+    ArrayRef<int64_t> tileSizes = launchConfig->getTileSizes(rootOp, 0);
+    if (tileSizes.size() < numParallelDims) {
+      return rootOp->emitError(
+                 "invalid tile size configuration, expected at least as many "
+                 "as the number of tiled loops: ")
+             << numParallelDims;
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "Queried workload size: ";
-      llvm::interleaveComma(workloadSize, llvm::dbgs());
-      llvm::dbgs() << "\n";
-    });
-
-    // 3. Query the scheme for tiling among workgroups.
-
-    SmallVector<int64_t, 4> tileSize;
-    SmallVector<int64_t, 4> workgroupSize;
-
-    // Try to use configuration from the command-line first for testing.
-    tileSize.assign(options.tileSizes.begin(), options.tileSizes.end());
-    tileSize.resize(numTilableDims, 0);
-    workgroupSize.assign(options.workgroupSize.begin(),
-                         options.workgroupSize.end());
-    if (tileSize.empty() || workgroupSize.empty()) {
-      auto sizes = getTileSizeAndWorkgroupSize(rootOp, inputTypes, outputTypes);
-      if (sizes) {
-        // The tile sizes are specified against the original dimension order of
-        // the workload shape. But Flow/HAL processor id/size/count ops' are
-        // created using the reverse order.
-        tileSize = sizes->first;
-        tileSize.resize(numTilableDims);
-        tileSize = llvm::to_vector<4>(llvm::reverse(tileSize));
-        workgroupSize = sizes->second;
-      } else {
-        return funcOp.emitError("failed to query tile size and workgroup size");
-      }
+    // TODO(ravishankarm): The flow tiling only tiles the inner parallel loops
+    // by default. Using the same approach here. This spooky action at a
+    // distance needs to be resolved. Potentially it can be made cleaner with
+    // use of a `linalg.tile` operation.
+    tileSizes = tileSizes.take_front(numParallelDims).take_back(numTiledDims);
+    if (llvm::any_of(tileSizes, [](int64_t ts) { return ts == 0; })) {
+      return rootOp->emitError(
+          "unhandled tile size setting of 0 for a loop that was tiled");
     }
 
+    // 4. The hal.workgroup.size is a representation of the tile size. Note
+    // that this is not the actual workgroup size used eventually; that is
+    // computed by the launch configuration and is set below.
+    auto halWorkgroupSize = llvm::to_vector<4>(llvm::reverse(tileSizes));
+
     LLVM_DEBUG({
       llvm::dbgs() << "Queried tile size: ";
-      llvm::interleaveComma(tileSize, llvm::dbgs());
+      llvm::interleaveComma(tileSizes, llvm::dbgs());
+      llvm::dbgs() << ", HAL workgroup size: ";
+      llvm::interleaveComma(halWorkgroupSize, llvm::dbgs());
       llvm::dbgs() << "\n";
     });
 
-    // 4. Replace hal.interface.workgroup symbolic ops with constant values.
-
-    {
-      OwningRewritePatternList patterns(&getContext());
-      patterns.insert<ConcretizeWorkgroupSizeOp, ConcretizeWorkgroupCountOp>(
-          &context, workloadSize, tileSize);
-
-      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+    // 5. Materialize the constant values for the hal.workgroup.size along
+    // different dimensions.
+    if (failed(materializeStaticLaunchInformation(funcOp, halWorkgroupSize))) {
+      return funcOp.emitOpError(
+          "failed to materialize static launch information");
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs()
-          << "--- After concretizing hal.interface.workgroup ops ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-
-    // 5. Set the entry point region for computing the number of workgroups
-    // to dispatch. The region has symbolic arguments representing the
-    // workload. So two modes here (see comments at the begining of this file).
-
-    {
-      SmallVector<int64_t, 4> numWorkgroups;
-      for (auto pair : llvm::zip(workloadSize, tileSize)) {
-        auto workload = std::get<0>(pair);
-        auto tile = std::get<1>(pair);
-        if (workload == ShapedType::kDynamicSize || tile == 0) {
-          numWorkgroups.push_back(ShapedType::kDynamicSize);
-        } else {
-          numWorkgroups.push_back(ceilDiv(workload, tile));
-        }
-      }
-
-      numWorkgroups.resize(kWorkgroupDimCount, 1);
-
-      // If all dimensions are known constant, then we can set the number of
-      // workgroups directly. Otherwise, we need to generate the IR for
-      // computing it using symbolic values.
-      if (llvm::none_of(numWorkgroups, [](int64_t dim) {
-            return dim == ShapedType::kDynamicSize;
-          })) {
-        OpBuilder builder(&context);
-        WorkgroupCountRegionBuilder regionBuilder =
-            [&](OpBuilder &builder, Location loc, std::array<Value, 3>) {
-              std::array<Value, 3> returnValues;
-              for (unsigned i = 0; i < kWorkgroupDimCount; ++i) {
-                returnValues[i] =
-                    builder.create<ConstantIndexOp>(loc, numWorkgroups[i]);
-              }
-              return returnValues;
-            };
-        if (failed(
-                defineWorkgroupCountRegion(builder, funcOp, regionBuilder))) {
-          return funcOp.emitError(
-              "failed to set entry point region for number of workgroups");
-        }
-      } else {
-        if (failed(materializeStaticLaunchInformation(funcOp, tileSize))) {
-          return funcOp.emitOpError(
-              "failed to materialize static launch information");
-        }
-      }
-    }
-
-    if (failed(updateWorkGroupSize(funcOp, workgroupSize))) {
+    // 6. Update the actual workgroup size to use based on the launch
+    // configuration.
+    if (failed(updateWorkGroupSize(funcOp, launchConfig->getWorkgroupSize()))) {
       return funcOp.emitOpError("failed to set workgroup size on function");
     }
-
-    // 6. Canonicalization and clean up.
+    launchConfig->finalize(funcOp);
 
     if (inlineTripOneLoops) {
-      OwningRewritePatternList patterns(&getContext());
-      patterns.insert<RemoveTripOneLoop>(&context, workloadSize, tileSize);
-
-      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+      removeOneTripTiledLoops(&context, funcOp, cast<linalg::LinalgOp>(rootOp),
+                              halWorkgroupSize);
     }
 
     return success();
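As context for the inlineTripOneLoops path: after distribution, each tiled loop has the shape scf.for %iv = %id * tile to %workload step %count * tile, so once %count is concretized to ceilDiv(workload, tile), each loop runs exactly once per workgroup. A minimal sketch of the trip-count reasoning, assuming this standard loop shape (illustrative, not the pattern's exact code):

    // A loop with lower bound lb, step, and upper bound ub executes exactly
    // one iteration when it is entered and the first step already exits it;
    // such a loop can be inlined with the IV replaced by lb.
    static bool runsExactlyOnce(int64_t lb, int64_t step, int64_t ub) {
      return lb < ub && lb + step >= ub;
    }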
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 67ef07c61d94..0c2431398716 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -66,27 +66,19 @@ static int64_t getMinIfShapeStatic(int64_t shape, int64_t tileSize) {
 
 /// Fills `inputTypes` and `outputTypes` with the original input/output types
 /// for all tiles for `op`.
-static void getInputOutputTypes(linalg::LinalgOp op,
-                                SmallVectorImpl<ShapedType> &inputTypes,
-                                SmallVectorImpl<ShapedType> &outputTypes) {
-  // NOTE: Special treatment to let the flow.dispatch.workgroups path to be
-  // able to query launch configurations. This should be cleaned up after the
-  // flow.dispatch.workgroups become the default path.
-  auto inputTypeAttr =
-      op->getAttrOfType<ArrayAttr>("iree.codegen.original_input_types");
-  auto outputTypeAttr =
-      op->getAttrOfType<ArrayAttr>("iree.codegen.original_output_types");
-  if (outputTypeAttr && inputTypeAttr) {
-    for (Type type : inputTypeAttr.getAsValueRange<TypeAttr>())
-      inputTypes.push_back(type.cast<ShapedType>());
-    for (Type type : outputTypeAttr.getAsValueRange<TypeAttr>())
-      outputTypes.push_back(type.cast<ShapedType>());
-  } else {
-    for (Type type : op.getInputBufferTypes())
-      inputTypes.push_back(type.cast<ShapedType>());
-    for (Type type : op.getOutputBufferTypes())
-      outputTypes.push_back(type.cast<ShapedType>());
-  }
+static std::tuple<SmallVector<ShapedType>, SmallVector<ShapedType>>
+getInputOutputTypes(linalg::LinalgOp op) {
+  SmallVector<ShapedType> inputTypes(op.getNumInputs()),
+      outputTypes(op.getNumOutputs());
+  for (auto operand : enumerate(op.getInputOpOperands())) {
+    Value source = getViewSource(operand.value().get());
+    inputTypes[operand.index()] = source.getType().dyn_cast<ShapedType>();
+  }
+  for (auto operand : enumerate(op.getOutputOpOperands())) {
+    Value source = getViewSource(operand.value().get());
+    outputTypes[operand.index()] = source.getType().dyn_cast<ShapedType>();
+  }
+  return std::make_tuple(std::move(inputTypes), std::move(outputTypes));
 }
 
 namespace {
@@ -148,13 +140,13 @@ static LogicalResult getMaliSpecificConfig(
     std::array<int64_t, 3> &numSubgroups) {
   if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure();
 
-  SmallVector<ShapedType, 4> inputTypes, outputTypes;
-  getInputOutputTypes(op, inputTypes, outputTypes);
+  SmallVector<ShapedType> inputTypes, outputTypes;
+  std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
 
   ShapedType lhsType = inputTypes[0], rhsType = inputTypes[1];
-  assert(lhsType.getElementType() == rhsType.getElementType());
-
-  if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure();
+  if (!lhsType || !rhsType || !lhsType.hasStaticShape() ||
+      !rhsType.hasStaticShape())
+    return failure();
   // Get a vector of best tile size ordered from best to worst.
   SmallVector<TileWorkgroupSizePair, 4> workgroupLevelTs;
   int64_t dstSize =
@@ -313,18 +305,17 @@ static LogicalResult getConfigForCooperativeMatmul(
 }
 
 /// Launch config for element-wise linalg.generic.
-template <>
-LogicalResult getOpLaunchConfig(linalg::GenericOp op,
-                                const spirv::TargetEnv &targetEnv,
-                                const SPIRVCodegenOptions &options,
-                                TileSizesListType &tileSizes,
-                                LaunchConfigInfo &config) {
+LogicalResult getGenericOpLaunchConfig(linalg::LinalgOp linalgOp,
+                                       const spirv::TargetEnv &targetEnv,
+                                       const SPIRVCodegenOptions &options,
+                                       TileSizesListType &tileSizes,
+                                       LaunchConfigInfo &config) {
   // Skip vectorization for non-minor identity inputs as it generates
   // transfer_read ops with permutation maps that we currently cannot lower.
   // TODO: Remove this restriction once the lowering of the permutation map is
   // supported in core.
   bool vectorize = options.enableVectorization &&
-                   llvm::all_of(op.getIndexingMaps(), [](AffineMap &map) {
+                   llvm::all_of(linalgOp.getIndexingMaps(), [](AffineMap &map) {
                      return map.isMinorIdentity();
                    });
   int64_t subgroupSize =
@@ -332,7 +323,7 @@ LogicalResult getGenericOpLaunchConfig(linalg::LinalgOp linalgOp,
   config.workgroupSize[0] = subgroupSize;
   config.workgroupSize[1] = 1;
   config.workgroupSize[2] = 1;
-  ShapedType outputShape = op.getOutputShapedType(0);
+  ShapedType outputShape = linalgOp.getOutputShapedType(0);
 
   SmallVector<int64_t, 4> candidateTileSizes;
   // When vectorization is not enabled we skip the second level of tiling and
@@ -354,8 +345,8 @@ LogicalResult getGenericOpLaunchConfig(linalg::LinalgOp linalgOp,
     lowerTs = size;
     break;
   }
+  unsigned numLoops = getNumOuterParallelLoops(linalgOp);
   SmallVector<int64_t, 4> ts;
-  size_t numLoops = getNumOuterParallelLoops(op);
   ts.resize(numLoops, 1);
   ts.back() = lowerTs;
   tileSizes.emplace_back(ts);  // Workgroup level.
@@ -367,7 +358,7 @@ LogicalResult getGenericOpLaunchConfig(linalg::LinalgOp linalgOp,
     return success();
   }
 
-  tileSizes.emplace_back();   // Subgroup level.
+  tileSizes.emplace_back();  // Subgroup level.
   ts.back() = lowerTs / subgroupSize;
   tileSizes.emplace_back(ts);  // Thread level.
   // Vectorize only if we are processing more than one element per thread.
@@ -375,6 +366,21 @@ LogicalResult getGenericOpLaunchConfig(linalg::LinalgOp linalgOp,
   return success();
 }
 
+#define GET_GENERIC_OP_LAUNCH_CONFIG(opType)                             \
+  template <>                                                            \
+  LogicalResult getOpLaunchConfig(                                       \
+      opType op, const spirv::TargetEnv &targetEnv,                      \
+      const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,  \
+      LaunchConfigInfo &config) {                                        \
+    return getGenericOpLaunchConfig(op, targetEnv, options, tileSizes,   \
+                                    config);                             \
+  }
+
+GET_GENERIC_OP_LAUNCH_CONFIG(linalg::GenericOp)
+GET_GENERIC_OP_LAUNCH_CONFIG(linalg::IndexedGenericOp)
+
+#undef GET_GENERIC_OP_LAUNCH_CONFIG
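For reference, each GET_GENERIC_OP_LAUNCH_CONFIG(opType) instantiation above expands to an explicit specialization that simply forwards to the shared implementation; e.g. for linalg::GenericOp:

    template <>
    LogicalResult getOpLaunchConfig(linalg::GenericOp op,
                                    const spirv::TargetEnv &targetEnv,
                                    const SPIRVCodegenOptions &options,
                                    TileSizesListType &tileSizes,
                                    LaunchConfigInfo &config) {
      return getGenericOpLaunchConfig(op, targetEnv, options, tileSizes,
                                      config);
    }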
 
 /// Launch configuration for different known GPU configuration.
 static LogicalResult getTargetSpecificConfig(
     linalg::MatmulOp op, const spirv::TargetEnv &targetEnv,
@@ -383,14 +389,15 @@ static LogicalResult getTargetSpecificConfig(
     std::array<int64_t, 3> &numSubgroups) {
   if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure();
 
-  SmallVector<ShapedType, 4> inputTypes, outputTypes;
-  getInputOutputTypes(op, inputTypes, outputTypes);
+  SmallVector<ShapedType> inputTypes, outputTypes;
+  std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
   ShapedType lhsType = inputTypes[0], rhsType = inputTypes[1];
-  assert(lhsType.getElementType() == rhsType.getElementType());
-  // If the shape size is unknonw fall back to none vectorized path.
-  if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return failure();
+  if (!lhsType || !rhsType || !lhsType.hasStaticShape() ||
+      !rhsType.hasStaticShape())
+    return failure();
 
+  // Pick ideal tile size based on the type.
   SmallVector<TileWorkgroupSizePair, 4> workgroupLevelTs;
   int64_t dstSize = lhsType.getDimSize(0) * rhsType.getDimSize(1);
@@ -472,11 +479,12 @@ static LogicalResult getMaliSpecificConfig(ConvOpTy op,
   Operation *operation = op.getOperation();
   if (!isa<linalg::ConvInputNHWCFilterHWCFOp>(operation)) return failure();
 
-  SmallVector<ShapedType, 4> inputTypes, outputTypes;
-  getInputOutputTypes(op, inputTypes, outputTypes);
+  SmallVector<ShapedType> inputTypes, outputTypes;
+  std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
   ShapedType inputType = inputTypes[0], outputType = outputTypes[0];
-  if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
+  if (!inputType || !outputType || !inputType.hasStaticShape() ||
+      !outputType.hasStaticShape())
     return failure();
 
   bool isInputTilable =
@@ -579,11 +587,12 @@ GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNDHWCFilterDHWCFOp)
 static LogicalResult getMaliSpecificConfig(
     linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes,
     LaunchConfigInfo &config) {
-  SmallVector<ShapedType, 4> inputTypes, outputTypes;
-  getInputOutputTypes(op, inputTypes, outputTypes);
+  SmallVector<ShapedType> inputTypes, outputTypes;
+  std::tie(inputTypes, outputTypes) = getInputOutputTypes(op);
   ShapedType inputType = inputTypes[0], outputType = outputTypes[0];
-  if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
+  if (!inputType || !outputType || !inputType.hasStaticShape() ||
+      !outputType.hasStaticShape())
     return failure();
 
   // A list of preferred tile sizes and workgroup sizes. This is for Mali
@@ -739,7 +748,6 @@ Optional<LaunchConfig> initGPULaunchConfig(
     SmallVector<int64_t, 3> workgroupSize(options.workgroupSize.begin(),
                                           options.workgroupSize.end());
     launchConfig.setWorkgroupSize(workgroupSize);
-    return launchConfig;
   }
 
   if (linalgOps.empty()) return launchConfig;
@@ -748,7 +756,6 @@ Optional<LaunchConfig> initGPULaunchConfig(
   Optional<linalg::LinalgOp> rootOperation = {};
   LaunchConfigInfo config;
 
-  for (linalg::LinalgOp linalgOp : linalgOps) {
 #define DISPATCH(opName)                                                    \
   if (auto op = dyn_cast<opName>(linalgOp.getOperation())) {                \
     if (rootOperation) {                                                    \
      return llvm::None;                                                    \
    }                                                                       \
    rootOperation = linalgOp;                                               \
+    if (launchConfig.hasTileSizes(linalgOp.getOperation())) continue;       \
    TileSizesListType tileSizesInfo;                                        \
    if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo,     \
                                 config))) {                                \
      return llvm::None;                                                    \
    }                                                                       \
    launchConfig.setTileSizes(op, tileSizesInfo);                           \
-    launchConfig.setRootOperation(op);                                      \
    continue;                                                               \
  }
 
+  for (linalg::LinalgOp linalgOp : linalgOps) {
     DISPATCH(linalg::BatchMatmulOp)
     DISPATCH(linalg::DepthwiseConvInputNHWCFilterHWCOp)
     DISPATCH(linalg::DepthwiseConvInputNHWCFilterHWCFOp)
     DISPATCH(linalg::MatmulOp)
     DISPATCH(linalg::PoolingNHWCMaxOp)
     DISPATCH(linalg::PoolingNHWCMinOp)
     DISPATCH(linalg::PoolingNHWCSumOp)
-
-#undef DISPATCH
   }
 
+  // Any generic/indexed_generic operations found are made the root if no
+  // other op is the root.
   if (!rootOperation) {
-    for (linalg::LinalgOp linalgOp : linalgOps) {
-      if (auto op = dyn_cast<linalg::GenericOp>(linalgOp.getOperation())) {
-        if (getNumOuterParallelLoops(linalgOp) == 0 ||
-            llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap &map) {
-              return !map.isProjectedPermutation();
-            })) {
-          continue;
-        }
-        TileSizesListType tileSizesInfo;
-        if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo,
-                                     config))) {
-          continue;
-        }
-        launchConfig.setTileSizes(op, tileSizesInfo);
-        launchConfig.setRootOperation(op);
-        break;
+    for (linalg::LinalgOp linalgOp : reverse(linalgOps)) {
+      size_t numLoops = getNumOuterParallelLoops(linalgOp);
+      if (numLoops == 0 ||
+          llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap &map) {
+            return !map.isProjectedPermutation();
+          })) {
+        return llvm::None;
       }
+
+      DISPATCH(linalg::GenericOp)
+      DISPATCH(linalg::IndexedGenericOp)
     }
   }
 
-  launchConfig.setWorkgroupSize(config.workgroupSize);
-  launchConfig.setNumSubgroups(config.numSubgroups);
-  launchConfig.setVectorize(config.vectorize);
+#undef DISPATCH
 
   if (!rootOperation) {
-    // No root operations found. Dont need to do anything.
-    return launchConfig;
+    return llvm::None;
   }
 
+  launchConfig.setRootOperation(*rootOperation);
+  if (options.workgroupSize.empty()) {
+    launchConfig.setWorkgroupSize(config.workgroupSize);
+  }
+  launchConfig.setNumSubgroups(config.numSubgroups);
+  launchConfig.setVectorize(config.vectorize);
+
   if (failed(propogateRootOperationLaunchConfig(launchConfig, *rootOperation,
                                                 dependenceGraph)))
     return llvm::None;
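In outline, the restructured root-operation selection above now behaves as follows (a simplified sketch of the control flow, not the exact code):

    // 1. Scan all Linalg ops: a named op (matmul, conv, pooling, ...) becomes
    //    the root; encountering a second named candidate aborts configuration
    //    by returning llvm::None.
    // 2. If no named root was found, scan the ops in reverse and let a
    //    (indexed_)generic op whose indexing maps are all projected
    //    permutations become the root.
    // 3. With no root at all, return llvm::None; the callers changed in this
    //    patch treat that as "nothing to do" rather than a pass failure.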
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp
index 05d5432077ab..ba1333b7cf44 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndDistributePass.cpp
@@ -86,7 +86,8 @@ class LinalgTileAndDistributePass
     SmallVector<Operation *, 4> tiledLoops;
 
     if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) {
-      return signalPassFailure();
+      // If there are no Linalg ops, nothing to do here.
+      continue;
     }
 
     linalg::Aliases aliases;
@@ -94,8 +95,8 @@ class LinalgTileAndDistributePass
     Optional<LaunchConfig> launchConfigOpt =
         initGPULaunchConfig(context, dependenceGraph, options, linalgOps);
     if (!launchConfigOpt) {
-      funcOp.emitError("unable to find launch configuration");
-      return signalPassFailure();
+      // Having no launch configuration also means nothing to do here.
+      continue;
     }
 
     LaunchConfig &launchConfig = *launchConfigOpt;
@@ -120,24 +121,6 @@ class LinalgTileAndDistributePass
         llvm::dbgs() << "}\n";
       }
     });
-    // Annotate the linalg op with the original types.
-    for (linalg::LinalgOp op : linalgOps) {
-      const char inputTypeAttrName[] = "iree.codegen.original_input_types";
-      const char outputTypeAttrName[] = "iree.codegen.original_output_types";
-
-      SmallVector<Type, 4> inputTypes;
-      SmallVector<Type, 4> outputTypes;
-      for (Type type : op.getInputBufferTypes()) inputTypes.push_back(type);
-      for (Type type : op.getOutputBufferTypes()) outputTypes.push_back(type);
-      if (!inputTypes.empty()) {
-        op->setAttr(inputTypeAttrName,
-                    Builder(op).getTypeArrayAttr(inputTypes));
-      }
-      if (!outputTypes.empty()) {
-        op->setAttr(outputTypeAttrName,
-                    Builder(op).getTypeArrayAttr(outputTypes));
-      }
-    }
 
     TileAndFuseOptions tileAndFuseOptions = {
         getWorkgroupDistributionOptions(), allocateWorkgroupMemory};
     if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
index c5460f280811..742d9f4007a2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
@@ -412,7 +412,8 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() {
     SmallVector<Operation *, 4> tiledLoops;
 
     if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) {
-      return signalPassFailure();
+      // Nothing to do here.
+      continue;
    }
 
     linalg::Aliases aliases;
@@ -420,8 +421,8 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() {
     Optional<LaunchConfig> launchConfigOpt =
         initGPULaunchConfig(context, dependenceGraph, options, linalgOps);
     if (!launchConfigOpt) {
-      funcOp.emitError("unable to find launch configuration");
-      return signalPassFailure();
+      // No configuration to tile and vectorize. Nothing to do here.
+      continue;
     }
 
     LaunchConfig &launchConfig = *launchConfigOpt;
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
index b4ea7c1ffad7..27990c52c306 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
@@ -29,6 +29,7 @@ iree_lit_test_suite(
         [
             "batch_matmul_vectorization.mlir",
             "concretize_tile_among_workgroups.mlir",
+            "concretize_tile_among_workgroups_dynamic.mlir",
             "convert_to_gpu.mlir",
             "convert_to_spirv.mlir",
             "dead_alloc.mlir",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
index 964c72322c48..d41713ab95c5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
@@ -16,6 +16,7 @@ iree_lit_test_suite(
   SRCS
     "batch_matmul_vectorization.mlir"
     "concretize_tile_among_workgroups.mlir"
+    "concretize_tile_among_workgroups_dynamic.mlir"
    "convert_to_gpu.mlir"
    "convert_to_spirv.mlir"
    "dead_alloc.mlir"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir
index 9e174cb709f9..febdf3ddbda9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups))" -iree-spirv-tile-size=16,4,4 -iree-spirv-workgroup-size=4,4,1 %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-spirv-tile-size=0,4,4,16 -iree-spirv-workgroup-size=4,4,1 -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups))" -canonicalize -cse %s | IreeFileCheck %s
 
 hal.executable @conv2d_static_shape attributes {sym_visibility = "private"} {
   hal.interface @legacy_io {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
     hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
     hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
   }
@@ -68,158 +68,55 @@ hal.executable @conv2d_static_shape attributes {sym_visibility = "private"} {
 // 2) Replace hal.interface.workgroup.{size|count} ops with constants,
 // 3) Canonicalize loops and memref.subview ops.
-// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
-// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (9, d0 * -2 + 225)>
-// CHECK: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + 32)>
-// CHECK: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (4, -d0 + 112)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (9, s0 * -8 + 225)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<()[s0] -> (16, s0 * -16 + 32)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<()[s0] -> (4, s0 * -4 + 112)>
 
-// CHECK: hal.executable.entry_point @conv2d_static_shape
-// CHECK:   %[[C2:.+]] = constant 2 : index
-// CHECK:   %[[C28_0:.+]] = constant 28 : index
-// CHECK:   %[[C28_1:.+]] = constant 28 : index
-// CHECK:   hal.return %[[C2]], %[[C28_0]], %[[C28_1]] : index, index, index
+// CHECK: hal.executable.entry_point @conv2d_static_shape
+// CHECK:   %[[C2:.+]] = constant 2 : index
+// CHECK:   %[[C28:.+]] = constant 28 : index
+// CHECK:   hal.return %[[C2]], %[[C28]], %[[C28]] : index, index, index
 
-// CHECK: func @conv2d_static_shape()
+// CHECK: func @conv2d_static_shape()
 // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>}
 
-// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index
-// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index
-// CHECK: %[[ID_Z:.+]] = hal.interface.workgroup.id[2] : index
-
-// CHECK: %[[Z_MUL_4:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Z]], %c4]
-// CHECK: %[[Y_MUL_4:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Y]], %c4]
-// CHECK: %[[X_MUL_16:.+]] = affine.apply #[[MULMAP]]()[%[[ID_X]], %c16]
-
-// CHECK: %[[Z_OFFSET:.+]] = affine.apply #[[MAP0]](%[[Z_MUL_4]])
-// CHECK: %[[Z_SIZE:.+]] = affine.min #[[MAP1]](%[[Z_MUL_4]])[%c4]
-// CHECK: %[[Y_OFFSET:.+]] = affine.apply #[[MAP0]](%[[Y_MUL_4]])
-// CHECK: %[[Y_SIZE:.+]] = affine.min #[[MAP1]](%[[Y_MUL_4]])[%c4]
-
-// CHECK: %[[INPUT:.+]] = memref.subview %{{.+}}[0, %[[Z_OFFSET]], %[[Y_OFFSET]], 0] [1, %[[Z_SIZE]], %[[Y_SIZE]], 16] [1, 1, 1, 1] : memref<1x225x225x16xf32> to memref<1x?x?x16xf32, {{.+}}>
-
-// CHECK: %[[X_SIZE:.+]] = affine.min #[[MAP2]](%[[X_MUL_16]])[%c16]
-
-// CHECK: %[[FILTER:.+]] = memref.subview %{{.+}}[0, 0, 0, %[[X_MUL_16]]] [3, 3, 16, %[[X_SIZE]]] [1, 1, 1, 1] : memref<3x3x16x32xf32> to memref<3x3x16x?xf32, {{.+}}>
-
-// CHECK: %[[Z_SIZE:.+]] = affine.min #[[MAP3]](%[[Z_MUL_4]])[%c4]
-// CHECK: %[[Y_SIZE:.+]] = affine.min #[[MAP3]](%[[Y_MUL_4]])[%c4]
-// CHECK: %[[OUTPUT:.+]] = memref.subview %{{.+}}[0, %[[Z_MUL_4]], %[[Y_MUL_4]], %[[X_MUL_16]]] [1, %[[Z_SIZE]], %[[Y_SIZE]], %[[X_SIZE]]] [1, 1, 1, 1] : memref<1x112x112x32xf32> to memref<1x?x?x?xf32, {{.+}}>
-
-// CHECK: linalg.fill(%[[OUTPUT]], %{{.+}})
-// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, is_root_op, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[FILTER]] : memref<1x?x?x16xf32, {{.+}}>, memref<3x3x16x?xf32, {{.+}}>) outs(%[[OUTPUT]] : memref<1x?x?x?xf32, {{.+}}>)
-
-// -----
-
-hal.executable @matmul_dynamic_shape attributes {sym_visibility = "private"} {
-  hal.interface @legacy_io {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-  }
-  hal.executable.target @vulkan_spirv, filter="vulkan*" {
-    hal.executable.entry_point @matmul_dynamic_shape attributes {
-      interface = @legacy_io, ordinal = 0 : index,
-      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
-    module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
-      func @matmul_dynamic_shape() {
-        %cst = constant 0.000000e+00 : f32
-        %c0 = constant 0 : index
-        %0 = hal.interface.load.constant offset = 0 : index
-        %1 = hal.interface.load.constant offset = 1 : index
-        %2 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<?x?xf32>
-        %3 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<?x?xf32>
-        %4 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<?x?xf32>
-        %5 = hal.interface.load.constant offset = 2 : index
-        %6 = hal.interface.load.constant offset = 3 : index
-        %7 = hal.interface.load.constant offset = 4 : index
-        %8 = hal.interface.load.constant offset = 5 : index
-        %9 = hal.interface.load.constant offset = 6 : index
-        %10 = hal.interface.load.constant offset = 7 : index
-        %11 = shapex.make_ranked_shape %5, %6 : (index, index) -> !shapex.ranked_shape<[?,?]>
-        %12 = shapex.tie_shape %2, %11 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
-        %13 = shapex.make_ranked_shape %7, %8 : (index, index) -> !shapex.ranked_shape<[?,?]>
-        %14 = shapex.tie_shape %3, %13 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
-        %15 = shapex.make_ranked_shape %9, %10 : (index, index) -> !shapex.ranked_shape<[?,?]>
-        %16 = shapex.tie_shape %4, %15 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %17 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %18 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %17 to %5 step %18 {
-          %19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %20 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %19 to %8 step %20 {
-            %21 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%5, %workgroup_size_y]
-            %22 = memref.subview %12[%arg0, 0] [%21, %6] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
-            %23 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%8, %workgroup_size_x]
-            %24 = memref.subview %14[0, %arg1] [%7, %23] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
-            %25 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%0, %workgroup_size_y]
-            %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%1, %workgroup_size_x]
-            %27 = memref.subview %16[%arg0, %arg1] [%25, %26] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
-            linalg.fill(%27, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, f32
-            linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%22, %24 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) outs(%27 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>)
-          }
-        }
-        return
-      }
-      hal.interface @legacy_io attributes {sym_visibility = "private"} {
-        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-      }
-    }
-  }
-}
-
-// Check that for a fully dynamic shaped dispatch region, we can:
-// 1) Generate symbolic workgroup counts,
-// 2) Replace hal.interface.workgroup.size (but not .count) ops with constants.
-
-// CHECK: #[[DIV16MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-// CHECK: #[[DIV4MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
-// CHECK: #[[YBOUNDMAP:.+]] = affine_map<(d0)[s0, s1] -> (4, -d0 + s0)>
-// CHECK: #[[XBOUNDMAP:.+]] = affine_map<(d0)[s0, s1] -> (16, -d0 + s0)>
-
-// CHECK: hal.executable.entry_point @matmul_dynamic_shape
-// CHECK: ^{{.+}}(%[[BBARG0:.+]]: index, %[[BBARG1:.+]]: index, %{{.+}}: index):
-// CHECK:   %c1 = constant 1 : index
-// CHECK:   %[[SIZE0:.+]] = affine.apply #[[DIV16MAP]]()[%[[BBARG0]]]
-// CHECK:   %[[SIZE1:.+]] = affine.apply #[[DIV4MAP]]()[%[[BBARG1]]]
-// CHECK:   hal.return %[[SIZE0]], %[[SIZE1]], %c1
-
-// CHECK: func @matmul_dynamic_shape()
-// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>}
-
-// CHECK: %[[C_DIM0:.+]] = hal.interface.load.constant offset = 0 : index
-// CHECK: %[[C_DIM1:.+]] = hal.interface.load.constant offset = 1 : index
-// CHECK: %[[A_DIM0:.+]] = hal.interface.load.constant offset = 2 : index
-// CHECK: %[[A_DIM1:.+]] = hal.interface.load.constant offset = 3 : index
-// CHECK: %[[B_DIM0:.+]] = hal.interface.load.constant offset = 4 : index
-// CHECK: %[[B_DIM1:.+]] = hal.interface.load.constant offset = 5 : index
-
-// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index
-// CHECK: %[[COUNT_X:.+]] = hal.interface.workgroup.count[0] : index
-// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index
-// CHECK: %[[COUNT_Y:.+]] = hal.interface.workgroup.count[1] : index
-
-// CHECK: %[[Y_LB:.+]] = affine.apply #[[MULMAP]]()[%[[ID_Y]], %c4]
-// CHECK: %[[Y_STEP:.+]] = affine.apply #[[MULMAP]]()[%[[COUNT_Y]], %c4]
-// CHECK: scf.for %[[IV_Y:.+]] = %[[Y_LB]] to %[[A_DIM0]] step %[[Y_STEP]]
-// CHECK:   %[[X_LB:.+]] = affine.apply #[[MULMAP]]()[%[[ID_X]], %c16]
-// CHECK:   %[[X_STEP:.+]] = affine.apply #[[MULMAP]]()[%[[COUNT_X]], %c16]
-// CHECK:   scf.for %[[IV_X:.+]] = %[[X_LB]] to %[[B_DIM1]] step %[[X_STEP]]
-// CHECK:     %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[A_DIM0]], %c4]
-// CHECK:     %[[A_TILE:.+]] = memref.subview %{{.+}}[%[[IV_Y]], 0] [%[[Y_SIZE]], %[[A_DIM1]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map{{[0-9]+}}>
-// CHECK:     %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[B_DIM1]], %c16]
-// CHECK:     %[[B_TILE:.+]] = memref.subview %{{.+}}[0, %[[IV_X]]] [%[[B_DIM0]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map{{[0-9]+}}>
-// CHECK:     %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[C_DIM0]], %c4]
-// CHECK:     %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[C_DIM1]], %c16]
-// CHECK:     %[[C_TILE:.+]] = memref.subview %{{.+}}[%[[IV_Y]], %[[IV_X]]] [%[[Y_SIZE]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map{{[0-9]+}}>
-// CHECK:     linalg.fill(%[[C_TILE]], %cst) {__internal_linalg_transform__ = "workgroup"} : memref<?x?xf32, #map{{[0-9]+}}>, f32
-// CHECK:     linalg.matmul {__internal_linalg_transform__ = "workgroup", is_root_op} ins(%[[A_TILE]], %[[B_TILE]] : memref<?x?xf32, #map{{[0-9]+}}>, memref<?x?xf32, #map{{[0-9]+}}>) outs(%[[C_TILE]] : memref<?x?xf32, #map{{[0-9]+}}>)
+
+// CHECK-DAG: %[[INPUT:.+]] = hal.interface.binding.subspan @legacy_io::@arg0
+// CHECK-DAG: %[[FILTER:.+]] = hal.interface.binding.subspan @legacy_io::@arg1
+// CHECK-DAG: %[[OUTPUT:.+]] = hal.interface.binding.subspan @legacy_io::@ret0
+
+// CHECK-DAG: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index
+// CHECK-DAG: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index
+// CHECK-DAG: %[[ID_Z:.+]] = hal.interface.workgroup.id[2] : index
+
+// CHECK-DAG: %[[OUTPUT_OFFSET_Z:.+]] = affine.apply #[[MAP0]]()[%[[ID_Z]]]
+// CHECK-DAG: %[[OUTPUT_OFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[ID_Y]]]
+// CHECK-DAG: %[[OUTPUT_OFFSET_X:.+]] = affine.apply #[[MAP1]]()[%[[ID_X]]]
+// CHECK-DAG: %[[INPUT_OFFSET_Z:.+]] = affine.apply #[[MAP2]]()[%[[ID_Z]]]
+// CHECK-DAG: %[[INPUT_SIZE_Z:.+]] = affine.min #[[MAP3]]()[%[[ID_Z]]]
+// CHECK-DAG: %[[INPUT_OFFSET_Y:.+]] = affine.apply #[[MAP2]]()[%[[ID_Y]]]
+// CHECK-DAG: %[[INPUT_SIZE_Y:.+]] = affine.min #[[MAP3]]()[%[[ID_Y]]]
+
+// CHECK: %[[INPUT_VIEW:.+]] = memref.subview %[[INPUT]]
+// CHECK-SAME:   [0, %[[INPUT_OFFSET_Z]], %[[INPUT_OFFSET_Y]], 0]
+// CHECK-SAME:   [1, %[[INPUT_SIZE_Z]], %[[INPUT_SIZE_Y]], 16] [1, 1, 1, 1]
+// CHECK-SAME:   memref<1x225x225x16xf32> to memref<1x?x?x16xf32, {{.+}}>
+
+// CHECK: %[[OUTPUT_SIZE_X:.+]] = affine.min #[[MAP5]]()[%[[ID_X]]]
+// CHECK: %[[FILTER_VIEW:.+]] = memref.subview %[[FILTER]]
+// CHECK-SAME:   [0, 0, 0, %[[OUTPUT_OFFSET_X]]] [3, 3, 16, %[[OUTPUT_SIZE_X]]]
+// CHECK-SAME:   memref<3x3x16x32xf32> to memref<3x3x16x?xf32, {{.+}}>
+
+// CHECK-DAG: %[[OUTPUT_SIZE_Z:.+]] = affine.min #[[MAP7]]()[%[[ID_Z]]]
+// CHECK-DAG: %[[OUTPUT_SIZE_Y:.+]] = affine.min #[[MAP7]]()[%[[ID_Y]]]
+// CHECK: %[[OUTPUT_VIEW:.+]] = memref.subview %[[OUTPUT]]
+// CHECK-SAME:   [0, %[[OUTPUT_OFFSET_Z]], %[[OUTPUT_OFFSET_Y]], %[[OUTPUT_OFFSET_X]]]
+// CHECK-SAME:   [1, %[[OUTPUT_SIZE_Z]], %[[OUTPUT_SIZE_Y]], %[[OUTPUT_SIZE_X]]]
+// CHECK-SAME:   memref<1x112x112x32xf32> to memref<1x?x?x?xf32, {{.+}}>
+
+// CHECK: linalg.fill(%[[OUTPUT_VIEW]], %{{.+}})
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME:   ins(%[[INPUT_VIEW]], %[[FILTER_VIEW]] : memref<1x?x?x16xf32, #map{{[0-9]+}}>, memref<3x3x16x?xf32, #map{{[0-9]+}}>)
+// CHECK-SAME:   outs(%[[OUTPUT_VIEW]] : memref<1x?x?x?xf32, #map{{[0-9]+}}>)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups_dynamic.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups_dynamic.mlir
new file mode 100644
index 000000000000..d2070e5d05b2
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups_dynamic.mlir
@@ -0,0 +1,117 @@
+// RUN: iree-opt -split-input-file -iree-spirv-tile-size=4,16 -iree-spirv-workgroup-size=4,4,1 -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups))" -canonicalize -cse %s | IreeFileCheck %s
+
+hal.executable @matmul_dynamic_shape attributes {sym_visibility = "private"} {
+  hal.interface @legacy_io {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
+    hal.executable.entry_point @matmul_dynamic_shape attributes {
+      interface = @legacy_io, ordinal = 0 : index,
+      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
+    module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
+      func @matmul_dynamic_shape() {
+        %cst = constant 0.000000e+00 : f32
+        %c0 = constant 0 : index
+        %0 = hal.interface.load.constant offset = 0 : index
+        %1 = hal.interface.load.constant offset = 1 : index
+        %2 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<?x?xf32>
+        %3 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<?x?xf32>
+        %4 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<?x?xf32>
+        %5 = hal.interface.load.constant offset = 2 : index
+        %6 = hal.interface.load.constant offset = 3 : index
+        %7 = hal.interface.load.constant offset = 4 : index
+        %8 = hal.interface.load.constant offset = 5 : index
+        %9 = hal.interface.load.constant offset = 6 : index
+        %10 = hal.interface.load.constant offset = 7 : index
+        %11 = shapex.make_ranked_shape %5, %6 : (index, index) -> !shapex.ranked_shape<[?,?]>
+        %12 = shapex.tie_shape %2, %11 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+        %13 = shapex.make_ranked_shape %7, %8 : (index, index) -> !shapex.ranked_shape<[?,?]>
+        %14 = shapex.tie_shape %3, %13 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+        %15 = shapex.make_ranked_shape %9, %10 : (index, index) -> !shapex.ranked_shape<[?,?]>
+        %16 = shapex.tie_shape %4, %15 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %17 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+        %18 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+        scf.for %arg0 = %17 to %5 step %18 {
+          %19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+          %20 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+          scf.for %arg1 = %19 to %8 step %20 {
+            %21 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%5, %workgroup_size_y]
+            %22 = memref.subview %12[%arg0, 0] [%21, %6] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
+            %23 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%8, %workgroup_size_x]
+            %24 = memref.subview %14[0, %arg1] [%7, %23] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
+            %25 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%0, %workgroup_size_y]
+            %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%1, %workgroup_size_x]
+            %27 = memref.subview %16[%arg0, %arg1] [%25, %26] [1, 1] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
+            linalg.fill(%27, %cst) {__internal_linalg_transform__ = "workgroup"} : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, f32
+            linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%22, %24 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) outs(%27 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>)
+          }
+        }
+        return
+      }
+      hal.interface @legacy_io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+// Check that for a fully dynamic shaped dispatch region, we can:
+// 1) Generate symbolic workgroup counts,
+// 2) Replace hal.interface.workgroup.size (but not .count) ops with constants.
+
+// CHECK-DAG: #[[DIV16MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK-DAG: #[[DIV4MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK-DAG: #[[MUL16MAP:.+]] = affine_map<()[s0] -> (s0 * 16)>
+// CHECK-DAG: #[[MUL4MAP:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[YBOUNDMAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+// CHECK-DAG: #[[XBOUNDMAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+
+// CHECK: hal.executable.entry_point @matmul_dynamic_shape
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: index, %[[BBARG1:.+]]: index, %{{.+}}: index):
+// CHECK:   %c1 = constant 1 : index
+// CHECK:   %[[SIZE0:.+]] = affine.apply #[[DIV16MAP]]()[%[[BBARG0]]]
+// CHECK:   %[[SIZE1:.+]] = affine.apply #[[DIV4MAP]]()[%[[BBARG1]]]
+// CHECK:   hal.return %[[SIZE0]], %[[SIZE1]], %c1
+
+// CHECK: func @matmul_dynamic_shape()
+// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[4, 4, 1]> : vector<3xi32>}
+
+// CHECK: %[[C_DIM0:.+]] = hal.interface.load.constant offset = 0 : index
+// CHECK: %[[C_DIM1:.+]] = hal.interface.load.constant offset = 1 : index
+// CHECK: %[[A_DIM0:.+]] = hal.interface.load.constant offset = 2 : index
+// CHECK: %[[A_DIM1:.+]] = hal.interface.load.constant offset = 3 : index
+// CHECK: %[[B_DIM0:.+]] = hal.interface.load.constant offset = 4 : index
+// CHECK: %[[B_DIM1:.+]] = hal.interface.load.constant offset = 5 : index
+
+// CHECK: %[[ID_X:.+]] = hal.interface.workgroup.id[0] : index
+// CHECK: %[[COUNT_X:.+]] = hal.interface.workgroup.count[0] : index
+// CHECK: %[[ID_Y:.+]] = hal.interface.workgroup.id[1] : index
+// CHECK: %[[COUNT_Y:.+]] = hal.interface.workgroup.count[1] : index
+
+// CHECK: %[[Y_LB:.+]] = affine.apply #[[MUL4MAP]]()[%[[ID_Y]]]
+// CHECK: %[[Y_STEP:.+]] = affine.apply #[[MUL4MAP]]()[%[[COUNT_Y]]]
+// CHECK: scf.for %[[IV_Y:.+]] = %[[Y_LB]] to %[[A_DIM0]] step %[[Y_STEP]]
+// CHECK:   %[[X_LB:.+]] = affine.apply #[[MUL16MAP]]()[%[[ID_X]]]
+// CHECK:   %[[X_STEP:.+]] = affine.apply #[[MUL16MAP]]()[%[[COUNT_X]]]
+// CHECK:   scf.for %[[IV_X:.+]] = %[[X_LB]] to %[[B_DIM1]] step %[[X_STEP]]
+// CHECK:     %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[A_DIM0]]]
+// CHECK:     %[[A_TILE:.+]] = memref.subview %{{.+}}[%[[IV_Y]], 0] [%[[Y_SIZE]], %[[A_DIM1]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map{{[0-9]+}}>
+// CHECK:     %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[B_DIM1]]]
+// CHECK:     %[[B_TILE:.+]] = memref.subview %{{.+}}[0, %[[IV_X]]] [%[[B_DIM0]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map{{[0-9]+}}>
+// CHECK:     %[[Y_SIZE:.+]] = affine.min #[[YBOUNDMAP]](%[[IV_Y]])[%[[C_DIM0]]]
+// CHECK:     %[[X_SIZE:.+]] = affine.min #[[XBOUNDMAP]](%[[IV_X]])[%[[C_DIM1]]]
+// CHECK:     %[[C_TILE:.+]] = memref.subview %{{.+}}[%[[IV_Y]], %[[IV_X]]] [%[[Y_SIZE]], %[[X_SIZE]]] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map{{[0-9]+}}>
+// CHECK:     linalg.fill(%[[C_TILE]], %cst)
+// CHECK:     linalg.matmul
+// CHECK-SAME:   ins(%[[A_TILE]], %[[B_TILE]]
+// CHECK-SAME:   outs(%[[C_TILE]]
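To see the symbolic counts in action, take a hypothetical dynamic workload whose matmul result is 300x512 with the tile sizes above (y = 4, x = 16). The entry-point region would compute, per the DIV16MAP/DIV4MAP maps checked above:

    // ceilDiv as used by the pass; hypothetical sizes for illustration only.
    static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }
    // x = ceilDiv(512, 16) = 32,  y = ceilDiv(300, 4) = 75,  z = 1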