Skip to content

Commit

Permalink
update tests and small refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
allisonwang-db committed Apr 20, 2021
1 parent d804c22 commit c34e700
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
// Decides whether `expression` qualifies for pull-up (over an aggregate, going by the method name).
// An equality with a bare Attribute on one side qualifies only when `containsAttribute` is false
// for the opposite side.
private def canPullUpOverAgg(expression: Expression): Boolean = expression match {
// Attribute on the left: the right-hand side must not contain an attribute.
case Equality(_: Attribute, b) => !containsAttribute(b)
// Attribute on the right: the left-hand side must not contain an attribute.
case Equality(a, _: Attribute) => !containsAttribute(a)
// NOTE(review): the two fallbacks below appear to be the removed and the added line of this diff
// hunk shown together. Read literally, `case _` matches everything first, so `case o` is
// unreachable; confirm which fallback is the live one before relying on this block.
case _ => false
case o => !containsAttribute(o)
}

/**
Expand Down Expand Up @@ -190,6 +190,64 @@ object DecorrelateInnerQuery extends PredicateHelper {
}
}

/**
 * Pair every domain attribute with the outer query expression it is joined against.
 *
 * Only null-safe equalities whose left operand is one of `domainAttrs` contribute an entry;
 * every other condition is ignored.
 */
private def buildDomainAttrMap(
    conditions: Seq[Expression],
    domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = {
  val domainAttrSet = AttributeSet(domainAttrs)
  // The join conditions between domain attributes and outer references are always built with
  // the domain attribute (as used in the inner query) on the left and the outer query side on
  // the right. After optimization the right operand may no longer be a plain attribute: it can
  // be, for instance, a literal (if foldable) or a cast expression.
  val domainToOuter = conditions.flatMap {
    case EqualNullSafe(domainAttr: Attribute, outerExpr: Expression)
        if domainAttrSet.contains(domainAttr) =>
      Some(domainAttr -> outerExpr)
    case _ =>
      None
  }
  domainToOuter.toMap
}

/**
 * Replace each [[DomainJoin]] placeholder in the inner query with a real inner join against a
 * domain computed from the outer query.
 *
 * @param outerPlan  the outer query plan the domain is aggregated from
 * @param innerPlan  the inner query plan that may contain [[DomainJoin]] nodes
 * @param conditions the join conditions used to resolve domain attributes to outer expressions
 * @return the inner plan with every [[DomainJoin]] rewritten
 * @throws UnsupportedOperationException if some domain attribute has no matching condition
 */
def rewriteDomainJoins(
    outerPlan: LogicalPlan,
    innerPlan: LogicalPlan,
    conditions: Seq[Expression]): LogicalPlan = {
  innerPlan transform {
    case d @ DomainJoin(domainAttrs, child) =>
      val outerExprFor = buildDomainAttrMap(conditions, domainAttrs)
      // A domain join can only be rewritten when every domain attribute has a corresponding
      // outer plan expression in the join conditions; bail out otherwise.
      if (outerExprFor.size != domainAttrs.size) {
        throw new UnsupportedOperationException(
          s"Unable to rewrite domain join with conditions: $conditions\n$d")
      }
      val groupingExprs = domainAttrs.map(outerExprFor)
      // Re-alias each outer expression so it keeps the domain attribute's name and expression id.
      val aggregateExprs = domainAttrs.map { domainAttr =>
        Alias(outerExprFor(domainAttr), domainAttr.name)(domainAttr.exprId)
      }
      // Build the domain on top of the outer query plan:
      //   DomainJoin [a', b']  =>  Aggregate [a, b] [a AS a', b AS b']
      //                            +- Relation [a, b]
      val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan)
      child match {
        // Special case: with a OneRowRelation child, the domain itself is the result.
        // TODO: add a more general rule to optimize join with OneRowRelation.
        case _: OneRowRelation =>
          domain
        // General case:
        //   Join Inner
        //   :- Inner Query
        //   +- Domain
        case _ =>
          Join(child, domain, Inner, None, JoinHint.NONE)
      }
  }
}

def apply(
innerPlan: LogicalPlan,
outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
if (newCond.isEmpty) oldCond else newCond
}

