[transform] e2e UT for packed gemm (#892)
wyzero committed Dec 22, 2022
1 parent f2246ca commit 9b6419a
Showing 16 changed files with 315 additions and 1 deletion.
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/BUILD
@@ -2052,6 +2052,7 @@ cc_library(
"//tensorflow/compiler/xla/mlir_hlo:mlir_hlo",
"//tensorflow/compiler/xla/mlir_hlo:lhlo",
"//tensorflow/compiler/mlir/disc/tools/disc-transform:all_passes",
"//tensorflow/compiler/mlir/disc/tools/disc-transform:DISCLinalgExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:FuncDialect",
@@ -0,0 +1,10 @@
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 0 : i32}} {
func.func @main(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) attributes {tf.entry_function = {inputs = "{{INPUTS}}", outputs = "{{OUTPUTS}}", input_placements="{{INPUT_PLACEMENTS}}", output_placements="{{OUTPUT_PLACEMENTS}}"}} {
%graph = tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.Const"() {value = dense<-0.8> : tensor<512x1024xf32>} : () -> tensor<512x1024xf32>
%1:2 = tf_executor.island wraps "tf.MatMul"(%arg0, %0) {transpose_a = false, transpose_b = false} : (tensor<?x?xf32>, tensor<512x1024xf32>) -> (tensor<?x?xf32>)
tf_executor.fetch %1 : tensor<?x?xf32>
}
return %graph : tensor<?x?xf32>
}
}
@@ -0,0 +1,73 @@
transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1

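// Distribute the matmul onto a single (1x1) scf.foreach_thread and fuse the fill op into it.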
%0:2 = transform.structured.tile_to_foreach_thread_op %matmul num_threads [1, 1]
transform.structured.fuse_into_containing_op %fill into %0#0

// First-level tiling: tile the matmul and fuse the fill op into the tiled loops.
%1:3 = transform.structured.fuse %0#1 {tile_sizes = [288, 48, 0], tile_interchange = [0, 1, 2]}
// Second-level tiling: tile and fuse the matmul and fill ops again at a finer granularity.
%2:3 = transform.structured.fuse %1#0 {tile_sizes = [6, 16, 0], tile_interchange = [0, 1, 2]}

// Tile the GEMM reduction axis.
%3:2 = transform.structured.tile %2#0 [0, 0, 1] {interchange=[0, 1, 2]}

// clean up
%func0 = transform.structured.match ops{["func.func"]} in %arg1
transform.disc.apply_patterns %func0 {canonicalization}
// Fold the two extract_slice ops generated by two-level tiling. This is needed to enable the
// following pad-and-hoist schedule.
%weight_inner_slice = get_producer_of_operand %3#0[1] : (!pdl.operation) -> !pdl.operation
transform.disc.fold_producer_extract_slice %weight_inner_slice {max_repeat_num = 2}

// Pad to match the requirements of the hardware vector/tensor instructions.
%4 = transform.structured.pad %3#0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0], hoist_paddings=[0, 0, 0], transpose_paddings=[[1, 0], [0, 1], [0, 1]]}

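// Promote the padded input and weight into packed, cache-friendly buffers: the input is
// cached at the outer-most loop, the weight at the scf.foreach_thread level.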
%pad_for_input = get_producer_of_operand %4[0] : (!pdl.operation) -> !pdl.operation
%pad_for_weight = get_producer_of_operand %4[1] : (!pdl.operation) -> !pdl.operation
%foreach_op = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
%loop_for_outter_most_n = transform.loop.get_parent_for %4 { num_loops = 4} : (!pdl.operation) -> !pdl.operation
transform.disc.cache_read {padded} %pad_for_input at %loop_for_outter_most_n with tile_levels = [1, 1] tile_sizes = [6, 1] permutation = [0, 2, 3, 1]
transform.disc.cache_read {padded} %pad_for_weight at %foreach_op with tile_levels = [1, 1] tile_sizes = [1, 16] permutation = [2, 0, 1, 3]
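// E.g. with tile_sizes = [1, 16] and permutation = [2, 0, 1, 3], a 512x1024 weight is packed
// into a 64x512x1x16 buffer (see the multi_level_pack tests below).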

%func1 = transform.structured.match ops{["func.func"]} in %arg1
transform.disc.apply_patterns %func1 {canonicalization}

%pack_op = transform.structured.match ops{["disc_linalg_ext.multi_level_pack"]} in %arg1
transform.disc.lower_multi_level_pack_to_loop %pack_op

%func2 = transform.structured.match ops{["func.func"]} in %arg1
transform.disc.apply_patterns %func2 {canonicalization}

%func3 = transform.structured.match ops{["func.func"]} in %arg1
transform.structured.vectorize %func3 {vectorize_padding}

%func4 = transform.structured.match ops{["func.func"]} in %arg1
transform.disc.apply_patterns %func4 {canonicalization}

transform.disc.bufferize %arg1

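// Lower vector ops progressively: the first transform.lower_vectors invocation runs stages
// [0, 1, 2, 3], the second runs stages [5, 6, 7].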
transform.lower_vectors {
contraction_lowering = "outerproduct",
multireduction_lowering = "innerparallel",
split_transfers = "linalg-copy",
// stages = [0, 1, 2, 3, 4, 5, 6, 7],
stages = [0, 1, 2, 3],
transpose_avx2_lowering = false,
transpose_lowering = "eltwise",
unroll_vector_transfers = true
}

transform.lower_vectors {
contraction_lowering = "outerproduct",
multireduction_lowering = "innerparallel",
split_transfers = "linalg-copy",
// stages = [0, 1, 2, 3, 4, 5, 6, 7],
stages = [5, 6, 7],
transpose_avx2_lowering = false,
transpose_lowering = "eltwise",
unroll_vector_transfers = true
}
}
50 changes: 50 additions & 0 deletions tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc
@@ -0,0 +1,50 @@
/* Copyright 2022 The BladeDISC Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/disc/tests/mlir_feature_test.h"
#include "tensorflow/compiler/mlir/disc/tests/mlir_test.h"
#include "tensorflow/core/platform/test.h"

namespace mlir_test {

const std::string c_ft_path =
"tensorflow/compiler/mlir/disc/tests/disc-transform/data/";

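// Run the test with a single OMP thread.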
static bool init_threads = []() {
setenv("OMP_NUM_THREADS", "1", 1);
return true;
}();

TEST(PackedMatmul, F32_304x1024x512) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{c_ft_path + "packed_matmul_nn_p_f32_large_schedule.mlir", false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
{"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}};
EnvSettingContext ctx(setting);
EXPECT_TRUE(feature_test_main(
/*mlir_file_path*/ c_ft_path + "packed_matmul_nn_p_512x1024_f32.mlir",
/*backend_types*/ {BackendType::kAArch64},
/*num_inputs*/ 1,
/*num_outputs*/ 1,
/*input_descriptors*/ {"304x512xf32_X"},
/*output_descriptors*/ {"f32_X"},
/*input_vals*/ {},
/*expected_output_vals*/ {},
/*profiling*/ true));
}

} // namespace mlir_test
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/tools/disc-transform/BUILD
@@ -282,6 +282,7 @@ cc_library(
name = "legalize_lmhlo_fusion_to_linalg",
srcs = ["transforms/legalize_lmhlo_fusion_to_linalg.cc"],
deps = [
":DISCLinalgExtDialect",
":pass_details",
"//tensorflow/compiler/xla/mlir_hlo:lhlo",
"@llvm-project//llvm:Support",
@@ -40,6 +40,38 @@ class DISCLinalgExt_Op<string mnemonic, list<Trait> traits = []> :
code extraLinalgExtOpClassDeclaration = "";
}

// We intentionally mark this op as not foldable since folding it may lose the
// special semantics captured by this op. For example, it may be folded to
// an arith.constant op. It's only supposed to be folded in some special
// use cases (e.g. for packed weights).
def DISCLinalgExt_ConstantWrapperOp : Op<DISCLinalgExt_Dialect, "constant_wrapper", [
Pure, AllTypesMatch<["value", "result"]>]>{
let summary = "integer or floating point tensor constant";
let description = [{
The `constant_wrapper` operation produces an SSA value equal to some integer or
floating-point constant specified by an attribute, mirroring the way MLIR
forms simple integer and floating point constants.

Note that this op is not foldable; it is meant to serve as a placeholder that is
rewritten for RAL at a later stage.

Example:

```
// Integer constant
%1 = disc_linalg_ext.constant_wrapper dense<42> : tensor<i32>

// Equivalent generic form
%1 = "disc_linalg_ext.constant_wrapper"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
```
}];

let arguments = (ins ElementsAttr:$value);
let results = (outs AnyType:$result);

let assemblyFormat = "attr-dict $value";
}

def DISCLinalgExt_MultiLevelPackOp : DISCLinalgExt_Op<"multi_level_pack", [
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
]>{
@@ -5,4 +5,12 @@ func.func @main(%arg0: tensor<128x128xf32>) -> tensor<4x4x32x32xf32> {
// CHECK: disc_linalg_ext.multi_level_pack
%1 = disc_linalg_ext.multi_level_pack %arg0 with tile_levels = [1, 1] tile_sizes = [32, 32] permutation = [0, 2, 1, 3] into %0 : (tensor<128x128xf32> tensor<4x4x32x32xf32>) -> tensor<4x4x32x32xf32>
return %1 : tensor<4x4x32x32xf32>
}
}

// -----

// CHECK-LABEL: @const
func.func @const() -> tensor<2x4xf32> {
%0 = disc_linalg_ext.constant_wrapper dense<[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]> : tensor<2x4xf32>
return %0 : tensor<2x4xf32>
}
@@ -57,6 +57,7 @@ CommonExtensions::CommonExtensions() {
using bufferization::BufferizationOptions;
using bufferization::OneShotAnalysisState;
using bufferization::OneShotBufferizationOptions;
using disc_linalg_ext::ConstantWrapperOp;
using disc_linalg_ext::MultiLevelPackOp;

namespace {
@@ -108,6 +109,13 @@ OneShotBufferizationOptions getBufferizationOptions() {
options.functionBoundaryTypeConversion =
BufferizationOptions::LayoutMapOption::IdentityLayoutMap;

// bufferization.to_memref is used to bufferize constant_wrapper ops. DISC has
// its own logic to handle constants. We'd like to leave these constant
// ops as-is and insert bufferization.to_memref to convert the tensor to a
// memref.
options.opFilter.denyOperation<disc_linalg_ext::ConstantWrapperOp>();
options.opFilter.denyOperation<bufferization::ToMemrefOp>();

// This type converter converts tensor types to memref types when no exact
// memref type can be inferred from the context.
options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
@@ -444,6 +452,37 @@ struct TransferWriteOfFillOpPattern
}
};

/// convert:
/// %0 = disc_linalg_ext.constant_wrapper ...
/// %1 = disc_linalg_ext.multi_level_pack %0 ...
/// use(%1)
/// to:
/// %0 = disc_linalg_ext.constant_wrapper ... // folded
/// use(%0)
struct FoldMultiLevelPackOfConstantWrapperPattern
: public OpRewritePattern<MultiLevelPackOp> {
using OpRewritePattern<MultiLevelPackOp>::OpRewritePattern;

LogicalResult matchAndRewrite(MultiLevelPackOp op,
PatternRewriter& rewriter) const override {
auto constOp = op.getInput().getDefiningOp<ConstantWrapperOp>();
if (!constOp) return failure();

Attribute paddingAttr;
if (op.getPaddingValue() &&
!matchPattern(op.getPaddingValue(), m_Constant(&paddingAttr)))
return failure();

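// Attributes for fold(), in operand order: the constant input value, a null attribute for the
// init operand, and the optional padding value.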
SmallVector<Attribute> attrs{constOp.getValue(), nullptr, paddingAttr};
SmallVector<OpFoldResult> results;
if (failed(op.fold(attrs, results))) return failure();

rewriter.replaceOpWithNewOp<ConstantWrapperOp>(op, op.getOutputType(),
results[0].get<Attribute>());
return success();
}
};

static void addAllRegisteredCanonicalizationPatterns(
RewritePatternSet& patterns) {
MLIRContext* ctx = patterns.getContext();
@@ -458,6 +497,7 @@ static void addAllRegisteredCanonicalizationPatterns(
patterns.insert<FoldSelfInsertSlicePattern>(ctx);
patterns.insert<TransferReadOfFillOpPattern>(ctx);
patterns.insert<TransferWriteOfFillOpPattern>(ctx);
patterns.insert<FoldMultiLevelPackOfConstantWrapperPattern>(ctx);
}

} // namespace
@@ -60,6 +60,10 @@ class LinalgDialect;

namespace disc_ral {

namespace disc_linalg_ext {
class DISCLinalgExtDialect;
}

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.h.inc"

@@ -19,6 +19,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtDialect.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h"

#define DEBUG_TYPE "disc-legalize-lmhlo-fusion-to-linalg"
@@ -104,10 +106,23 @@ LogicalResult emitDotGeneralOp(lmhlo::DotGeneralOp op, OpBuilder& b,
return success();
}

LogicalResult emitConstOp(lmhlo::ConstantOp op, OpBuilder& b,
BlockAndValueMapping& mapping) {
auto resultTy = convertMemRefToTensorType(op->getOperand(0).getType());
Location loc = op->getLoc();
auto newOp = b.create<disc_linalg_ext::ConstantWrapperOp>(loc, resultTy,
op.getValue());
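// Re-map the lmhlo constant's output buffer to the new tensor-level constant so that
// subsequent ops in the fusion consume the wrapper op's result.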
mapping.erase(op->getOperand(0));
mapping.map(op->getOperand(0), newOp.getResult());
return success();
}

LogicalResult emitLmhloOp(Operation* op, OpBuilder& b,
BlockAndValueMapping& mapping) {
if (auto dotGeneralOp = dyn_cast<lmhlo::DotGeneralOp>(op)) {
return emitDotGeneralOp(dotGeneralOp, b, mapping);
} else if (auto constOp = dyn_cast<lmhlo::ConstantOp>(op)) {
return emitConstOp(constOp, b, mapping);
}
// TODO(wyzero): support other lmhlo ops.
return failure();
@@ -35,6 +35,10 @@ class OperationPass;

namespace disc_ral {

namespace disc_linalg_ext {
class DISCLinalgExtDialect;
}

// Converts an lmhlo fusion op inside a function to its linalg-on-tensor
// equivalent.
std::unique_ptr<OperationPass<ModuleOp>>
@@ -109,3 +109,21 @@ transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
transform.disc.apply_patterns %arg1 {canonicalization}
}

// -----

// CHECK-LABEL: @fold_constant_wrapper_and_multi_level_pack
func.func @fold_constant_wrapper_and_multi_level_pack() -> tensor<64x512x1x16xf32> {
// CHECK: %[[T0:.*]] = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<64x512x1x16xf32>
// CHECK-NEXT: return %[[T0]] : tensor<64x512x1x16xf32>
%cst = arith.constant 0.000000e+00 : f32
%0 = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
%1 = tensor.empty() : tensor<64x512x1x16xf32>
%2 = disc_linalg_ext.multi_level_pack %0 with padding_value(%cst : f32) tile_levels = [1, 1] tile_sizes = [1, 16] permutation = [2, 0, 1, 3] into %1 : (tensor<512x1024xf32> tensor<64x512x1x16xf32>) -> tensor<64x512x1x16xf32>
return %2 : tensor<64x512x1x16xf32>
}

transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
transform.disc.apply_patterns %arg1 {canonicalization}
}
@@ -88,6 +88,22 @@ func.func @not_use_alloca_due_to_dynamic_shape(%arg0: index) -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}

transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
transform.disc.bufferize %arg1
}

// -----

// CHECK-LABEL: @bufferize_constant_wrapper
func.func @bufferize_constant_wrapper() -> tensor<512x1024xf32> {
// CHECK: %[[T0:.*]] = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
// CHECK-NEXT: %[[T1:.*]] = bufferization.to_memref %[[T0]]
// CHECK-NEXT: return %[[T1]]
%0 = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
return %0 : tensor<512x1024xf32>
}

transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
transform.disc.bufferize %arg1
@@ -12,4 +12,22 @@ func.func @matmul_nn(%arg1: memref<?x?xf32, "cpu">, %arg2: memref<?x?xf32, "cpu"
"lmhlo.terminator"() : () -> ()
}) {disc.device = "cpu", disc.fusion.name = "matmul_nn_kTransform_dot_general__1_1_0", disc.fusion_type = "kTransform"} : () -> ()
return %arg3 : memref<?x?xf32, "cpu">
}

// -----

// CHECK-LABEL: @packed_matmul_nn
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x512xf32>, %[[ARG1:.*]]: tensor<512x1024xf32>, %[[ARG2:.*]]: tensor<?x1024xf32>)
func.func @packed_matmul_nn(%arg1: memref<?x512xf32, "cpu">, %arg2: memref<512x1024xf32, "cpu">, %arg3: memref<?x1024xf32, "cpu">) -> memref<?x1024xf32, "cpu"> {
// CHECK: %[[T0:.*]] = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
// CHECK: %[[T1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[T2:.*]] = linalg.fill ins(%[[T1]] : f32) outs(%[[ARG2]] : tensor<?x1024xf32>) -> tensor<?x1024xf32>
// CHECK: %[[T3:.*]] = linalg.matmul ins(%[[ARG0]], %[[T0]] : tensor<?x512xf32>, tensor<512x1024xf32>) outs(%[[T2]] : tensor<?x1024xf32>) -> tensor<?x1024xf32>
// CHECK: return %[[T3]]
"lmhlo.fusion"() ({
"lmhlo.constant"(%arg2) {disc.device = "cpu", value = dense<-8.000000e-01> : tensor<512x1024xf32>} : (memref<512x1024xf32, "cpu">) -> ()
"lmhlo.dot_general"(%arg1, %arg2, %arg3) {disc.device = "cpu", dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (memref<?x512xf32, "cpu">, memref<512x1024xf32, "cpu">, memref<?x1024xf32, "cpu">) -> ()
"lmhlo.terminator"() : () -> ()
}) {disc.device = "cpu", disc.fusion.name = "matmul_nn_kTransform_dot_general__1_1_0", disc.fusion_type = "kTransform"} : () -> ()
return %arg3 : memref<?x1024xf32, "cpu">
}