diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt index 643b229576e8..33e4fba8e8e3 100644 --- a/SUBMODULE_VERSIONS.txt +++ b/SUBMODULE_VERSIONS.txt @@ -4,16 +4,16 @@ 4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing -013b829185fee6d8eaa515a7e36ec468a2a02600 third_party/llvm-bazel -cd442157cff4aad209ae532cbf031abbe10bc1df third_party/llvm-project +189e771009a640214e08e855830ae6f15a83c655 third_party/llvm-bazel +1f6a57c1a0fad922e04a2b1f414b092d4b0cd8b0 third_party/llvm-project 68547d08daca039467df49c7cc50c3a0061787f3 third_party/mlir-emitc -431be0e9b235e1b98adf0367f3beb440aa672875 third_party/mlir-hlo +cbef26c6a8f1e4be3f4cfb902db992c45e93b7a6 third_party/mlir-hlo 2b2bd45bbf9be04fd22ece5cc1f54679202e9257 third_party/pffft d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11 2887692065c38ef6617f423feafc6b69dd0a0681 third_party/ruy 685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers -aa3bd9f6de5a76c4c226548a48e448d211978e92 third_party/tensorflow +da3da1e8a81a9866d98bcfe54eb21ec27cab7000 third_party/tensorflow 8732f0e94e4e41049a43029202bda94d7b4e85da third_party/tracy 9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers 3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator diff --git a/experimental/ModelBuilder/ModelRunner.cpp b/experimental/ModelBuilder/ModelRunner.cpp index b7c01e2bb210..39741d36c8b0 100644 --- a/experimental/ModelBuilder/ModelRunner.cpp +++ b/experimental/ModelBuilder/ModelRunner.cpp @@ -59,12 +59,10 @@ void mlir::ModelRunner::compile( if (target == Target::CPUTarget) { // Lower vector operations progressively into more elementary // vector operations before running the regular compiler passes. - mlir::OwningRewritePatternList patterns; - mlir::vector::populateVectorSlicesLoweringPatterns(patterns, - module->getContext()); + mlir::OwningRewritePatternList patterns(module->getContext()); + mlir::vector::populateVectorSlicesLoweringPatterns(patterns); mlir::vector::populateVectorContractLoweringPatterns( - patterns, module->getContext(), - compilationOptions.vectorTransformsOptions); + patterns, compilationOptions.vectorTransformsOptions); (void)mlir::applyPatternsAndFoldGreedily(*module, std::move(patterns)); } runLoweringPass(compilationOptions.loweringPasses diff --git a/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp b/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp index e6b8de3d0160..40c1c1d59d8d 100644 --- a/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp +++ b/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp @@ -58,15 +58,15 @@ class ConvertToMHLOPass : public PassWrapper { // Lower TF Patterns must be separate from canonocalization patterns as // they are sometimes inversions of eachother. 
- OwningRewritePatternList lowerTfPatterns; + OwningRewritePatternList lowerTfPatterns(&getContext()); mlir::TF::PopulateLoweringTFPatterns(context, &lowerTfPatterns); - OwningRewritePatternList canonicalizePatterns; + OwningRewritePatternList canonicalizePatterns(&getContext()); for (auto *op : context->getRegisteredOperations()) { op->getCanonicalizationPatterns(canonicalizePatterns, context); } - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); // Note that the `OperationConverter` orders patterns lexicographically by: // 1) Ascending legalization depth (i.e., minimum number of patterns // necessary to arrive at conversion target). @@ -98,10 +98,10 @@ class ConvertToMHLOPass : public PassWrapper { DenseSet prevUnconvertedOps; DenseSet unconvertedOps; - FrozenRewritePatternList frozenPatterns(std::move(patterns)); - FrozenRewritePatternList frozenCanonicalizePatterns( + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenCanonicalizePatterns( std::move(canonicalizePatterns)); - FrozenRewritePatternList frozenTfPatterns(std::move(lowerTfPatterns)); + FrozenRewritePatternSet frozenTfPatterns(std::move(lowerTfPatterns)); while (true) { if (failed( applyPatternsAndFoldGreedily(op, frozenCanonicalizePatterns))) { diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc index c2460b8e3227..55c399f128a8 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc @@ -146,7 +146,7 @@ class ConvertTFToTFStringsPass void populateTFToTFStringsPatterns(MLIRContext *ctx, OwningRewritePatternList &patterns) { - populateWithGenerated(ctx, patterns); + populateWithGenerated(patterns); patterns.insert(ctx); patterns.insert(ctx); } diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc index 3e8aa175e395..1a83f35e5be5 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc @@ -98,8 +98,8 @@ void ConvertTFToTFTensorListPass::runOnOperation() { // The MLIR type conversion infrastructure doesn't handle this situation well. // It only knows how to handle blindly convert one type to another type. 
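The recurring change in this patch is the rewrite-pattern API update at the new llvm-project pin: OwningRewritePatternList is now constructed with an MLIRContext*, the populate* helpers no longer take a context argument, and FrozenRewritePatternList has been renamed to FrozenRewritePatternSet. A minimal before/after sketch of the idiom, assuming code running inside a pass; populateMyPatterns is a placeholder for any of the populate helpers touched in this diff:

    // Before (old API): the list was default-constructed and every populate
    // helper took the MLIRContext explicitly.
    //   OwningRewritePatternList patterns;
    //   populateMyPatterns(context, patterns);
    //   FrozenRewritePatternList frozen(std::move(patterns));
    //
    // After (API at the updated llvm-project revision):
    OwningRewritePatternList patterns(&getContext());     // context bound at construction
    populateMyPatterns(patterns);                         // helpers read the context from the list
    FrozenRewritePatternSet frozen(std::move(patterns));  // renamed from FrozenRewritePatternList
    if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
      signalPassFailure();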
- OwningRewritePatternList patterns; - populateWithGenerated(&getContext(), patterns); + OwningRewritePatternList patterns(&getContext()); + populateWithGenerated(patterns); patterns.insert(&getContext()); ConversionTarget target(getContext()); diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/utils/conversion_utils.h b/integrations/tensorflow/iree_tf_compiler/dialect/utils/conversion_utils.h index 107205f7b032..37942b7e2c06 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/utils/conversion_utils.h +++ b/integrations/tensorflow/iree_tf_compiler/dialect/utils/conversion_utils.h @@ -55,7 +55,7 @@ class ConversionPass : public PassWrapper> { LogicalResult run() { auto module = this->getOperation(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&this->getContext()); Converter typeConverter; // Lower to the standard string operations. @@ -82,10 +82,8 @@ class ConversionPass : public PassWrapper> { llvm::all_of(op.getResultTypes(), func); }); - populateFuncOpTypeConversionPattern(patterns, &this->getContext(), - typeConverter); - populateCallOpTypeConversionPattern(patterns, &this->getContext(), - typeConverter); + populateFuncOpTypeConversionPattern(patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); auto result = applyPartialConversion(module.getOperation(), target, std::move(patterns)); diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp index df1eb41c86bc..985bdb1ed2ac 100644 --- a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp +++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp @@ -217,7 +217,7 @@ struct ForOpCanonicalizationPass : PassWrapper { void runOnFunction() override { FuncOp fn = getFunction(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(fn.getContext()); (void)applyPatternsAndFoldGreedily(fn, std::move(patterns)); diff --git a/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp b/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp index 63213097cc82..c66b4e434005 100644 --- a/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp +++ b/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp @@ -108,7 +108,7 @@ struct RemoveDeadMemAllocs : RewritePattern { struct BufferAllocViewCleanUpPass : public PassWrapper { void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(&getContext()); patterns.insert(); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); diff --git a/iree/compiler/Conversion/Common/LinalgRewriteDestructiveUpdatesPass.cpp b/iree/compiler/Conversion/Common/LinalgRewriteDestructiveUpdatesPass.cpp index 774305657ba0..a6a0f2d03e7e 100644 --- a/iree/compiler/Conversion/Common/LinalgRewriteDestructiveUpdatesPass.cpp +++ b/iree/compiler/Conversion/Common/LinalgRewriteDestructiveUpdatesPass.cpp @@ -532,7 +532,7 @@ void LinalgRewriteDestructiveUpdates::runOnFunction() { // Non-default canonicalization patterns. // TODO: add Linalg tiling canonicalization patterns, affineminscf and others // as needed. 
- OwningRewritePatternList canonicalizationPatterns; + OwningRewritePatternList canonicalizationPatterns(&getContext()); scf::ForOp::getCanonicalizationPatterns(canonicalizationPatterns, context); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns)); diff --git a/iree/compiler/Conversion/Common/Transforms.cpp b/iree/compiler/Conversion/Common/Transforms.cpp index 5d7fa02a1b78..d8da92f6a02b 100644 --- a/iree/compiler/Conversion/Common/Transforms.cpp +++ b/iree/compiler/Conversion/Common/Transforms.cpp @@ -45,7 +45,7 @@ namespace iree_compiler { /// easier. void applyCanonicalizationPatternsForTiling(MLIRContext *context, Operation *op) { - OwningRewritePatternList canonicalizationPatterns; + OwningRewritePatternList canonicalizationPatterns(context); canonicalizationPatterns.insert(context); scf::ForOp::getCanonicalizationPatterns(canonicalizationPatterns, context); AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context); @@ -344,7 +344,7 @@ LogicalResult defineWorkgroupCountRegion( LogicalResult materializeStaticLaunchInformation( FuncOp funcOp, ArrayRef workloadPerWorkgroup) { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(funcOp.getContext()); patterns.insert(funcOp.getContext(), workloadPerWorkgroup); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { diff --git a/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp index c0ad44b6bde6..9904c49a1d8c 100644 --- a/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp +++ b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp @@ -64,9 +64,8 @@ struct VectorTransferOptimizationPass // Generate vector.shape_cast for dropping leading one dimensions in vector // ops. This increases the chance that we can forward more transfer writes // to transfer reads. 
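The vector-lowering hunks follow the same convention: populateVectorSlicesLoweringPatterns, populateVectorContractLoweringPatterns, and populateCastAwayVectorLeadingOneDimPatterns keep only their non-context parameters. A rough sketch of the post-patch call sequence inside a function pass (pass boilerplate elided; the default-constructed VectorTransformsOptions stands in for whatever options a real caller supplies):

    void runOnFunction() override {
      FuncOp funcOp = getFunction();
      OwningRewritePatternList patterns(&getContext());
      // No MLIRContext argument anymore; only the remaining parameters survive.
      mlir::vector::populateVectorSlicesLoweringPatterns(patterns);
      mlir::vector::populateVectorContractLoweringPatterns(
          patterns, mlir::vector::VectorTransformsOptions());
      mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    }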
- OwningRewritePatternList patterns; - mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( - patterns, funcOp.getContext()); + OwningRewritePatternList patterns(&getContext()); + mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); vector::transferOpflowOpt(funcOp); diff --git a/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp b/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp index 4bfe8ec192aa..2a142f038067 100644 --- a/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp +++ b/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp @@ -130,7 +130,7 @@ struct Convert1x1ConvToDotPass void runOnFunction() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp index 0adbd571d0d6..d294d3e05484 100644 --- a/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp +++ b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp @@ -60,7 +60,7 @@ struct DecomposeHLOClampPass void runOnFunction() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp b/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp index c66e2807981b..92bdc62bfa73 100644 --- a/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp +++ b/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp @@ -172,9 +172,9 @@ void ConvertF32ToF16Pass::runOnOperation() { ModuleOp moduleOp = getOperation(); FloatTypeConverter converter; - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context, converter); - populateFuncOpTypeConversionPattern(patterns, context, converter); + populateFuncOpTypeConversionPattern(patterns, converter); F32ToF16ConversionTarget target(*context); target.markUnknownOpDynamicallyLegal(); if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) diff --git a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp index 7dc89dc28980..bc5819c32e23 100644 --- a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp +++ b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp @@ -73,18 +73,19 @@ struct FusionOfTensorOpsPass } void runOnOperation() override { - OwningRewritePatternList fusionPatterns, interfacePatterns; + OwningRewritePatternList fusionPatterns(&getContext()); + OwningRewritePatternList interfacePatterns(&getContext()); Operation *op = getOperation(); MLIRContext *context = op->getContext(); interfacePatterns.insert(context); - FrozenRewritePatternList frozenInterfacePatterns( + FrozenRewritePatternSet frozenInterfacePatterns( std::move(interfacePatterns)); (void)applyPatternsAndFoldGreedily(op->getRegions(), frozenInterfacePatterns); - populateLinalgTensorOpsFusionPatterns(context, fusionPatterns); + populateLinalgTensorOpsFusionPatterns(fusionPatterns); (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(fusionPatterns)); diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp 
b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp index 68b7eac97d01..38acef360621 100644 --- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp +++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp @@ -825,6 +825,8 @@ static LogicalResult createAndPropagateBufferUsedForResultTensors( // Canonicalization patterns. //===----------------------------------------------------------------------===// +// TODO(hanchung): Revisit the pattern, this seems no longer needed because the +// reshape ops are folded in tensors world. // Folds linalg.reshape op that directly reshaping an iree.placeholder op into // the iree.placeholder op itself. class FoldReshapeIntoPlaceholder final @@ -900,7 +902,7 @@ void ConvertHLOToLinalgOnBuffersPass::runOnFunction() { return signalPassFailure(); } - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateHLOToLinalgOnBuffersConversionPatterns(context, patterns, resultTensorToBufferMap); patterns.insert( @@ -940,7 +942,7 @@ void ConvertHLOToLinalgOnBuffersPass::runOnFunction() { // Perform additional canonicalizations. { - OwningRewritePatternList foldingPatterns; + OwningRewritePatternList foldingPatterns(&getContext()); foldingPatterns.insert(context); (void)applyPatternsAndFoldGreedily(funcOp, std::move(foldingPatterns)); } diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp index aecec545b36e..cfbc1ae586f1 100644 --- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp +++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp @@ -194,7 +194,7 @@ struct ConvertHLOToLinalgOnTensorsPass } void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); MLIRContext *context = &getContext(); populateHLOToLinalgOnTensorsConversionPatterns(context, patterns); if (useLinalgOnTensorsPath) { diff --git a/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp b/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp index 02d34eca08d1..4f9107d3077f 100644 --- a/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp +++ b/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp @@ -98,7 +98,7 @@ struct ResolveShapeOpsPass void ResolveShapeOpsPass::runOnFunction() { MLIRContext *context = &getContext(); - OwningRewritePatternList dimPatterns; + OwningRewritePatternList dimPatterns(&getContext()); dimPatterns.insert(context); // Set up a target to convert all std.dim ops. 
We need a conversion target @@ -111,7 +111,7 @@ void ResolveShapeOpsPass::runOnFunction() { return signalPassFailure(); } - OwningRewritePatternList shapePatterns; + OwningRewritePatternList shapePatterns(&getContext()); shapePatterns.insert(context); Shape::RankedDimOp::getCanonicalizationPatterns(shapePatterns, context); diff --git a/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir b/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir index e964a08ea089..a832e1955254 100644 --- a/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir +++ b/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir @@ -32,10 +32,9 @@ module { // ----- module { - func @fuse_store_reshape() { + func @fuse_store_reshape(%arg0: tensor<100xi32>) { %c0 = constant 0 : index - %c42 = constant dense<42> : tensor<100xi32> - %0 = linalg.tensor_reshape %c42 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<100xi32> into tensor<4x25xi32> + %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<100xi32> into tensor<4x25xi32> hal.interface.store.tensor %0, @legacy_io::@ret0, offset = %c0 : tensor<4x25xi32> return } @@ -45,8 +44,8 @@ module { } // CHECK-LABEL: func @fuse_store_reshape -// CHECK: %[[C42:.+]] = constant dense<{{.+}}> : tensor<100xi32> -// CHECK: hal.interface.store.tensor %[[C42]] +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<100xi32> +// CHECK: hal.interface.store.tensor %[[ARG0]] // ----- diff --git a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir index d39791f1efff..c3e921a42818 100644 --- a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir +++ b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir @@ -320,66 +320,6 @@ module { // ----- -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d2)> - -module { - func @store_reshape_src_and_result_2() { - %c0 = constant 0 : index - %shape = linalg.init_tensor[2, 4] : tensor<2x4xf32> - %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 - {operand_result_index = 0 : i32} : tensor<2x4xf32> - %1 = linalg.generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%0 : tensor<2x4xf32>) - outs(%shape : tensor<2x4xf32>) { - ^bb0(%arg0: f32, %s: f32): // no predecessors - %2 = math.tanh %arg0 : f32 - linalg.yield %2 : f32 - } -> tensor<2x4xf32> - %3 = linalg.tensor_reshape %1 [#map1, #map2] - : tensor<2x4xf32> into tensor<1x2x4xf32> - %4 = linalg.tensor_reshape %1 [#map1, #map2] - : tensor<2x4xf32> into tensor<1x2x4xf32> - %5 = linalg.tensor_reshape %1 [#map1, #map2] - : tensor<2x4xf32> into tensor<1x2x4xf32> - hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 - {operand_result_index = 1 : i32} : tensor<1x2x4xf32> - hal.interface.store.tensor %4, @legacy_io::@ret1, offset = %c0 - {operand_result_index = 2 : i32} : tensor<1x2x4xf32> - hal.interface.store.tensor %5, @legacy_io::@ret2, offset = %c0 - {operand_result_index = 3 : i32} : tensor<1x2x4xf32> - return - } - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", - access="Read" - hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", - access="Write|Discard" - hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer", - access="Write|Discard" - hal.interface.binding @ret2, set=0, binding=3, type="StorageBuffer", - 
access="Write|Discard" - } -} - -// CHECK-LABEL: func @store_reshape_src_and_result_2 -// CHECK-DAG: %[[T0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret2, operand_result_index = 3 : i32} : memref<1x2x4xf32> -// CHECK-DAG: %[[T1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret2, operand_result_index = 3 : i32} : memref<2x4xf32> -// CHECK-DAG: %[[T2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1, operand_result_index = 2 : i32} : memref<1x2x4xf32> -// CHECK-DAG: %[[T3:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 1 : i32} : memref<1x2x4xf32> -// CHECK-DAG: %[[T4:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<2x4xf32> -// CHECK: linalg.generic -// CHECK-SAME: ins(%[[T4]] : -// CHECK-SAME: outs(%[[T1]] : -// CHECK: linalg.copy(%[[T0]], %[[T3]]) -// CHECK: linalg.copy(%[[T0]], %[[T2]]) -// CHECK: return - -// ----- - #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> #map2 = affine_map<(d0, d1) -> (d0, d1)> diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvImg2ColMatmulConversion.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvImg2ColMatmulConversion.cpp index 783662d4a51e..26559860d048 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/ConvImg2ColMatmulConversion.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/ConvImg2ColMatmulConversion.cpp @@ -200,7 +200,7 @@ void populateConvImg2ColMatmulConversionPatterns( void ConvImg2ColMatmulConversionPass::runOnFunction() { auto funcOp = getOperation(); auto context = funcOp.getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateConvImg2ColMatmulConversionPatterns(context, patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp index fc1500f8481b..8d2883fa0c0d 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp @@ -644,26 +644,24 @@ class ConvertToLLVMPass void ConvertToLLVMPass::runOnOperation() { // Run Vector -> Vector transformations ahead of conversion to LLVM. { - OwningRewritePatternList patterns; - vector::populateVectorToVectorCanonicalizationPatterns(patterns, - &getContext()); - vector::populateVectorSlicesLoweringPatterns(patterns, &getContext()); - vector::populateVectorContractLoweringPatterns(patterns, &getContext()); + OwningRewritePatternList patterns(&getContext()); + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + vector::populateVectorSlicesLoweringPatterns(patterns); + vector::populateVectorContractLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } { - OwningRewritePatternList vectorToLoopsPatterns; + OwningRewritePatternList vectorToLoopsPatterns(&getContext()); populateVectorToSCFConversionPatterns( - vectorToLoopsPatterns, &getContext(), - VectorTransferToSCFOptions().setUnroll(true)); + vectorToLoopsPatterns, VectorTransferToSCFOptions().setUnroll(true)); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(vectorToLoopsPatterns)); } // math dialect elementry functions -> polynomial form. 
{ - OwningRewritePatternList mathPatterns; - populateMathPolynomialApproximationPatterns(mathPatterns, &getContext()); + OwningRewritePatternList mathPatterns(&getContext()); + populateMathPolynomialApproximationPatterns(mathPatterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(mathPatterns)); } @@ -674,12 +672,12 @@ void ConvertToLLVMPass::runOnOperation() { return success(); }); - OwningRewritePatternList patterns; - populateAffineToStdConversionPatterns(patterns, &getContext()); - populateLoopToStdConversionPatterns(patterns, &getContext()); - populateExpandTanhPattern(patterns, &getContext()); + OwningRewritePatternList patterns(&getContext()); + populateAffineToStdConversionPatterns(patterns); + populateLoopToStdConversionPatterns(patterns); + populateExpandTanhPattern(patterns); populateStdToLLVMConversionPatterns(converter, patterns); - populateVectorToSCFConversionPatterns(patterns, &getContext()); + populateVectorToSCFConversionPatterns(patterns); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateLinalgToLLVMConversionPatterns(converter, patterns); @@ -732,7 +730,7 @@ void ConvertToLLVMPass::runOnOperation() { // Post conversion patterns. { - OwningRewritePatternList postPatterns; + OwningRewritePatternList postPatterns(&getContext()); if (options_.unfuseFMAOps) { populateUnfusedFMAOpsPassPatterns(&getContext(), postPatterns); (void)applyPatternsAndFoldGreedily(module, std::move(postPatterns)); diff --git a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp b/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp index 53e078ba14b6..026dd95cba4c 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp @@ -62,9 +62,8 @@ class FoldTensorExtractOpPass } // namespace void FoldTensorExtractOpPass::runOnOperation() { - MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; - populateWithGenerated(context, patterns); + OwningRewritePatternList patterns(&getContext()); + populateWithGenerated(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp index 5ca30867d466..441d9e7950b1 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp @@ -136,7 +136,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Promotes workgroups subviews to a full-tile allocated on the stack. if (clEnablePromoteWorkgroupToFullTiles) { - OwningRewritePatternList promotionPatterns; + OwningRewritePatternList promotionPatterns(&getContext()); promotionPatterns.insert( context, linalg::LinalgPromotionOptions().setAllocationDeallocationFns( @@ -151,7 +151,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Workgroup first level of tiling. { // First level of tiling patterns. (workgroups memory) - OwningRewritePatternList l1patterns; + OwningRewritePatternList l1patterns(&getContext()); l1patterns.insert( linalg::LinalgTilingOptions().setTileSizeComputationFunction( [](OpBuilder &builder, @@ -173,7 +173,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Second level of tiling. 
(workgroups memory -> vectors) { - OwningRewritePatternList l2patterns; + OwningRewritePatternList l2patterns(&getContext()); l2patterns.insert( linalg::LinalgTilingOptions().setTileSizeComputationFunction( [](OpBuilder &builder, @@ -192,7 +192,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Apply canonicalization. { - OwningRewritePatternList canonicalizationPatterns; + OwningRewritePatternList canonicalizationPatterns(&getContext()); canonicalizationPatterns.insert(context); AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context); @@ -207,10 +207,10 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Apply vectorization patterns. { - OwningRewritePatternList vectorizationPatterns; + OwningRewritePatternList vectorizationPatterns(&getContext()); linalg::insertVectorizationPatterns( - vectorizationPatterns, context, linalg::LinalgVectorizationOptions(), + vectorizationPatterns, linalg::LinalgVectorizationOptions(), linalg::LinalgTransformationFilter( Identifier::get(getVectorizeMarker(), context))); if (failed(applyPatternsAndFoldGreedily( @@ -232,7 +232,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { vector::VectorTransformsOptions vectorTransformsOptions = vector::VectorTransformsOptions().setVectorTransformsOptions( vector::VectorContractLowering::OuterProduct); - OwningRewritePatternList vectorContractLoweringPatterns; + OwningRewritePatternList vectorContractLoweringPatterns(&getContext()); vectorContractLoweringPatterns .insert( @@ -247,16 +247,15 @@ void TileAndVectorizeWorkgroups::runOnFunction() { { VectorTransferToSCFOptions vectorToSCFOptions = VectorTransferToSCFOptions().setUnroll(true); - OwningRewritePatternList vectorToLoopsPatterns; - populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, + OwningRewritePatternList vectorToLoopsPatterns(&getContext()); + populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, vectorToSCFOptions); // Hosit hierarchical tiling indexing and other loop invariant transfer // ops computation. linalg::hoistRedundantVectorTransfers(funcOp); // TODO(ataei): Move this to common vector dialect patterns. - populateStdLegalizationPatternsForSPIRVLowering(context, - vectorToLoopsPatterns); + populateStdLegalizationPatternsForSPIRVLowering(vectorToLoopsPatterns); if (failed(applyPatternsAndFoldGreedily( funcOp, std::move(vectorToLoopsPatterns)))) { return signalPassFailure(); diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp index ada02740234d..5c7ac44f90a2 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp @@ -58,10 +58,10 @@ void LinalgVectorizationPass::runOnFunction() { MLIRContext *context = &getContext(); // Apply vectorization patterns. { - OwningRewritePatternList vectorizationPatterns; + OwningRewritePatternList vectorizationPatterns(&getContext()); linalg::insertVectorizationPatterns( - vectorizationPatterns, context, linalg::LinalgVectorizationOptions(), + vectorizationPatterns, linalg::LinalgVectorizationOptions(), linalg::LinalgTransformationFilter(ArrayRef( Identifier::get(getWorkgroupMarker(), context)))); (void)applyPatternsAndFoldGreedily(funcOp, @@ -84,22 +84,21 @@ void LinalgVectorizationPass::runOnFunction() { // Apply unrolling patterns. 
{ - OwningRewritePatternList vectorUnrollPatterns; + OwningRewritePatternList vectorUnrollPatterns(&getContext()); vectorUnrollPatterns.insert( context, vector::UnrollVectorOptions().setNativeShapeFn(getShape)); (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns)); - OwningRewritePatternList canonicalizationPatterns1; + OwningRewritePatternList canonicalizationPatterns1(&getContext()); vector::populateVectorToVectorCanonicalizationPatterns( - canonicalizationPatterns1, funcOp.getContext()); + canonicalizationPatterns1); vector::populateVectorToVectorTransformationPatterns( - canonicalizationPatterns1, funcOp.getContext()); + canonicalizationPatterns1); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns1)); - OwningRewritePatternList canonicalizationPatterns2; - vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2, - funcOp.getContext()); + OwningRewritePatternList canonicalizationPatterns2(&getContext()); + vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns2)); diff --git a/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp index b6a596b99c3d..3f87a4a39359 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp @@ -55,9 +55,8 @@ void PlanConvLoopOrderPass::runOnFunction() { /*output_channel=*/3, }; - OwningRewritePatternList patterns; - linalg::populateLinalgConvGeneralizationPatterns(context, patterns, - firstStepMarker); + OwningRewritePatternList patterns(&getContext()); + linalg::populateLinalgConvGeneralizationPatterns(patterns, firstStepMarker); patterns.insert>( context, loopOrder, secondStepMarker); diff --git a/iree/compiler/Conversion/LinalgToLLVM/UnfuseFMAOps.cpp b/iree/compiler/Conversion/LinalgToLLVM/UnfuseFMAOps.cpp index 9890cf716b8a..d2b0243c3ae8 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/UnfuseFMAOps.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/UnfuseFMAOps.cpp @@ -58,7 +58,7 @@ void populateUnfusedFMAOpsPassPatterns(MLIRContext *context, void UnfusedFMAOpsPass::runOnFunction() { auto funcOp = getOperation(); auto context = funcOp.getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateUnfusedFMAOpsPassPatterns(context, patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir index b64ef0a4939d..98e9489e67f9 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir @@ -52,14 +52,14 @@ hal.executable @dynamic_matmul attributes {sym_visibility = "private"} { // CHECK-PROMOTED: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)> // CHECK-PROMOTED: func @matmul_128x128x128 // CHECK-PROMOTED: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>) { -// CHECK-PROMOTED: %[[KDIM_SIZE:.+]] = constant 128 : index -// CHECK-PROMOTED: %[[WORGKROUP_SIZE:.+]] = constant 64 : index -// CHECK-PROMOTED: %[[VECTOR_SIZE:.+]] = constant 4 : index -// CHECK-PROMOTED: %[[L1_SIZE:.+]] = constant 32 : index -// CHECK-PROMOTED: %[[START:.+]] = constant 0 : index -// CHECK-PROMOTED: 
%[[C1:.+]] = constant 1 : index -// CHECK-PROMOTED: %[[C1:.+]] = constant 2 : index -// CHECK-PROMOTED: %[[C1:.+]] = constant 3 : index +// CHECK-PROMOTED-DAG: %[[KDIM_SIZE:.+]] = constant 128 : index +// CHECK-PROMOTED-DAG: %[[WORGKROUP_SIZE:.+]] = constant 64 : index +// CHECK-PROMOTED-DAG: %[[VECTOR_SIZE:.+]] = constant 4 : index +// CHECK-PROMOTED-DAG: %[[L1_SIZE:.+]] = constant 32 : index +// CHECK-PROMOTED-DAG: %[[START:.+]] = constant 0 : index +// CHECK-PROMOTED-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-PROMOTED-DAG: %[[C1:.+]] = constant 2 : index +// CHECK-PROMOTED-DAG: %[[C1:.+]] = constant 3 : index // CHECK-PROMOTED: %[[A_PROMOTED_TILE:.+]] = memref.alloca() : memref<64x64xf32> // CHECK-PROMOTED: %[[B_PROMOTED_TILE:.+]] = memref.alloca() : memref<128x64xf32> // CHECK-PROMOTED: %[[C_PROMOTED_TILE:.+]] = memref.alloca() : memref<64x128xf32> diff --git a/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp index 786fa3156a31..75708aef2887 100644 --- a/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp +++ b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp @@ -184,12 +184,12 @@ struct ConvertToNVVMPass // which need to be lowered further, which is not supported by a single // conversion pass. { - OwningRewritePatternList patterns; - populateGpuRewritePatterns(m.getContext(), patterns); + OwningRewritePatternList patterns(&getContext()); + populateGpuRewritePatterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); } { - OwningRewritePatternList llvmPatterns; + OwningRewritePatternList llvmPatterns(&getContext()); llvmPatterns.insert(m.getContext(), converter); llvmPatterns diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp index bda0a1372ccd..02f853beceac 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp @@ -466,7 +466,7 @@ class ConcretizeTileAmongWorkgroupsPass // 4. Replace hal.interface.workgroup symbolic ops with constant values. { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert( &context, workloadSize, tileSize); @@ -534,7 +534,7 @@ class ConcretizeTileAmongWorkgroupsPass // 6. Canonicalization and clean up. if (inlineTripOneLoops) { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(&context, workloadSize, tileSize); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp index bee5dd7f79b5..172e9a95d1f0 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp @@ -824,7 +824,7 @@ void ConvertToGPUPass::runOnOperation() { // Let the rest fall through. 
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert< MapLinalgOpToGlobalInvocationId, @@ -845,7 +845,7 @@ void ConvertToGPUPass::runOnOperation() { MapLinalgOpToLocalInvocationId, RemoveLinalgRange, SerializeParallelLoopPattern>( context, options.usingLinalgOnTensors); - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (FuncOp funcOp : getOperation().getInnerModule().getOps()) { if (!isEntryPoint(funcOp)) continue; diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp index e4808a3dcff3..7b55c8e8657f 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp @@ -539,27 +539,25 @@ void ConvertToSPIRVPass::runOnOperation() { SPIRVTypeConverter typeConverter(targetAttr); ScfToSPIRVContext scfToSPIRVContext; - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); // Pull in GPU patterns to convert processor ID ops and loop ops. - populateGPUToSPIRVPatterns(context, typeConverter, patterns); + populateGPUToSPIRVPatterns(typeConverter, patterns); // Pull in SCF patterns to convert control flow ops. - populateSCFToSPIRVPatterns(context, typeConverter, scfToSPIRVContext, - patterns); + populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns); // Pull in standard patterns to convert arithmetic ops and others. - populateStandardToSPIRVPatterns(context, typeConverter, patterns); + populateStandardToSPIRVPatterns(typeConverter, patterns); // Pull in standard patterns to convert tensor operations to SPIR-V. These are // primarily used to handle tensor-type constants and contain a // threshold. Only those constants that are below the threshold are converted // to SPIR-V. In IREE we want to control this threshold at Flow level. So set // this value arbitrarily high to make sure that everything within a dispatch // region is converted. - mlir::populateTensorToSPIRVPatterns(context, typeConverter, - std::numeric_limits::max() / 8, - patterns); + mlir::populateTensorToSPIRVPatterns( + typeConverter, std::numeric_limits::max() / 8, patterns); // Pull in vector patterns to convert vector ops. - mlir::populateVectorToSPIRVPatterns(context, typeConverter, patterns); + mlir::populateVectorToSPIRVPatterns(typeConverter, patterns); // Pull in builtin func to spv.func conversion. 
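The SPIR-V conversion hunks apply the same signature trim to converter-driven populate functions: the leading MLIRContext* is dropped while the remaining arguments (type converter, ScfToSPIRVContext, pattern list) keep their relative order. A condensed sketch of the updated call sites, assuming the pass-local typeConverter and scfToSPIRVContext variables from the surrounding ConvertToSPIRVPass.cpp hunk:

    OwningRewritePatternList patterns(&getContext());
    populateGPUToSPIRVPatterns(typeConverter, patterns);
    populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
    populateStandardToSPIRVPatterns(typeConverter, patterns);
    // populateTensorToSPIRVPatterns additionally keeps its constant-size
    // threshold argument (elided here) between the converter and the list.
    mlir::populateVectorToSPIRVPatterns(typeConverter, patterns);
    populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
    FrozenRewritePatternSet frozenPatterns(std::move(patterns));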
- populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); + populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); auto &cooperativeMatrixAnalysis = getAnalysis(); populateVectorToSPIRVPatterns(context, typeConverter, patterns, cooperativeMatrixAnalysis); @@ -593,7 +591,7 @@ void ConvertToSPIRVPass::runOnOperation() { functions.push_back(fn); } - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (FuncOp fn : functions) if (failed(applyFullConversion(fn, *target, frozenPatterns))) return signalPassFailure(); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp index 323832d9e707..1e35b1388fdf 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp @@ -275,7 +275,7 @@ struct FoldGPUProcessIDUsesPass void runOnOperation() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateFoldGPUProcessorIDUsesPatterns(context, patterns); (void)applyPatternsAndFoldGreedily(getOperation().getInnerModule(), std::move(patterns)); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp index 431883273003..c5460f280811 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp @@ -308,7 +308,7 @@ static void populateVectorizationPatterns(MLIRContext *context, OwningRewritePatternList &patterns) { linalg::insertVectorizationPatterns( - patterns, context, linalg::LinalgVectorizationOptions(), + patterns, linalg::LinalgVectorizationOptions(), linalg::LinalgTransformationFilter( Identifier::get(getVectorizeMarker(), context))); } @@ -330,23 +330,21 @@ static void populateVectorUnrollPatterns(MLIRContext *context, static void applyVectorTransformation(FuncOp funcOp) { { - OwningRewritePatternList vectorUnrollPatterns; + OwningRewritePatternList vectorUnrollPatterns(funcOp.getContext()); populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns)); - OwningRewritePatternList canonicalizationPatterns1; + OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext()); vector::populateVectorToVectorCanonicalizationPatterns( - canonicalizationPatterns1, funcOp.getContext()); + canonicalizationPatterns1); vector::populateVectorToVectorTransformationPatterns( - canonicalizationPatterns1, funcOp.getContext()); - vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1, - funcOp.getContext()); + canonicalizationPatterns1); + vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns1)); - OwningRewritePatternList canonicalizationPatterns2; - vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2, - funcOp.getContext()); + OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext()); + vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns2)); LLVM_DEBUG({ @@ -451,7 +449,7 @@ void 
TileAndVectorizeInOneWorkgroupPass::runOnOperation() { // The promotion patterns are put separate from the tiling patterns to // make sure that the allocated scratchspace memory is constant sizes // which requires some folding to trigger. - OwningRewritePatternList promotionPatterns; + OwningRewritePatternList promotionPatterns(&getContext()); populatePromotionPatterns(context, promotionPatterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPatterns)); applyCanonicalizationPatternsForTiling(context, funcOp); @@ -465,7 +463,7 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { if (launchConfig.useVectorize()) { { - OwningRewritePatternList secondLevelTilingPatterns; + OwningRewritePatternList secondLevelTilingPatterns(&getContext()); populateTilingToSubgroupPatterns(context, launchConfig, secondLevelTilingPatterns); (void)applyPatternsAndFoldGreedily( @@ -481,7 +479,7 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { } { - OwningRewritePatternList thirdLevelTilingPatterns; + OwningRewritePatternList thirdLevelTilingPatterns(&getContext()); populateTilingToInvocationPatterns(context, launchConfig, thirdLevelTilingPatterns); (void)applyPatternsAndFoldGreedily(funcOp, @@ -497,7 +495,7 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { } { - OwningRewritePatternList tilingPatterns; + OwningRewritePatternList tilingPatterns(&getContext()); auto marker = getLinalgMatchAndReplaceMarker( getConvFilterTileMarker(), getVectorizeMarker(), context); populateTilingConvFilterPatterns(context, tilingPatterns, launchConfig, @@ -516,7 +514,7 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { } { - OwningRewritePatternList vectorizationPatterns; + OwningRewritePatternList vectorizationPatterns(&getContext()); populateVectorizationPatterns(context, launchConfig, vectorizationPatterns); populateVectorizeLinalgConvPatterns(context, vectorizationPatterns); @@ -556,9 +554,8 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { linalg::DepthwiseConvInputNHWCFilterHWCOp>(op)); }); - OwningRewritePatternList patterns; - linalg::populateLinalgNamedOpsGeneralizationPatterns(context, patterns, - marker); + OwningRewritePatternList patterns(&getContext()); + linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns, marker); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp index 79d64cd2c5a4..e07559164d49 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp @@ -180,7 +180,7 @@ void ConvertVectorToGPUPass::tileAndVectorizeLinalgCopy(FuncOp funcOp, return !(hasMarker(copy, getCopyToWorkgroupMemoryMarker())); }); target->markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - OwningRewritePatternList tileAndDistributePattern; + OwningRewritePatternList tileAndDistributePattern(&getContext()); populateLinalgTileAndDistributePatterns(context, tileAndDistributePattern); if (failed(applyPartialConversion(funcOp, *target, std::move(tileAndDistributePattern)))) { @@ -196,9 +196,9 @@ void ConvertVectorToGPUPass::tileAndVectorizeLinalgCopy(FuncOp funcOp, (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizePatterns)); // 3. Vectorize the tiled linalg to be able to map it to load/store vector. 
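One detail worth noting in the applyVectorTransformation hunk of TileAndVectorizeInOneWorkgroupPass.cpp above: inside a pass the list is seeded with &getContext(), but free helpers taking a FuncOp have no pass to ask, so they take the context from the operation instead (funcOp.getContext()). A hedged sketch of a helper in that style; applyMyVectorCleanups is a made-up name, and the two populate calls mirror ones used elsewhere in this patch:

    static void applyMyVectorCleanups(FuncOp funcOp) {
      // No PassWrapper::getContext() available here, so use the op's context.
      OwningRewritePatternList patterns(funcOp.getContext());
      vector::populateVectorToVectorCanonicalizationPatterns(patterns);
      vector::populateVectorSlicesLoweringPatterns(patterns);
      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    }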
- OwningRewritePatternList vectorizationPatterns; + OwningRewritePatternList vectorizationPatterns(&getContext()); linalg::insertVectorizationPatterns( - vectorizationPatterns, context, linalg::LinalgVectorizationOptions(), + vectorizationPatterns, linalg::LinalgVectorizationOptions(), linalg::LinalgTransformationFilter( Identifier::get(getVectorizeMarker(), context), {})); (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns)); @@ -366,7 +366,7 @@ class ExtractStridedLowering // Lower vector ops to instructions that can be later converted to SPIR-V. void ConvertVectorToGPUPass::lowerVectorOps(FuncOp funcOp, MLIRContext *context) { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); @@ -381,7 +381,7 @@ void ConvertVectorToGPUPass::runOnOperation() { lowerVectorOps(funcOp, context); auto &cooperativeMatrixAnalysis = getAnalysis(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert, VectorTransferReadConversion, VectorTransferWriteConversion>(context, cooperativeMatrixAnalysis); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp index 8a6fbe0a6322..c968c9aec776 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp @@ -440,7 +440,7 @@ void VectorizeMemRefPass::runOnOperation() { memrefUsageAnalysis = &getAnalysis(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert( context, *memrefUsageAnalysis); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir index 5fc2981330d4..07cec3eab73c 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir @@ -64,10 +64,10 @@ hal.executable @matmul_tensors attributes {sym_visibility = "private"} { // CHECK-DAG: %[[WGY:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]] // CHECK: hal.return %[[WGX]], %[[WGY]], %[[C1]] // CHECK-NOT: hal.interface.workgroup.size -// CHECK-DAG: %[[C0:.+]] = constant 0 -// CHECK-DAG: %[[C1:.+]] = constant 1 -// CHECK-DAG: %[[C16:.+]] = constant 16 -// CHECK-DAG: %[[C8:.+]] = constant 8 +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C16:.+]] = constant 16 : index +// CHECK-DAG: %[[C8:.+]] = constant 8 : index // CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @legacy_io::@arg0 // CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @legacy_io::@arg1 // CHECK-DAG: %[[INIT:.+]] = hal.interface.binding.subspan @legacy_io::@arg2 diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir index 46f667dda78c..6c9ad186e039 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir @@ -33,9 +33,9 @@ module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.v } // CHECK: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 4)> - // CHECK: %[[C1024:.+]] = constant 1024 : index - // CHECK: %[[C8:.+]] = constant 8 : index - // CHECK: %[[C0:.+]] = constant 0 : index + // CHECK-DAG: %[[C1024:.+]] = constant 1024 
: index + // CHECK-DAG: %[[C8:.+]] = constant 8 : index + // CHECK-DAG: %[[C0:.+]] = constant 0 : index // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<128x32xf32, 3> // CHECK: %[[DST:.+]] = memref.subview %{{.+}}[0, 0] [128, 32] [1, 1] : memref<4096x4096xf32> to memref<128x32xf32, #map0> // CHECK: %[[TIDx:.+]] = "gpu.thread_id"() {dimension = "x"} : () -> index diff --git a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp index 0476a21e1e62..221d9711a128 100644 --- a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp +++ b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp @@ -230,7 +230,7 @@ struct LoadStoreVectorizationPass void runOnOperation() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); // clang-format off patterns.insert< VectorizeGenericOp, diff --git a/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp index 9e0ee62abb26..14ad4bad5f66 100644 --- a/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp +++ b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp @@ -347,7 +347,7 @@ struct VectorizeLinalgConvPass void runOnOperation() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir index e92ec490df82..69e79db29b61 100644 --- a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir +++ b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir @@ -1,9 +1,10 @@ // RUN: iree-opt -split-input-file -iree-codegen-vectorize-linalg-conv -canonicalize -cse %s | IreeFileCheck %s -func @vectorize_conv(%filter: memref<1x1x3x4xf32>, %input: memref<1x2x2x3xf32>, %output: memref<1x2x2x4xf32>) { - %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 3, 4] [1, 1, 1, 1] : memref<1x1x3x4xf32> to memref<1x1x3x4xf32> - %1 = memref.subview %input[0, 0, 0, 0] [1, 2, 2, 3] [1, 1, 1, 1] : memref<1x2x2x3xf32> to memref<1x2x2x3xf32> - %2 = memref.subview %output[0, 0, 0, 0] [1, 2, 2, 4] [1, 1, 1, 1] : memref<1x2x2x4xf32> to memref<1x2x2x4xf32> +// Passing bigger buffers to avoid memref.subview fold awawy. 
+func @vectorize_conv(%filter: memref<2x1x3x4xf32>, %input: memref<2x2x2x3xf32>, %output: memref<2x2x2x4xf32>) { + %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 3, 4] [1, 1, 1, 1] : memref<2x1x3x4xf32> to memref<1x1x3x4xf32> + %1 = memref.subview %input[0, 0, 0, 0] [1, 2, 2, 3] [1, 1, 1, 1] : memref<2x2x2x3xf32> to memref<1x2x2x3xf32> + %2 = memref.subview %output[0, 0, 0, 0] [1, 2, 2, 4] [1, 1, 1, 1] : memref<2x2x2x4xf32> to memref<1x2x2x4xf32> linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins (%1, %0: memref<1x2x2x3xf32>, memref<1x1x3x4xf32>) outs (%2: memref<1x2x2x4xf32>) @@ -15,69 +16,74 @@ func @vectorize_conv(%filter: memref<1x1x3x4xf32>, %input: memref<1x2x2x3xf32>, // CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK: func @vectorize_conv -// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<1x1x3x4xf32>, -// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<1x2x2x3xf32>, -// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<1x2x2x4xf32> +// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<2x1x3x4xf32>, +// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<2x2x2x3xf32>, +// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<2x2x2x4xf32> // CHECK: %[[FLOAT_ZERO:.+]] = constant 0.000000e+00 : f32 +// CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER_ARG]]{{.*}} to memref<1x1x3x4xf32> +// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT_ARG]]{{.*}} to memref<1x2x2x3xf32> +// CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT_ARG]]{{.*}} to memref<1x2x2x4xf32> + // Read in the filter and get slices -// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_ARG]][%c0, %c0, %c0, %c0], %cst {masked = [false, false]} : memref<1x1x3x4xf32>, vector<3x4xf32> +// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_SUBVIEW]][%c0, %c0, %c0, %c0], %cst {masked = [false, false]} : memref<1x1x3x4xf32>, vector<3x4xf32> // CHECK: %[[FILTER_0:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<3x4xf32> to vector<1x4xf32> // CHECK: %[[FILTER_1:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<3x4xf32> to vector<1x4xf32> // CHECK: %[[FILTER_2:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<3x4xf32> to vector<1x4xf32> // Handle batch #0 -// CHECK: %[[INPUT_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32> -// CHECK: %[[OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32> +// CHECK: %[[INPUT_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32> +// CHECK: %[[OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32> // CHECK: %[[INPUT_0_0:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[INPUT_0_1:.+]] = 
vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_0_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[INPUT_0_2:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_0_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_ARG]][%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32> +// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32> // Handle batch #1 -// CHECK: %[[INPUT_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32> -// CHECK: %[[OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32> +// CHECK: %[[INPUT_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32> +// CHECK: %[[OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32> // CHECK: %[[INPUT_1_0:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[INPUT_1_1:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_1_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[INPUT_1_2:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_1_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_ARG]][%c0, %c0, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32> +// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32> // Handle batch #2 -// CHECK: %[[INPUT_2:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32> -// 
CHECK: %[[OUTPUT_2:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32> +// CHECK: %[[INPUT_2:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32> +// CHECK: %[[OUTPUT_2:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32> // CHECK: %[[INPUT_2_0:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_2_0]], %[[FILTER_0]], %[[OUTPUT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[INPUT_2_1:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_2_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[INPUT_2_2:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_2_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_ARG]][%c0, %c1, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32> +// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32> // Handle batch #3 -// CHECK: %[[INPUT_3:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32> -// CHECK: %[[OUTPUT_3:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32> +// CHECK: %[[INPUT_3:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32> +// CHECK: %[[OUTPUT_3:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32> // CHECK: %[[INPUT_3_0:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_3_0]], %[[FILTER_0]], %[[OUTPUT_3]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[INPUT_3_1:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_3_1]], %[[FILTER_1]], %[[DOT_0]] : 
vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[INPUT_3_2:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32> // CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[INPUT_3_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32> -// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_ARG]][%c0, %c1, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32> +// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32> // ----- // CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_batch -func @do_not_vectorize_conv_with_non_1_batch(%filter: memref<1x1x4x4xf32>, %input: memref<2x1x7x4xf32>, %output: memref<2x1x4x4xf32>) { - %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32> - %1 = memref.subview %input[0, 0, 0, 0] [2, 1, 7, 4] [1, 1, 1, 1] : memref<2x1x7x4xf32> to memref<2x1x7x4xf32> - %2 = memref.subview %output[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<2x1x4x4xf32> +// Passing bigger buffers to avoid memref.subview folding away. +func @do_not_vectorize_conv_with_non_1_batch(%filter: memref<2x1x4x4xf32>, %input: memref<3x1x7x4xf32>, %output: memref<3x1x4x4xf32>) { + %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32> + %1 = memref.subview %input[0, 0, 0, 0] [2, 1, 7, 4] [1, 1, 1, 1] : memref<3x1x7x4xf32> to memref<2x1x7x4xf32> + %2 = memref.subview %output[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<3x1x4x4xf32> to memref<2x1x4x4xf32> // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins (%1, %0: memref<2x1x7x4xf32>, memref<1x1x4x4xf32>) @@ -88,10 +94,11 @@ func @do_not_vectorize_conv_with_non_1_batch(%filter: memref<1x1x4x4xf32>, %inpu // ----- // CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_filter_height -func @do_not_vectorize_conv_with_non_1_filter_height(%filter: memref<2x1x4x4xf32>, %input: memref<1x2x7x4xf32>, %output: memref<1x1x4x4xf32>) { - %0 = memref.subview %filter[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<2x1x4x4xf32> - %1 = memref.subview %input[0, 0, 0, 0] [1, 2, 7, 4] [1, 1, 1, 1] : memref<1x2x7x4xf32> to memref<1x2x7x4xf32> - %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32> +// Passing bigger buffers to avoid memref.subview folding away. 
+func @do_not_vectorize_conv_with_non_1_filter_height(%filter: memref<3x1x4x4xf32>, %input: memref<2x2x7x4xf32>, %output: memref<2x1x4x4xf32>) { + %0 = memref.subview %filter[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<3x1x4x4xf32> to memref<2x1x4x4xf32> + %1 = memref.subview %input[0, 0, 0, 0] [1, 2, 7, 4] [1, 1, 1, 1] : memref<2x2x7x4xf32> to memref<1x2x7x4xf32> + %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32> // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins (%1, %0: memref<1x2x7x4xf32>, memref<2x1x4x4xf32>) @@ -102,10 +109,11 @@ func @do_not_vectorize_conv_with_non_1_filter_height(%filter: memref<2x1x4x4xf32 // ----- // CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_filter_width -func @do_not_vectorize_conv_with_non_1_filter_width(%filter: memref<1x2x4x4xf32>, %input: memref<1x1x8x4xf32>, %output: memref<1x1x4x4xf32>) { - %0 = memref.subview %filter[0, 0, 0, 0] [1, 2, 4, 4] [1, 1, 1, 1] : memref<1x2x4x4xf32> to memref<1x2x4x4xf32> - %1 = memref.subview %input[0, 0, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1] : memref<1x1x8x4xf32> to memref<1x1x8x4xf32> - %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32> +// Passing bigger buffers to avoid memref.subview folding away. +func @do_not_vectorize_conv_with_non_1_filter_width(%filter: memref<2x2x4x4xf32>, %input: memref<2x1x8x4xf32>, %output: memref<2x1x4x4xf32>) { + %0 = memref.subview %filter[0, 0, 0, 0] [1, 2, 4, 4] [1, 1, 1, 1] : memref<2x2x4x4xf32> to memref<1x2x4x4xf32> + %1 = memref.subview %input[0, 0, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1] : memref<2x1x8x4xf32> to memref<1x1x8x4xf32> + %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32> // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins (%1, %0: memref<1x1x8x4xf32>, memref<1x2x4x4xf32>) @@ -116,10 +124,11 @@ func @do_not_vectorize_conv_with_non_1_filter_width(%filter: memref<1x2x4x4xf32> // ----- // CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_dilation -func @do_not_vectorize_conv_with_non_1_dilation(%filter: memref<1x1x4x4xf32>, %input: memref<1x1x7x4xf32>, %output: memref<1x1x4x4xf32>) { - %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32> - %1 = memref.subview %input[0, 0, 0, 0] [1, 1, 7, 4] [1, 1, 1, 1] : memref<1x1x7x4xf32> to memref<1x1x7x4xf32> - %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32> +// Passing bigger buffers to avoid memref.subview folding away. 
+func @do_not_vectorize_conv_with_non_1_dilation(%filter: memref<2x1x4x4xf32>, %input: memref<2x1x7x4xf32>, %output: memref<2x1x4x4xf32>) { + %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32> + %1 = memref.subview %input[0, 0, 0, 0] [1, 1, 7, 4] [1, 1, 1, 1] : memref<2x1x7x4xf32> to memref<1x1x7x4xf32> + %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32> // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<[2, 1]> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins (%1, %0: memref<1x1x7x4xf32>, memref<1x1x4x4xf32>) @@ -129,76 +138,82 @@ func @do_not_vectorize_conv_with_non_1_dilation(%filter: memref<1x1x4x4xf32>, %i // ----- -func @vectorize_depthwise_conv(%input: memref<1x3x3x8xf32>, %filter: memref<1x1x8xf32>, %output: memref<1x2x2x8xf32>) { - %0 = memref.subview %input[0, 0, 0, 0] [1, 3, 3, 8] [1, 1, 1, 1] : memref<1x3x3x8xf32> to memref<1x3x3x8xf32> - %1 = memref.subview %filter[0, 0, 0] [1, 1, 8] [1, 1, 1] : memref<1x1x8xf32> to memref<1x1x8xf32> - %2 = memref.subview %output[0, 0, 0, 0] [1, 2, 2, 8] [1, 1, 1, 1] : memref<1x2x2x8xf32> to memref<1x2x2x8xf32> +// Passing bigger buffers to avoid memref.subview folding away. +func @vectorize_depthwise_conv(%input: memref<2x3x3x8xf32>, %filter: memref<2x1x8xf32>, %output: memref<2x2x2x8xf32>) { + %0 = memref.subview %input[0, 0, 0, 0] [1, 3, 3, 8] [1, 1, 1, 1] : memref<2x3x3x8xf32> to memref<1x3x3x8xf32> + %1 = memref.subview %filter[0, 0, 0] [1, 1, 8] [1, 1, 1] : memref<2x1x8xf32> to memref<1x1x8xf32> + %2 = memref.subview %output[0, 0, 0, 0] [1, 2, 2, 8] [1, 1, 1, 1] : memref<2x2x2x8xf32> to memref<1x2x2x8xf32> linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%0, %1 : memref<1x3x3x8xf32>, memref<1x1x8xf32>) outs(%2 : memref<1x2x2x8xf32>) return } // CHECK-LABEL: func @vectorize_depthwise_conv -// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<1x3x3x8xf32>, -// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<1x1x8xf32>, -// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<1x2x2x8xf32> +// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<2x3x3x8xf32>, +// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<2x1x8xf32>, +// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<2x2x2x8xf32> // CHECK: %[[FLOAT_ZERO:.+]] = constant 0.000000e+00 : f32 -// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_ARG]][%c0, %c0, %c0], %cst {masked = [false]} : memref<1x1x8xf32>, vector<8xf32> +// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT_ARG]]{{.*}} to memref<1x3x3x8xf32> +// CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER_ARG]]{{.*}} to memref<1x1x8xf32> +// CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT_ARG]]{{.*}} to memref<1x2x2x8xf32> + +// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_SUBVIEW]][%c0, %c0, %c0], %cst {masked = [false]} : memref<1x1x8xf32>, vector<8xf32> // Common filter #0 // CHECK: %[[FILTER_0:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32> -// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c0, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> -// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> +// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %cst 
{masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> +// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> // CHECK: %[[FMA_0_0:.+]] = vector.fma %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0_0]] : vector<4xf32> -// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_ARG]][%c0, %c0, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> +// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> -// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> -// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> +// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> +// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> // CHECK: %[[FMA_0_1:.+]] = vector.fma %[[INPUT_0_1]], %[[FILTER_0]], %[[OUTPUT_0_1]] : vector<4xf32> -// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_ARG]][%c0, %c0, %c1, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> +// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> -// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c0, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> -// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c2, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> +// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> +// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> // CHECK: %[[FMA_1_0:.+]] = vector.fma %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1_0]] : vector<4xf32> -// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_ARG]][%c0, %c1, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> +// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> -// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> -// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c2, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> +// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> +// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> // CHECK: %[[FMA_1_1:.+]] = vector.fma %[[INPUT_1_1]], %[[FILTER_0]], %[[OUTPUT_1_1]] : vector<4xf32> -// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_ARG]][%c0, %c1, %c1, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> +// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0] {masked = [false]} : vector<4xf32>, 
memref<1x2x2x8xf32> // Common filter #1 // CHECK: %[[FILTER_1:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32> -// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> -// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> +// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> +// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> // CHECK: %[[FMA_0_0:.+]] = vector.fma %[[INPUT_0_0]], %[[FILTER_1]], %[[OUTPUT_0_0]] : vector<4xf32> -// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_ARG]][%c0, %c0, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> +// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> -// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> -// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> +// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> +// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> // CHECK: %[[FMA_0_1:.+]] = vector.fma %[[INPUT_0_1]], %[[FILTER_1]], %[[OUTPUT_0_1]] : vector<4xf32> -// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_ARG]][%c0, %c0, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> +// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> -// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> -// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c2, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> +// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> +// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> // CHECK: %[[FMA_1_0:.+]] = vector.fma %[[INPUT_1_0]], %[[FILTER_1]], %[[OUTPUT_1_0]] : vector<4xf32> -// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_ARG]][%c0, %c1, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> +// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> -// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> -// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c2, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> +// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, 
%c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32> +// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32> // CHECK: %[[FMA_1_1:.+]] = vector.fma %[[INPUT_1_1]], %[[FILTER_1]], %[[OUTPUT_1_1]] : vector<4xf32> -// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_ARG]][%c0, %c1, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> +// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32> // ----- // CHECK-LABEL: func @do_not_vectorize_depthwise_conv_with_non_1_filter_height -func @do_not_vectorize_depthwise_conv_with_non_1_filter_height(%input: memref<1x2x3x4xf32>, %filter: memref<2x1x4xf32>, %output: memref<1x1x2x4xf32>) { - %0 = memref.subview %input[0, 0, 0, 0] [1, 2, 3, 4] [1, 1, 1, 1] : memref<1x2x3x4xf32> to memref<1x2x3x4xf32> - %1 = memref.subview %filter[0, 0, 0] [2, 1, 4] [1, 1, 1] : memref<2x1x4xf32> to memref<2x1x4xf32> - %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<1x1x2x4xf32> to memref<1x1x2x4xf32> +// Passing bigger buffers to avoid memref.subview folding away. +func @do_not_vectorize_depthwise_conv_with_non_1_filter_height(%input: memref<2x2x3x4xf32>, %filter: memref<3x1x4xf32>, %output: memref<2x1x2x4xf32>) { + %0 = memref.subview %input[0, 0, 0, 0] [1, 2, 3, 4] [1, 1, 1, 1] : memref<2x2x3x4xf32> to memref<1x2x3x4xf32> + %1 = memref.subview %filter[0, 0, 0] [2, 1, 4] [1, 1, 1] : memref<3x1x4xf32> to memref<2x1x4xf32> + %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<2x1x2x4xf32> to memref<1x1x2x4xf32> // CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%0, %1 : memref<1x2x3x4xf32>, memref<2x1x4xf32>) outs(%2 : memref<1x1x2x4xf32>) return @@ -207,10 +222,11 @@ func @do_not_vectorize_depthwise_conv_with_non_1_filter_height(%input: memref<1x // ----- // CHECK-LABEL: func @do_not_vectorize_depthwise_conv_with_non_1_filter_width -func @do_not_vectorize_depthwise_conv_with_non_1_filter_width(%input: memref<1x1x4x4xf32>, %filter: memref<1x2x4xf32>, %output: memref<1x1x2x4xf32>) { - %0 = memref.subview %input[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32> - %1 = memref.subview %filter[0, 0, 0] [1, 2, 4] [1, 1, 1] : memref<1x2x4xf32> to memref<1x2x4xf32> - %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<1x1x2x4xf32> to memref<1x1x2x4xf32> +// Passing bigger buffers to avoid memref.subview folding away. 
+func @do_not_vectorize_depthwise_conv_with_non_1_filter_width(%input: memref<2x1x4x4xf32>, %filter: memref<2x2x4xf32>, %output: memref<2x1x2x4xf32>) { + %0 = memref.subview %input[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32> + %1 = memref.subview %filter[0, 0, 0] [1, 2, 4] [1, 1, 1] : memref<2x2x4xf32> to memref<1x2x4xf32> + %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<2x1x2x4xf32> to memref<1x1x2x4xf32> // CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%0, %1 : memref<1x1x4x4xf32>, memref<1x2x4xf32>) outs(%2 : memref<1x1x2x4xf32>) return diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir index 81ca6be3ebca..12685567cc28 100644 --- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir @@ -95,7 +95,7 @@ func @dynamicUpdateSliceImmutability( %start1 = constant 1 : index %workload = constant 8 : index // CHECK: %[[TARGET_CLONE:.+]] = flow.tensor.clone %[[TARGET]] : tensor<2x4xi32> - // CHECK-NEXT: %[[UPDATED:.+]] = flow.tensor.update %[[UPDATE]], %[[TARGET]] + // CHECK: %[[UPDATED:.+]] = flow.tensor.update %[[UPDATE]], %[[TARGET]] %t0 = flow.tensor.update %stream_update, %stream_target[%start0, %start1] : tensor<1x1xi32> -> tensor<2x4xi32> // CHECK-NEXT: %[[RETURN:.+]] = flow.dispatch @ex::@entry[%c8](%[[TARGET_CLONE]], %[[UPDATED]]) %t1 = flow.dispatch @ex::@entry[%workload](%stream_target, %t0) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp index f92a2b441460..865ce1bd6610 100644 --- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp @@ -96,7 +96,7 @@ struct ConvertToFlowTensorOpsPass FuncOp funcOp = getOperation(); MLIRContext *context = funcOp->getContext(); context->allowUnregisteredDialects(true); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp index 9661159e0615..2e0627aa30a6 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp @@ -486,7 +486,7 @@ LogicalResult rewriteLinalgDestructiveUpdates( // Non-default canonicalization patterns. // TODO(nicolasvasilache): add Linalg tiling canonicalization patterns, // affineminscf and others as needed. 
- OwningRewritePatternList canonicalizationPatterns; + OwningRewritePatternList canonicalizationPatterns(context); scf::ForOp::getCanonicalizationPatterns(canonicalizationPatterns, context); (void)applyPatternsAndFoldGreedily(dispatchOp, std::move(canonicalizationPatterns)); diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp index 177703cda9cd..4cf9fefe2cc0 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp @@ -933,7 +933,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { // Use the workgroup size as a proxy for tile size here. At the flow level // this represents the "workload" per processors and is not necessarily tied // to the workgroup size specified by the backend. - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); auto linalgTilingOptions = linalg::LinalgTilingOptions() .setDistributionOptions(workgroupDistributionOptions) @@ -948,7 +948,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { ArrayRef(), Identifier::get("workgroup", context))); // Add canonicalization patterns. - linalg::populateLinalgTilingCanonicalizationPatterns(patterns, context); + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); patterns.insert(context); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } @@ -965,7 +965,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { // Move other operations into their own dispatch regions. { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } @@ -982,7 +982,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { // Run necessary canonicalization patterns before destructive updates. { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); // This is needed because tiling and distribution may create // subtensor_insert ops whose source operands come from tensor.cast ops. // Those tensor.cast ops cast tensors into a more dynamic shape, in order diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp index db16081c17bd..4c3c7b47a596 100644 --- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp @@ -796,7 +796,7 @@ struct HLOToHLOPreprocessing void runOnFunction() override { MLIRContext *context = &getContext(); ConversionTarget conversionTarget(*context); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); // Note that various input modalities may do their own legalization of // CHLO. Converting here allows IREE to accept CHLO dialect regardless of // whether it was legalized away at a higher level. 
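The pattern-list change repeated across these files is mechanical: with this LLVM version, OwningRewritePatternList must be handed the MLIRContext when it is constructed, and most populate*Patterns helpers lose their separate context parameter because the list now carries it. A minimal sketch of what a pass looks like after the change; MyCleanupPass and MyPattern are hypothetical names used only for illustration, not code from this patch:

    struct MyCleanupPass : public PassWrapper<MyCleanupPass, FunctionPass> {
      void runOnFunction() override {
        MLIRContext *context = &getContext();
        // The context is now supplied when the pattern list is created.
        OwningRewritePatternList patterns(context);
        // Individual patterns are still inserted with the context.
        patterns.insert<MyPattern>(context);
        // populate* helpers read the context from the list itself.
        linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
        (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
      }
    };
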
@@ -810,7 +810,7 @@ struct HLOToHLOPreprocessing return signalPassFailure(); } - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); mhlo::PopulateUnfuseBatchNormPatterns(context, &patterns); mhlo::PopulateComplexLoweringPatterns(context, &patterns); mhlo::PopulateGatherToTorchIndexSelectPatterns(context, &patterns); diff --git a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp index 802e4340eeca..621afac0f00f 100644 --- a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp @@ -41,7 +41,7 @@ class PrePartitioningConversionPass void runOnFunction() override { auto *context = &getContext(); ConversionTarget conversionTarget(*context); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); conversionTarget.addLegalDialect(); diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir index ee1f3fea0f18..9b93d1b05f56 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir @@ -194,7 +194,8 @@ func @two_dispatches(%A : tensor, %B : tensor) -> tensor {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?]> {iree.reflection = {}}) -> (tensor {iree.reflection = {}}) attributes {iree.module.export} { diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp index 7f951297318b..b567f05ccc01 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp @@ -127,7 +127,7 @@ class ConvertHALToVMPass StringRef(hal_imports_create()->data, hal_imports_create()->size), innerModuleOp); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); populateStandardToVMPatterns(context, typeConverter, conversionPatterns); SymbolTable importSymbols(innerModuleOp); diff --git a/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp index 64a90c2ed8c6..007478077547 100644 --- a/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp @@ -71,7 +71,7 @@ class ConvertToHALPass HALTypeConverter typeConverter(conversionInterfaces); HALConversionTarget conversionTarget(context, typeConverter); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); setupIREEToHALLegality(context, conversionTarget); populateIREEToHALPatterns(context, patterns); diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index 1f2a3a3bb8cc..75f20a079366 100644 --- a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -527,7 +527,7 @@ class MaterializeInterfacesPass } // Convert interface-related flow.dispatch.* ops to their hal.* versions. 
- OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert, diff --git a/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp b/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp index 54c19202878d..bd6a32dbc43e 100644 --- a/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp @@ -84,7 +84,7 @@ class ResolveEntryPointOrdinalsPass public: void runOnOperation() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); patterns.insert(context); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); diff --git a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp index d7b189258055..1aaa9a202bdb 100644 --- a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp +++ b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp @@ -284,7 +284,7 @@ class ConvertShapeToShapex conversionTarget.addLegalDialect(); // Patterns. - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); patterns.insert(context); patterns.insert(context); diff --git a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir index 73070967dc46..c92301a9129e 100644 --- a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir +++ b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir @@ -45,8 +45,8 @@ func @foldStaticRankedDim(%arg0: !shapex.ranked_shape<[1,?,2,?]>) -> (i32, i32) func @foldFullyStaticRankedShape(%arg0: tensor<1x2xf32>) -> (i32, i32) { // CHECK-NOT: shapex.get_ranked_shape // CHECK-NOT: shapex.ranked_dim - // CHECK: constant 1 - // CHECK: constant 2 + // CHECK-DAG: constant 1 + // CHECK-DAG: constant 2 %0 = shapex.get_ranked_shape %arg0 : tensor<1x2xf32> -> !shapex.ranked_shape<[1,2]> %1 = shapex.ranked_dim %0[0] : !shapex.ranked_shape<[1,2]> -> i32 %2 = shapex.ranked_dim %0[1] : !shapex.ranked_shape<[1,2]> -> i32 @@ -74,8 +74,8 @@ func @foldFullyStaticRankedShapeDims(%arg0: tensor<1x2xf32>) -> (i32, i32) { // CHECK-NOT: shapex.get_ranked_shape // CHECK-NOT: shapex.ranked_dims // CHECK-NOT: shapex.ranked_dim - // CHECK: constant 1 - // CHECK: constant 2 + // CHECK-DAG: constant 1 + // CHECK-DAG: constant 2 %0 = shapex.get_ranked_shape %arg0 : tensor<1x2xf32> -> !shapex.ranked_shape<[1,2]> %1:2 = shapex.ranked_dims %0 : !shapex.ranked_shape<[1,2]> -> i32, i32 return %1#0, %1#1 : i32, i32 diff --git a/iree/compiler/Dialect/Shape/Transforms/CleanupPlaceholdersPass.cpp b/iree/compiler/Dialect/Shape/Transforms/CleanupPlaceholdersPass.cpp index e4f11c9eaf8b..762cea8e113d 100644 --- a/iree/compiler/Dialect/Shape/Transforms/CleanupPlaceholdersPass.cpp +++ b/iree/compiler/Dialect/Shape/Transforms/CleanupPlaceholdersPass.cpp @@ -38,7 +38,7 @@ class CleanupTieShapePattern : public OpRewritePattern { class CleanupShapePlaceholdersPass : public PassWrapper { void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(&getContext()); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } diff --git a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp 
index c1af25b2381f..0b23d90bb928 100644 --- a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp +++ b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp @@ -72,7 +72,7 @@ class ConvertHLOToShapePass void runOnFunction() override { ConversionTarget conversionTarget(getContext()); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp index 4ef0e2fa9b17..ff93f8c1767f 100644 --- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp +++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp @@ -57,7 +57,7 @@ class MaterializeShapeCalculationsPass target.addLegalDialect(); setupMaterializeShapeCalculationsLegality(target); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); populateMaterializeShapeCalculationsConversionPatterns(conversionPatterns, context); if (failed(applyPartialConversion(getOperation(), target, @@ -69,7 +69,7 @@ class MaterializeShapeCalculationsPass // And then canonicalize shape ops. // TODO(laurenzo): I would prefer to get the list of ops in the dialect // versus doing this, but I don't know that is possible. - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); CastCompatibleShapeOp::getCanonicalizationPatterns(patterns, context); GetRankedShapeOp::getCanonicalizationPatterns(patterns, context); MakeRankedShapeOp::getCanonicalizationPatterns(patterns, context); diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVMTest.cpp b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVMTest.cpp index 15b9ed3a746b..416341e12401 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVMTest.cpp +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVMTest.cpp @@ -41,7 +41,7 @@ class ConvertStandardToVMTestPass IREE::VM::TypeConverter typeConverter( IREE::VM::getTargetOptionsFromFlags()); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateStandardToVMPatterns(&getContext(), typeConverter, patterns); // NOTE: we allow other dialects besides just VM during this pass as we are diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir index b3d6de31a377..b12a8b233dd1 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir @@ -8,13 +8,15 @@ module @my_module { // CHECK: func @my_fn // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]] - func @my_fn(%arg0 : i32, %arg1 : i32) -> (i32) { + // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]] + // CHECK-SAME: %[[ARG3:[a-zA-Z0-9$._-]+]] + func @my_fn(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> (i32) { // Note that in std, cmp returns an i1 and this relies on the dialect // conversion framework promoting that to i32. 
// CHECK: %[[CMP:[a-zA-Z0-9$._-]+]] = vm.cmp.eq.i32 %1 = cmpi eq, %arg0, %arg1 : i32 - // CHECK: vm.select.i32 %[[CMP]], %[[ARG0]], %[[ARG1]] : i32 - %2 = select %1, %arg0, %arg1 : i32 + // CHECK: vm.select.i32 %[[CMP]], %[[ARG2]], %[[ARG3]] : i32 + %2 = select %1, %arg2, %arg3 : i32 return %2 : i32 } } @@ -29,13 +31,15 @@ module @my_module { // CHECK: func @my_fn // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]] - func @my_fn(%arg0 : index, %arg1 : index) -> (index) { + // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]] + // CHECK-SAME: %[[ARG3:[a-zA-Z0-9$._-]+]] + func @my_fn(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> (index) { // Note that in std, cmp returns an i1 and this relies on the dialect // conversion framework promoting that to i32. // CHECK: %[[CMP:[a-zA-Z0-9$._-]+]] = vm.cmp.eq.i32 %1 = cmpi eq, %arg0, %arg1 : index - // CHECK: vm.select.i32 %[[CMP]], %[[ARG0]], %[[ARG1]] : i32 - %2 = select %1, %arg0, %arg1 : index + // CHECK: vm.select.i32 %[[CMP]], %[[ARG2]], %[[ARG3]] : i32 + %2 = select %1, %arg2, %arg3 : index return %2 : index } } diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 7a4ddbbac843..48d9784d4ab9 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -348,7 +348,7 @@ class ConvertVMToEmitCPass void runOnOperation() override { ConversionTarget target(getContext()); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateVMToCPatterns(&getContext(), patterns); target.addLegalDialect(); diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index b002f4fd3c14..246467b3ec76 100644 --- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp @@ -284,10 +284,38 @@ void GlobalStoreIndirectRefOp::getCanonicalizationPatterns( // Constants //===----------------------------------------------------------------------===// +namespace { + +template +struct FoldZeroConstInteger final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GeneralOp constOp, + PatternRewriter &rewriter) const override { + if (matchPattern(constOp.result(), m_Zero())) { + rewriter.replaceOpWithNewOp(constOp); + return success(); + } + return failure(); + } +}; + +} // namespace + OpFoldResult ConstI32Op::fold(ArrayRef operands) { return value(); } +void ConstI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert>(context); +} + OpFoldResult ConstI64Op::fold(ArrayRef operands) { return value(); } +void ConstI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert>(context); +} + OpFoldResult ConstI32ZeroOp::fold(ArrayRef operands) { return IntegerAttr::get(getResult().getType(), 0); } diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td index e624738d9ded..38fca4d27c47 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/iree/compiler/Dialect/VM/IR/VMOps.td @@ -661,6 +661,7 @@ def VM_ConstI32Op : VM_ConstIntegerOp { let summary = [{32-bit integer constant operation}]; let hasFolder = 1; + let hasCanonicalizer = 1; } def VM_ConstI64Op : @@ -668,6 +669,7 @@ def VM_ConstI64Op : [VM_ExtI64]> { let summary = [{64-bit integer 
constant operation}]; let hasFolder = 1; + let hasCanonicalizer = 1; } class VM_ConstIntegerZeroOp BytecodeEncoder::encodeFunction( } for (auto &op : block.getOperations()) { - auto *serializableOp = - op.getAbstractOperation()->getInterface(); + auto serializableOp = dyn_cast(op); if (!serializableOp) { op.emitOpError() << "is not serializable"; return llvm::None; } if (failed(encoder.beginOp(&op)) || - failed(serializableOp->encode(&op, symbolTable, encoder)) || + failed(serializableOp.encode(symbolTable, encoder)) || failed(encoder.endOp(&op))) { op.emitOpError() << "failed to encode"; return llvm::None; diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index 5c5aef17d087..821ba43ec517 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp @@ -108,7 +108,7 @@ static std::vector buildTypeTable(IREE::VM::ModuleOp moduleOp) { // required transformations (such as debug op stripping). static LogicalResult canonicalizeModule(BytecodeTargetOptions targetOptions, IREE::VM::ModuleOp moduleOp) { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(moduleOp.getContext()); ConversionTarget target(*moduleOp.getContext()); target.addLegalDialect(); target.addLegalOp(); diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp index 58fead54eeac..7aec29665635 100644 --- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp +++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp @@ -568,7 +568,7 @@ static LogicalResult buildModuleDescriptors(IREE::VM::ModuleOp &moduleOp, // Adapted from BytecodeModuleTarget and extended by C specific passes static LogicalResult canonicalizeModule( IREE::VM::ModuleOp moduleOp, IREE::VM::CTargetOptions targetOptions) { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(moduleOp.getContext()); ConversionTarget target(*moduleOp.getContext()); target.addLegalDialect(); target.addLegalOp(); diff --git a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp index b010a6bb3d3a..c74864f86ff1 100644 --- a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp +++ b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp @@ -120,7 +120,7 @@ class ConversionPass } } - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); populateIREEToVMPatterns(context, typeConverter, conversionPatterns); populateStandardToVMPatterns(context, typeConverter, conversionPatterns); conversionPatterns.insert(context); diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir index 027bfcd28538..0991c792b24f 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir +++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir @@ -1,9 +1,9 @@ // RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s func private @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) { - // CHECK: [[RS:%.+]] = 
shapex.const_ranked_shape : !shapex.ranked_shape<[8]> + // CHECK-DAG: [[C32:%.+]] = constant 32 : index + // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer // CHECK-NEXT: vmla.fft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32 %real, %imag = "vmla.fft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) @@ -11,9 +11,9 @@ func private @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32> } func private @ifft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) { - // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]> - // CHECK-NEXT: [[C32:%.+]] = constant 32 : index - // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer + // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]> + // CHECK-DAG: [[C32:%.+]] = constant 32 : index + // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer // CHECK-NEXT: vmla.ifft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32 %real, %imag = "vmla.ifft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) @@ -21,9 +21,9 @@ func private @ifft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32 } func private @rfft(%arg0: tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) { - // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]> - // CHECK-NEXT: [[C20:%.+]] = constant 20 : index - // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C20]] : !vmla.buffer + // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]> + // CHECK-DAG: [[C20:%.+]] = constant 20 : index + // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C20]] : !vmla.buffer // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C20]] : !vmla.buffer // CHECK-NEXT: vmla.rfft %arg0([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32 %real, %imag = "vmla.rfft.pseudo"(%arg0) : (tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) @@ -31,8 +31,8 @@ func private @rfft(%arg0: tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) { } func private @irfft(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<8xf32> { - // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[5]> - // CHECK-NEXT: [[C32:%.+]] = constant 32 : index + // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[5]> + // CHECK-DAG: [[C32:%.+]] = constant 32 : index // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer // CHECK-NEXT: vmla.irfft %arg0([[RS]] : !shapex.ranked_shape<[5]>), %arg1([[RS]] : !shapex.ranked_shape<[5]>), out [[OUTBUF1]] : f32 %real = "vmla.irfft.pseudo"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> (tensor<8xf32>) diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp index ad011a8b0b21..b49c6e3fdb1f 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp @@ -402,7 +402,7 @@ class 
ConvertVMLAToVMPass StringRef(vmla_imports_create()->data, vmla_imports_create()->size), innerModuleOp); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); populateStandardToVMPatterns(context, typeConverter, conversionPatterns); SymbolTable importSymbols(innerModuleOp); diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp index be6a9073eea5..cfdab50ae968 100644 --- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp +++ b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp @@ -85,14 +85,13 @@ class ConversionPass conversionTarget.addIllegalDialect(); conversionTarget.addIllegalDialect(); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); populateStandardToVMLAPatterns(context, conversionPatterns, typeConverter); populateHLOToVMLAPatterns(context, conversionPatterns, typeConverter); populateHALToVMLAPatterns(context, conversionPatterns, typeConverter); // Ensure FuncOp signatures are updated. - populateFuncOpTypeConversionPattern(conversionPatterns, context, - typeConverter); + populateFuncOpTypeConversionPattern(conversionPatterns, typeConverter); // We allow the shape dialect to persist, making specific dim queries // illegal (which allows them to fold away). These patterns allow dimension diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp index 4b22d31b9d4d..1afdbd20761c 100644 --- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp +++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp @@ -470,14 +470,14 @@ class PreConversionLoweringPass // These patterns should be run greedily as they are not dialect // conversions. 
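The comment above draws the key distinction between MLIR's two rewrite drivers: applyPatternsAndFoldGreedily applies patterns to a fixed point with no notion of legality, while the dialect-conversion driver (applyPartialConversion guided by a ConversionTarget) only rewrites operations that the target declares illegal. A compressed sketch of the two call shapes as they are used in this patch, not a verbatim excerpt of this file:

    // Greedy, fixed-point rewriting: no ConversionTarget is involved.
    OwningRewritePatternList greedyPatterns(&getContext());
    mhlo::PopulateComplexLoweringPatterns(context, &greedyPatterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(greedyPatterns));

    // Dialect conversion: only ops that `target` marks illegal are rewritten.
    ConversionTarget target(*context);
    OwningRewritePatternList patterns(&getContext());
    if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
      return signalPassFailure();
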
- OwningRewritePatternList greedyPatterns; + OwningRewritePatternList greedyPatterns(&getContext()); mhlo::PopulateComplexLoweringPatterns(context, &greedyPatterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(greedyPatterns)))) { return signalPassFailure(); } - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); ConversionTarget target(*context); target.addLegalDialect(); target.addLegalDialect(); @@ -503,7 +503,7 @@ class PreConversionLoweringPass } { - OwningRewritePatternList greedyPatterns; + OwningRewritePatternList greedyPatterns(&getContext()); greedyPatterns.insert(context); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(greedyPatterns)))) { diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir index 493aa88aabca..9e7b73c5c579 100644 --- a/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir +++ b/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir @@ -17,8 +17,8 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} { } // CHECK: func @simpleMath_rgn_dispatch_0(%arg0: !vmla.interface, %arg1: index, %arg2: index, %arg3: index) { -// CHECK-NEXT: %c0 = constant 0 : index -// CHECK-NEXT: %c16 = constant 16 : index +// CHECK-DAG: %c0 = constant 0 : index +// CHECK-DAG: %c16 = constant 16 : index // CHECK-NEXT: %0 = vmla.interface.binding %arg0 {binding = 0 : i32, set = 0 : i32} : !vmla.buffer // CHECK-NEXT: %1 = vmla.buffer.view %0[%c0], byte_length = %c16 : !vmla.buffer // CHECK-NEXT: %2 = vmla.buffer.alloc byte_length = %c16 : !vmla.buffer diff --git a/iree/test/e2e/tosa_ops/BUILD b/iree/test/e2e/tosa_ops/BUILD index cb3a05392252..06a3544b2e1c 100644 --- a/iree/test/e2e/tosa_ops/BUILD +++ b/iree/test/e2e/tosa_ops/BUILD @@ -47,7 +47,6 @@ ALL_SRCS = enforce_glob( "logical_right_shift.mlir", "maximum.mlir", "minimum.mlir", - "mul.mlir", "negate.mlir", "reluN.mlir", "reshape.mlir", @@ -59,6 +58,9 @@ ALL_SRCS = enforce_glob( "while.mlir", ], include = ["*.mlir"], + exclude = [ + "mul.mlir", # TODO(suderman): Re-enable once apply_scale lowering lands. 
+ ], ) iree_check_single_backend_test_suite( diff --git a/iree/test/e2e/tosa_ops/CMakeLists.txt b/iree/test/e2e/tosa_ops/CMakeLists.txt index 3948a95a51dd..0fc880bf3425 100644 --- a/iree/test/e2e/tosa_ops/CMakeLists.txt +++ b/iree/test/e2e/tosa_ops/CMakeLists.txt @@ -32,7 +32,6 @@ iree_check_single_backend_test_suite( "logical_right_shift.mlir" "maximum.mlir" "minimum.mlir" - "mul.mlir" "negate.mlir" "reluN.mlir" "reshape.mlir" @@ -70,7 +69,6 @@ iree_check_single_backend_test_suite( "logical_right_shift.mlir" "maximum.mlir" "minimum.mlir" - "mul.mlir" "negate.mlir" "reluN.mlir" "reshape.mlir" diff --git a/third_party/llvm-bazel b/third_party/llvm-bazel index 013b829185fe..189e771009a6 160000 --- a/third_party/llvm-bazel +++ b/third_party/llvm-bazel @@ -1 +1 @@ -Subproject commit 013b829185fee6d8eaa515a7e36ec468a2a02600 +Subproject commit 189e771009a640214e08e855830ae6f15a83c655 diff --git a/third_party/llvm-project b/third_party/llvm-project index cd442157cff4..1f6a57c1a0fa 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit cd442157cff4aad209ae532cbf031abbe10bc1df +Subproject commit 1f6a57c1a0fad922e04a2b1f414b092d4b0cd8b0 diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo index 431be0e9b235..cbef26c6a8f1 160000 --- a/third_party/mlir-hlo +++ b/third_party/mlir-hlo @@ -1 +1 @@ -Subproject commit 431be0e9b235e1b98adf0367f3beb440aa672875 +Subproject commit cbef26c6a8f1e4be3f4cfb902db992c45e93b7a6 diff --git a/third_party/tensorflow b/third_party/tensorflow index aa3bd9f6de5a..da3da1e8a81a 160000 --- a/third_party/tensorflow +++ b/third_party/tensorflow @@ -1 +1 @@ -Subproject commit aa3bd9f6de5a76c4c226548a48e448d211978e92 +Subproject commit da3da1e8a81a9866d98bcfe54eb21ec27cab7000
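Beyond the mechanical API updates, one of the few behavioral additions in this patch is the new canonicalization in VMOpFolders.cpp: integer constants with value zero are rewritten to the dedicated zero-constant ops, enabled by the new `let hasCanonicalizer = 1;` lines in VMOps.td. The template parameters of that pattern did not survive in the listing above; the restatement below fills them in as a best guess (GeneralOp appears in the hunk, but the second parameter, named ZeroOp here, and the exact registration line are reconstructions, not a verbatim quote of the patch):

    template <typename GeneralOp, typename ZeroOp>
    struct FoldZeroConstInteger final : public OpRewritePattern<GeneralOp> {
      using OpRewritePattern<GeneralOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(GeneralOp constOp,
                                    PatternRewriter &rewriter) const override {
        // vm.const.i32 0 becomes vm.const.i32.zero (and likewise for i64).
        if (matchPattern(constOp.result(), m_Zero())) {
          rewriter.replaceOpWithNewOp<ZeroOp>(constOp);
          return success();
        }
        return failure();
      }
    };

    void ConstI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                                 MLIRContext *context) {
      results.insert<FoldZeroConstInteger<ConstI32Op, ConstI32ZeroOp>>(context);
    }

The i64 registration would follow the same shape, pairing the general constant op with its zero-constant counterpart.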