diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 81bb012ac6d74..376a9f36568a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -314,7 +314,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
         // a && a => a
         case (l, r) if l fastEquals r => l
         // (a || b) && (a || c) => a || (b && c)
-        case (_, _) =>
+        case _ =>
           // 1. Split left and right to get the disjunctive predicates,
           // i.e. lhsSet = (a, b), rhsSet = (a, c)
           // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
@@ -323,19 +323,20 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
           val lhsSet = splitDisjunctivePredicates(left).toSet
           val rhsSet = splitDisjunctivePredicates(right).toSet
           val common = lhsSet.intersect(rhsSet)
-          val ldiff = lhsSet.diff(common)
-          val rdiff = rhsSet.diff(common)
-          if (ldiff.size == 0 || rdiff.size == 0) {
-            // a && (a || b) => a
-            common.reduce(Or)
+          if (common.isEmpty) {
+            // No common factors, return the original predicate
+            and
           } else {
-            // (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... =>
-            // (a || b) || ((c || ...) && (f || ...) && (e || ...) && ...)
-            (ldiff.reduceOption(Or) ++ rdiff.reduceOption(Or))
-              .reduceOption(And)
-              .map(_ :: common.toList)
-              .getOrElse(common.toList)
-              .reduce(Or)
+            val ldiff = lhsSet.diff(common)
+            val rdiff = rhsSet.diff(common)
+            if (ldiff.isEmpty || rdiff.isEmpty) {
+              // (a || b || c || ...) && (a || b) => (a || b)
+              common.reduce(Or)
+            } else {
+              // (a || b || c || ...) && (a || b || d || ...) =>
+              // ((c || ...) && (d || ...)) || a || b
+              (common + And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
+            }
           }
       } // end of And(left, right)
 
@@ -351,7 +352,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
         // a || a => a
         case (l, r) if l fastEquals r => l
         // (a && b) || (a && c) => a && (b || c)
-        case (_, _) =>
+        case _ =>
           // 1. Split left and right to get the conjunctive predicates,
           // i.e. lhsSet = (a, b), rhsSet = (a, c)
           // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
@@ -360,19 +361,20 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
           val lhsSet = splitConjunctivePredicates(left).toSet
           val rhsSet = splitConjunctivePredicates(right).toSet
           val common = lhsSet.intersect(rhsSet)
-          val ldiff = lhsSet.diff(common)
-          val rdiff = rhsSet.diff(common)
-          if ( ldiff.size == 0 || rdiff.size == 0) {
-            // a || (b && a) => a
-            common.reduce(And)
+          if (common.isEmpty) {
+            // No common factors, return the original predicate
+            or
           } else {
-            // (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... =>
-            // a && b && ((c && ...) || (d && ...) || (e && ...) || ...)
-            (ldiff.reduceOption(And) ++ rdiff.reduceOption(And))
-              .reduceOption(Or)
-              .map(_ :: common.toList)
-              .getOrElse(common.toList)
-              .reduce(And)
+            val ldiff = lhsSet.diff(common)
+            val rdiff = rhsSet.diff(common)
+            if (ldiff.isEmpty || rdiff.isEmpty) {
+              // (a && b) || (a && b && c && ...) => a && b
+              common.reduce(And)
+            } else {
+              // (a && b && c && ...) || (a && b && d && ...) =>
+              // ((c && ...) || (d && ...)) && a && b
+              (common + Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
+            }
           }
       } // end of Or(left, right)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index a0863dad96eb0..264a0eff37d34 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -18,14 +18,14 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
-import org.apache.spark.sql.catalyst.expressions.{Literal, Expression}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 
-class BooleanSimplificationSuite extends PlanTest {
+class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
@@ -40,11 +40,36 @@ class BooleanSimplificationSuite extends PlanTest {
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)
 
-  def checkCondition(originCondition: Expression, optimizedCondition: Expression): Unit = {
-    val originQuery = testRelation.where(originCondition).analyze
-    val optimized = Optimize(originQuery)
-    val expected = testRelation.where(optimizedCondition).analyze
-    comparePlans(optimized, expected)
+  // Structurally compares two conditions while ignoring the order and nesting of the operands
+  // of `And`/`Or`, so that e.g. `a && (b && c)` and `(c && a) && b` are considered equal.
+  def compareConditions(e1: Expression, e2: Expression): Boolean = {
+    // Operand lists match when every operand on either side has an equivalent on the other.
+    // Checking both directions rejects the case where one side carries extra operands.
+    def sameOperands(lhs: Seq[Expression], rhs: Seq[Expression]): Boolean =
+      lhs.forall(l => rhs.exists(compareConditions(l, _))) &&
+        rhs.forall(r => lhs.exists(compareConditions(_, r)))
+
+    (e1, e2) match {
+      case (lhs: And, rhs: And) =>
+        sameOperands(splitConjunctivePredicates(lhs), splitConjunctivePredicates(rhs))
+
+      case (lhs: Or, rhs: Or) =>
+        sameOperands(splitDisjunctivePredicates(lhs), splitDisjunctivePredicates(rhs))
+
+      case (l, r) => l == r
+    }
+  }
+
+  def checkCondition(input: Expression, expected: Expression): Unit = {
+    val plan = testRelation.where(input).analyze
+    val actual = Optimize(plan).expressions.head
+    // Resolve `expected` against the same relation so that its attributes are comparable
+    // with the analyzed attributes of `actual`.
+    val resolved = testRelation.where(expected).analyze.expressions.head
+    // The comparison result must be asserted, otherwise a mismatch is silently ignored.
+    assert(
+      compareConditions(actual, resolved),
+      s"Optimized condition [$actual] does not match expected condition [$resolved]")
   }
 
   test("a && a => a") {
@@ -72,8 +97,8 @@ class BooleanSimplificationSuite extends PlanTest {
       (((('b > 3) && ('c > 2)) ||
         (('c < 1) && ('a === 5))) ||
         (('b < 5) && ('a > 1))) && ('a === 'b)
-    checkCondition(input, expected)
 
+    checkCondition(input, expected)
   }
 
   test("(a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ...") {
@@ -85,8 +110,8 @@ class BooleanSimplificationSuite extends PlanTest {
 
     checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), ('b > 3 && 'c > 5) || 'a < 2)
 
-    var input: Expression = ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5)
-    var expected: Expression = ('b > 3 && 'a > 3 && 'a < 5) || 'a === 'b
-    checkCondition(input, expected)
+    checkCondition(
+      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
+      ('b > 3 && 'a > 3 && 'a < 5) || 'a === 'b)
   }
 }