From 7072c93e406b361641367fd29f6ecddff01e5a13 Mon Sep 17 00:00:00 2001 From: GMNGeoffrey Date: Mon, 22 Mar 2021 12:17:52 +0000 Subject: [PATCH 1/5] Synchronize submodules with LLVM at llvm/llvm-project@cd442157cff4 --- SUBMODULE_VERSIONS.txt | 4 ++-- third_party/mlir-hlo | 2 +- third_party/tensorflow | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt index 6b0cf5bff932..35c5fd335a75 100644 --- a/SUBMODULE_VERSIONS.txt +++ b/SUBMODULE_VERSIONS.txt @@ -7,13 +7,13 @@ b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest 013b829185fee6d8eaa515a7e36ec468a2a02600 third_party/llvm-bazel cd442157cff4aad209ae532cbf031abbe10bc1df third_party/llvm-project 3483f1653fc7cb3bfb3a4d1b463f3a651ecaa676 third_party/mlir-emitc -431be0e9b235e1b98adf0367f3beb440aa672875 third_party/mlir-hlo +98debb127d3a14e0239a3432461e3876d293b409 third_party/mlir-hlo 2b2bd45bbf9be04fd22ece5cc1f54679202e9257 third_party/pffft d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11 2887692065c38ef6617f423feafc6b69dd0a0681 third_party/ruy 685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers -aa3bd9f6de5a76c4c226548a48e448d211978e92 third_party/tensorflow +5c483374ac525a388ac9b1b24e468eb874ed0980 third_party/tensorflow 8732f0e94e4e41049a43029202bda94d7b4e85da third_party/tracy 9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers 3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo index 431be0e9b235..98debb127d3a 160000 --- a/third_party/mlir-hlo +++ b/third_party/mlir-hlo @@ -1 +1 @@ -Subproject commit 431be0e9b235e1b98adf0367f3beb440aa672875 +Subproject commit 98debb127d3a14e0239a3432461e3876d293b409 diff --git a/third_party/tensorflow b/third_party/tensorflow index aa3bd9f6de5a..5c483374ac52 160000 --- a/third_party/tensorflow +++ b/third_party/tensorflow @@ -1 +1 @@ -Subproject commit aa3bd9f6de5a76c4c226548a48e448d211978e92 +Subproject commit 5c483374ac525a388ac9b1b24e468eb874ed0980 From c6b2bc74c14226b32ad6e0e90e78eec8e19d77d4 Mon Sep 17 00:00:00 2001 From: iree-copybara-bot Date: Tue, 23 Mar 2021 04:38:28 -0700 Subject: [PATCH 2/5] Integrate LLVM at llvm/llvm-project@e990fa217031 Updates LLVM usage to match [e990fa217031](https://github.com/llvm/llvm-project/commit/e990fa217031) PiperOrigin-RevId: 364529739 --- SUBMODULE_VERSIONS.txt | 2 +- third_party/llvm-project | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt index 35c5fd335a75..618d511d211a 100644 --- a/SUBMODULE_VERSIONS.txt +++ b/SUBMODULE_VERSIONS.txt @@ -5,7 +5,7 @@ daff5fead3fbe22c6fc58310ca3f49caf117f185 third_party/benchmark b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing 013b829185fee6d8eaa515a7e36ec468a2a02600 third_party/llvm-bazel -cd442157cff4aad209ae532cbf031abbe10bc1df third_party/llvm-project +e990fa2170314b179ec025b68fd00fbe9aab398d third_party/llvm-project 3483f1653fc7cb3bfb3a4d1b463f3a651ecaa676 third_party/mlir-emitc 98debb127d3a14e0239a3432461e3876d293b409 third_party/mlir-hlo 2b2bd45bbf9be04fd22ece5cc1f54679202e9257 third_party/pffft diff --git a/third_party/llvm-project b/third_party/llvm-project index cd442157cff4..e990fa217031 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 
cd442157cff4aad209ae532cbf031abbe10bc1df +Subproject commit e990fa2170314b179ec025b68fd00fbe9aab398d From 9407792e3d9cafa1c0f5518beb720247205eea07 Mon Sep 17 00:00:00 2001 From: iree-copybara-bot Date: Tue, 23 Mar 2021 06:14:57 -0700 Subject: [PATCH 3/5] Integrate LLVM at llvm/llvm-project@5657f93e788f Updates LLVM usage to match [5657f93e788f](https://github.com/llvm/llvm-project/commit/5657f93e788f) PiperOrigin-RevId: 364541987 --- SUBMODULE_VERSIONS.txt | 2 +- third_party/llvm-project | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt index 618d511d211a..758063c9409e 100644 --- a/SUBMODULE_VERSIONS.txt +++ b/SUBMODULE_VERSIONS.txt @@ -5,7 +5,7 @@ daff5fead3fbe22c6fc58310ca3f49caf117f185 third_party/benchmark b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing 013b829185fee6d8eaa515a7e36ec468a2a02600 third_party/llvm-bazel -e990fa2170314b179ec025b68fd00fbe9aab398d third_party/llvm-project +5657f93e788f093c70fb448dd6f9398b149df278 third_party/llvm-project 3483f1653fc7cb3bfb3a4d1b463f3a651ecaa676 third_party/mlir-emitc 98debb127d3a14e0239a3432461e3876d293b409 third_party/mlir-hlo 2b2bd45bbf9be04fd22ece5cc1f54679202e9257 third_party/pffft diff --git a/third_party/llvm-project b/third_party/llvm-project index e990fa217031..5657f93e788f 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit e990fa2170314b179ec025b68fd00fbe9aab398d +Subproject commit 5657f93e788f093c70fb448dd6f9398b149df278 From 136052f41a2960fef2661c9cb85740e4dd5ccf88 Mon Sep 17 00:00:00 2001 From: iree-copybara-bot Date: Tue, 23 Mar 2021 07:25:17 -0700 Subject: [PATCH 4/5] Integrate LLVM at llvm/llvm-project@0776eca7a4e7 Updates LLVM usage to match [0776eca7a4e7](https://github.com/llvm/llvm-project/commit/0776eca7a4e7) PiperOrigin-RevId: 364552759 --- SUBMODULE_VERSIONS.txt | 2 +- .../StandardToVM/test/assignment_ops.mlir | 16 ++++++++++------ third_party/llvm-project | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt index 758063c9409e..8277552ca90a 100644 --- a/SUBMODULE_VERSIONS.txt +++ b/SUBMODULE_VERSIONS.txt @@ -5,7 +5,7 @@ daff5fead3fbe22c6fc58310ca3f49caf117f185 third_party/benchmark b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing 013b829185fee6d8eaa515a7e36ec468a2a02600 third_party/llvm-bazel -5657f93e788f093c70fb448dd6f9398b149df278 third_party/llvm-project +0776eca7a4e76bfadc311f3607be3a4f0c0e989a third_party/llvm-project 3483f1653fc7cb3bfb3a4d1b463f3a651ecaa676 third_party/mlir-emitc 98debb127d3a14e0239a3432461e3876d293b409 third_party/mlir-hlo 2b2bd45bbf9be04fd22ece5cc1f54679202e9257 third_party/pffft diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir index b3d6de31a377..b12a8b233dd1 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/assignment_ops.mlir @@ -8,13 +8,15 @@ module @my_module { // CHECK: func @my_fn // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]] - func @my_fn(%arg0 : i32, %arg1 : i32) -> (i32) { + // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]] + // CHECK-SAME: %[[ARG3:[a-zA-Z0-9$._-]+]] + func @my_fn(%arg0 : i32, %arg1 : i32, %arg2 : i32, 
%arg3 : i32) -> (i32) { // Note that in std, cmp returns an i1 and this relies on the dialect // conversion framework promoting that to i32. // CHECK: %[[CMP:[a-zA-Z0-9$._-]+]] = vm.cmp.eq.i32 %1 = cmpi eq, %arg0, %arg1 : i32 - // CHECK: vm.select.i32 %[[CMP]], %[[ARG0]], %[[ARG1]] : i32 - %2 = select %1, %arg0, %arg1 : i32 + // CHECK: vm.select.i32 %[[CMP]], %[[ARG2]], %[[ARG3]] : i32 + %2 = select %1, %arg2, %arg3 : i32 return %2 : i32 } } @@ -29,13 +31,15 @@ module @my_module { // CHECK: func @my_fn // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]] - func @my_fn(%arg0 : index, %arg1 : index) -> (index) { + // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]] + // CHECK-SAME: %[[ARG3:[a-zA-Z0-9$._-]+]] + func @my_fn(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> (index) { // Note that in std, cmp returns an i1 and this relies on the dialect // conversion framework promoting that to i32. // CHECK: %[[CMP:[a-zA-Z0-9$._-]+]] = vm.cmp.eq.i32 %1 = cmpi eq, %arg0, %arg1 : index - // CHECK: vm.select.i32 %[[CMP]], %[[ARG0]], %[[ARG1]] : i32 - %2 = select %1, %arg0, %arg1 : index + // CHECK: vm.select.i32 %[[CMP]], %[[ARG2]], %[[ARG3]] : i32 + %2 = select %1, %arg2, %arg3 : index return %2 : index } } diff --git a/third_party/llvm-project b/third_party/llvm-project index 5657f93e788f..0776eca7a4e7 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 5657f93e788f093c70fb448dd6f9398b149df278 +Subproject commit 0776eca7a4e76bfadc311f3607be3a4f0c0e989a From f29d6c80e3272fd1b4e5042ba401fefaf975602e Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 23 Mar 2021 12:18:57 -0700 Subject: [PATCH 5/5] Integrate LLVM at llvm/llvm-project@b24436ac96bd Updates LLVM usage to match [b24436ac96bd](https://github.com/llvm/llvm-project/commit/b24436ac96bd) PiperOrigin-RevId: 364615807 --- SUBMODULE_VERSIONS.txt | 2 +- experimental/ModelBuilder/ModelRunner.cpp | 8 +- .../iree_tf_compiler/TF/ConvertToMHLO.cpp | 12 +- .../conversion/convert_tf_to_tf_strings.cc | 2 +- .../conversion/convert_tf_to_tf_tensorlist.cc | 4 +- .../dialect/utils/conversion_utils.h | 8 +- .../CodegenUtils/ForOpCanonicalization.cpp | 2 +- .../Common/BufferAllocViewCleanUpPass.cpp | 2 +- .../LinalgRewriteDestructiveUpdatesPass.cpp | 2 +- .../compiler/Conversion/Common/Transforms.cpp | 4 +- .../Common/VectorTransferOptimization.cpp | 5 +- .../HLOToHLO/Convert1x1ConvToDot.cpp | 2 +- .../Conversion/HLOToHLO/DecomposeHLOClamp.cpp | 2 +- .../Conversion/HLOToHLO/DemoteF32ToF16.cpp | 4 +- .../HLOToLinalg/FusionOfTensorOps.cpp | 7 +- .../HLOToLinalg/HLOToLinalgOnBuffers.cpp | 6 +- .../HLOToLinalg/HLOToLinalgOnTensors.cpp | 2 +- .../HLOToLinalg/ResolveShapeOps.cpp | 4 +- .../Conversion/HLOToLinalg/test/fusion.mlir | 9 +- .../test/linalg_tensor_to_buffer.mlir | 60 ------- .../ConvImg2ColMatmulConversion.cpp | 2 +- .../Conversion/LinalgToLLVM/ConvertToLLVM.cpp | 30 ++-- .../LinalgToLLVM/FoldTensorExtractOpPass.cpp | 5 +- .../LinalgTileAndVectorizePass.cpp | 21 ++- .../LinalgToLLVM/LinalgVectorizePass.cpp | 17 +- .../LinalgToLLVM/PlanConvLoopOrder.cpp | 5 +- .../Conversion/LinalgToLLVM/UnfuseFMAOps.cpp | 2 +- .../test/matmul_vectorization.mlir | 16 +- .../Conversion/LinalgToNVVM/ConvertToNVVM.cpp | 6 +- .../ConcretizeTileAmongWorkgroupsPass.cpp | 4 +- .../LinalgToSPIRV/ConvertToGPUPass.cpp | 4 +- .../LinalgToSPIRV/ConvertToSPIRVPass.cpp | 20 +-- .../LinalgToSPIRV/FoldGPUProcessorIDUses.cpp | 2 +- .../TileAndVectorizeInOneWorkgroupPass.cpp | 33 ++-- 
.../LinalgToSPIRV/VectorToGPUPass.cpp | 10 +- .../LinalgToSPIRV/VectorizeMemref.cpp | 2 +- .../materialize_launch_configuration.mlir | 8 +- .../LinalgToSPIRV/test/vector_to_gpu.mlir | 6 +- .../LinalgToVector/LoadStoreVectorization.cpp | 2 +- .../LinalgToVector/VectorizeConv.cpp | 2 +- .../test/vectorize_linalg_conv.mlir | 168 ++++++++++-------- .../Dialect/Flow/IR/test/stream_folding.mlir | 2 +- .../Transforms/ConvertToFlowTensorOps.cpp | 2 +- .../Transforms/DestructiveUpdateUtils.cpp | 2 +- .../Transforms/DispatchLinalgOnTensors.cpp | 8 +- .../Flow/Transforms/HLOToHLOPreprocessing.cpp | 4 +- .../PrePostPartitioningConversion.cpp | 4 +- .../dispatch_linalg_on_tensors_dynamic.mlir | 3 +- .../Flow/Transforms/test/form_streams.mlir | 2 +- .../HAL/Conversion/HALToVM/ConvertHALToVM.cpp | 2 +- .../Dialect/HAL/Transforms/ConvertToHAL.cpp | 2 +- .../HAL/Transforms/MaterializeInterfaces.cpp | 2 +- .../Transforms/ResolveEntryPointOrdinals.cpp | 2 +- .../Shape/Conversion/ConvertShapeToShapex.cpp | 2 +- .../Dialect/Shape/IR/test/canonicalize.mlir | 8 +- .../Transforms/CleanupPlaceholdersPass.cpp | 2 +- .../ConvertHLOToShapeDialectPass.cpp | 2 +- .../MaterializeShapeCalculationsPass.cpp | 4 +- .../StandardToVM/ConvertStandardToVMTest.cpp | 2 +- .../Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 2 +- iree/compiler/Dialect/VM/IR/VMOpFolders.cpp | 28 +++ iree/compiler/Dialect/VM/IR/VMOps.td | 2 + .../Target/Bytecode/BytecodeModuleTarget.cpp | 2 +- .../Dialect/VM/Target/C/CModuleTarget.cpp | 2 +- .../Dialect/VM/Transforms/Conversion.cpp | 2 +- .../VMLA/Conversion/HLOToVMLA/test/fft.mlir | 22 +-- .../Conversion/VMLAToVM/ConvertVMLAToVM.cpp | 2 +- .../Dialect/VMLA/Transforms/Conversion.cpp | 5 +- .../VMLA/Transforms/PreConversionLowering.cpp | 6 +- .../VMLA/Transforms/test/transformation.mlir | 4 +- iree/test/e2e/models/BUILD | 4 +- iree/test/e2e/models/CMakeLists.txt | 17 -- iree/test/e2e/tosa_ops/BUILD | 4 +- iree/test/e2e/tosa_ops/CMakeLists.txt | 2 - third_party/llvm-project | 2 +- 75 files changed, 313 insertions(+), 360 deletions(-) diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt index 8277552ca90a..34b027863ec3 100644 --- a/SUBMODULE_VERSIONS.txt +++ b/SUBMODULE_VERSIONS.txt @@ -5,7 +5,7 @@ daff5fead3fbe22c6fc58310ca3f49caf117f185 third_party/benchmark b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing 013b829185fee6d8eaa515a7e36ec468a2a02600 third_party/llvm-bazel -0776eca7a4e76bfadc311f3607be3a4f0c0e989a third_party/llvm-project +b24436ac96bdf3f2c545fc85dc8af239d618c9c4 third_party/llvm-project 3483f1653fc7cb3bfb3a4d1b463f3a651ecaa676 third_party/mlir-emitc 98debb127d3a14e0239a3432461e3876d293b409 third_party/mlir-hlo 2b2bd45bbf9be04fd22ece5cc1f54679202e9257 third_party/pffft diff --git a/experimental/ModelBuilder/ModelRunner.cpp b/experimental/ModelBuilder/ModelRunner.cpp index b7c01e2bb210..39741d36c8b0 100644 --- a/experimental/ModelBuilder/ModelRunner.cpp +++ b/experimental/ModelBuilder/ModelRunner.cpp @@ -59,12 +59,10 @@ void mlir::ModelRunner::compile( if (target == Target::CPUTarget) { // Lower vector operations progressively into more elementary // vector operations before running the regular compiler passes. 
- mlir::OwningRewritePatternList patterns; - mlir::vector::populateVectorSlicesLoweringPatterns(patterns, - module->getContext()); + mlir::OwningRewritePatternList patterns(module->getContext()); + mlir::vector::populateVectorSlicesLoweringPatterns(patterns); mlir::vector::populateVectorContractLoweringPatterns( - patterns, module->getContext(), - compilationOptions.vectorTransformsOptions); + patterns, compilationOptions.vectorTransformsOptions); (void)mlir::applyPatternsAndFoldGreedily(*module, std::move(patterns)); } runLoweringPass(compilationOptions.loweringPasses diff --git a/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp b/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp index e6b8de3d0160..40c1c1d59d8d 100644 --- a/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp +++ b/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp @@ -58,15 +58,15 @@ class ConvertToMHLOPass : public PassWrapper { // Lower TF Patterns must be separate from canonicalization patterns as // they are sometimes inversions of each other. - OwningRewritePatternList lowerTfPatterns; + OwningRewritePatternList lowerTfPatterns(&getContext()); mlir::TF::PopulateLoweringTFPatterns(context, &lowerTfPatterns); - OwningRewritePatternList canonicalizePatterns; + OwningRewritePatternList canonicalizePatterns(&getContext()); for (auto *op : context->getRegisteredOperations()) { op->getCanonicalizationPatterns(canonicalizePatterns, context); } - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); // Note that the `OperationConverter` orders patterns lexicographically by: // 1) Ascending legalization depth (i.e., minimum number of patterns // necessary to arrive at conversion target). @@ -98,10 +98,10 @@ class ConvertToMHLOPass : public PassWrapper { DenseSet prevUnconvertedOps; DenseSet unconvertedOps; - FrozenRewritePatternList frozenPatterns(std::move(patterns)); - FrozenRewritePatternList frozenCanonicalizePatterns( + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenCanonicalizePatterns( + std::move(canonicalizePatterns)); - FrozenRewritePatternList frozenTfPatterns(std::move(lowerTfPatterns)); + FrozenRewritePatternSet frozenTfPatterns(std::move(lowerTfPatterns)); while (true) { if (failed( applyPatternsAndFoldGreedily(op, frozenCanonicalizePatterns))) { diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc index c2460b8e3227..55c399f128a8 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc +++ b/integrations/tensorflow/iree_tf_compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc @@ -146,7 +146,7 @@ class ConvertTFToTFStringsPass void populateTFToTFStringsPatterns(MLIRContext *ctx, OwningRewritePatternList &patterns) { - populateWithGenerated(ctx, patterns); + populateWithGenerated(patterns); patterns.insert(ctx); patterns.insert(ctx); } diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc index 3e8aa175e395..1a83f35e5be5 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc +++ 
b/integrations/tensorflow/iree_tf_compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc @@ -98,8 +98,8 @@ void ConvertTFToTFTensorListPass::runOnOperation() { // The MLIR type conversion infrastructure doesn't handle this situation well. // It only knows how to blindly convert one type to another type. - OwningRewritePatternList patterns; - populateWithGenerated(&getContext(), patterns); + OwningRewritePatternList patterns(&getContext()); + populateWithGenerated(patterns); patterns.insert(&getContext()); ConversionTarget target(getContext()); diff --git a/integrations/tensorflow/iree_tf_compiler/dialect/utils/conversion_utils.h b/integrations/tensorflow/iree_tf_compiler/dialect/utils/conversion_utils.h index 107205f7b032..37942b7e2c06 100644 --- a/integrations/tensorflow/iree_tf_compiler/dialect/utils/conversion_utils.h +++ b/integrations/tensorflow/iree_tf_compiler/dialect/utils/conversion_utils.h @@ -55,7 +55,7 @@ class ConversionPass : public PassWrapper> { LogicalResult run() { auto module = this->getOperation(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&this->getContext()); Converter typeConverter; // Lower to the standard string operations. @@ -82,10 +82,8 @@ class ConversionPass : public PassWrapper> { llvm::all_of(op.getResultTypes(), func); }); - populateFuncOpTypeConversionPattern(patterns, &this->getContext(), - typeConverter); - populateCallOpTypeConversionPattern(patterns, &this->getContext(), - typeConverter); + populateFuncOpTypeConversionPattern(patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); auto result = applyPartialConversion(module.getOperation(), target, std::move(patterns)); diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp index df1eb41c86bc..985bdb1ed2ac 100644 --- a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp +++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp @@ -217,7 +217,7 @@ struct ForOpCanonicalizationPass : PassWrapper { void runOnFunction() override { FuncOp fn = getFunction(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(fn.getContext()); (void)applyPatternsAndFoldGreedily(fn, std::move(patterns)); diff --git a/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp b/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp index 63213097cc82..c66b4e434005 100644 --- a/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp +++ b/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp @@ -108,7 +108,7 @@ struct RemoveDeadMemAllocs : RewritePattern { struct BufferAllocViewCleanUpPass : public PassWrapper { void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(&getContext()); patterns.insert(); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); diff --git a/iree/compiler/Conversion/Common/LinalgRewriteDestructiveUpdatesPass.cpp b/iree/compiler/Conversion/Common/LinalgRewriteDestructiveUpdatesPass.cpp index 774305657ba0..a6a0f2d03e7e 100644 --- a/iree/compiler/Conversion/Common/LinalgRewriteDestructiveUpdatesPass.cpp +++ b/iree/compiler/Conversion/Common/LinalgRewriteDestructiveUpdatesPass.cpp @@ -532,7 +532,7 @@ void LinalgRewriteDestructiveUpdates::runOnFunction() { // Non-default canonicalization patterns. 
// TODO: add Linalg tiling canonicalization patterns, affineminscf and others // as needed. - OwningRewritePatternList canonicalizationPatterns; + OwningRewritePatternList canonicalizationPatterns(&getContext()); scf::ForOp::getCanonicalizationPatterns(canonicalizationPatterns, context); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns)); diff --git a/iree/compiler/Conversion/Common/Transforms.cpp b/iree/compiler/Conversion/Common/Transforms.cpp index 0c4af0b9d9f0..8fdf1545bade 100644 --- a/iree/compiler/Conversion/Common/Transforms.cpp +++ b/iree/compiler/Conversion/Common/Transforms.cpp @@ -45,7 +45,7 @@ namespace iree_compiler { /// easier. void applyCanonicalizationPatternsForTiling(MLIRContext *context, Operation *op) { - OwningRewritePatternList canonicalizationPatterns; + OwningRewritePatternList canonicalizationPatterns(context); canonicalizationPatterns.insert(context); scf::ForOp::getCanonicalizationPatterns(canonicalizationPatterns, context); AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context); @@ -345,7 +345,7 @@ LogicalResult defineWorkgroupCountRegion( LogicalResult materializeStaticLaunchInformation( FuncOp funcOp, ArrayRef workloadPerWorkgroup) { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(funcOp.getContext()); patterns.insert(funcOp.getContext(), workloadPerWorkgroup); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { diff --git a/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp index c0ad44b6bde6..9904c49a1d8c 100644 --- a/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp +++ b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp @@ -64,9 +64,8 @@ struct VectorTransferOptimizationPass // Generate vector.shape_cast for dropping leading one dimensions in vector // ops. This increases the chance that we can forward more transfer writes // to transfer reads. 
- OwningRewritePatternList patterns; - mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( - patterns, funcOp.getContext()); + OwningRewritePatternList patterns(&getContext()); + mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); vector::transferOpflowOpt(funcOp); diff --git a/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp b/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp index 4bfe8ec192aa..2a142f038067 100644 --- a/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp +++ b/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp @@ -130,7 +130,7 @@ struct Convert1x1ConvToDotPass void runOnFunction() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp index 0adbd571d0d6..d294d3e05484 100644 --- a/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp +++ b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp @@ -60,7 +60,7 @@ struct DecomposeHLOClampPass void runOnFunction() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp b/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp index c66e2807981b..92bdc62bfa73 100644 --- a/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp +++ b/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp @@ -172,9 +172,9 @@ void ConvertF32ToF16Pass::runOnOperation() { ModuleOp moduleOp = getOperation(); FloatTypeConverter converter; - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context, converter); - populateFuncOpTypeConversionPattern(patterns, context, converter); + populateFuncOpTypeConversionPattern(patterns, converter); F32ToF16ConversionTarget target(*context); target.markUnknownOpDynamicallyLegal(); if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) diff --git a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp index 7dc89dc28980..bc5819c32e23 100644 --- a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp +++ b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp @@ -73,18 +73,19 @@ struct FusionOfTensorOpsPass } void runOnOperation() override { - OwningRewritePatternList fusionPatterns, interfacePatterns; + OwningRewritePatternList fusionPatterns(&getContext()); + OwningRewritePatternList interfacePatterns(&getContext()); Operation *op = getOperation(); MLIRContext *context = op->getContext(); interfacePatterns.insert(context); - FrozenRewritePatternList frozenInterfacePatterns( + FrozenRewritePatternSet frozenInterfacePatterns( std::move(interfacePatterns)); (void)applyPatternsAndFoldGreedily(op->getRegions(), frozenInterfacePatterns); - populateLinalgTensorOpsFusionPatterns(context, fusionPatterns); + populateLinalgTensorOpsFusionPatterns(fusionPatterns); (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(fusionPatterns)); diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp 
b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp index 68b7eac97d01..38acef360621 100644 --- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp +++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp @@ -825,6 +825,8 @@ static LogicalResult createAndPropagateBufferUsedForResultTensors( // Canonicalization patterns. //===----------------------------------------------------------------------===// +// TODO(hanchung): Revisit this pattern; it seems no longer needed because +// reshape ops are folded in the tensors world. // Folds a linalg.reshape op that directly reshapes an iree.placeholder op into // the iree.placeholder op itself. class FoldReshapeIntoPlaceholder final @@ -900,7 +902,7 @@ void ConvertHLOToLinalgOnBuffersPass::runOnFunction() { return signalPassFailure(); } - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateHLOToLinalgOnBuffersConversionPatterns(context, patterns, resultTensorToBufferMap); patterns.insert( @@ -940,7 +942,7 @@ void ConvertHLOToLinalgOnBuffersPass::runOnFunction() { // Perform additional canonicalizations. { - OwningRewritePatternList foldingPatterns; + OwningRewritePatternList foldingPatterns(&getContext()); foldingPatterns.insert(context); (void)applyPatternsAndFoldGreedily(funcOp, std::move(foldingPatterns)); } diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp index aecec545b36e..cfbc1ae586f1 100644 --- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp +++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp @@ -194,7 +194,7 @@ struct ConvertHLOToLinalgOnTensorsPass } void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); MLIRContext *context = &getContext(); populateHLOToLinalgOnTensorsConversionPatterns(context, patterns); if (useLinalgOnTensorsPath) { diff --git a/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp b/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp index 02d34eca08d1..4f9107d3077f 100644 --- a/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp +++ b/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp @@ -98,7 +98,7 @@ struct ResolveShapeOpsPass void ResolveShapeOpsPass::runOnFunction() { MLIRContext *context = &getContext(); - OwningRewritePatternList dimPatterns; + OwningRewritePatternList dimPatterns(&getContext()); dimPatterns.insert(context); // Set up a target to convert all std.dim ops. 
We need a conversion target @@ -111,7 +111,7 @@ void ResolveShapeOpsPass::runOnFunction() { return signalPassFailure(); } - OwningRewritePatternList shapePatterns; + OwningRewritePatternList shapePatterns(&getContext()); shapePatterns.insert(context); Shape::RankedDimOp::getCanonicalizationPatterns(shapePatterns, context); diff --git a/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir b/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir index e964a08ea089..a832e1955254 100644 --- a/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir +++ b/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir @@ -32,10 +32,9 @@ module { // ----- module { - func @fuse_store_reshape() { + func @fuse_store_reshape(%arg0: tensor<100xi32>) { %c0 = constant 0 : index - %c42 = constant dense<42> : tensor<100xi32> - %0 = linalg.tensor_reshape %c42 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<100xi32> into tensor<4x25xi32> + %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<100xi32> into tensor<4x25xi32> hal.interface.store.tensor %0, @legacy_io::@ret0, offset = %c0 : tensor<4x25xi32> return } @@ -45,8 +44,8 @@ module { } // CHECK-LABEL: func @fuse_store_reshape -// CHECK: %[[C42:.+]] = constant dense<{{.+}}> : tensor<100xi32> -// CHECK: hal.interface.store.tensor %[[C42]] +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<100xi32> +// CHECK: hal.interface.store.tensor %[[ARG0]] // ----- diff --git a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir index d39791f1efff..c3e921a42818 100644 --- a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir +++ b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir @@ -320,66 +320,6 @@ module { // ----- -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1, d2) -> (d0, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d2)> - -module { - func @store_reshape_src_and_result_2() { - %c0 = constant 0 : index - %shape = linalg.init_tensor[2, 4] : tensor<2x4xf32> - %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 - {operand_result_index = 0 : i32} : tensor<2x4xf32> - %1 = linalg.generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%0 : tensor<2x4xf32>) - outs(%shape : tensor<2x4xf32>) { - ^bb0(%arg0: f32, %s: f32): // no predecessors - %2 = math.tanh %arg0 : f32 - linalg.yield %2 : f32 - } -> tensor<2x4xf32> - %3 = linalg.tensor_reshape %1 [#map1, #map2] - : tensor<2x4xf32> into tensor<1x2x4xf32> - %4 = linalg.tensor_reshape %1 [#map1, #map2] - : tensor<2x4xf32> into tensor<1x2x4xf32> - %5 = linalg.tensor_reshape %1 [#map1, #map2] - : tensor<2x4xf32> into tensor<1x2x4xf32> - hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 - {operand_result_index = 1 : i32} : tensor<1x2x4xf32> - hal.interface.store.tensor %4, @legacy_io::@ret1, offset = %c0 - {operand_result_index = 2 : i32} : tensor<1x2x4xf32> - hal.interface.store.tensor %5, @legacy_io::@ret2, offset = %c0 - {operand_result_index = 3 : i32} : tensor<1x2x4xf32> - return - } - hal.interface @legacy_io attributes {sym_visibility = "private"} { - hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", - access="Read" - hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", - access="Write|Discard" - hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer", - access="Write|Discard" - hal.interface.binding @ret2, set=0, binding=3, type="StorageBuffer", - 
access="Write|Discard" - } -} - -// CHECK-LABEL: func @store_reshape_src_and_result_2 -// CHECK-DAG: %[[T0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret2, operand_result_index = 3 : i32} : memref<1x2x4xf32> -// CHECK-DAG: %[[T1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret2, operand_result_index = 3 : i32} : memref<2x4xf32> -// CHECK-DAG: %[[T2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1, operand_result_index = 2 : i32} : memref<1x2x4xf32> -// CHECK-DAG: %[[T3:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 1 : i32} : memref<1x2x4xf32> -// CHECK-DAG: %[[T4:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<2x4xf32> -// CHECK: linalg.generic -// CHECK-SAME: ins(%[[T4]] : -// CHECK-SAME: outs(%[[T1]] : -// CHECK: linalg.copy(%[[T0]], %[[T3]]) -// CHECK: linalg.copy(%[[T0]], %[[T2]]) -// CHECK: return - -// ----- - #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> #map2 = affine_map<(d0, d1) -> (d0, d1)> diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvImg2ColMatmulConversion.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvImg2ColMatmulConversion.cpp index 783662d4a51e..26559860d048 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/ConvImg2ColMatmulConversion.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/ConvImg2ColMatmulConversion.cpp @@ -200,7 +200,7 @@ void populateConvImg2ColMatmulConversionPatterns( void ConvImg2ColMatmulConversionPass::runOnFunction() { auto funcOp = getOperation(); auto context = funcOp.getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateConvImg2ColMatmulConversionPatterns(context, patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp index 118bf89974c0..1e48e2b45925 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp @@ -633,26 +633,24 @@ class ConvertToLLVMPass void ConvertToLLVMPass::runOnOperation() { // Run Vector -> Vector transformations ahead of conversion to LLVM. { - OwningRewritePatternList patterns; - vector::populateVectorToVectorCanonicalizationPatterns(patterns, - &getContext()); - vector::populateVectorSlicesLoweringPatterns(patterns, &getContext()); - vector::populateVectorContractLoweringPatterns(patterns, &getContext()); + OwningRewritePatternList patterns(&getContext()); + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + vector::populateVectorSlicesLoweringPatterns(patterns); + vector::populateVectorContractLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } { - OwningRewritePatternList vectorToLoopsPatterns; + OwningRewritePatternList vectorToLoopsPatterns(&getContext()); populateVectorToSCFConversionPatterns( - vectorToLoopsPatterns, &getContext(), - VectorTransferToSCFOptions().setUnroll(true)); + vectorToLoopsPatterns, VectorTransferToSCFOptions().setUnroll(true)); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(vectorToLoopsPatterns)); } // math dialect elementary functions -> polynomial form. 
{ - OwningRewritePatternList mathPatterns; - populateMathPolynomialApproximationPatterns(mathPatterns, &getContext()); + OwningRewritePatternList mathPatterns(&getContext()); + populateMathPolynomialApproximationPatterns(mathPatterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(mathPatterns)); } @@ -663,12 +661,12 @@ void ConvertToLLVMPass::runOnOperation() { return success(); }); - OwningRewritePatternList patterns; - populateAffineToStdConversionPatterns(patterns, &getContext()); - populateLoopToStdConversionPatterns(patterns, &getContext()); - populateExpandTanhPattern(patterns, &getContext()); + OwningRewritePatternList patterns(&getContext()); + populateAffineToStdConversionPatterns(patterns); + populateLoopToStdConversionPatterns(patterns); + populateExpandTanhPattern(patterns); populateStdToLLVMConversionPatterns(converter, patterns); - populateVectorToSCFConversionPatterns(patterns, &getContext()); + populateVectorToSCFConversionPatterns(patterns); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateLinalgToLLVMConversionPatterns(converter, patterns); @@ -721,7 +719,7 @@ void ConvertToLLVMPass::runOnOperation() { // Post conversion patterns. { - OwningRewritePatternList postPatterns; + OwningRewritePatternList postPatterns(&getContext()); if (options_.unfuseFMAOps) { populateUnfusedFMAOpsPassPatterns(&getContext(), postPatterns); (void)applyPatternsAndFoldGreedily(module, std::move(postPatterns)); diff --git a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp b/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp index 53e078ba14b6..026dd95cba4c 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp @@ -62,9 +62,8 @@ class FoldTensorExtractOpPass } // namespace void FoldTensorExtractOpPass::runOnOperation() { - MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; - populateWithGenerated(context, patterns); + OwningRewritePatternList patterns(&getContext()); + populateWithGenerated(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp index 5ca30867d466..441d9e7950b1 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp @@ -136,7 +136,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Promotes workgroups subviews to a full-tile allocated on the stack. if (clEnablePromoteWorkgroupToFullTiles) { - OwningRewritePatternList promotionPatterns; + OwningRewritePatternList promotionPatterns(&getContext()); promotionPatterns.insert( context, linalg::LinalgPromotionOptions().setAllocationDeallocationFns( @@ -151,7 +151,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Workgroup first level of tiling. { // First level of tiling patterns. (workgroups memory) - OwningRewritePatternList l1patterns; + OwningRewritePatternList l1patterns(&getContext()); l1patterns.insert( linalg::LinalgTilingOptions().setTileSizeComputationFunction( [](OpBuilder &builder, @@ -173,7 +173,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Second level of tiling. 
(workgroups memory -> vectors) { - OwningRewritePatternList l2patterns; + OwningRewritePatternList l2patterns(&getContext()); l2patterns.insert( linalg::LinalgTilingOptions().setTileSizeComputationFunction( [](OpBuilder &builder, @@ -192,7 +192,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Apply canonicalization. { - OwningRewritePatternList canonicalizationPatterns; + OwningRewritePatternList canonicalizationPatterns(&getContext()); canonicalizationPatterns.insert(context); AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context); @@ -207,10 +207,10 @@ void TileAndVectorizeWorkgroups::runOnFunction() { // Apply vectorization patterns. { - OwningRewritePatternList vectorizationPatterns; + OwningRewritePatternList vectorizationPatterns(&getContext()); linalg::insertVectorizationPatterns( - vectorizationPatterns, context, linalg::LinalgVectorizationOptions(), + vectorizationPatterns, linalg::LinalgVectorizationOptions(), linalg::LinalgTransformationFilter( Identifier::get(getVectorizeMarker(), context))); if (failed(applyPatternsAndFoldGreedily( @@ -232,7 +232,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() { vector::VectorTransformsOptions vectorTransformsOptions = vector::VectorTransformsOptions().setVectorTransformsOptions( vector::VectorContractLowering::OuterProduct); - OwningRewritePatternList vectorContractLoweringPatterns; + OwningRewritePatternList vectorContractLoweringPatterns(&getContext()); vectorContractLoweringPatterns .insert( @@ -247,16 +247,15 @@ void TileAndVectorizeWorkgroups::runOnFunction() { { VectorTransferToSCFOptions vectorToSCFOptions = VectorTransferToSCFOptions().setUnroll(true); - OwningRewritePatternList vectorToLoopsPatterns; - populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, + OwningRewritePatternList vectorToLoopsPatterns(&getContext()); + populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, vectorToSCFOptions); // Hoist hierarchical tiling indexing and other loop invariant transfer // ops computation. linalg::hoistRedundantVectorTransfers(funcOp); // TODO(ataei): Move this to common vector dialect patterns. - populateStdLegalizationPatternsForSPIRVLowering(context, - vectorToLoopsPatterns); + populateStdLegalizationPatternsForSPIRVLowering(vectorToLoopsPatterns); if (failed(applyPatternsAndFoldGreedily( funcOp, std::move(vectorToLoopsPatterns)))) { return signalPassFailure(); diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp index ada02740234d..5c7ac44f90a2 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgVectorizePass.cpp @@ -58,10 +58,10 @@ void LinalgVectorizationPass::runOnFunction() { MLIRContext *context = &getContext(); // Apply vectorization patterns. { - OwningRewritePatternList vectorizationPatterns; + OwningRewritePatternList vectorizationPatterns(&getContext()); linalg::insertVectorizationPatterns( - vectorizationPatterns, context, linalg::LinalgVectorizationOptions(), + vectorizationPatterns, linalg::LinalgVectorizationOptions(), linalg::LinalgTransformationFilter(ArrayRef( Identifier::get(getWorkgroupMarker(), context)))); (void)applyPatternsAndFoldGreedily(funcOp, @@ -84,22 +84,21 @@ void LinalgVectorizationPass::runOnFunction() { // Apply unrolling patterns. 
{ - OwningRewritePatternList vectorUnrollPatterns; + OwningRewritePatternList vectorUnrollPatterns(&getContext()); vectorUnrollPatterns.insert( context, vector::UnrollVectorOptions().setNativeShapeFn(getShape)); (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns)); - OwningRewritePatternList canonicalizationPatterns1; + OwningRewritePatternList canonicalizationPatterns1(&getContext()); vector::populateVectorToVectorCanonicalizationPatterns( - canonicalizationPatterns1, funcOp.getContext()); + canonicalizationPatterns1); vector::populateVectorToVectorTransformationPatterns( - canonicalizationPatterns1, funcOp.getContext()); + canonicalizationPatterns1); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns1)); - OwningRewritePatternList canonicalizationPatterns2; - vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2, - funcOp.getContext()); + OwningRewritePatternList canonicalizationPatterns2(&getContext()); + vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns2)); diff --git a/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp index b6a596b99c3d..3f87a4a39359 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp @@ -55,9 +55,8 @@ void PlanConvLoopOrderPass::runOnFunction() { /*output_channel=*/3, }; - OwningRewritePatternList patterns; - linalg::populateLinalgConvGeneralizationPatterns(context, patterns, - firstStepMarker); + OwningRewritePatternList patterns(&getContext()); + linalg::populateLinalgConvGeneralizationPatterns(patterns, firstStepMarker); patterns.insert>( context, loopOrder, secondStepMarker); diff --git a/iree/compiler/Conversion/LinalgToLLVM/UnfuseFMAOps.cpp b/iree/compiler/Conversion/LinalgToLLVM/UnfuseFMAOps.cpp index 9890cf716b8a..d2b0243c3ae8 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/UnfuseFMAOps.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/UnfuseFMAOps.cpp @@ -58,7 +58,7 @@ void populateUnfusedFMAOpsPassPatterns(MLIRContext *context, void UnfusedFMAOpsPass::runOnFunction() { auto funcOp = getOperation(); auto context = funcOp.getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateUnfusedFMAOpsPassPatterns(context, patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir index b64ef0a4939d..98e9489e67f9 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir @@ -52,14 +52,14 @@ hal.executable @dynamic_matmul attributes {sym_visibility = "private"} { // CHECK-PROMOTED: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)> // CHECK-PROMOTED: func @matmul_128x128x128 // CHECK-PROMOTED: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>) { -// CHECK-PROMOTED: %[[KDIM_SIZE:.+]] = constant 128 : index -// CHECK-PROMOTED: %[[WORGKROUP_SIZE:.+]] = constant 64 : index -// CHECK-PROMOTED: %[[VECTOR_SIZE:.+]] = constant 4 : index -// CHECK-PROMOTED: %[[L1_SIZE:.+]] = constant 32 : index -// CHECK-PROMOTED: %[[START:.+]] = constant 0 : index -// CHECK-PROMOTED: 
%[[C1:.+]] = constant 1 : index -// CHECK-PROMOTED: %[[C1:.+]] = constant 2 : index -// CHECK-PROMOTED: %[[C1:.+]] = constant 3 : index +// CHECK-PROMOTED-DAG: %[[KDIM_SIZE:.+]] = constant 128 : index +// CHECK-PROMOTED-DAG: %[[WORGKROUP_SIZE:.+]] = constant 64 : index +// CHECK-PROMOTED-DAG: %[[VECTOR_SIZE:.+]] = constant 4 : index +// CHECK-PROMOTED-DAG: %[[L1_SIZE:.+]] = constant 32 : index +// CHECK-PROMOTED-DAG: %[[START:.+]] = constant 0 : index +// CHECK-PROMOTED-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-PROMOTED-DAG: %[[C1:.+]] = constant 2 : index +// CHECK-PROMOTED-DAG: %[[C1:.+]] = constant 3 : index // CHECK-PROMOTED: %[[A_PROMOTED_TILE:.+]] = memref.alloca() : memref<64x64xf32> // CHECK-PROMOTED: %[[B_PROMOTED_TILE:.+]] = memref.alloca() : memref<128x64xf32> // CHECK-PROMOTED: %[[C_PROMOTED_TILE:.+]] = memref.alloca() : memref<64x128xf32> diff --git a/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp index 786fa3156a31..75708aef2887 100644 --- a/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp +++ b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp @@ -184,12 +184,12 @@ struct ConvertToNVVMPass // which need to be lowered further, which is not supported by a single // conversion pass. { - OwningRewritePatternList patterns; - populateGpuRewritePatterns(m.getContext(), patterns); + OwningRewritePatternList patterns(&getContext()); + populateGpuRewritePatterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); } { - OwningRewritePatternList llvmPatterns; + OwningRewritePatternList llvmPatterns(&getContext()); llvmPatterns.insert(m.getContext(), converter); llvmPatterns diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp index abce25c1f694..934b9d9c9777 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConcretizeTileAmongWorkgroupsPass.cpp @@ -462,7 +462,7 @@ class ConcretizeTileAmongWorkgroupsPass // 4. Replace hal.interface.workgroup symbolic ops with constant values. { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert( &context, workloadSize, tileSize); @@ -530,7 +530,7 @@ class ConcretizeTileAmongWorkgroupsPass // 6. Canonicalization and clean up. if (inlineTripOneLoops) { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(&context, workloadSize, tileSize); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp index bee5dd7f79b5..172e9a95d1f0 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp @@ -824,7 +824,7 @@ void ConvertToGPUPass::runOnOperation() { // Let the rest fall through. 
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert< MapLinalgOpToGlobalInvocationId, @@ -845,7 +845,7 @@ void ConvertToGPUPass::runOnOperation() { MapLinalgOpToLocalInvocationId, RemoveLinalgRange, SerializeParallelLoopPattern>( context, options.usingLinalgOnTensors); - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (FuncOp funcOp : getOperation().getInnerModule().getOps()) { if (!isEntryPoint(funcOp)) continue; diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp index e4808a3dcff3..7b55c8e8657f 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp @@ -539,27 +539,25 @@ void ConvertToSPIRVPass::runOnOperation() { SPIRVTypeConverter typeConverter(targetAttr); ScfToSPIRVContext scfToSPIRVContext; - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); // Pull in GPU patterns to convert processor ID ops and loop ops. - populateGPUToSPIRVPatterns(context, typeConverter, patterns); + populateGPUToSPIRVPatterns(typeConverter, patterns); // Pull in SCF patterns to convert control flow ops. - populateSCFToSPIRVPatterns(context, typeConverter, scfToSPIRVContext, - patterns); + populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns); // Pull in standard patterns to convert arithmetic ops and others. - populateStandardToSPIRVPatterns(context, typeConverter, patterns); + populateStandardToSPIRVPatterns(typeConverter, patterns); // Pull in standard patterns to convert tensor operations to SPIR-V. These are // primarily used to handle tensor-type constants and contain a // threshold. Only those constants that are below the threshold are converted // to SPIR-V. In IREE we want to control this threshold at Flow level. So set // this value arbitrarily high to make sure that everything within a dispatch // region is converted. - mlir::populateTensorToSPIRVPatterns(context, typeConverter, - std::numeric_limits::max() / 8, - patterns); + mlir::populateTensorToSPIRVPatterns( + typeConverter, std::numeric_limits::max() / 8, patterns); // Pull in vector patterns to convert vector ops. - mlir::populateVectorToSPIRVPatterns(context, typeConverter, patterns); + mlir::populateVectorToSPIRVPatterns(typeConverter, patterns); // Pull in builtin func to spv.func conversion. 
- populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); + populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); auto &cooperativeMatrixAnalysis = getAnalysis(); populateVectorToSPIRVPatterns(context, typeConverter, patterns, cooperativeMatrixAnalysis); @@ -593,7 +591,7 @@ void ConvertToSPIRVPass::runOnOperation() { functions.push_back(fn); } - FrozenRewritePatternList frozenPatterns(std::move(patterns)); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (FuncOp fn : functions) if (failed(applyFullConversion(fn, *target, frozenPatterns))) return signalPassFailure(); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp index 323832d9e707..1e35b1388fdf 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp @@ -275,7 +275,7 @@ struct FoldGPUProcessIDUsesPass void runOnOperation() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateFoldGPUProcessorIDUsesPatterns(context, patterns); (void)applyPatternsAndFoldGreedily(getOperation().getInnerModule(), std::move(patterns)); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp index ee9e20fd3712..b351412686d9 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp @@ -308,7 +308,7 @@ static void populateVectorizationPatterns(MLIRContext *context, OwningRewritePatternList &patterns) { linalg::insertVectorizationPatterns( - patterns, context, linalg::LinalgVectorizationOptions(), + patterns, linalg::LinalgVectorizationOptions(), linalg::LinalgTransformationFilter( Identifier::get(getVectorizeMarker(), context))); } @@ -330,23 +330,21 @@ static void populateVectorUnrollPatterns(MLIRContext *context, static void applyVectorTransformation(FuncOp funcOp) { { - OwningRewritePatternList vectorUnrollPatterns; + OwningRewritePatternList vectorUnrollPatterns(funcOp.getContext()); populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns)); - OwningRewritePatternList canonicalizationPatterns1; + OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext()); vector::populateVectorToVectorCanonicalizationPatterns( - canonicalizationPatterns1, funcOp.getContext()); + canonicalizationPatterns1); vector::populateVectorToVectorTransformationPatterns( - canonicalizationPatterns1, funcOp.getContext()); - vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1, - funcOp.getContext()); + canonicalizationPatterns1); + vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns1)); - OwningRewritePatternList canonicalizationPatterns2; - vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2, - funcOp.getContext()); + OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext()); + vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2); (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns2)); LLVM_DEBUG({ @@ -450,7 +448,7 @@ void 
TileAndVectorizeInOneWorkgroupPass::runOnOperation() { // The promotion patterns are put separate from the tiling patterns to // make sure that the allocated scratchspace memory is constant sizes // which requires some folding to trigger. - OwningRewritePatternList promotionPatterns; + OwningRewritePatternList promotionPatterns(&getContext()); populatePromotionPatterns(context, promotionPatterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPatterns)); applyCanonicalizationPatternsForTiling(context, funcOp); @@ -464,7 +462,7 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { if (launchConfig.useVectorize()) { { - OwningRewritePatternList secondLevelTilingPatterns; + OwningRewritePatternList secondLevelTilingPatterns(&getContext()); populateTilingToSubgroupPatterns(context, launchConfig, secondLevelTilingPatterns); (void)applyPatternsAndFoldGreedily( @@ -480,7 +478,7 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { } { - OwningRewritePatternList thirdLevelTilingPatterns; + OwningRewritePatternList thirdLevelTilingPatterns(&getContext()); populateTilingToInvocationPatterns(context, launchConfig, thirdLevelTilingPatterns); (void)applyPatternsAndFoldGreedily(funcOp, @@ -496,7 +494,7 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { } { - OwningRewritePatternList tilingPatterns; + OwningRewritePatternList tilingPatterns(&getContext()); auto marker = getLinalgMatchAndReplaceMarker( getConvFilterTileMarker(), getVectorizeMarker(), context); populateTilingConvFilterPatterns(context, tilingPatterns, launchConfig, @@ -515,7 +513,7 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { } { - OwningRewritePatternList vectorizationPatterns; + OwningRewritePatternList vectorizationPatterns(&getContext()); populateVectorizationPatterns(context, launchConfig, vectorizationPatterns); populateVectorizeLinalgConvPatterns(context, vectorizationPatterns); @@ -555,9 +553,8 @@ void TileAndVectorizeInOneWorkgroupPass::runOnOperation() { linalg::DepthwiseConvInputNHWCFilterHWCOp>(op)); }); - OwningRewritePatternList patterns; - linalg::populateLinalgNamedOpsGeneralizationPatterns(context, patterns, - marker); + OwningRewritePatternList patterns(&getContext()); + linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns, marker); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp index 79d64cd2c5a4..e07559164d49 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp @@ -180,7 +180,7 @@ void ConvertVectorToGPUPass::tileAndVectorizeLinalgCopy(FuncOp funcOp, return !(hasMarker(copy, getCopyToWorkgroupMemoryMarker())); }); target->markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - OwningRewritePatternList tileAndDistributePattern; + OwningRewritePatternList tileAndDistributePattern(&getContext()); populateLinalgTileAndDistributePatterns(context, tileAndDistributePattern); if (failed(applyPartialConversion(funcOp, *target, std::move(tileAndDistributePattern)))) { @@ -196,9 +196,9 @@ void ConvertVectorToGPUPass::tileAndVectorizeLinalgCopy(FuncOp funcOp, (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizePatterns)); // 3. Vectorize the tiled linalg to be able to map it to load/store vector. 
-  OwningRewritePatternList vectorizationPatterns;
+  OwningRewritePatternList vectorizationPatterns(&getContext());
   linalg::insertVectorizationPatterns(
-      vectorizationPatterns, context, linalg::LinalgVectorizationOptions(),
+      vectorizationPatterns, linalg::LinalgVectorizationOptions(),
       linalg::LinalgTransformationFilter(
           Identifier::get(getVectorizeMarker(), context), {}));
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns));
@@ -366,7 +366,7 @@ class ExtractStridedLowering
 // Lower vector ops to instructions that can be later converted to SPIR-V.
 void ConvertVectorToGPUPass::lowerVectorOps(FuncOp funcOp,
                                             MLIRContext *context) {
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   patterns.insert(context);
@@ -381,7 +381,7 @@ void ConvertVectorToGPUPass::runOnOperation() {
     lowerVectorOps(funcOp, context);

   auto &cooperativeMatrixAnalysis = getAnalysis<CooperativeMatrixAnalysis>();
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   patterns.insert, VectorTransferReadConversion,
                   VectorTransferWriteConversion>(context, cooperativeMatrixAnalysis);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
index 8a6fbe0a6322..c968c9aec776 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
@@ -440,7 +440,7 @@ void VectorizeMemRefPass::runOnOperation() {
   memrefUsageAnalysis = &getAnalysis<MemRefUsageAnalysis>();
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(&getContext());
   patterns.insert(
       context, *memrefUsageAnalysis);
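The same integrate also drops the separate MLIRContext parameter from populate-style helpers, since the pattern list now knows its own context. A hedged before/after sketch (populateMyPatterns, MyPatternA, and MyPatternB are illustrative names, not actual MLIR or IREE functions):

    // Before: the context was threaded through every populate call.
    // void populateMyPatterns(MLIRContext *context,
    //                         OwningRewritePatternList &patterns);

    // After: helpers recover the context from the list itself.
    void populateMyPatterns(OwningRewritePatternList &patterns) {
      MLIRContext *context = patterns.getContext();
      patterns.insert<MyPatternA, MyPatternB>(context);
    }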
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir
index 5fc2981330d4..07cec3eab73c 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir
@@ -64,10 +64,10 @@ hal.executable @matmul_tensors attributes {sym_visibility = "private"} {
 // CHECK-DAG: %[[WGY:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
 // CHECK: hal.return %[[WGX]], %[[WGY]], %[[C1]]
 // CHECK-NOT: hal.interface.workgroup.size
-// CHECK-DAG: %[[C0:.+]] = constant 0
-// CHECK-DAG: %[[C1:.+]] = constant 1
-// CHECK-DAG: %[[C16:.+]] = constant 16
-// CHECK-DAG: %[[C8:.+]] = constant 8
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C16:.+]] = constant 16 : index
+// CHECK-DAG: %[[C8:.+]] = constant 8 : index
 // CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @legacy_io::@arg0
 // CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @legacy_io::@arg1
 // CHECK-DAG: %[[INIT:.+]] = hal.interface.binding.subspan @legacy_io::@arg2
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
index 46f667dda78c..6c9ad186e039 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
@@ -33,9 +33,9 @@ module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.v
 }
 // CHECK: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 4)>
-  // CHECK: %[[C1024:.+]] = constant 1024 : index
-  // CHECK: %[[C8:.+]] = constant 8 : index
-  // CHECK: %[[C0:.+]] = constant 0 : index
+  // CHECK-DAG: %[[C1024:.+]] = constant 1024 : index
+  // CHECK-DAG: %[[C8:.+]] = constant 8 : index
+  // CHECK-DAG: %[[C0:.+]] = constant 0 : index
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<128x32xf32, 3>
 // CHECK: %[[DST:.+]] = memref.subview %{{.+}}[0, 0] [128, 32] [1, 1] : memref<4096x4096xf32> to memref<128x32xf32, #map0>
 // CHECK: %[[TIDx:.+]] = "gpu.thread_id"() {dimension = "x"} : () -> index
diff --git a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
index 0476a21e1e62..221d9711a128 100644
--- a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
+++ b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
@@ -230,7 +230,7 @@ struct LoadStoreVectorizationPass
   void runOnOperation() override {
     MLIRContext *context = &getContext();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
     // clang-format off
     patterns.insert<
         VectorizeGenericOp,
diff --git a/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
index 9e0ee62abb26..14ad4bad5f66 100644
--- a/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
+++ b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
@@ -347,7 +347,7 @@ struct VectorizeLinalgConvPass
   void runOnOperation() override {
     MLIRContext *context = &getContext();
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
     patterns.insert(context);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 }
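For the greedy-rewrite passes touched here, only the construction line changes; the driver idiom is otherwise untouched. As a reminder of the shape these runOnOperation bodies share (a sketch under the same API assumptions; VectorizeSomethingPattern is a placeholder):

    void runOnOperation() override {
      MLIRContext *context = &getContext();
      OwningRewritePatternList patterns(context);
      patterns.insert<VectorizeSomethingPattern>(context);
      // The driver consumes the list via std::move, so each application
      // site builds a fresh list; the result is deliberately ignored when
      // best-effort rewriting is acceptable.
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }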
diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
index e92ec490df82..69e79db29b61 100644
--- a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
+++ b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
@@ -1,9 +1,10 @@
 // RUN: iree-opt -split-input-file -iree-codegen-vectorize-linalg-conv -canonicalize -cse %s | IreeFileCheck %s

-func @vectorize_conv(%filter: memref<1x1x3x4xf32>, %input: memref<1x2x2x3xf32>, %output: memref<1x2x2x4xf32>) {
-  %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 3, 4] [1, 1, 1, 1] : memref<1x1x3x4xf32> to memref<1x1x3x4xf32>
-  %1 = memref.subview %input[0, 0, 0, 0] [1, 2, 2, 3] [1, 1, 1, 1] : memref<1x2x2x3xf32> to memref<1x2x2x3xf32>
-  %2 = memref.subview %output[0, 0, 0, 0] [1, 2, 2, 4] [1, 1, 1, 1] : memref<1x2x2x4xf32> to memref<1x2x2x4xf32>
+// Passing bigger buffers to avoid memref.subview folding away.
+func @vectorize_conv(%filter: memref<2x1x3x4xf32>, %input: memref<2x2x2x3xf32>, %output: memref<2x2x2x4xf32>) {
+  %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 3, 4] [1, 1, 1, 1] : memref<2x1x3x4xf32> to memref<1x1x3x4xf32>
+  %1 = memref.subview %input[0, 0, 0, 0] [1, 2, 2, 3] [1, 1, 1, 1] : memref<2x2x2x3xf32> to memref<1x2x2x3xf32>
+  %2 = memref.subview %output[0, 0, 0, 0] [1, 2, 2, 4] [1, 1, 1, 1] : memref<2x2x2x4xf32> to memref<1x2x2x4xf32>
   linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
     ins (%1, %0: memref<1x2x2x3xf32>, memref<1x1x3x4xf32>)
     outs (%2: memref<1x2x2x4xf32>)
@@ -15,69 +16,74 @@ func @vectorize_conv(%filter: memref<1x1x3x4xf32>, %input: memref<1x2x2x3xf32>,
 // CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

 // CHECK: func @vectorize_conv
-// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<1x1x3x4xf32>,
-// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<1x2x2x3xf32>,
-// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<1x2x2x4xf32>
+// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<2x1x3x4xf32>,
+// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<2x2x2x3xf32>,
+// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<2x2x2x4xf32>
 // CHECK: %[[FLOAT_ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER_ARG]]{{.*}} to memref<1x1x3x4xf32>
+// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT_ARG]]{{.*}} to memref<1x2x2x3xf32>
+// CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT_ARG]]{{.*}} to memref<1x2x2x4xf32>
+
 // Read in the filter and get slices
-// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_ARG]][%c0, %c0, %c0, %c0], %cst {masked = [false, false]} : memref<1x1x3x4xf32>, vector<3x4xf32>
+// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_SUBVIEW]][%c0, %c0, %c0, %c0], %cst {masked = [false, false]} : memref<1x1x3x4xf32>, vector<3x4xf32>
 // CHECK: %[[FILTER_0:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<3x4xf32> to vector<1x4xf32>
 // CHECK: %[[FILTER_1:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<3x4xf32> to vector<1x4xf32>
 // CHECK: %[[FILTER_2:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<3x4xf32> to vector<1x4xf32>

 // Handle batch #0
-// CHECK: %[[INPUT_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
-// CHECK: %[[OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
+// CHECK: %[[OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
 // CHECK: %[[INPUT_0_0:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK: %[[INPUT_0_1:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_0_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK: %[[INPUT_0_2:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_0_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
-// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_ARG]][%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>
+// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>

 // Handle batch #1
-// CHECK: %[[INPUT_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
-// CHECK: %[[OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
+// CHECK: %[[OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
 // CHECK: %[[INPUT_1_0:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK: %[[INPUT_1_1:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_1_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK: %[[INPUT_1_2:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_1_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
-// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_ARG]][%c0, %c0, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>
+// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>

 // Handle batch #2
-// CHECK: %[[INPUT_2:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
-// CHECK: %[[OUTPUT_2:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_2:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
+// CHECK: %[[OUTPUT_2:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
 // CHECK: %[[INPUT_2_0:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_2_0]], %[[FILTER_0]], %[[OUTPUT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK: %[[INPUT_2_1:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_2_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK: %[[INPUT_2_2:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_2_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
-// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_ARG]][%c0, %c1, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>
+// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>

 // Handle batch #3
-// CHECK: %[[INPUT_3:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
-// CHECK: %[[OUTPUT_3:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_3:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x3xf32>, vector<1x3xf32>
+// CHECK: %[[OUTPUT_3:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x2x2x4xf32>, vector<1x4xf32>
 // CHECK: %[[INPUT_3_0:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_3_0]], %[[FILTER_0]], %[[OUTPUT_3]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK: %[[INPUT_3_1:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_3_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK: %[[INPUT_3_2:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x3xf32> to vector<1x1xf32>
 // CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[INPUT_3_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
-// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_ARG]][%c0, %c1, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>
+// CHECK: vector.transfer_write %[[DOT_2]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x2x2x4xf32>

 // -----
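The do_not_vectorize_* cases that follow pin down the guards of the conv vectorization pattern: it only fires for unit batch, 1x1 filters, and unit dilations, leaving everything else to the default linalg path. Schematically, the match side behaves like this simplified, hypothetical guard (batchSize, filterHeight, filterWidth, isAllOnes, and vectorizeAsContraction are illustrative helpers, not the actual IREE implementation):

    LogicalResult matchAndRewrite(linalg::ConvInputNHWCFilterHWCFOp op,
                                  PatternRewriter &rewriter) const override {
      // Bail out on any shape the pattern cannot handle; the negative tests
      // below verify that these ops are left untouched.
      if (batchSize(op) != 1 || filterHeight(op) != 1 || filterWidth(op) != 1)
        return failure();
      if (!isAllOnes(op.dilations())) return failure();
      vectorizeAsContraction(op, rewriter);
      return success();
    }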
 // CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_batch
-func @do_not_vectorize_conv_with_non_1_batch(%filter: memref<1x1x4x4xf32>, %input: memref<2x1x7x4xf32>, %output: memref<2x1x4x4xf32>) {
-  %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
-  %1 = memref.subview %input[0, 0, 0, 0] [2, 1, 7, 4] [1, 1, 1, 1] : memref<2x1x7x4xf32> to memref<2x1x7x4xf32>
-  %2 = memref.subview %output[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<2x1x4x4xf32>
+// Passing bigger buffers to avoid memref.subview folding away.
+func @do_not_vectorize_conv_with_non_1_batch(%filter: memref<2x1x4x4xf32>, %input: memref<3x1x7x4xf32>, %output: memref<3x1x4x4xf32>) {
+  %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32>
+  %1 = memref.subview %input[0, 0, 0, 0] [2, 1, 7, 4] [1, 1, 1, 1] : memref<3x1x7x4xf32> to memref<2x1x7x4xf32>
+  %2 = memref.subview %output[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<3x1x4x4xf32> to memref<2x1x4x4xf32>
   // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
   linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
     ins (%1, %0: memref<2x1x7x4xf32>, memref<1x1x4x4xf32>)
@@ -88,10 +94,11 @@ func @do_not_vectorize_conv_with_non_1_batch(%filter: memref<1x1x4x4xf32>, %inpu

 // -----

 // CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_filter_height
-func @do_not_vectorize_conv_with_non_1_filter_height(%filter: memref<2x1x4x4xf32>, %input: memref<1x2x7x4xf32>, %output: memref<1x1x4x4xf32>) {
-  %0 = memref.subview %filter[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<2x1x4x4xf32>
-  %1 = memref.subview %input[0, 0, 0, 0] [1, 2, 7, 4] [1, 1, 1, 1] : memref<1x2x7x4xf32> to memref<1x2x7x4xf32>
-  %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+// Passing bigger buffers to avoid memref.subview folding away.
+func @do_not_vectorize_conv_with_non_1_filter_height(%filter: memref<3x1x4x4xf32>, %input: memref<2x2x7x4xf32>, %output: memref<2x1x4x4xf32>) {
+  %0 = memref.subview %filter[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<3x1x4x4xf32> to memref<2x1x4x4xf32>
+  %1 = memref.subview %input[0, 0, 0, 0] [1, 2, 7, 4] [1, 1, 1, 1] : memref<2x2x7x4xf32> to memref<1x2x7x4xf32>
+  %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32>
   // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
   linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
     ins (%1, %0: memref<1x2x7x4xf32>, memref<2x1x4x4xf32>)
@@ -102,10 +109,11 @@ func @do_not_vectorize_conv_with_non_1_filter_height(%filter: memref<2x1x4x4xf32

 // -----

 // CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_filter_width
-func @do_not_vectorize_conv_with_non_1_filter_width(%filter: memref<1x2x4x4xf32>, %input: memref<1x1x8x4xf32>, %output: memref<1x1x4x4xf32>) {
-  %0 = memref.subview %filter[0, 0, 0, 0] [1, 2, 4, 4] [1, 1, 1, 1] : memref<1x2x4x4xf32> to memref<1x2x4x4xf32>
-  %1 = memref.subview %input[0, 0, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1] : memref<1x1x8x4xf32> to memref<1x1x8x4xf32>
-  %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+// Passing bigger buffers to avoid memref.subview folding away.
+func @do_not_vectorize_conv_with_non_1_filter_width(%filter: memref<2x2x4x4xf32>, %input: memref<2x1x8x4xf32>, %output: memref<2x1x4x4xf32>) {
+  %0 = memref.subview %filter[0, 0, 0, 0] [1, 2, 4, 4] [1, 1, 1, 1] : memref<2x2x4x4xf32> to memref<1x2x4x4xf32>
+  %1 = memref.subview %input[0, 0, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1] : memref<2x1x8x4xf32> to memref<1x1x8x4xf32>
+  %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32>
   // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
   linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
     ins (%1, %0: memref<1x1x8x4xf32>, memref<1x2x4x4xf32>)
@@ -116,10 +124,11 @@ func @do_not_vectorize_conv_with_non_1_filter_width(%filter: memref<1x2x4x4xf32>

 // -----

 // CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_dilation
-func @do_not_vectorize_conv_with_non_1_dilation(%filter: memref<1x1x4x4xf32>, %input: memref<1x1x7x4xf32>, %output: memref<1x1x4x4xf32>) {
-  %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
-  %1 = memref.subview %input[0, 0, 0, 0] [1, 1, 7, 4] [1, 1, 1, 1] : memref<1x1x7x4xf32> to memref<1x1x7x4xf32>
-  %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+// Passing bigger buffers to avoid memref.subview folding away.
+func @do_not_vectorize_conv_with_non_1_dilation(%filter: memref<2x1x4x4xf32>, %input: memref<2x1x7x4xf32>, %output: memref<2x1x4x4xf32>) {
+  %0 = memref.subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32>
+  %1 = memref.subview %input[0, 0, 0, 0] [1, 1, 7, 4] [1, 1, 1, 1] : memref<2x1x7x4xf32> to memref<1x1x7x4xf32>
+  %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32>
   // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
   linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<[2, 1]> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
     ins (%1, %0: memref<1x1x7x4xf32>, memref<1x1x4x4xf32>)
@@ -129,76 +138,82 @@ func @do_not_vectorize_conv_with_non_1_dilation(%filter: memref<1x1x4x4xf32>, %i

 // -----

-func @vectorize_depthwise_conv(%input: memref<1x3x3x8xf32>, %filter: memref<1x1x8xf32>, %output: memref<1x2x2x8xf32>) {
-  %0 = memref.subview %input[0, 0, 0, 0] [1, 3, 3, 8] [1, 1, 1, 1] : memref<1x3x3x8xf32> to memref<1x3x3x8xf32>
-  %1 = memref.subview %filter[0, 0, 0] [1, 1, 8] [1, 1, 1] : memref<1x1x8xf32> to memref<1x1x8xf32>
-  %2 = memref.subview %output[0, 0, 0, 0] [1, 2, 2, 8] [1, 1, 1, 1] : memref<1x2x2x8xf32> to memref<1x2x2x8xf32>
+// Passing bigger buffers to avoid memref.subview folding away.
+func @vectorize_depthwise_conv(%input: memref<2x3x3x8xf32>, %filter: memref<2x1x8xf32>, %output: memref<2x2x2x8xf32>) {
+  %0 = memref.subview %input[0, 0, 0, 0] [1, 3, 3, 8] [1, 1, 1, 1] : memref<2x3x3x8xf32> to memref<1x3x3x8xf32>
+  %1 = memref.subview %filter[0, 0, 0] [1, 1, 8] [1, 1, 1] : memref<2x1x8xf32> to memref<1x1x8xf32>
+  %2 = memref.subview %output[0, 0, 0, 0] [1, 2, 2, 8] [1, 1, 1, 1] : memref<2x2x2x8xf32> to memref<1x2x2x8xf32>
   linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>}
     ins(%0, %1 : memref<1x3x3x8xf32>, memref<1x1x8xf32>)
     outs(%2 : memref<1x2x2x8xf32>)
   return
 }

 // CHECK-LABEL: func @vectorize_depthwise_conv
-// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<1x3x3x8xf32>,
-// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<1x1x8xf32>,
-// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<1x2x2x8xf32>
+// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<2x3x3x8xf32>,
+// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<2x1x8xf32>,
+// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<2x2x2x8xf32>
 // CHECK: %[[FLOAT_ZERO:.+]] = constant 0.000000e+00 : f32
-// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_ARG]][%c0, %c0, %c0], %cst {masked = [false]} : memref<1x1x8xf32>, vector<8xf32>
+// CHECK-DAG: %[[INPUT_SUBVIEW:.+]] = memref.subview %[[INPUT_ARG]]{{.*}} to memref<1x3x3x8xf32>
+// CHECK-DAG: %[[FILTER_SUBVIEW:.+]] = memref.subview %[[FILTER_ARG]]{{.*}} to memref<1x1x8xf32>
+// CHECK-DAG: %[[OUTPUT_SUBVIEW:.+]] = memref.subview %[[OUTPUT_ARG]]{{.*}} to memref<1x2x2x8xf32>
+
+// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER_SUBVIEW]][%c0, %c0, %c0], %cst {masked = [false]} : memref<1x1x8xf32>, vector<8xf32>

 // Common filter #0
 // CHECK: %[[FILTER_0:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>

-// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c0, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
-// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
 // CHECK: %[[FMA_0_0:.+]] = vector.fma %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0_0]] : vector<4xf32>
-// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_ARG]][%c0, %c0, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>

-// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
-// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
 // CHECK: %[[FMA_0_1:.+]] = vector.fma %[[INPUT_0_1]], %[[FILTER_0]], %[[OUTPUT_0_1]] : vector<4xf32>
-// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_ARG]][%c0, %c0, %c1, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>

-// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c0, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
-// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c2, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c0, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
 // CHECK: %[[FMA_1_0:.+]] = vector.fma %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1_0]] : vector<4xf32>
-// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_ARG]][%c0, %c1, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>

-// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
-// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c2, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c2, %c0], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
 // CHECK: %[[FMA_1_1:.+]] = vector.fma %[[INPUT_1_1]], %[[FILTER_0]], %[[OUTPUT_1_1]] : vector<4xf32>
-// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_ARG]][%c0, %c1, %c1, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c0] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>

 // Common filter #1
 // CHECK: %[[FILTER_1:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>

-// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
-// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[OUTPUT_0_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_0_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
 // CHECK: %[[FMA_0_0:.+]] = vector.fma %[[INPUT_0_0]], %[[FILTER_1]], %[[OUTPUT_0_0]] : vector<4xf32>
-// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_ARG]][%c0, %c0, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+// CHECK: vector.transfer_write %[[FMA_0_0]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>

-// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c0, %c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
-// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c0, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[OUTPUT_0_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_0_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c0, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
 // CHECK: %[[FMA_0_1:.+]] = vector.fma %[[INPUT_0_1]], %[[FILTER_1]], %[[OUTPUT_0_1]] : vector<4xf32>
-// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_ARG]][%c0, %c0, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+// CHECK: vector.transfer_write %[[FMA_0_1]], %[[OUTPUT_SUBVIEW]][%c0, %c0, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>

-// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
-// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c2, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[OUTPUT_1_0:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_1_0:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c0, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
 // CHECK: %[[FMA_1_0:.+]] = vector.fma %[[INPUT_1_0]], %[[FILTER_1]], %[[OUTPUT_1_0]] : vector<4xf32>
-// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_ARG]][%c0, %c1, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+// CHECK: vector.transfer_write %[[FMA_1_0]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c0, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>

-// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_ARG]][%c0, %c1, %c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
-// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_ARG]][%c0, %c2, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
+// CHECK: %[[OUTPUT_1_1:.+]] = vector.transfer_read %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c4], %cst {masked = [false]} : memref<1x2x2x8xf32>, vector<4xf32>
+// CHECK: %[[INPUT_1_1:.+]] = vector.transfer_read %[[INPUT_SUBVIEW]][%c0, %c2, %c2, %c4], %cst {masked = [false]} : memref<1x3x3x8xf32>, vector<4xf32>
 // CHECK: %[[FMA_1_1:.+]] = vector.fma %[[INPUT_1_1]], %[[FILTER_1]], %[[OUTPUT_1_1]] : vector<4xf32>
-// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_ARG]][%c0, %c1, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>
+// CHECK: vector.transfer_write %[[FMA_1_1]], %[[OUTPUT_SUBVIEW]][%c0, %c1, %c1, %c4] {masked = [false]} : vector<4xf32>, memref<1x2x2x8xf32>

 // -----

 // CHECK-LABEL: func @do_not_vectorize_depthwise_conv_with_non_1_filter_height
-func @do_not_vectorize_depthwise_conv_with_non_1_filter_height(%input: memref<1x2x3x4xf32>, %filter: memref<2x1x4xf32>, %output: memref<1x1x2x4xf32>) {
-  %0 = memref.subview %input[0, 0, 0, 0] [1, 2, 3, 4] [1, 1, 1, 1] : memref<1x2x3x4xf32> to memref<1x2x3x4xf32>
-  %1 = memref.subview %filter[0, 0, 0] [2, 1, 4] [1, 1, 1] : memref<2x1x4xf32> to memref<2x1x4xf32>
-  %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<1x1x2x4xf32> to memref<1x1x2x4xf32>
+// Passing bigger buffers to avoid memref.subview folding away.
+func @do_not_vectorize_depthwise_conv_with_non_1_filter_height(%input: memref<2x2x3x4xf32>, %filter: memref<3x1x4xf32>, %output: memref<2x1x2x4xf32>) {
+  %0 = memref.subview %input[0, 0, 0, 0] [1, 2, 3, 4] [1, 1, 1, 1] : memref<2x2x3x4xf32> to memref<1x2x3x4xf32>
+  %1 = memref.subview %filter[0, 0, 0] [2, 1, 4] [1, 1, 1] : memref<3x1x4xf32> to memref<2x1x4xf32>
+  %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<2x1x2x4xf32> to memref<1x1x2x4xf32>
   // CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc
   linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>}
     ins(%0, %1 : memref<1x2x3x4xf32>, memref<2x1x4xf32>)
     outs(%2 : memref<1x1x2x4xf32>)
   return
@@ -207,10 +222,11 @@ func @do_not_vectorize_depthwise_conv_with_non_1_filter_height(%input: memref<1x

 // -----

 // CHECK-LABEL: func @do_not_vectorize_depthwise_conv_with_non_1_filter_width
-func @do_not_vectorize_depthwise_conv_with_non_1_filter_width(%input: memref<1x1x4x4xf32>, %filter: memref<1x2x4xf32>, %output: memref<1x1x2x4xf32>) {
-  %0 = memref.subview %input[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
-  %1 = memref.subview %filter[0, 0, 0] [1, 2, 4] [1, 1, 1] : memref<1x2x4xf32> to memref<1x2x4xf32>
-  %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<1x1x2x4xf32> to memref<1x1x2x4xf32>
+// Passing bigger buffers to avoid memref.subview folding away.
+func @do_not_vectorize_depthwise_conv_with_non_1_filter_width(%input: memref<2x1x4x4xf32>, %filter: memref<2x2x4xf32>, %output: memref<2x1x2x4xf32>) { + %0 = memref.subview %input[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<1x1x4x4xf32> + %1 = memref.subview %filter[0, 0, 0] [1, 2, 4] [1, 1, 1] : memref<2x2x4xf32> to memref<1x2x4xf32> + %2 = memref.subview %output[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : memref<2x1x2x4xf32> to memref<1x1x2x4xf32> // CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%0, %1 : memref<1x1x4x4xf32>, memref<1x2x4xf32>) outs(%2 : memref<1x1x2x4xf32>) return diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir index 81ca6be3ebca..12685567cc28 100644 --- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir @@ -95,7 +95,7 @@ func @dynamicUpdateSliceImmutability( %start1 = constant 1 : index %workload = constant 8 : index // CHECK: %[[TARGET_CLONE:.+]] = flow.tensor.clone %[[TARGET]] : tensor<2x4xi32> - // CHECK-NEXT: %[[UPDATED:.+]] = flow.tensor.update %[[UPDATE]], %[[TARGET]] + // CHECK: %[[UPDATED:.+]] = flow.tensor.update %[[UPDATE]], %[[TARGET]] %t0 = flow.tensor.update %stream_update, %stream_target[%start0, %start1] : tensor<1x1xi32> -> tensor<2x4xi32> // CHECK-NEXT: %[[RETURN:.+]] = flow.dispatch @ex::@entry[%c8](%[[TARGET_CLONE]], %[[UPDATED]]) %t1 = flow.dispatch @ex::@entry[%workload](%stream_target, %t0) : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp index f92a2b441460..865ce1bd6610 100644 --- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp @@ -96,7 +96,7 @@ struct ConvertToFlowTensorOpsPass FuncOp funcOp = getOperation(); MLIRContext *context = funcOp->getContext(); context->allowUnregisteredDialects(true); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp index 9661159e0615..2e0627aa30a6 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp @@ -486,7 +486,7 @@ LogicalResult rewriteLinalgDestructiveUpdates( // Non-default canonicalization patterns. // TODO(nicolasvasilache): add Linalg tiling canonicalization patterns, // affineminscf and others as needed. 
- OwningRewritePatternList canonicalizationPatterns; + OwningRewritePatternList canonicalizationPatterns(context); scf::ForOp::getCanonicalizationPatterns(canonicalizationPatterns, context); (void)applyPatternsAndFoldGreedily(dispatchOp, std::move(canonicalizationPatterns)); diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp index 5feb9471dde9..b4ed6f1ffd3c 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp @@ -930,7 +930,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { // Use the workgroup size as a proxy for tile size here. At the flow level // this represents the "workload" per processors and is not necessarily tied // to the workgroup size specified by the backend. - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); auto linalgTilingOptions = linalg::LinalgTilingOptions() .setDistributionOptions(workgroupDistributionOptions) @@ -945,7 +945,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { ArrayRef(), Identifier::get("workgroup", context))); // Add canonicalization patterns. - linalg::populateLinalgTilingCanonicalizationPatterns(patterns, context); + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); patterns.insert(context); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } @@ -962,7 +962,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { // Move other operations into their own dispatch regions. { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } @@ -979,7 +979,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { // Run necessary canonicalization patterns before destructive updates. { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); // This is needed because tiling and distribution may create // subtensor_insert ops whose source operands come from tensor.cast ops. // Those tensor.cast ops cast tensors into a more dynamic shape, in order diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp index db16081c17bd..4c3c7b47a596 100644 --- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp @@ -796,7 +796,7 @@ struct HLOToHLOPreprocessing void runOnFunction() override { MLIRContext *context = &getContext(); ConversionTarget conversionTarget(*context); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); // Note that various input modalities may do their own legalization of // CHLO. Converting here allows IREE to accept CHLO dialect regardless of // whether it was legalized away at a higher level. 
@@ -810,7 +810,7 @@ struct HLOToHLOPreprocessing return signalPassFailure(); } - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); mhlo::PopulateUnfuseBatchNormPatterns(context, &patterns); mhlo::PopulateComplexLoweringPatterns(context, &patterns); mhlo::PopulateGatherToTorchIndexSelectPatterns(context, &patterns); diff --git a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp index 1e49f16d0ee6..2cda52291dd5 100644 --- a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp @@ -69,7 +69,7 @@ class PrePartitioningConversionPass void runOnFunction() override { auto *context = &getContext(); ConversionTarget conversionTarget(*context); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); conversionTarget.addLegalDialect(); @@ -118,7 +118,7 @@ class PostPartitioningConversionPass void runOnFunction() override { auto *context = &getContext(); ConversionTarget conversionTarget(getContext()); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); // We have completed all flow op creation at this point. conversionTarget.addLegalDialect(); diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir index ee1f3fea0f18..9b93d1b05f56 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_dynamic.mlir @@ -194,7 +194,8 @@ func @two_dispatches(%A : tensor, %B : tensor) -> tensor {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?]> {iree.reflection = {}}) -> (tensor {iree.reflection = {}}) attributes {iree.module.export} { diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp index 7f951297318b..b567f05ccc01 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp @@ -127,7 +127,7 @@ class ConvertHALToVMPass StringRef(hal_imports_create()->data, hal_imports_create()->size), innerModuleOp); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); populateStandardToVMPatterns(context, typeConverter, conversionPatterns); SymbolTable importSymbols(innerModuleOp); diff --git a/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp index 64a90c2ed8c6..007478077547 100644 --- a/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp @@ -71,7 +71,7 @@ class ConvertToHALPass HALTypeConverter typeConverter(conversionInterfaces); HALConversionTarget conversionTarget(context, typeConverter); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); setupIREEToHALLegality(context, conversionTarget); populateIREEToHALPatterns(context, patterns); diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index 69d4da0b2a61..bdfbd259dc32 100644 --- 
a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -504,7 +504,7 @@ class MaterializeInterfacesPass } // Convert interface-related flow.dispatch.* ops to their hal.* versions. - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert, diff --git a/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp b/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp index 54c19202878d..bd6a32dbc43e 100644 --- a/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp @@ -84,7 +84,7 @@ class ResolveEntryPointOrdinalsPass public: void runOnOperation() override { MLIRContext *context = &getContext(); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); patterns.insert(context); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); diff --git a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp index d7b189258055..1aaa9a202bdb 100644 --- a/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp +++ b/iree/compiler/Dialect/Shape/Conversion/ConvertShapeToShapex.cpp @@ -284,7 +284,7 @@ class ConvertShapeToShapex conversionTarget.addLegalDialect(); // Patterns. - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(context); patterns.insert(context); patterns.insert(context); diff --git a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir index 73070967dc46..c92301a9129e 100644 --- a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir +++ b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir @@ -45,8 +45,8 @@ func @foldStaticRankedDim(%arg0: !shapex.ranked_shape<[1,?,2,?]>) -> (i32, i32) func @foldFullyStaticRankedShape(%arg0: tensor<1x2xf32>) -> (i32, i32) { // CHECK-NOT: shapex.get_ranked_shape // CHECK-NOT: shapex.ranked_dim - // CHECK: constant 1 - // CHECK: constant 2 + // CHECK-DAG: constant 1 + // CHECK-DAG: constant 2 %0 = shapex.get_ranked_shape %arg0 : tensor<1x2xf32> -> !shapex.ranked_shape<[1,2]> %1 = shapex.ranked_dim %0[0] : !shapex.ranked_shape<[1,2]> -> i32 %2 = shapex.ranked_dim %0[1] : !shapex.ranked_shape<[1,2]> -> i32 @@ -74,8 +74,8 @@ func @foldFullyStaticRankedShapeDims(%arg0: tensor<1x2xf32>) -> (i32, i32) { // CHECK-NOT: shapex.get_ranked_shape // CHECK-NOT: shapex.ranked_dims // CHECK-NOT: shapex.ranked_dim - // CHECK: constant 1 - // CHECK: constant 2 + // CHECK-DAG: constant 1 + // CHECK-DAG: constant 2 %0 = shapex.get_ranked_shape %arg0 : tensor<1x2xf32> -> !shapex.ranked_shape<[1,2]> %1:2 = shapex.ranked_dims %0 : !shapex.ranked_shape<[1,2]> -> i32, i32 return %1#0, %1#1 : i32, i32 diff --git a/iree/compiler/Dialect/Shape/Transforms/CleanupPlaceholdersPass.cpp b/iree/compiler/Dialect/Shape/Transforms/CleanupPlaceholdersPass.cpp index e4f11c9eaf8b..762cea8e113d 100644 --- a/iree/compiler/Dialect/Shape/Transforms/CleanupPlaceholdersPass.cpp +++ b/iree/compiler/Dialect/Shape/Transforms/CleanupPlaceholdersPass.cpp @@ -38,7 +38,7 @@ class CleanupTieShapePattern : public OpRewritePattern { class CleanupShapePlaceholdersPass : public PassWrapper { void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); 
patterns.insert(&getContext()); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } diff --git a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp index c1af25b2381f..0b23d90bb928 100644 --- a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp +++ b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp @@ -72,7 +72,7 @@ class ConvertHLOToShapePass void runOnFunction() override { ConversionTarget conversionTarget(getContext()); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp index 4ef0e2fa9b17..ff93f8c1767f 100644 --- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp +++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculationsPass.cpp @@ -57,7 +57,7 @@ class MaterializeShapeCalculationsPass target.addLegalDialect(); setupMaterializeShapeCalculationsLegality(target); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); populateMaterializeShapeCalculationsConversionPatterns(conversionPatterns, context); if (failed(applyPartialConversion(getOperation(), target, @@ -69,7 +69,7 @@ class MaterializeShapeCalculationsPass // And then canonicalize shape ops. // TODO(laurenzo): I would prefer to get the list of ops in the dialect // versus doing this, but I don't know that is possible. - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); CastCompatibleShapeOp::getCanonicalizationPatterns(patterns, context); GetRankedShapeOp::getCanonicalizationPatterns(patterns, context); MakeRankedShapeOp::getCanonicalizationPatterns(patterns, context); diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVMTest.cpp b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVMTest.cpp index 15b9ed3a746b..416341e12401 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVMTest.cpp +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVMTest.cpp @@ -41,7 +41,7 @@ class ConvertStandardToVMTestPass IREE::VM::TypeConverter typeConverter( IREE::VM::getTargetOptionsFromFlags()); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateStandardToVMPatterns(&getContext(), typeConverter, patterns); // NOTE: we allow other dialects besides just VM during this pass as we are diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 3173aa304d65..d9972a3ffffe 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -354,7 +354,7 @@ class ConvertVMToEmitCPass void runOnOperation() override { ConversionTarget target(getContext()); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); populateVMToCPatterns(&getContext(), patterns); target.addLegalDialect(); diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index b002f4fd3c14..246467b3ec76 100644 --- 
a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -284,10 +284,38 @@ void GlobalStoreIndirectRefOp::getCanonicalizationPatterns(
 // Constants
 //===----------------------------------------------------------------------===//

+namespace {
+
+template <typename GeneralOp, typename ZeroOp>
+struct FoldZeroConstInteger final : public OpRewritePattern<GeneralOp> {
+  using OpRewritePattern<GeneralOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GeneralOp constOp,
+                                PatternRewriter &rewriter) const override {
+    if (matchPattern(constOp.result(), m_Zero())) {
+      rewriter.replaceOpWithNewOp<ZeroOp>(constOp);
+      return success();
+    }
+    return failure();
+  }
+};
+
+}  // namespace
+
 OpFoldResult ConstI32Op::fold(ArrayRef<Attribute> operands) { return value(); }

+void ConstI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                             MLIRContext *context) {
+  results.insert<FoldZeroConstInteger<ConstI32Op, ConstI32ZeroOp>>(context);
+}
+
 OpFoldResult ConstI64Op::fold(ArrayRef<Attribute> operands) { return value(); }

+void ConstI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                             MLIRContext *context) {
+  results.insert<FoldZeroConstInteger<ConstI64Op, ConstI64ZeroOp>>(context);
+}
+
 OpFoldResult ConstI32ZeroOp::fold(ArrayRef<Attribute> operands) {
   return IntegerAttr::get(getResult().getType(), 0);
 }
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index e624738d9ded..38fca4d27c47 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -661,6 +661,7 @@ def VM_ConstI32Op :
    VM_ConstIntegerOp {
   let summary = [{32-bit integer constant operation}];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 def VM_ConstI64Op :
@@ -668,6 +669,7 @@ def VM_ConstI64Op :
    [VM_ExtI64]> {
   let summary = [{64-bit integer constant operation}];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
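The FoldZeroConstInteger pattern above (its template arguments are reconstructed here as the ConstI32Op/ConstI32ZeroOp and ConstI64Op/ConstI64ZeroOp pairings used by the two registration sites) rewrites a general integer constant holding zero into the dedicated zero op, which presumably encodes more compactly in the VM module. With hasCanonicalizer = 1 declared in ODS, any canonicalization run picks it up; a sketch of driving it manually, assuming the standard MLIR hooks:

    OwningRewritePatternList patterns(context);
    IREE::VM::ConstI32Op::getCanonicalizationPatterns(patterns, context);
    IREE::VM::ConstI64Op::getCanonicalizationPatterns(patterns, context);
    // Effect on IR:  %0 = vm.const.i32 0 : i32
    //       becomes  %0 = vm.const.i32.zero : i32
    (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));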
 class VM_ConstIntegerZeroOp
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ ... @@ buildTypeTable(IREE::VM::ModuleOp moduleOp) {
 // required transformations (such as debug op stripping).
 static LogicalResult canonicalizeModule(BytecodeTargetOptions targetOptions,
                                         IREE::VM::ModuleOp moduleOp) {
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(moduleOp.getContext());
   ConversionTarget target(*moduleOp.getContext());
   target.addLegalDialect();
   target.addLegalOp();
diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
index 2ee2b0c8e897..b332a606c390 100644
--- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -425,7 +425,7 @@ static LogicalResult buildModuleDescriptors(IREE::VM::ModuleOp &moduleOp,
 // Adapted from BytecodeModuleTarget and extended by C specific passes
 static LogicalResult canonicalizeModule(
     IREE::VM::ModuleOp moduleOp, IREE::VM::CTargetOptions targetOptions) {
-  OwningRewritePatternList patterns;
+  OwningRewritePatternList patterns(moduleOp.getContext());
   ConversionTarget target(*moduleOp.getContext());
   target.addLegalDialect();
   target.addLegalOp();
diff --git a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
index b010a6bb3d3a..c74864f86ff1 100644
--- a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
@@ -120,7 +120,7 @@ class ConversionPass
     }
   }

-  OwningRewritePatternList conversionPatterns;
+  OwningRewritePatternList conversionPatterns(&getContext());
   populateIREEToVMPatterns(context, typeConverter, conversionPatterns);
   populateStandardToVMPatterns(context, typeConverter, conversionPatterns);
   conversionPatterns.insert(context);
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
index 027bfcd28538..0991c792b24f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
@@ -1,9 +1,9 @@
 // RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s

 func private @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
-  // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
-  // CHECK-NEXT: [[C32:%.+]] = constant 32 : index
-  // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
+  // CHECK-DAG: [[C32:%.+]] = constant 32 : index
+  // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
   // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
   // CHECK-NEXT: vmla.fft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32
   %real, %imag = "vmla.fft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>)
@@ -11,9 +11,9 @@ func private @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>
 }

 func private @ifft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
-  // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
-  // CHECK-NEXT: [[C32:%.+]] = constant 32 : index
-  // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
+  // CHECK-DAG: [[C32:%.+]] = constant 32 : index
+  // CHECK:
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
index 027bfcd28538..0991c792b24f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
@@ -1,9 +1,9 @@
 // RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
 
 func private @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
-  // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
-  // CHECK-NEXT: [[C32:%.+]] = constant 32 : index
-  // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
+  // CHECK-DAG: [[C32:%.+]] = constant 32 : index
+  // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
   // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
   // CHECK-NEXT: vmla.fft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32
   %real, %imag = "vmla.fft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>)
@@ -11,9 +11,9 @@ func private @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>
 }
 
 func private @ifft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
-  // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
-  // CHECK-NEXT: [[C32:%.+]] = constant 32 : index
-  // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
+  // CHECK-DAG: [[C32:%.+]] = constant 32 : index
+  // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
   // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
   // CHECK-NEXT: vmla.ifft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32
   %real, %imag = "vmla.ifft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>)
@@ -21,9 +21,9 @@ func private @ifft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>
 }
 
 func private @rfft(%arg0: tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
-  // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
-  // CHECK-NEXT: [[C20:%.+]] = constant 20 : index
-  // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C20]] : !vmla.buffer
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
+  // CHECK-DAG: [[C20:%.+]] = constant 20 : index
+  // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C20]] : !vmla.buffer
   // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C20]] : !vmla.buffer
   // CHECK-NEXT: vmla.rfft %arg0([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32
   %real, %imag = "vmla.rfft.pseudo"(%arg0) : (tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>)
@@ -31,8 +31,8 @@ func private @rfft(%arg0: tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
 }
 
 func private @irfft(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<8xf32> {
-  // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[5]>
-  // CHECK-NEXT: [[C32:%.+]] = constant 32 : index
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[5]>
+  // CHECK-DAG: [[C32:%.+]] = constant 32 : index
   // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
   // CHECK-NEXT: vmla.irfft %arg0([[RS]] : !shapex.ranked_shape<[5]>), %arg1([[RS]] : !shapex.ranked_shape<[5]>), out [[OUTBUF1]] : f32
   %real = "vmla.irfft.pseudo"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> (tensor<8xf32>)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index ad011a8b0b21..b49c6e3fdb1f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -402,7 +402,7 @@ class ConvertVMLAToVMPass
         StringRef(vmla_imports_create()->data, vmla_imports_create()->size),
         innerModuleOp);
 
-    OwningRewritePatternList conversionPatterns;
+    OwningRewritePatternList conversionPatterns(&getContext());
    populateStandardToVMPatterns(context, typeConverter, conversionPatterns);
 
    SymbolTable importSymbols(innerModuleOp);
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
index be6a9073eea5..cfdab50ae968 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
@@ -85,14 +85,13 @@ class ConversionPass
    conversionTarget.addIllegalDialect();
    conversionTarget.addIllegalDialect();
 
-    OwningRewritePatternList conversionPatterns;
+    OwningRewritePatternList conversionPatterns(&getContext());
    populateStandardToVMLAPatterns(context, conversionPatterns, typeConverter);
    populateHLOToVMLAPatterns(context, conversionPatterns, typeConverter);
    populateHALToVMLAPatterns(context, conversionPatterns, typeConverter);
 
    // Ensure FuncOp signatures are updated.
-    populateFuncOpTypeConversionPattern(conversionPatterns, context,
-                                        typeConverter);
+    populateFuncOpTypeConversionPattern(conversionPatterns, typeConverter);
 
    // We allow the shape dialect to persist, making specific dim queries
    // illegal (which allows them to fold away). These patterns allow dimension
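Note: two independent cleanups ride along in the VMLA diffs above. First, the fft.mlir checks (and transformation.mlir below) relax CHECK-NEXT to CHECK-DAG for the shape and byte-length constants: after the new constant canonicalizations those two values may be emitted in either order, and CHECK-DAG accepts both orders while CHECK-NEXT pins an exact sequence. Second, populateFuncOpTypeConversionPattern dropped its MLIRContext parameter, again because the pattern list now carries the context; the call site shrinks accordingly:

    // Before: the context was threaded through explicitly.
    //   populateFuncOpTypeConversionPattern(conversionPatterns, context,
    //                                       typeConverter);
    // After: the context is taken from conversionPatterns itself.
    populateFuncOpTypeConversionPattern(conversionPatterns, typeConverter);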
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
index 4b22d31b9d4d..1afdbd20761c 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -470,14 +470,14 @@ class PreConversionLoweringPass
 
    // These patterns should be run greedily as they are not dialect
    // conversions.
-    OwningRewritePatternList greedyPatterns;
+    OwningRewritePatternList greedyPatterns(&getContext());
    mhlo::PopulateComplexLoweringPatterns(context, &greedyPatterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(greedyPatterns)))) {
      return signalPassFailure();
    }
 
-    OwningRewritePatternList patterns;
+    OwningRewritePatternList patterns(&getContext());
    ConversionTarget target(*context);
    target.addLegalDialect();
    target.addLegalDialect();
@@ -503,7 +503,7 @@ class PreConversionLoweringPass
    }
 
    {
-      OwningRewritePatternList greedyPatterns;
+      OwningRewritePatternList greedyPatterns(&getContext());
      greedyPatterns.insert(context);
      if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                              std::move(greedyPatterns)))) {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir
index 493aa88aabca..9e7b73c5c579 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir
+++ b/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir
@@ -17,8 +17,8 @@ hal.interface @legacy_io attributes {sym_visibility = "private"} {
 }
 
 // CHECK: func @simpleMath_rgn_dispatch_0(%arg0: !vmla.interface, %arg1: index, %arg2: index, %arg3: index) {
-// CHECK-NEXT: %c0 = constant 0 : index
-// CHECK-NEXT: %c16 = constant 16 : index
+// CHECK-DAG: %c0 = constant 0 : index
+// CHECK-DAG: %c16 = constant 16 : index
 // CHECK-NEXT: %0 = vmla.interface.binding %arg0 {binding = 0 : i32, set = 0 : i32} : !vmla.buffer
 // CHECK-NEXT: %1 = vmla.buffer.view %0[%c0], byte_length = %c16 : !vmla.buffer
 // CHECK-NEXT: %2 = vmla.buffer.alloc byte_length = %c16 : !vmla.buffer
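Note: applyPatternsAndFoldGreedily now consumes the pattern list by value, which is why PreConversionLowering.cpp moves each list into the driver; a moved-from list must not be reused, so the pass builds a fresh one per application. A hedged sketch of the idiom (SomePattern is the hypothetical stand-in from the earlier sketch):

    #include <utility>

    #include "mlir/Pass/Pass.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    struct GreedySketchPass
        : public mlir::PassWrapper<GreedySketchPass, mlir::OperationPass<>> {
      void runOnOperation() override {
        mlir::OwningRewritePatternList greedyPatterns(&getContext());
        greedyPatterns.insert<SomePattern>(&getContext());
        // The driver takes ownership of (and freezes) the pattern list.
        if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
                getOperation(), std::move(greedyPatterns)))) {
          return signalPassFailure();
        }
        // greedyPatterns is consumed here; build a new list for further runs.
      }
    };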
diff --git a/iree/test/e2e/models/BUILD b/iree/test/e2e/models/BUILD
index d897e764de68..d311688884ea 100644
--- a/iree/test/e2e/models/BUILD
+++ b/iree/test/e2e/models/BUILD
@@ -72,9 +72,7 @@ iree_check_single_backend_test_suite(
 
 iree_check_single_backend_test_suite(
     name = "check_linalg_on_tensors_vulkan-spirv_vulkan",
-    srcs = [
-        "mobilenetv2_fake_weights.mlir",
-    ],
+    srcs = CHECK_FRAMEWORK_TESTS,
     compiler_flags = [
         "-iree-flow-dispatch-linalg-on-tensors",
         "-iree-codegen-spirv-experimental-linalg-on-tensors",
diff --git a/iree/test/e2e/models/CMakeLists.txt b/iree/test/e2e/models/CMakeLists.txt
index 0a8aced8b2c8..9a5ffa4868d6 100644
--- a/iree/test/e2e/models/CMakeLists.txt
+++ b/iree/test/e2e/models/CMakeLists.txt
@@ -46,26 +46,10 @@ iree_check_single_backend_test_suite(
    check_vulkan-spirv_vulkan
  SRCS
    "bert_encoder_unrolled_fake_weights.mlir"
-    "mobilenetv2_fake_weights.mlir"
-  TARGET_BACKEND
-    "vulkan-spirv"
-  DRIVER
-    "vulkan"
-)
-
-iree_check_single_backend_test_suite(
-  NAME
-    check_linalg_on_tensors_vulkan-spirv_vulkan
-  SRCS
-    "mobilenetv2_fake_weights.mlir"
  TARGET_BACKEND
    "vulkan-spirv"
  DRIVER
    "vulkan"
-  COMPILER_FLAGS
-    "-iree-flow-dispatch-linalg-on-tensors"
-    "-iree-codegen-spirv-experimental-linalg-on-tensors"
-    "-iree-spirv-enable-vectorization"
 )
 
 iree_check_single_backend_test_suite(
@@ -73,7 +57,6 @@ iree_check_single_backend_test_suite(
    check_linalg_on_tensors_dylib-llvm-aot_dylib
  SRCS
    "bert_encoder_unrolled_fake_weights.mlir"
-    "mobilenetv2_fake_weights.mlir"
  TARGET_BACKEND
    "dylib-llvm-aot"
  DRIVER
diff --git a/iree/test/e2e/tosa_ops/BUILD b/iree/test/e2e/tosa_ops/BUILD
index cb3a05392252..06a3544b2e1c 100644
--- a/iree/test/e2e/tosa_ops/BUILD
+++ b/iree/test/e2e/tosa_ops/BUILD
@@ -47,7 +47,6 @@ ALL_SRCS = enforce_glob(
        "logical_right_shift.mlir",
        "maximum.mlir",
        "minimum.mlir",
-        "mul.mlir",
        "negate.mlir",
        "reluN.mlir",
        "reshape.mlir",
@@ -59,6 +58,9 @@ ALL_SRCS = enforce_glob(
        "while.mlir",
    ],
    include = ["*.mlir"],
+    exclude = [
+        "mul.mlir",  # TODO(suderman): Re-enable once apply_scale lowering lands.
+    ],
 )
 
 iree_check_single_backend_test_suite(
diff --git a/iree/test/e2e/tosa_ops/CMakeLists.txt b/iree/test/e2e/tosa_ops/CMakeLists.txt
index 3948a95a51dd..0fc880bf3425 100644
--- a/iree/test/e2e/tosa_ops/CMakeLists.txt
+++ b/iree/test/e2e/tosa_ops/CMakeLists.txt
@@ -32,7 +32,6 @@ iree_check_single_backend_test_suite(
    "logical_right_shift.mlir"
    "maximum.mlir"
    "minimum.mlir"
-    "mul.mlir"
    "negate.mlir"
    "reluN.mlir"
    "reshape.mlir"
@@ -70,7 +69,6 @@ iree_check_single_backend_test_suite(
    "logical_right_shift.mlir"
    "maximum.mlir"
    "minimum.mlir"
-    "mul.mlir"
    "negate.mlir"
    "reluN.mlir"
    "reshape.mlir"
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 0776eca7a4e7..b24436ac96bd 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 0776eca7a4e76bfadc311f3607be3a4f0c0e989a
+Subproject commit b24436ac96bdf3f2c545fc85dc8af239d618c9c4