Skip to content

Commit

Permalink
[mlir] Canonicalize IfOp with trivial then and else bodies to lis…
Browse files Browse the repository at this point in the history
…t of SelectOp's

* Do we need a threshold on maximum number of Yeild arguments processed (maximum number of SelectOp's to be generated)?
* Had to modify some old IfOp tests to not get optimized by this pattern

Differential Revision: https://reviews.llvm.org/D98592
  • Loading branch information
Butygin committed Mar 20, 2021
1 parent 319d093 commit 5657f93
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 1 deletion.
40 changes: 39 additions & 1 deletion mlir/lib/Dialect/SCF/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -934,11 +934,49 @@ struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
return success();
}
};

struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
if (op->getNumResults() == 0)
return failure();

if (!llvm::hasSingleElement(op.thenRegion().front()) ||
!llvm::hasSingleElement(op.elseRegion().front()))
return failure();

auto cond = op.condition();
auto thenYieldArgs =
cast<scf::YieldOp>(op.thenRegion().front().getTerminator())
.getOperands();
auto elseYieldArgs =
cast<scf::YieldOp>(op.elseRegion().front().getTerminator())
.getOperands();
SmallVector<Value> results(op->getNumResults());
assert(thenYieldArgs.size() == results.size());
assert(elseYieldArgs.size() == results.size());
for (auto it : llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
Value trueVal = std::get<0>(it.value());
Value falseVal = std::get<1>(it.value());
if (trueVal == falseVal)
results[it.index()] = trueVal;
else
results[it.index()] =
rewriter.create<SelectOp>(op.getLoc(), cond, trueVal, falseVal);
}

rewriter.replaceOp(op, results);
return success();
}
};
} // namespace

void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RemoveUnusedResults, RemoveStaticCondition>(context);
results.insert<RemoveUnusedResults, RemoveStaticCondition,
ConvertTrivialIfToSelect>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
96 changes: 96 additions & 0 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ func @single_iteration(%A: memref<?x?x?xi32>) {

// -----

func private @side_effect()
func @one_unused(%cond: i1) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0, %1 = scf.if %cond -> (index, index) {
call @side_effect() : () -> ()
scf.yield %c0, %c1 : index, index
} else {
scf.yield %c0, %c1 : index, index
Expand All @@ -49,6 +51,7 @@ func @one_unused(%cond: i1) -> (index) {
// CHECK-LABEL: func @one_unused
// CHECK: [[C0:%.*]] = constant 1 : index
// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
// CHECK: call @side_effect() : () -> ()
// CHECK: scf.yield [[C0]] : index
// CHECK: } else
// CHECK: scf.yield [[C0]] : index
Expand All @@ -57,11 +60,13 @@ func @one_unused(%cond: i1) -> (index) {

// -----

func private @side_effect()
func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0, %1 = scf.if %cond1 -> (index, index) {
%2, %3 = scf.if %cond2 -> (index, index) {
call @side_effect() : () -> ()
scf.yield %c0, %c1 : index, index
} else {
scf.yield %c0, %c1 : index, index
Expand All @@ -77,6 +82,7 @@ func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
// CHECK: [[C0:%.*]] = constant 1 : index
// CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
// CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
// CHECK: call @side_effect() : () -> ()
// CHECK: scf.yield [[C0]] : index
// CHECK: } else
// CHECK: scf.yield [[C0]] : index
Expand Down Expand Up @@ -113,6 +119,96 @@ func @all_unused(%cond: i1) {

// -----

func @empty_if1(%cond: i1) {
scf.if %cond {
scf.yield
}
return
}

// CHECK-LABEL: func @empty_if1
// CHECK-NOT: scf.if
// CHECK: return

// -----

func @empty_if2(%cond: i1) {
scf.if %cond {
scf.yield
} else {
scf.yield
}
return
}

// CHECK-LABEL: func @empty_if2
// CHECK-NOT: scf.if
// CHECK: return

// -----

func @to_select1(%cond: i1) -> index {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = scf.if %cond -> index {
scf.yield %c0 : index
} else {
scf.yield %c1 : index
}
return %0 : index
}

// CHECK-LABEL: func @to_select1
// CHECK: [[C0:%.*]] = constant 0 : index
// CHECK: [[C1:%.*]] = constant 1 : index
// CHECK: [[V0:%.*]] = select {{.*}}, [[C0]], [[C1]]
// CHECK: return [[V0]] : index

// -----

func @to_select_same_val(%cond: i1) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0, %1 = scf.if %cond -> (index, index) {
scf.yield %c0, %c1 : index, index
} else {
scf.yield %c1, %c1 : index, index
}
return %0, %1 : index, index
}

// CHECK-LABEL: func @to_select_same_val
// CHECK: [[C0:%.*]] = constant 0 : index
// CHECK: [[C1:%.*]] = constant 1 : index
// CHECK: [[V0:%.*]] = select {{.*}}, [[C0]], [[C1]]
// CHECK: return [[V0]], [[C1]] : index, index

// -----

func @to_select2(%cond: i1) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
%0, %1 = scf.if %cond -> (index, index) {
scf.yield %c0, %c1 : index, index
} else {
scf.yield %c2, %c3 : index, index
}
return %0, %1 : index, index
}

// CHECK-LABEL: func @to_select2
// CHECK: [[C0:%.*]] = constant 0 : index
// CHECK: [[C1:%.*]] = constant 1 : index
// CHECK: [[C2:%.*]] = constant 2 : index
// CHECK: [[C3:%.*]] = constant 3 : index
// CHECK: [[V0:%.*]] = select {{.*}}, [[C0]], [[C2]]
// CHECK: [[V1:%.*]] = select {{.*}}, [[C1]], [[C3]]
// CHECK: return [[V0]], [[V1]] : index

// -----

func private @make_i32() -> i32

func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {
Expand Down

0 comments on commit 5657f93

Please sign in to comment.