def rewrite(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
def decorrelate(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
if (SQLConf.get.decorrelateInnerQueryEnabled) {
DecorrelateInnerQuery(sub, outer)
} else {
Expand All @@ -305,7 +305,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper

plan transformExpressions {
case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
val (newPlan, newCond) = rewrite(sub, outerPlans)
val (newPlan, newCond) = decorrelate(sub, outerPlans)
ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
case Exists(sub, children, exprId) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
Expand Down Expand Up @@ -509,56 +509,6 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
// Name of generated column used in rewrite below
val ALWAYS_TRUE_COLNAME = "alwaysTrue"

/**
 * Derive, from the join conditions, which outer query expression each domain attribute
 * corresponds to.
 */
private def buildDomainAttrMap(
    conditions: Seq[Expression],
    domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = {
  val outputSet = AttributeSet(domainAttrs)
  // Equality conditions are always built with the domain attribute (used in the inner plan) on
  // the left and the outer plan side on the right. The right side need not be an attribute:
  // it can be, for example, a literal (if foldable) or a cast expression.
  val domainPair: PartialFunction[Expression, (Attribute, Expression)] = {
    case EqualNullSafe(domainAttr: Attribute, outerExpr: Expression)
        if outputSet.contains(domainAttr) =>
      domainAttr -> outerExpr
  }
  conditions.collect(domainPair).toMap
}

/**
 * Turn every domain join placeholder in `innerPlan` into an actual inner join with a domain
 * aggregated from `outerPlan`. Throws an UnsupportedOperationException when the conditions do
 * not cover all domain attributes.
 */
private def rewriteDomainJoins(
    outerPlan: LogicalPlan,
    innerPlan: LogicalPlan,
    conditions: Seq[Expression]): LogicalPlan = {
  innerPlan transform {
    case d @ DomainJoin(domainAttrs, child) =>
      val outerExprs = buildDomainAttrMap(conditions, domainAttrs)
      // Rewriting is only possible when each domain attribute was matched to an outer plan
      // expression through the join conditions.
      val fullyMapped = outerExprs.size == domainAttrs.size
      if (!fullyMapped) {
        throw new UnsupportedOperationException(
          s"Unable to rewrite domain join with conditions: $conditions\n$d")
      } else {
        val groupingExprs = domainAttrs.map(outerExprs)
        // Restore the original aliases: same name and expression id as the domain attribute.
        val aggregateExprs = domainAttrs.zip(groupingExprs).map {
          case (outputAttr, inputExpr) => Alias(inputExpr, outputAttr.name)(outputAttr.exprId)
        }
        val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan)
        child match {
          // With a OneRowRelation child the domain is returned directly, without a join.
          // TODO: add a more general rule to optimize join with OneRowRelation.
          case _: OneRowRelation => domain
          case _ => Join(child, domain, Inner, None, JoinHint.NONE)
        }
      }
  }
}

/**
* Construct a new child plan by left joining the given subqueries to a base plan.
* This method returns the child plan and an attribute mapping
Expand All @@ -571,7 +521,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
val newChild = subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(sub, conditions, _)) =>
val query = rewriteDomainJoins(currentChild, sub, conditions)
val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
val origOutput = query.output.head

val resultWithZeroTups = evalSubqueryOnZeroTups(query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,34 +110,35 @@ class DecorrelateInnerQuerySuite extends PlanTest {
check(innerPlan, outerPlan, correctAnswer, Seq(x === y + b))
}

test("aggregate with correlated equality predicates - 1") {
test("aggregate with correlated equality predicates that can be pulled up") {
val outerPlan = testRelation2
val minB = Alias(min(b), "min_b")()
val innerPlan =
Aggregate(Nil, Seq(minB),
Filter(And(OuterReference(x) === a + c, b === 3),
Filter(And(OuterReference(x) === a, b === 3),
testRelation))
val correctAnswer =
Aggregate(Seq(a, c), Seq(minB, a, c),
Aggregate(Seq(a), Seq(minB, a),
Filter(b === 3,
testRelation))
check(innerPlan, outerPlan, correctAnswer, Seq(x === a + c))
check(innerPlan, outerPlan, correctAnswer, Seq(x === a))
}

test("aggregate with correlated equality predicates - 2") {
test("aggregate with correlated equality predicates that cannot be pulled up") {
val outerPlan = testRelation2
val minB = Alias(min(b), "min_b")()
val innerPlan =
Aggregate(Nil, Seq(minB),
Filter(OuterReference(x) === OuterReference(y) + a,
testRelation))
val correctAnswer =
Aggregate(Seq(a), Seq(minB, a),
testRelation)
check(innerPlan, outerPlan, correctAnswer, Seq(x === y + a))
Aggregate(Seq(x, y), Seq(minB, x, y),
Filter(x === y + a,
DomainJoin(Seq(x, y), testRelation)))
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
}

test("aggregate with correlated equality predicates - 3") {
test("aggregate with correlated equality predicates that has no attribute") {
val outerPlan = testRelation2
val minB = Alias(min(b), "min_b")()
val innerPlan =
Expand Down

0 comments on commit c34e700

Please sign in to comment.