[transform] e2e UT for packed gemm #892

Merged (1 commit) on Dec 22, 2022
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/BUILD
@@ -2052,6 +2052,7 @@ cc_library(
"//tensorflow/compiler/xla/mlir_hlo:mlir_hlo",
"//tensorflow/compiler/xla/mlir_hlo:lhlo",
"//tensorflow/compiler/mlir/disc/tools/disc-transform:all_passes",
"//tensorflow/compiler/mlir/disc/tools/disc-transform:DISCLinalgExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:FuncDialect",
@@ -0,0 +1,10 @@
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 0 : i32}} {
func.func @main(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) attributes {tf.entry_function = {inputs = "{{INPUTS}}", outputs = "{{OUTPUTS}}", input_placements="{{INPUT_PLACEMENTS}}", output_placements="{{OUTPUT_PLACEMENTS}}"}} {
%graph = tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.Const"() {value = dense<-0.8> : tensor<512x1024xf32>} : () -> tensor<512x1024xf32>
%1:2 = tf_executor.island wraps "tf.MatMul"(%arg0, %0) {transpose_a = false, transpose_b = false} : (tensor<?x?xf32>, tensor<512x1024xf32>) -> (tensor<?x?xf32>)
tf_executor.fetch %1 : tensor<?x?xf32>
}
return %graph : tensor<?x?xf32>
}
}
@@ -0,0 +1,73 @@
transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1

%0:2 = transform.structured.tile_to_foreach_thread_op %matmul num_threads [1, 1]
transform.structured.fuse_into_containing_op %fill into %0#0

// First-level tiling: tile and fuse the matmul and fill ops.
%1:3 = transform.structured.fuse %0#1 {tile_sizes = [288, 48, 0], tile_interchange = [0, 1, 2]}
// Second-level tiling: tile and fuse the matmul and fill ops.
%2:3 = transform.structured.fuse %1#0 {tile_sizes = [6, 16, 0], tile_interchange = [0, 1, 2]}

// gemm reduction axis tiling
%3:2 = transform.structured.tile %2#0 [0, 0, 1] {interchange=[0, 1, 2]}

// clean up
%func0 = transform.structured.match ops{["func.func"]} in %arg1
transform.disc.apply_patterns %func0 {canonicalization}
// Fold the two extract_slice ops generated by two-level tiling. This is needed to enable the
// following pad and hoist steps.
%weight_inner_slice = get_producer_of_operand %3#0[1] : (!pdl.operation) -> !pdl.operation
transform.disc.fold_producer_extract_slice %weight_inner_slice {max_repeat_num = 2}
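// Illustration (not the exact IR): the two tiling levels above produce nested
// slices such as
//   %s0 = tensor.extract_slice %weight[...]
//   %s1 = tensor.extract_slice %s0[...]
// which are folded into a single extract_slice on %weight, so the pad/hoist
// steps below see only one producer slice.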

// Pad to match the requirements of hardware vector/tensor instructions.
%4 = transform.structured.pad %3#0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0], hoist_paddings=[0, 0, 0], transpose_paddings=[[1, 0], [0, 1], [0, 1]]}

%pad_for_input = get_producer_of_operand %4[0] : (!pdl.operation) -> !pdl.operation
%pad_for_weight = get_producer_of_operand %4[1] : (!pdl.operation) -> !pdl.operation
%foreach_op = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
%loop_for_outter_most_n = transform.loop.get_parent_for %4 { num_loops = 4} : (!pdl.operation) -> !pdl.operation
transform.disc.cache_read {padded} %pad_for_input at %loop_for_outter_most_n with tile_levels = [1, 1] tile_sizes = [6, 1] permutation = [0, 2, 3, 1]
transform.disc.cache_read {padded} %pad_for_weight at %foreach_op with tile_levels = [1, 1] tile_sizes = [1, 16] permutation = [2, 0, 1, 3]
Collaborator

It looks pretty readable, except for things like %2#0; one has to memorize the meaning of each output.

Collaborator Author

I'll try to give more meaningful names to these return results in a later schedule.
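
For illustration only, a sketch of what more descriptive result names could look like; the SSA names below are hypothetical and not part of this PR (MLIR allows binding an op's results to separate names instead of using %N#M):

```
// Hypothetical renaming of the two fuse steps above (names are made up):
%l1_matmul, %l1_loop_m, %l1_loop_n = transform.structured.fuse %tiled_in_foreach
    {tile_sizes = [288, 48, 0], tile_interchange = [0, 1, 2]}
%l2_matmul, %l2_loop_m, %l2_loop_n = transform.structured.fuse %l1_matmul
    {tile_sizes = [6, 16, 0], tile_interchange = [0, 1, 2]}
```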


%func1 = transform.structured.match ops{["func.func"]} in %arg1
transform.disc.apply_patterns %func1 {canonicalization}

%pack_op = transform.structured.match ops{["disc_linalg_ext.multi_level_pack"]} in %arg1
transform.disc.lower_multi_level_pack_to_loop %pack_op

%func2 = transform.structured.match ops{["func.func"]} in %arg1
transform.disc.apply_patterns %func2 {canonicalization}

%func3 = transform.structured.match ops{["func.func"]} in %arg1
transform.structured.vectorize %func3 {vectorize_padding}

%func4 = transform.structured.match ops{["func.func"]} in %arg1
transform.disc.apply_patterns %func4 {canonicalization}

transform.disc.bufferize %arg1

transform.lower_vectors {
contraction_lowering = "outerproduct",
multireduction_lowering = "innerparallel",
split_transfers = "linalg-copy",
// stages = [0, 1, 2, 3, 4, 5, 6, 7],
stages = [0, 1, 2, 3],
transpose_avx2_lowering = false,
transpose_lowering = "eltwise",
unroll_vector_transfers = true
}

transform.lower_vectors {
contraction_lowering = "outerproduct",
multireduction_lowering = "innerparallel",
split_transfers = "linalg-copy",
// stages = [0, 1, 2, 3, 4, 5, 6, 7],
stages = [5, 6, 7],
transpose_avx2_lowering = false,
transpose_lowering = "eltwise",
unroll_vector_transfers = true
}
}
50 changes: 50 additions & 0 deletions tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc
@@ -0,0 +1,50 @@
/* Copyright 2022 The BladeDISC Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/disc/tests/mlir_feature_test.h"
#include "tensorflow/compiler/mlir/disc/tests/mlir_test.h"
#include "tensorflow/core/platform/test.h"

namespace mlir_test {

const std::string c_ft_path =
"tensorflow/compiler/mlir/disc/tests/disc-transform/data/";

// Force single-threaded execution for this test by setting OMP_NUM_THREADS
// before any test body runs.
static bool init_threads = []() {
setenv("OMP_NUM_THREADS", "1", 1);
return true;
}();
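
// End-to-end test: a dynamic-shaped 304x512 input multiplied by a constant
// 512x1024 weight, compiled with the packed-matmul transform schedule above
// and executed on the AArch64 backend with profiling enabled.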

TEST(PackedMatmul, F32_304x1024x512) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{c_ft_path + "packed_matmul_nn_p_f32_large_schedule.mlir", false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
{"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}};
EnvSettingContext ctx(setting);
EXPECT_TRUE(feature_test_main(
/*mlir_file_path*/ c_ft_path + "packed_matmul_nn_p_512x1024_f32.mlir",
/*backend_types*/ {BackendType::kAArch64},
/*num_inputs*/ 1,
/*num_outputs*/ 1,
/*input_descriptors*/ {"304x512xf32_X"},
/*output_descriptors*/ {"f32_X"},
/*input_vals*/ {},
/*expected_output_vals*/ {},
/*profiling*/ true));
}

} // namespace mlir_test
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/tools/disc-transform/BUILD
@@ -282,6 +282,7 @@ cc_library(
name = "legalize_lmhlo_fusion_to_linalg",
srcs = ["transforms/legalize_lmhlo_fusion_to_linalg.cc"],
deps = [
":DISCLinalgExtDialect",
":pass_details",
"//tensorflow/compiler/xla/mlir_hlo:lhlo",
"@llvm-project//llvm:Support",
@@ -40,6 +40,38 @@ class DISCLinalgExt_Op<string mnemonic, list<Trait> traits = []> :
code extraLinalgExtOpClassDeclaration = "";
}

// We intentionally mark this op as not foldable since folding it may lose the
// special semantics captured by this op. For example, it may be folded into
// an arith.constant op. It's only supposed to be folded in some special
// use cases (e.g. for packed weights).
def DISCLinalgExt_ConstantWrapperOp : Op<DISCLinalgExt_Dialect, "constant_wrapper", [
Pure, AllTypesMatch<["value", "result"]>]>{
let summary = "integer or floating point tensor constant";
let description = [{
The `constant_wrapper` operation produces an SSA value equal to some integer or
floating-point constant specified by an attribute. This is how MLIR forms
simple integer and floating-point constants.

Note that this op is not foldable, and is supposed to be used as a placeholder for
later rewriting for RAL.

Example:

```
// Integer constant
%1 = disc_linalg_ext.constant_wrapper dense<42> : tensor<i32>

// Equivalent generic form
%1 = "disc_linalg_ext.constant_wrapper"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
```
}];

let arguments = (ins ElementsAttr:$value);
let results = (outs AnyType:$result);

let assemblyFormat = "attr-dict $value";
}

def DISCLinalgExt_MultiLevelPackOp : DISCLinalgExt_Op<"multi_level_pack", [
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
]>{
@@ -5,4 +5,12 @@ func.func @main(%arg0: tensor<128x128xf32>) -> tensor<4x4x32x32xf32> {
// CHECK: disc_linalg_ext.multi_level_pack
%1 = disc_linalg_ext.multi_level_pack %arg0 with tile_levels = [1, 1] tile_sizes = [32, 32] permutation = [0, 2, 1, 3] into %0 : (tensor<128x128xf32> tensor<4x4x32x32xf32>) -> tensor<4x4x32x32xf32>
return %1 : tensor<4x4x32x32xf32>
}
}

// -----

// CHECK-LABEL: @const
func.func @const() -> tensor<2x4xf32> {
%0 = disc_linalg_ext.constant_wrapper dense<[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]> : tensor<2x4xf32>
return %0 : tensor<2x4xf32>
}
@@ -57,6 +57,7 @@ CommonExtensions::CommonExtensions() {
using bufferization::BufferizationOptions;
using bufferization::OneShotAnalysisState;
using bufferization::OneShotBufferizationOptions;
using disc_linalg_ext::ConstantWrapperOp;
using disc_linalg_ext::MultiLevelPackOp;

namespace {
@@ -108,6 +109,13 @@ OneShotBufferizationOptions getBufferizationOptions() {
options.functionBoundaryTypeConversion =
BufferizationOptions::LayoutMapOption::IdentityLayoutMap;

// bufferization.to_memref is used to bufferize constant_wrapper ops. DISC has
// its own logic to handle constants. We'd like to leave these constant ops
// as-is and insert bufferization.to_memref to convert the tensor to a memref.
options.opFilter.denyOperation<disc_linalg_ext::ConstantWrapperOp>();
options.opFilter.denyOperation<bufferization::ToMemrefOp>();
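// As a result, a constant_wrapper survives bufferization in tensor form and is
// consumed through an explicit to_memref, e.g.:
//   %0 = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
//   %1 = bufferization.to_memref %0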

// This type converter converts tensor types to memref types when no exact
// memref type can be inferred from the context.
options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
@@ -444,6 +452,37 @@ struct TransferWriteOfFillOpPattern
}
};

/// convert:
/// %0 = disc_linalg_ext.constant_wrapper ...
/// %1 = disc_linalg_ext.multi_level_pack %0 ...
/// use(%1)
/// to:
/// %0 = disc_linalg_ext.constant_wrapper ... // folded
/// use(%0)
struct FoldMultiLevelPackOfConstantWrapperPattern
: public OpRewritePattern<MultiLevelPackOp> {
using OpRewritePattern<MultiLevelPackOp>::OpRewritePattern;

LogicalResult matchAndRewrite(MultiLevelPackOp op,
PatternRewriter& rewriter) const override {
auto constOp = op.getInput().getDefiningOp<ConstantWrapperOp>();
if (!constOp) return failure();

Attribute paddingAttr;
if (op.getPaddingValue() &&
!matchPattern(op.getPaddingValue(), m_Constant(&paddingAttr)))
return failure();

SmallVector<Attribute> attrs{constOp.getValue(), nullptr, paddingAttr};
SmallVector<OpFoldResult> results;
if (failed(op.fold(attrs, results))) return failure();

rewriter.replaceOpWithNewOp<ConstantWrapperOp>(op, op.getOutputType(),
results[0].get<Attribute>());
return success();
}
};

static void addAllRegisteredCanonicalizationPatterns(
RewritePatternSet& patterns) {
MLIRContext* ctx = patterns.getContext();
@@ -458,6 +497,7 @@ static void addAllRegisteredCanonicalizationPatterns(
patterns.insert<FoldSelfInsertSlicePattern>(ctx);
patterns.insert<TransferReadOfFillOpPattern>(ctx);
patterns.insert<TransferWriteOfFillOpPattern>(ctx);
patterns.insert<FoldMultiLevelPackOfConstantWrapperPattern>(ctx);
}

} // namespace
@@ -60,6 +60,10 @@ class LinalgDialect;

namespace disc_ral {

namespace disc_linalg_ext {
class DISCLinalgExtDialect;
}

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.h.inc"

@@ -19,6 +19,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtDialect.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h"

#define DEBUG_TYPE "disc-legalize-lmhlo-fusion-to-linalg"
@@ -104,10 +106,23 @@ LogicalResult emitDotGeneralOp(lmhlo::DotGeneralOp op, OpBuilder& b,
return success();
}

// Converts an lmhlo.constant op into a disc_linalg_ext.constant_wrapper op
// producing the corresponding tensor value, and remaps the constant's output
// buffer to that tensor result.
LogicalResult emitConstOp(lmhlo::ConstantOp op, OpBuilder& b,
BlockAndValueMapping& mapping) {
auto resultTy = convertMemRefToTensorType(op->getOperand(0).getType());
Location loc = op->getLoc();
auto newOp = b.create<disc_linalg_ext::ConstantWrapperOp>(loc, resultTy,
op.getValue());
mapping.erase(op->getOperand(0));
mapping.map(op->getOperand(0), newOp.getResult());
return success();
}

LogicalResult emitLmhloOp(Operation* op, OpBuilder& b,
BlockAndValueMapping& mapping) {
if (auto dotGeneralOp = dyn_cast<lmhlo::DotGeneralOp>(op)) {
return emitDotGeneralOp(dotGeneralOp, b, mapping);
} else if (auto constOp = dyn_cast<lmhlo::ConstantOp>(op)) {
return emitConstOp(constOp, b, mapping);
}
// TODO(wyzero): support other lmhlo ops.
return failure();
@@ -35,6 +35,10 @@ class OperationPass;

namespace disc_ral {

namespace disc_linalg_ext {
class DISCLinalgExtDialect;
}

// Converts an lmhlo fusion op inside a function to its linalg-on-tensor
// equivalent.
std::unique_ptr<OperationPass<ModuleOp>>
@@ -109,3 +109,21 @@ transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
transform.disc.apply_patterns %arg1 {canonicalization}
}

// -----

// CHECK-LABEL: @fold_constant_wrapper_and_multi_level_pack
func.func @fold_constant_wrapper_and_multi_level_pack() -> tensor<64x512x1x16xf32> {
// CHECK: %[[T0:.*]] = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<64x512x1x16xf32>
// CHECK-NEXT: return %[[T0]] : tensor<64x512x1x16xf32>
%cst = arith.constant 0.000000e+00 : f32
%0 = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
%1 = tensor.empty() : tensor<64x512x1x16xf32>
%2 = disc_linalg_ext.multi_level_pack %0 with padding_value(%cst : f32) tile_levels = [1, 1] tile_sizes = [1, 16] permutation = [2, 0, 1, 3] into %1 : (tensor<512x1024xf32> tensor<64x512x1x16xf32>) -> tensor<64x512x1x16xf32>
return %2 : tensor<64x512x1x16xf32>
}

transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
transform.disc.apply_patterns %arg1 {canonicalization}
}
@@ -88,6 +88,22 @@ func.func @not_use_alloca_due_to_dynamic_shape(%arg0: index) -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}

transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
transform.disc.bufferize %arg1
}

// -----

// CHECK-LABEL: @bufferize_constant_wrapper
func.func @bufferize_constant_wrapper() -> tensor<512x1024xf32> {
// CHECK: %[[T0:.*]] = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
// CHECK-NEXT: %[[T1:.*]] = bufferization.to_memref %[[T0]]
// CHECK-NEXT: return %[[T1]]
%0 = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
return %0 : tensor<512x1024xf32>
}

transform.structured.canonicalized_sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
transform.disc.bufferize %arg1
@@ -12,4 +12,22 @@ func.func @matmul_nn(%arg1: memref<?x?xf32, "cpu">, %arg2: memref<?x?xf32, "cpu"
"lmhlo.terminator"() : () -> ()
}) {disc.device = "cpu", disc.fusion.name = "matmul_nn_kTransform_dot_general__1_1_0", disc.fusion_type = "kTransform"} : () -> ()
return %arg3 : memref<?x?xf32, "cpu">
}

// -----

// CHECK-LABEL: @packed_matmul_nn
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x512xf32>, %[[ARG1:.*]]: tensor<512x1024xf32>, %[[ARG2:.*]]: tensor<?x1024xf32>)
func.func @packed_matmul_nn(%arg1: memref<?x512xf32, "cpu">, %arg2: memref<512x1024xf32, "cpu">, %arg3: memref<?x1024xf32, "cpu">) -> memref<?x1024xf32, "cpu"> {
// CHECK: %[[T0:.*]] = disc_linalg_ext.constant_wrapper dense<-8.000000e-01> : tensor<512x1024xf32>
// CHECK: %[[T1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[T2:.*]] = linalg.fill ins(%[[T1]] : f32) outs(%[[ARG2]] : tensor<?x1024xf32>) -> tensor<?x1024xf32>
// CHECK: %[[T3:.*]] = linalg.matmul ins(%[[ARG0]], %[[T0]] : tensor<?x512xf32>, tensor<512x1024xf32>) outs(%[[T2]] : tensor<?x1024xf32>) -> tensor<?x1024xf32>
// CHECK: return %[[T3]]
"lmhlo.fusion"() ({
"lmhlo.constant"(%arg2) {disc.device = "cpu", value = dense<-8.000000e-01> : tensor<512x1024xf32>} : (memref<512x1024xf32, "cpu">) -> ()
"lmhlo.dot_general"(%arg1, %arg2, %arg3) {disc.device = "cpu", dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (memref<?x512xf32, "cpu">, memref<512x1024xf32, "cpu">, memref<?x1024xf32, "cpu">) -> ()
"lmhlo.terminator"() : () -> ()
}) {disc.device = "cpu", disc.fusion.name = "matmul_nn_kTransform_dot_general__1_1_0", disc.fusion_type = "kTransform"} : () -> ()
return %arg3 : memref<?x1024xf32, "cpu">
}