diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD
index 99f102d400b..792cc2c5301 100644
--- a/tao_compiler/mlir/disc/BUILD
+++ b/tao_compiler/mlir/disc/BUILD
@@ -2052,6 +2052,7 @@ cc_library(
         "//tensorflow/compiler/xla/mlir_hlo:mlir_hlo",
         "//tensorflow/compiler/xla/mlir_hlo:lhlo",
         "//tensorflow/compiler/mlir/disc/tools/disc-transform:all_passes",
+        "//tensorflow/compiler/mlir/disc/tools/disc-transform:DISCLinalgExtDialect",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineDialect",
        "@llvm-project//mlir:FuncDialect",
diff --git a/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_512x1024_f32.mlir b/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_512x1024_f32.mlir
new file mode 100644
index 00000000000..9de4365ac7e
--- /dev/null
+++ b/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_512x1024_f32.mlir
@@ -0,0 +1,10 @@
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 0 : i32}} {
+  func.func @main(%arg0: tensor<?x512xf32>) -> (tensor<?x1024xf32>) attributes {tf.entry_function = {inputs = "{{INPUTS}}", outputs = "{{OUTPUTS}}", input_placements="{{INPUT_PLACEMENTS}}", output_placements="{{OUTPUT_PLACEMENTS}}"}} {
+    %graph = tf_executor.graph {
+      %0:2 = tf_executor.island wraps "tf.Const"() {value = dense<-0.8> : tensor<512x1024xf32>} : () -> tensor<512x1024xf32>
+      %1:2 = tf_executor.island wraps "tf.MatMul"(%arg0, %0) {transpose_a = false, transpose_b = false} : (tensor<?x512xf32>, tensor<512x1024xf32>) -> (tensor<?x1024xf32>)
+      tf_executor.fetch %1 : tensor<?x1024xf32>
+    }
+    return %graph : tensor<?x1024xf32>
+  }
+}
\ No newline at end of file
diff --git a/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_f32_large_schedule.mlir b/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_f32_large_schedule.mlir
new file mode 100644
index 00000000000..8d6c30b9cdf
--- /dev/null
+++ b/tao_compiler/mlir/disc/tests/disc-transform/data/packed_matmul_nn_p_f32_large_schedule.mlir
@@ -0,0 +1,73 @@
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %fill = transform.structured.match ops{["linalg.fill"]} in %arg1
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+
+  %0:2 = transform.structured.tile_to_foreach_thread_op %matmul num_threads [1, 1]
+  transform.structured.fuse_into_containing_op %fill into %0#0
+
+  // First-level tile and fuse of the matmul and fill ops.
+  %1:3 = transform.structured.fuse %0#1 {tile_sizes = [288, 48, 0], tile_interchange = [0, 1, 2]}
+  // Second-level tile and fuse of the matmul and fill ops.
+  %2:3 = transform.structured.fuse %1#0 {tile_sizes = [6, 16, 0], tile_interchange = [0, 1, 2]}
+
+  // Tile the gemm reduction axis.
+  %3:2 = transform.structured.tile %2#0 [0, 0, 1] {interchange=[0, 1, 2]}
+
+  // Clean up.
+  %func0 = transform.structured.match ops{["func.func"]} in %arg1
+  transform.disc.apply_patterns %func0 {canonicalization}
+  // Fold the two extract_slice ops generated by the two-level tiling. This is
+  // needed to enable the following pad and hoist schedule.
+  %weight_inner_slice = get_producer_of_operand %3#0[1] : (!pdl.operation) -> !pdl.operation
+  transform.disc.fold_producer_extract_slice %weight_inner_slice {max_repeat_num = 2}
+
+  // Pad to match the requirement of hardware vector/tensor instructions.
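+  // A hedged illustration, not verbatim compiler output: with the 6x16
+  // second-level tile and the k=1 reduction tile chosen above, a partial
+  // boundary tile such as
+  //   %lhs_tile = tensor.extract_slice ... : tensor<?x512xf32> to tensor<?x1xf32>
+  //   %lhs_pad = tensor.pad %lhs_tile ... : tensor<?x1xf32> to tensor<6x1xf32>
+  // is widened to the fixed 6x1 (lhs) / 1x16 (rhs) shapes, so the innermost
+  // kernel always computes a full 6x16 output tile.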
+  %4 = transform.structured.pad %3#0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0], hoist_paddings=[0, 0, 0], transpose_paddings=[[1, 0], [0, 1], [0, 1]]}
+
+  %pad_for_input = get_producer_of_operand %4[0] : (!pdl.operation) -> !pdl.operation
+  %pad_for_weight = get_producer_of_operand %4[1] : (!pdl.operation) -> !pdl.operation
+  %foreach_op = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+  %loop_for_outer_most_n = transform.loop.get_parent_for %4 {num_loops = 4} : (!pdl.operation) -> !pdl.operation
+  transform.disc.cache_read {padded} %pad_for_input at %loop_for_outer_most_n with tile_levels = [1, 1] tile_sizes = [6, 1] permutation = [0, 2, 3, 1]
+  transform.disc.cache_read {padded} %pad_for_weight at %foreach_op with tile_levels = [1, 1] tile_sizes = [1, 16] permutation = [2, 0, 1, 3]
+
+  %func1 = transform.structured.match ops{["func.func"]} in %arg1
+  transform.disc.apply_patterns %func1 {canonicalization}
+
+  %pack_op = transform.structured.match ops{["disc_linalg_ext.multi_level_pack"]} in %arg1
+  transform.disc.lower_multi_level_pack_to_loop %pack_op
+
+  %func2 = transform.structured.match ops{["func.func"]} in %arg1
+  transform.disc.apply_patterns %func2 {canonicalization}
+
+  %func3 = transform.structured.match ops{["func.func"]} in %arg1
+  transform.structured.vectorize %func3 {vectorize_padding}
+
+  %func4 = transform.structured.match ops{["func.func"]} in %arg1
+  transform.disc.apply_patterns %func4 {canonicalization}
+
+  transform.disc.bufferize %arg1
+
+  transform.lower_vectors {
+    contraction_lowering = "outerproduct",
+    multireduction_lowering = "innerparallel",
+    split_transfers = "linalg-copy",
+    // stages = [0, 1, 2, 3, 4, 5, 6, 7],
+    stages = [0, 1, 2, 3],
+    transpose_avx2_lowering = false,
+    transpose_lowering = "eltwise",
+    unroll_vector_transfers = true
+  }
+
+  transform.lower_vectors {
+    contraction_lowering = "outerproduct",
+    multireduction_lowering = "innerparallel",
+    split_transfers = "linalg-copy",
+    // stages = [0, 1, 2, 3, 4, 5, 6, 7],
+    stages = [5, 6, 7],
+    transpose_avx2_lowering = false,
+    transpose_lowering = "eltwise",
+    unroll_vector_transfers = true
+  }
+}
\ No newline at end of file
diff --git a/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc b/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc
new file mode 100644
index 00000000000..dad1f1b796e
--- /dev/null
+++ b/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc
@@ -0,0 +1,50 @@
+/* Copyright 2022 The BladeDISC Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/disc/tests/mlir_feature_test.h"
+#include "tensorflow/compiler/mlir/disc/tests/mlir_test.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace mlir_test {
+
+const std::string c_ft_path =
+    "tensorflow/compiler/mlir/disc/tests/disc-transform/data/";
+
+static bool init_threads = []() {
+  setenv("OMP_NUM_THREADS", "1", 1);
+  return true;
+}();
+
+TEST(PackedMatmul, F32_304x1024x512) {
+  EnvSetting setting = {
+      {"DISC_TRANSFORM_SCHEDULE_FILE",
+       {c_ft_path + "packed_matmul_nn_p_f32_large_schedule.mlir", false}},
+      {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}},
+      {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
+      {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}};
+  EnvSettingContext ctx(setting);
+  EXPECT_TRUE(feature_test_main(
+      /*mlir_file_path*/ c_ft_path + "packed_matmul_nn_p_512x1024_f32.mlir",
+      /*backend_types*/ {BackendType::kAArch64},
+      /*num_inputs*/ 1,
+      /*num_outputs*/ 1,
+      /*input_descriptors*/ {"304x512xf32_X"},
+      /*output_descriptors*/ {"f32_X"},
+      /*input_vals*/ {},
+      /*expected_output_vals*/ {},
+      /*profiling*/ true));
+}
+
+}  // namespace mlir_test
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/BUILD b/tao_compiler/mlir/disc/tools/disc-transform/BUILD
index 04f22ee1a2d..30302019b59 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/BUILD
+++ b/tao_compiler/mlir/disc/tools/disc-transform/BUILD
@@ -282,6 +282,7 @@ cc_library(
     name = "legalize_lmhlo_fusion_to_linalg",
     srcs = ["transforms/legalize_lmhlo_fusion_to_linalg.cc"],
     deps = [
+        ":DISCLinalgExtDialect",
         ":pass_details",
         "//tensorflow/compiler/xla/mlir_hlo:lhlo",
         "@llvm-project//llvm:Support",
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.td b/tao_compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.td
index 71cdcd444e0..a26532a28db 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.td
+++ b/tao_compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.td
@@ -40,6 +40,38 @@ class DISCLinalgExt_Op<string mnemonic, list<Trait> traits = []> :
   code extraLinalgExtOpClassDeclaration = "";
 }
 
+// We intentionally mark this op as not foldable, since folding it may lose
+// the special semantics captured by it. For example, it might otherwise be
+// folded into an arith.constant op. It is only supposed to be folded in some
+// special use cases (e.g. for packed weights).
+def DISCLinalgExt_ConstantWrapperOp : Op<DISCLinalgExt_Dialect, "constant_wrapper",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]>{
+  let summary = "integer or floating point tensor constant";
+  let description = [{
+    The `constant_wrapper` operation produces an SSA value equal to some
+    integer or floating-point constant specified by an attribute. This is the
+    way MLIR forms simple integer and floating point constants.
+
+    Note that this op is not foldable; it is supposed to be used as a
+    placeholder for later rewriting for RAL.
+
+    Example:
+
+    ```
+    // Integer constant
+    %1 = disc_linalg_ext.constant_wrapper dense<42> : tensor<i32>
+
+    // Equivalent generic form
+    %1 = "disc_linalg_ext.constant_wrapper"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
+    ```
+  }];
+
+  let arguments = (ins ElementsAttr:$value);
+  let results = (outs AnyType:$result);
+
+  let assemblyFormat = "attr-dict $value";
+}
+
 def DISCLinalgExt_MultiLevelPackOp : DISCLinalgExt_Op<"multi_level_pack", [
   DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
 ]>{
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/LinalgExt/tests/ops.mlir b/tao_compiler/mlir/disc/tools/disc-transform/LinalgExt/tests/ops.mlir
index d61f0da789e..fdb8557492d 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/LinalgExt/tests/ops.mlir
+++ b/tao_compiler/mlir/disc/tools/disc-transform/LinalgExt/tests/ops.mlir
@@ -5,4 +5,12 @@ func.func @main(%arg0: tensor<128x128xf32>) -> tensor<4x4x32x32xf32> {
   // CHECK: disc_linalg_ext.multi_level_pack
   %1 = disc_linalg_ext.multi_level_pack %arg0 with tile_levels = [1, 1] tile_sizes = [32, 32] permutation = [0, 2, 1, 3] into %0 : (tensor<128x128xf32> tensor<4x4x32x32xf32>) -> tensor<4x4x32x32xf32>
   return %1 : tensor<4x4x32x32xf32>
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: @const
+func.func @const() -> tensor<2x4xf32> {
+  %0 = disc_linalg_ext.constant_wrapper dense<[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]> : tensor<2x4xf32>
+  return %0 : tensor<2x4xf32>
+}
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc
index c8fce07709d..2a94bc13798 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc
+++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc
@@ -57,6 +57,7 @@ CommonExtensions::CommonExtensions() {
 using bufferization::BufferizationOptions;
 using bufferization::OneShotAnalysisState;
 using bufferization::OneShotBufferizationOptions;
+using disc_linalg_ext::ConstantWrapperOp;
 using disc_linalg_ext::MultiLevelPackOp;
 
 namespace {
@@ -108,6 +109,13 @@ OneShotBufferizationOptions getBufferizationOptions() {
   options.functionBoundaryTypeConversion =
       BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
 
+  // bufferization.to_memref is used to bufferize constant_wrapper ops. DISC
+  // has its own logic to handle constants, so we'd like to leave these
+  // constant ops as is and only insert bufferization.to_memref to convert
+  // the tensor into a memref.
+  options.opFilter.denyOperation<ConstantWrapperOp>();
+  options.opFilter.denyOperation<bufferization::ToMemrefOp>();
+
   // This type converter converts tensor types to memref types when no exact
   // memref type can be inferred from the context.
   options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
@@ -444,6 +452,37 @@ struct TransferWriteOfFillOpPattern
   }
 };
 
+/// convert:
+///   %0 = disc_linalg_ext.constant_wrapper ...
+///   %1 = disc_linalg_ext.multi_level_pack %0 ...
+///   use(%1)
+/// to:
+///   %0 = disc_linalg_ext.constant_wrapper ...  // folded
+///   use(%0)
+struct FoldMultiLevelPackOfConstantWrapperPattern
+    : public OpRewritePattern<MultiLevelPackOp> {
+  using OpRewritePattern<MultiLevelPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MultiLevelPackOp op,
+                                PatternRewriter& rewriter) const override {
+    auto constOp = op.getInput().getDefiningOp<ConstantWrapperOp>();
+    if (!constOp) return failure();
+
+    Attribute paddingAttr;
+    if (op.getPaddingValue() &&
+        !matchPattern(op.getPaddingValue(), m_Constant(&paddingAttr)))
+      return failure();
+
+    SmallVector<Attribute> attrs{constOp.getValue(), nullptr, paddingAttr};
+    SmallVector<OpFoldResult> results;
+    if (failed(op.fold(attrs, results))) return failure();
+
+    rewriter.replaceOpWithNewOp<ConstantWrapperOp>(op, op.getOutputType(),
+                                                   results[0].get<Attribute>());
+    return success();
+  }
+};
+
 static void addAllRegisteredCanonicalizationPatterns(
     RewritePatternSet& patterns) {
   MLIRContext* ctx = patterns.getContext();
@@ -458,6 +497,7 @@ static void addAllRegisteredCanonicalizationPatterns(
   patterns.insert(ctx);
   patterns.insert(ctx);
   patterns.insert(ctx);
+  patterns.insert<FoldMultiLevelPackOfConstantWrapperPattern>(ctx);
 }
 
 }  // namespace
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h b/tao_compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h
index e5e584d0fdd..792c975fbf2 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h
@@ -60,6 +60,10 @@ class LinalgDialect;
 
 namespace disc_ral {
 
+namespace disc_linalg_ext {
+class DISCLinalgExtDialect;
+}
+
 #define GEN_PASS_CLASSES
 #include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.h.inc"
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/legalize_lmhlo_fusion_to_linalg.cc b/tao_compiler/mlir/disc/tools/disc-transform/transforms/legalize_lmhlo_fusion_to_linalg.cc
index ad21af36b6f..f1352ddab33 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/legalize_lmhlo_fusion_to_linalg.cc
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/legalize_lmhlo_fusion_to_linalg.cc
@@ -19,6 +19,8 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
+#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtDialect.h"
+#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.h"
 #include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h"
 
 #define DEBUG_TYPE "disc-legalize-lmhlo-fusion-to-linalg"
@@ -104,10 +106,23 @@ LogicalResult emitDotGeneralOp(lmhlo::DotGeneralOp op, OpBuilder& b,
   return success();
 }
 
+LogicalResult emitConstOp(lmhlo::ConstantOp op, OpBuilder& b,
+                          BlockAndValueMapping& mapping) {
+  auto resultTy = convertMemRefToTensorType(op->getOperand(0).getType());
+  Location loc = op->getLoc();
+  auto newOp = b.create<disc_linalg_ext::ConstantWrapperOp>(loc, resultTy,
+                                                            op.getValue());
+  mapping.erase(op->getOperand(0));
+  mapping.map(op->getOperand(0), newOp.getResult());
+  return success();
+}
+
 LogicalResult emitLmhloOp(Operation* op, OpBuilder& b,
                           BlockAndValueMapping& mapping) {
   if (auto dotGeneralOp = dyn_cast<lmhlo::DotGeneralOp>(op)) {
     return emitDotGeneralOp(dotGeneralOp, b, mapping);
+  } else if (auto constOp = dyn_cast<lmhlo::ConstantOp>(op)) {
+    return emitConstOp(constOp, b, mapping);
   }
   // TODO(wyzero): support other lmhlo ops.
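+  // A hedged sketch of the rewrite performed by emitConstOp above (value
+  // names are illustrative; the authoritative IR is checked in
+  // tests/legalize-lmhlo-fusion-to-linalg.mlir):
+  //   "lmhlo.constant"(%buf) {value = dense<-0.8> : tensor<512x1024xf32>}
+  // becomes
+  //   %t = disc_linalg_ext.constant_wrapper dense<-0.8> : tensor<512x1024xf32>
+  // and %buf is re-mapped to %t for later consumers inside the fusion.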
   return failure();
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h b/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h
index 06332b4119c..2cf124a32b7 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h
@@ -35,6 +35,10 @@ class OperationPass;
 
 namespace disc_ral {
 
+namespace disc_linalg_ext {
+class DISCLinalgExtDialect;
+}
+
 // Converts a lmhlo fusion op inside a function to its linalg-on-tensor
 // equivalent.
 std::unique_ptr<OperationPass<ModuleOp>>
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/apply-patterns-canonicalization.mlir b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/apply-patterns-canonicalization.mlir
index 67e69b03a1f..ec162bf4079 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/apply-patterns-canonicalization.mlir
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/apply-patterns-canonicalization.mlir
@@ -109,3 +109,21 @@ transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   transform.disc.apply_patterns %arg1 {canonicalization}
 }
+
+// -----
+
+// CHECK-LABEL: @fold_constant_wrapper_and_multi_level_pack
+func.func @fold_constant_wrapper_and_multi_level_pack() -> tensor<64x512x1x16xf32> {
+  // CHECK: %[[T0:.*]] = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<64x512x1x16xf32>
+  // CHECK-NEXT: return %[[T0]] : tensor<64x512x1x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
+  %1 = tensor.empty() : tensor<64x512x1x16xf32>
+  %2 = disc_linalg_ext.multi_level_pack %0 with padding_value(%cst : f32) tile_levels = [1, 1] tile_sizes = [1, 16] permutation = [2, 0, 1, 3] into %1 : (tensor<512x1024xf32> tensor<64x512x1x16xf32>) -> tensor<64x512x1x16xf32>
+  return %2 : tensor<64x512x1x16xf32>
+}
+
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  transform.disc.apply_patterns %arg1 {canonicalization}
+}
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/disc-bufferize.mlir b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/disc-bufferize.mlir
index b6c88e21ca9..b81bdac7500 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/disc-bufferize.mlir
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/disc-bufferize.mlir
@@ -88,6 +88,22 @@ func.func @not_use_alloca_due_to_dynamic_shape(%arg0: index) -> tensor<?xf32> {
   return %2 : tensor<?xf32>
 }
 
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  transform.disc.bufferize %arg1
+}
+
+// -----
+
+// CHECK-LABEL: @bufferize_constant_wrapper
+func.func @bufferize_constant_wrapper() -> tensor<512x1024xf32> {
+  // CHECK: %[[T0:.*]] = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
+  // CHECK-NEXT: %[[T1:.*]] = bufferization.to_memref %[[T0]]
+  // CHECK-NEXT: return %[[T1]]
+  %0 = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
+  return %0 : tensor<512x1024xf32>
+}
+
 transform.structured.canonicalized_sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   transform.disc.bufferize %arg1
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/legalize-lmhlo-fusion-to-linalg.mlir b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/legalize-lmhlo-fusion-to-linalg.mlir
index ff958a999ed..aa5972bd874 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/legalize-lmhlo-fusion-to-linalg.mlir
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/legalize-lmhlo-fusion-to-linalg.mlir
@@ -12,4 +12,22 @@ func.func @matmul_nn(%arg1: memref<?x?xf32, "cpu">, %arg2: memref<?x?xf32, "cpu">
     "lmhlo.terminator"() : () -> ()
   }) {disc.device = "cpu", disc.fusion.name = "matmul_nn_kTransform_dot_general__1_1_0", disc.fusion_type = "kTransform"} : () -> ()
   return %arg3 : memref<?x?xf32, "cpu">
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: @packed_matmul_nn
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x512xf32>, %[[ARG1:.*]]: tensor<512x1024xf32>, %[[ARG2:.*]]: tensor<?x1024xf32>)
+func.func @packed_matmul_nn(%arg1: memref<?x512xf32, "cpu">, %arg2: memref<512x1024xf32, "cpu">, %arg3: memref<?x1024xf32, "cpu">) -> memref<?x1024xf32, "cpu"> {
+  // CHECK: %[[T0:.*]] = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
+  // CHECK: %[[T1:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[T2:.*]] = linalg.fill ins(%[[T1]] : f32) outs(%[[ARG2]] : tensor<?x1024xf32>) -> tensor<?x1024xf32>
+  // CHECK: %[[T3:.*]] = linalg.matmul ins(%[[ARG0]], %[[T0]] : tensor<?x512xf32>, tensor<512x1024xf32>) outs(%[[T2]] : tensor<?x1024xf32>) -> tensor<?x1024xf32>
+  // CHECK: return %[[T3]]
+  "lmhlo.fusion"() ({
+    "lmhlo.constant"(%arg2) {disc.device = "cpu", value = dense<-8.000000e-01> : tensor<512x1024xf32>} : (memref<512x1024xf32, "cpu">) -> ()
+    "lmhlo.dot_general"(%arg1, %arg2, %arg3) {disc.device = "cpu", dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (memref<?x512xf32, "cpu">, memref<512x1024xf32, "cpu">, memref<?x1024xf32, "cpu">) -> ()
+    "lmhlo.terminator"() : () -> ()
+  }) {disc.device = "cpu", disc.fusion.name = "matmul_nn_kTransform_dot_general__1_1_0", disc.fusion_type = "kTransform"} : () -> ()
+  return %arg3 : memref<?x1024xf32, "cpu">
+}
\ No newline at end of file
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td b/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td
index 1bcaf54b258..bda97e24d26 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td
@@ -18,6 +18,9 @@ include "mlir/Pass/PassBase.td"
 def DiscLegalizeLmhloFusionToLinalgPass : Pass<"disc-legalize-lmhlo-fusion-to-linalg", "ModuleOp"> {
   let summary = "Pass to convert a lmhlo fusion op to linalg on tensor.";
   let constructor = "createDiscLegalizeLmhloFusionToLinalgPass()";
+  let dependentDialects = [
+    "disc_ral::disc_linalg_ext::DISCLinalgExtDialect",
+  ];
 }
 
 def DiscTransformDialectInterpreterPass : Pass<"disc-transform-dialect-interpreter", "ModuleOp"> {
diff --git a/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc b/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc
index 95760b4ac67..b70d82f7dac 100644
--- a/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc
@@ -34,6 +34,8 @@ limitations under the License.
#include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "tensorflow/compiler/mlir/disc/disc_util.h" +#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtDialect.h" +#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.h" #include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/passes.h" #include "tensorflow/compiler/mlir/disc/transforms/PassDetail.h" #include "tensorflow/compiler/mlir/disc/transforms/codegen_utils.h" @@ -155,6 +157,25 @@ LogicalResult DiscTransformLegalizeToLoopPass::inlineTransformedModule( b.setInsertionPoint(body, body->begin()); auto mapping = buildValueMapping(fusionPattern, funcOps[0], true); for (auto& nestedOp : funcOps[0].getBody().front().without_terminator()) { + // ConstantWrapperOp will be cloned when we handle its corresponding + // bufferization::to_memref op. + if (isa(&nestedOp)) continue; + if (auto toMemrefOp = dyn_cast(&nestedOp)) { + auto constOp = toMemrefOp->getOperand(0) + .getDefiningOp(); + if (!constOp) + return constOp->emitError() + << "unkown operand for bufferization::ToMemrefOp\n"; + auto ip = b.saveInsertionPoint(); + b.setInsertionPoint(fusion); + Location loc = constOp->getLoc(); + Value buffer = b.create( + loc, toMemrefOp.getResult().getType().cast()); + b.create(loc, constOp.getValue(), buffer); + mapping.map(toMemrefOp.getResult(), buffer); + b.restoreInsertionPoint(ip); + continue; + } Operation* cloned = b.clone(nestedOp, mapping); } return success();