Skip to content

Commit

Permalink
[SQL][Minor] Refactors deeply nested FP style code in BooleanSimplifi…
Browse files Browse the repository at this point in the history
…cation

This is a follow-up of apache#4090. The original deeply nested `reduceOption` code is hard to grasp.

<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/4091)
<!-- Reviewable:end -->

Author: Cheng Lian <[email protected]>

Closes apache#4091 from liancheng/refactor-boolean-simplification and squashes the following commits:

cd8860b [Cheng Lian] Improves `compareConditions` to handle more subtle cases
1bf3258 [Cheng Lian] Avoids converting predicate sets to lists
e833ca4 [Cheng Lian] Refactors deeply nested FP style code
  • Loading branch information
liancheng authored and bomeng committed Jan 21, 2015
1 parent 03646d8 commit b02acb2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
// a && a => a
case (l, r) if l fastEquals r => l
// (a || b) && (a || c) => a || (b && c)
case (_, _) =>
case _ =>
// 1. Split left and right to get the disjunctive predicates,
// i.e. lhsSet = (a, b), rhsSet = (a, c)
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
Expand All @@ -323,19 +323,20 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
val lhsSet = splitDisjunctivePredicates(left).toSet
val rhsSet = splitDisjunctivePredicates(right).toSet
val common = lhsSet.intersect(rhsSet)
val ldiff = lhsSet.diff(common)
val rdiff = rhsSet.diff(common)
if (ldiff.size == 0 || rdiff.size == 0) {
// a && (a || b) => a
common.reduce(Or)
if (common.isEmpty) {
// No common factors, return the original predicate
and
} else {
// (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... =>
// (a || b) || ((c || ...) && (f || ...) && (e || ...) && ...)
(ldiff.reduceOption(Or) ++ rdiff.reduceOption(Or))
.reduceOption(And)
.map(_ :: common.toList)
.getOrElse(common.toList)
.reduce(Or)
val ldiff = lhsSet.diff(common)
val rdiff = rhsSet.diff(common)
if (ldiff.isEmpty || rdiff.isEmpty) {
// (a || b || c || ...) && (a || b) => (a || b)
common.reduce(Or)
} else {
// (a || b || c || ...) && (a || b || d || ...) =>
// ((c || ...) && (d || ...)) || a || b
(common + And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
}
}
} // end of And(left, right)

Expand All @@ -351,7 +352,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
// a || a => a
case (l, r) if l fastEquals r => l
// (a && b) || (a && c) => a && (b || c)
case (_, _) =>
case _ =>
// 1. Split left and right to get the conjunctive predicates,
// i.e. lhsSet = (a, b), rhsSet = (a, c)
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
Expand All @@ -360,19 +361,20 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
val lhsSet = splitConjunctivePredicates(left).toSet
val rhsSet = splitConjunctivePredicates(right).toSet
val common = lhsSet.intersect(rhsSet)
val ldiff = lhsSet.diff(common)
val rdiff = rhsSet.diff(common)
if ( ldiff.size == 0 || rdiff.size == 0) {
// a || (b && a) => a
common.reduce(And)
if (common.isEmpty) {
// No common factors, return the original predicate
or
} else {
// (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... =>
// a && b && ((c && ...) || (d && ...) || (e && ...) || ...)
(ldiff.reduceOption(And) ++ rdiff.reduceOption(And))
.reduceOption(Or)
.map(_ :: common.toList)
.getOrElse(common.toList)
.reduce(And)
val ldiff = lhsSet.diff(common)
val rdiff = rhsSet.diff(common)
if (ldiff.isEmpty || rdiff.isEmpty) {
// (a && b) || (a && b && c && ...) => a && b
common.reduce(And)
} else {
// (a && b && c && ...) || (a && b && d && ...) =>
// ((c && ...) || (d && ...)) && a && b
(common + Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
}
}
} // end of Or(left, right)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.expressions.{Literal, Expression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class BooleanSimplificationSuite extends PlanTest {
class BooleanSimplificationSuite extends PlanTest with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Expand All @@ -40,11 +40,29 @@ class BooleanSimplificationSuite extends PlanTest {

val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)

def checkCondition(originCondition: Expression, optimizedCondition: Expression): Unit = {
val originQuery = testRelation.where(originCondition).analyze
val optimized = Optimize(originQuery)
val expected = testRelation.where(optimizedCondition).analyze
comparePlans(optimized, expected)
// The `foldLeft` is required to handle cases like comparing `a && (b && c)` and `(a && b) && c`
def compareConditions(e1: Expression, e2: Expression): Boolean = (e1, e2) match {
case (lhs: And, rhs: And) =>
val lhsSet = splitConjunctivePredicates(lhs).toSet
val rhsSet = splitConjunctivePredicates(rhs).toSet
lhsSet.foldLeft(rhsSet) { (set, e) =>
set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
}.isEmpty

case (lhs: Or, rhs: Or) =>
val lhsSet = splitDisjunctivePredicates(lhs).toSet
val rhsSet = splitDisjunctivePredicates(rhs).toSet
lhsSet.foldLeft(rhsSet) { (set, e) =>
set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
}.isEmpty

case (l, r) => l == r
}

def checkCondition(input: Expression, expected: Expression): Unit = {
val plan = testRelation.where(input).analyze
val actual = Optimize(plan).expressions.head
compareConditions(actual, expected)
}

test("a && a => a") {
Expand Down Expand Up @@ -72,8 +90,8 @@ class BooleanSimplificationSuite extends PlanTest {
(((('b > 3) && ('c > 2)) ||
(('c < 1) && ('a === 5))) ||
(('b < 5) && ('a > 1))) && ('a === 'b)
checkCondition(input, expected)

checkCondition(input, expected)
}

test("(a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ...") {
Expand All @@ -85,8 +103,8 @@ class BooleanSimplificationSuite extends PlanTest {

checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), ('b > 3 && 'c > 5) || 'a < 2)

var input: Expression = ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5)
var expected: Expression = ('b > 3 && 'a > 3 && 'a < 5) || 'a === 'b
checkCondition(input, expected)
checkCondition(
('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
('b > 3 && 'a > 3 && 'a < 5) || 'a === 'b)
}
}

0 comments on commit b02acb2

Please sign in to comment.