[transform] enhance fusion pass to support epilogue fusion #983

Merged 1 commit on Feb 1, 2023
86 changes: 76 additions & 10 deletions tao_compiler/mlir/disc/transforms/fusion_utils_transform_based.cc
@@ -40,8 +40,44 @@ bool isSupportedDot(Operation* op) {
return (lhsCntractingDims.size() == 1 && rhsCntractingDims.size() == 1);
}

bool isScalarConstOp(Operation* op) {
auto constOp = dyn_cast<lmhlo::ConstantOp>(op);
if (!constOp) return false;
MemRefType type = constOp.getOutput().getType().cast<MemRefType>();
return (type.getRank() == 0 || constOp.getValue().isSplat());
}

bool isBcastOp(Operation* op) {
return isa<lmhlo::BroadcastInDimOp, lmhlo::BroadcastOp,
lmhlo::DynamicBroadcastInDimOp>(op);
}

bool isSupportedBcast(Operation* op, ShapeAnalysis& shapeAnalysisBase) {
if (!isBcastOp(op)) return false;
auto shapeIRAnalysis =
dynamic_cast<ShapeConstraintIRAnalysis*>(&shapeAnalysisBase);
auto dimAttr = op->getAttrOfType<DenseElementsAttr>("broadcast_dimensions");
assert(dimAttr);
auto dimensions = dimAttr.getValues<int64_t>();
auto in = op->getOperand(0);
auto inType = in.getType().cast<MemRefType>();
auto out = cast<lmhlo::LmhloOp>(op).getResultBuffer();
if (inType.getRank() != dimensions.size()) return false;
for (auto [inDimIdx, inDimSize] : llvm::enumerate(inType.getShape())) {
int64_t outDimIdx = dimensions[inDimIdx];
if (inDimSize != ShapedType::kDynamicSize) continue;
// linalg generic op does not support "runtime broadcast semantics", thus we
// have to know at compile time whether we need to broadcast.
if (!shapeIRAnalysis ||
!shapeIRAnalysis->isProductEqual(in, {inDimIdx}, out, {outDimIdx}))
return false;
}
return true;
}

bool TransformBasedCpuFusionStrategy::isFusible(Operation* op) {
-  return isSupportedDot(op) || isa<lmhlo::ConstantOp>(op);
+  return isSupportedDot(op) || isElementWise(op) || isBcastOp(op) ||
+         isa<lmhlo::ConstantOp>(op);
}

bool TransformBasedCpuFusionStrategy::initFusionPattern(
@@ -53,9 +89,11 @@ bool TransformBasedCpuFusionStrategy::initFusionPattern(
// special case for single operation.
if (fusionPattern.getOpList().size() == 1) {
Operation* op = *fusionPattern.getOpList().begin();
-  if (this->isFusible(op)) {
+  if (!isBcastOp(op) && this->isFusible(op) ||
+      isBcastOp(op) && isSupportedBcast(op, shapeAnalysis)) {
fusionPattern.setDominantOp(op);
-  fusionPattern.setFusionType(FusionType::kTransform);
+  fusionPattern.setFusionType(isSupportedDot(op) ? FusionType::kTransform
+                                                 : FusionType::kLoop);
}
return true;
}
@@ -64,7 +102,9 @@ bool TransformBasedCpuFusionStrategy::initFusionPattern(
DenseSet<Operation*> supportedDotOps;
for (Operation* op : fusionPattern.getOpList()) {
// early return for the case where there are non supported ops.
-  if (!this->isFusible(op)) return true;
+  if (!this->isFusible(op) ||
+      isBcastOp(op) && !isSupportedBcast(op, shapeAnalysis))
+    return true;
if (isSupportedDot(op)) {
supportedDotOps.insert(op);
dotWeights.insert(op->getOperand(1));
@@ -73,18 +113,44 @@ bool TransformBasedCpuFusionStrategy::initFusionPattern(

// Only support one gemm a.t.m.
if (supportedDotOps.size() != 1) return true;
Operation* dominantDotOp = *supportedDotOps.begin();

-  // Only support fuse const ops that are used as weights for some dot ops and
-  // not consumed by ops outside the fusion pattern.
+  // Only support fusing const ops that are not consumed by ops outside the
+  // fusion pattern and that have one of the following properties:
+  // - the const op is used as a weight of some dot op;
+  // - the const op has a single element.
DenseSet<Value> constDotWeights;
for (Operation* op : fusionPattern.getOpList()) {
if (!isa<lmhlo::ConstantOp>(op)) continue;
-  if (llvm::find(dotWeights, op->getOperand(0)) == dotWeights.end() ||
-      llvm::find(fusionPattern.getRootOps(), op) !=
-          fusionPattern.getRootOps().end())
+  if (llvm::find(fusionPattern.getRootOps(), op) !=
+          fusionPattern.getRootOps().end())
return true;
if (llvm::find(dotWeights, op->getOperand(0)) != dotWeights.end()) {
constDotWeights.insert(op->getOperand(0));
continue;
}
if (!isScalarConstOp(op)) return true;
}

// We only support epilogue fusion right now.
// Check that the dot op does not consume any result produced by other
// lmhlo ops (except const ops).
for (Value operand : dominantDotOp->getOperands().drop_back()) {
if (llvm::find(constDotWeights, operand) != constDotWeights.end()) continue;
auto& operands = fusionPattern.getOperands();
if (llvm::find(operands, operand) == operands.end()) return true;
}

// We only support a single output right now.
if (fusionPattern.getResults().size() != 1) return true;
// The shape of the output should be the same as the shape of the result of
// the dominant op.
for (Value result : fusionPattern.getResults()) {
if (!shapeAnalysis.isShapeEqual(result, dominantDotOp->getOperand(2)))
return true;
}

-  fusionPattern.setDominantOp(*supportedDotOps.begin());
+  fusionPattern.setDominantOp(dominantDotOp);
fusionPattern.setFusionType(FusionType::kTransform);
return true;
}
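A note on the new isSupportedBcast check: linalg.generic fixes broadcasting in its indexing maps at compile time, so a broadcast whose expansion is only decidable at runtime cannot be lowered to a single generic op. The pair of cases below is a hypothetical illustration (SSA names and shapes are made up for this note, not taken from the patch) of what the check accepts and rejects.

// Accepted: the only broadcast dimension (dim 0, size 1 -> ?) is known statically,
// so the static dim is skipped by the check.
"lmhlo.dynamic_broadcast_in_dim"(%in0, %shape, %out) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (memref<1x1024xf32, "cpu">, memref<2xindex, "cpu">, memref<?x1024xf32, "cpu">) -> ()
// Rejected unless shape analysis proves dim 0 of %in1 equals dim 0 of %out:
// otherwise dim 0 may or may not be expanded at runtime, which a single
// linalg.generic cannot express.
"lmhlo.dynamic_broadcast_in_dim"(%in1, %shape, %out) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (memref<?x1024xf32, "cpu">, memref<2xindex, "cpu">, memref<?x1024xf32, "cpu">) -> ()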
44 changes: 42 additions & 2 deletions tao_compiler/mlir/disc/transforms/tests/cpu-only-lmhlo-fusion.mlir
@@ -1,11 +1,14 @@
- // RUN: disc-opt -pass-pipeline='func.func(disc-fusion{gpu-enabled=false fusion-strategy=base})' -split-input-file %s -o - | FileCheck %s --check-prefix=BASE
+ // RUN: disc-opt -split-input-file -pass-pipeline='func.func(disc-fusion{gpu-enabled=false fusion-strategy=base})' %s | FileCheck %s --check-prefix=BASE
// RUN: DISC_ENABLE_TRANSFORM_SCHEDULE=1 disc-opt -split-input-file -pass-pipeline='func.func(disc-fusion{gpu-enabled=false fusion-strategy=stitch})' %s -o - | FileCheck %s --check-prefix=TRANSFORM

// BASE-LABEL: @custom_call_op
// BASE-SAME: (%[[ARG0:.*]]: memref<?x?xf32, "cpu">, %[[ARG1:.*]]: memref<?x?xf32, "cpu">, %[[ARG2:.*]]: memref<?x?xf32, "cpu">, %[[ARG3:.*]]: memref<?x?xf32, "cpu">) -> memref<?x?xf32, "cpu">
func.func @custom_call_op(%arg0: memref<?x?xf32, "cpu">, %arg1: memref<?x?xf32, "cpu">,
%arg2: memref<?x?xf32, "cpu">, %arg3: memref<?x?xf32, "cpu">) -> memref<?x?xf32, "cpu"> {
// BASE-NOT: "lmhlo.fusion"
// BASE-NEXT: lmhlo.abs
// BASE-NEXT: lmhlo_disc.custom_call
// BASE-NEXT: lmhlo.add
// BASE-NEXT: return
"lmhlo.abs"(%arg0, %arg1) : (memref<?x?xf32, "cpu">, memref<?x?xf32, "cpu">) -> ()
"lmhlo_disc.custom_call"(%arg1, %arg2) {backend_config = "{}", call_target_name = "test", disc.device = "cpu", has_side_effect = false, operand_segment_sizes = array<i32: 1, 1>} : (memref<?x?xf32, "cpu">, memref<?x?xf32, "cpu">) -> ()
"lmhlo.add"(%arg1, %arg2, %arg3) : (memref<?x?xf32, "cpu">, memref<?x?xf32, "cpu">, memref<?x?xf32, "cpu">) -> ()
@@ -60,4 +63,41 @@ func.func @matmul_nn_const_weight_with_external_user(%arg1: memref<1024x1024xf32
"lmhlo.constant"(%arg1) {disc.device = "cpu", value = dense<-1.0> : tensor<1024x1024xf32>} : (memref<1024x1024xf32, "cpu">) -> ()
"lmhlo.dot_general"(%arg2, %arg1, %arg3) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (memref<?x1024xf32, "cpu">, memref<1024x1024xf32, "cpu">, memref<?x1024xf32, "cpu">) -> ()
return %arg1, %arg3 : memref<1024x1024xf32, "cpu">, memref<?x1024xf32, "cpu">
}

// -----

// TRANSFORM-LABEL: @matmul_nn_const_weight_with_epilogue0
func.func @matmul_nn_const_weight_with_epilogue0(%arg1: memref<1024x1024xf32, "cpu">,
%arg2: memref<?x1024xf32, "cpu">, %arg3: memref<?x1024xf32, "cpu">,
%arg4: memref<f32, "cpu">,
%arg5: memref<2xindex, "cpu">) -> (memref<?x1024xf32, "cpu">) {
"lmhlo.constant"(%arg1) {disc.device = "cpu", value = dense<-1.0> : tensor<1024x1024xf32>} : (memref<1024x1024xf32, "cpu">) -> ()
%c0 = arith.constant 0 : index
%d0 = memref.dim %arg2, %c0 : memref<?x1024xf32, "cpu">
%t0 = memref.alloc(%d0) {kDiscSymbolicDimAttr = [@S0, @C1024]} : memref<?x1024xf32, "cpu">
// TRANSFORM: "lmhlo.fusion"() ({
// TRANSFORM-NEXT: lmhlo.constant
// TRANSFORM-SAME: value = dense<-1.000000e+00>
// TRANSFORM-NEXT: lmhlo.dot_general
// TRANSFORM-NEXT: lmhlo.constant
// TRANSFORM-SAME: value = dense<1.000000e+00>
// TRANSFORM-NEXT: lmhlo.dynamic_broadcast_in_dim
// TRANSFORM-NEXT: lmhlo.add
// TRANSFORM-NEXT: lmhlo.terminator
// TRANSFORM-NEXT: })
// TRANSFORM-SAME: disc.fusion_type = "kTransform"
"lmhlo.dot_general"(%arg2, %arg1, %t0) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (memref<?x1024xf32, "cpu">, memref<1024x1024xf32, "cpu">, memref<?x1024xf32, "cpu">) -> ()
"lmhlo.constant"(%arg4) {disc.device = "cpu", value = dense<1.0> : tensor<f32>} : (memref<f32, "cpu">) -> ()
%t1 = memref.alloc(%d0) {kDiscSymbolicDimAttr = [@S0, @C1024]} : memref<?x1024xf32, "cpu">
"lmhlo.dynamic_broadcast_in_dim"(%arg4, %arg5, %t1) {disc.device = "cpu", broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<f32, "cpu">, memref<2xindex, "cpu">, memref<?x1024xf32, "cpu">) -> ()
%t2 = memref.alloc(%d0) {kDiscSymbolicDimAttr = [@S0, @C1024]} : memref<?x1024xf32, "cpu">
"lmhlo.add"(%t0, %t1, %t2) : (memref<?x1024xf32, "cpu">, memref<?x1024xf32, "cpu">, memref<?x1024xf32, "cpu">) -> ()
return %t2 : memref<?x1024xf32, "cpu">
}

"disc_shape.SymbolicDim"() {knownNegativeOne = false, knownNonNegative = true, knownNonSizeOne = false, knownNonSizeZero = false, sym_name = "S0", value = -1 : i64} : () -> ()
"disc_shape.SymbolicDim"() {knownNegativeOne = false, knownNonNegative = true, knownNonSizeOne = true, knownNonSizeZero = true, sym_name = "C1024", value = 1024 : i64} : () -> ()
func.func @shape_constraint_graph() {
return
}
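To verify only the new epilogue-fusion coverage locally, the TRANSFORM RUN line above can be expanded by hand. A sketch of the invocation, assuming disc-opt and FileCheck are built and on PATH and the command is run from the repository root:

DISC_ENABLE_TRANSFORM_SCHEDULE=1 disc-opt -split-input-file \
    -pass-pipeline='func.func(disc-fusion{gpu-enabled=false fusion-strategy=stitch})' \
    tao_compiler/mlir/disc/transforms/tests/cpu-only-lmhlo-fusion.mlir -o - \
  | FileCheck tao_compiler/mlir/disc/transforms/tests/cpu-only-lmhlo-fusion.mlir --check-prefix=TRANSFORM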