Skip to content

Commit

Permalink
[SPARK-33845][SQL] Remove unnecessary if when trueValue and falseValue are foldable boolean types
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

Improve `SimplifyConditionals` with two additional rewrites:
   - Simplify `If(cond, TrueLiteral, FalseLiteral)` to `cond`.
   - Simplify `If(cond, FalseLiteral, TrueLiteral)` to `Not(cond)`.

The use case is:
```sql
create table t1 using parquet as select id from range(10);
select if (id > 2, false, true) from t1;
```
Before this PR:
```
== Physical Plan ==
*(1) Project [if ((id#1L > 2)) false else true AS (IF((id > CAST(2 AS BIGINT)), false, true))#2]
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#1L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/spark-warehouse/org.apache.spark.sql.DataF..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint>
```
After this PR:
```
== Physical Plan ==
*(1) Project [(id#1L <= 2) AS (IF((id > CAST(2 AS BIGINT)), false, true))#2]
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#1L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/spark-warehouse/org.apache.spark.sql.DataF..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint>
```

### Why are the changes needed?

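Improve query performance: as the plans above show, the `if (...) false else true` wrapper is removed and only the (negated) condition is evaluated.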

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

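Unit tests: a new case in `SimplifyConditionalSuite`, plus updated expectations in `PushFoldableIntoBranchesSuite` and `ReplaceNullWithFalseInPredicateSuite`.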

Closes apache#30849 from wangyum/SPARK-33798-2.

Authored-by: Yuming Wang <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
wangyum authored and dongjoon-hyun committed Dec 21, 2020
1 parent b4bea1a commit 4b19f49
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
// A literal `true` condition always selects the first branch.
case If(TrueLiteral, trueValue, _) => trueValue
// A literal `false` condition always selects the second branch.
case If(FalseLiteral, _, falseValue) => falseValue
// A literal null condition is treated like `false`: the second branch is selected.
case If(Literal(null, _), _, falseValue) => falseValue
// `if (cond) true else false` is replaced by `cond` itself.
// NOTE(review): per the null-literal case above, a null condition selects falseValue, so the
// original If yields `false` where a bare `cond` yields null (and `true` vs. null for the
// Not variant below). Confirm both rewrites are safe when `cond` is nullable, e.g. by
// guarding on `!cond.nullable` as the last case here does.
case If(cond, TrueLiteral, FalseLiteral) => cond
// `if (cond) false else true` is replaced by the negated condition.
case If(cond, FalseLiteral, TrueLiteral) => Not(cond)
// When both branches are semantically equal the condition is dropped, but only if it is
// deterministic.
case If(cond, trueValue, falseValue)
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
// For a non-nullable `cond`, `if (cond) null else false` is rewritten to `cond AND null`.
case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PushFoldableIntoBranchesSuite

test("Push down EqualTo through If") {
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a))

// Push down at most one non-foldable expression.
assertEquivalent(
Expand All @@ -67,7 +67,7 @@ class PushFoldableIntoBranchesSuite
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(2))
assert(!nonDeterministic.deterministic)
assertEquivalent(EqualTo(nonDeterministic, Literal(2)),
If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, TrueLiteral))
GreaterThanOrEqual(Rand(1), Literal(0.5)))
assertEquivalent(EqualTo(nonDeterministic, Literal(3)),
If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral))

Expand Down Expand Up @@ -102,8 +102,7 @@ class PushFoldableIntoBranchesSuite
assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3)))
assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)),
If(a, Literal(2.0), Literal(3.0)))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral),
If(a, FalseLiteral, TrueLiteral))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a))
assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
Expand Down Expand Up @@ -236,12 +236,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
Literal(2) === nestedCaseWhen,
TrueLiteral,
FalseLiteral)
val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)
val condition = CaseWhen(branches)
testFilter(originalCond = condition, expectedCond = condition)
testJoin(originalCond = condition, expectedCond = condition)
testDelete(originalCond = condition, expectedCond = condition)
testUpdate(originalCond = condition, expectedCond = condition)
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
val expectedCond =
CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen)))
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond)
}

test("inability to replace null in non-boolean branches of If inside another If") {
Expand All @@ -252,10 +253,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
Literal(3)),
TrueLiteral,
FalseLiteral)
testFilter(originalCond = condition, expectedCond = condition)
testJoin(originalCond = condition, expectedCond = condition)
testDelete(originalCond = condition, expectedCond = condition)
testUpdate(originalCond = condition, expectedCond = condition)
val expectedCond = Literal(5) > If(
UnresolvedAttribute("i") === Literal(15),
Literal(null, IntegerType),
Literal(3))
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond)
}

test("replace null in If used as a join condition") {
Expand Down Expand Up @@ -405,9 +410,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
val lambda1 = LambdaFunction(
function = If(cond, Literal(null, BooleanType), TrueLiteral),
arguments = lambdaArgs)
// the optimized lambda body is: if(arg > 0, false, true)
// the optimized lambda body is: if(arg > 0, false, true) => arg <= 0
val lambda2 = LambdaFunction(
function = If(cond, FalseLiteral, TrueLiteral),
function = LessThanOrEqual(condArg, Literal(0)),
arguments = lambdaArgs)
testProjection(
originalExpr = createExpr(argument, lambda1) as 'x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,20 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow))
}
}

// SPARK-33845: an If whose two branches are the boolean literals is expected to collapse to
// its condition (true/false order) or to the negation of its condition (false/true order).
test("SPARK-33845: remove unnecessary if when the outputs are boolean type") {
  val attr = UnresolvedAttribute("a")

  // Condition on a plain attribute; the negated form is expected as IsNull(a).
  def notNull = IsNotNull(attr)
  assertEquivalent(If(notNull, TrueLiteral, FalseLiteral), notNull)
  assertEquivalent(If(notNull, FalseLiteral, TrueLiteral), IsNull(attr))

  // Condition involving Rand; `def` builds a fresh expression for every use, and the negated
  // form is expected as the complementary comparison.
  def randAboveAttr = GreaterThan(Rand(0), attr)
  assertEquivalent(If(randAboveAttr, TrueLiteral, FalseLiteral), randAboveAttr)
  assertEquivalent(If(randAboveAttr, FalseLiteral, TrueLiteral), LessThanOrEqual(Rand(0), attr))
}
}

0 comments on commit 4b19f49

Please sign in to comment.