diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala index 546b21bdcc3b4..377dcd6666d64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -101,7 +101,7 @@ object DecorrelateInnerQuery extends PredicateHelper { private def canPullUpOverAgg(expression: Expression): Boolean = expression match { case Equality(_: Attribute, b) => !containsAttribute(b) case Equality(a, _: Attribute) => !containsAttribute(a) - case _ => false + case o => !containsAttribute(o) } /** @@ -190,6 +190,64 @@ object DecorrelateInnerQuery extends PredicateHelper { } } + /** + * Build a mapping between domain attributes and corresponding outer query expressions + * using the join conditions. + */ + private def buildDomainAttrMap( + conditions: Seq[Expression], + domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = { + val domainAttrSet = AttributeSet(domainAttrs) + conditions.collect { + // When we build the join conditions between the domain attributes and outer references, + // the left hand side is always the domain attribute used in the inner query and the right + // hand side is the attribute from the outer query. Note here the right hand side of a + // condition is not necessarily an attribute, for example it can be a literal (if foldable) + // or a cast expression after the optimization. + case EqualNullSafe(left: Attribute, right: Expression) if domainAttrSet.contains(left) => + left -> right + }.toMap + } + + /** + * Rewrite all [[DomainJoin]]s in the inner query to actual inner joins with the outer query. + */ + def rewriteDomainJoins( + outerPlan: LogicalPlan, + innerPlan: LogicalPlan, + conditions: Seq[Expression]): LogicalPlan = { + innerPlan transform { + case d @ DomainJoin(domainAttrs, child) => + val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs) + // We should only rewrite a domain join when all corresponding outer plan attributes + // can be found from the join condition. + if (domainAttrMap.size == domainAttrs.size) { + val groupingExprs = domainAttrs.map(domainAttrMap) + val aggregateExprs = groupingExprs.zip(domainAttrs).map { + // Rebuild the aliases. + case (inputAttr, outputAttr) => Alias(inputAttr, outputAttr.name)(outputAttr.exprId) + } + // Construct a domain with the outer query plan. + // DomainJoin [a', b'] => Aggregate [a, b] [a AS a', b AS b'] + // +- Relation [a, b] + val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan) + child match { + // A special optimization for OneRowRelation. + // TODO: add a more general rule to optimize join with OneRowRelation. + case _: OneRowRelation => domain + // Construct a domain join. + // Join Inner + // :- Inner Query + // +- Domain + case _ => Join(child, domain, Inner, None, JoinHint.NONE) + } + } else { + throw new UnsupportedOperationException( + s"Unable to rewrite domain join with conditions: $conditions\n$d") + } + } + } + def apply( innerPlan: LogicalPlan, outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 48f2cf8e72f3d..9381796d3d06b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -295,7 +295,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper if (newCond.isEmpty) oldCond else newCond } - def rewrite(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + def decorrelate(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { if (SQLConf.get.decorrelateInnerQueryEnabled) { DecorrelateInnerQuery(sub, outer) } else { @@ -305,7 +305,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper plan transformExpressions { case ScalarSubquery(sub, children, exprId) if children.nonEmpty => - val (newPlan, newCond) = rewrite(sub, outerPlans) + val (newPlan, newCond) = decorrelate(sub, outerPlans) ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId) case Exists(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) @@ -509,56 +509,6 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe // Name of generated column used in rewrite below val ALWAYS_TRUE_COLNAME = "alwaysTrue" - /** - * Build a mapping between domain attributes and corresponding outer query expressions - * using the join conditions. - */ - private def buildDomainAttrMap( - conditions: Seq[Expression], - domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = { - val outputSet = AttributeSet(domainAttrs) - conditions.collect { - // When we build the equality conditions, the left side is always the - // domain attributes used in the inner plan, and the right side is the - // attribute from outer plan. Note the right hand side is not necessarily - // an attribute, for example it can be a literal (if foldable) or a cast expression. - case EqualNullSafe(left: Attribute, right: Expression) if outputSet.contains(left) => - left -> right - }.toMap - } - - /** - * Rewrite domain join placeholder to actual inner joins. - */ - private def rewriteDomainJoins( - outerPlan: LogicalPlan, - innerPlan: LogicalPlan, - conditions: Seq[Expression]): LogicalPlan = { - innerPlan transform { - case d @ DomainJoin(domainAttrs, child) => - val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs) - // We should only rewrite a domain join when all corresponding outer plan attributes - // can be found from the join condition. - if (domainAttrMap.size == domainAttrs.size) { - val groupingExprs = domainAttrs.map(domainAttrMap) - val aggregateExprs = groupingExprs.zip(domainAttrs).map { - // Rebuild the aliases. - case (inputAttr, outputAttr) => Alias(inputAttr, outputAttr.name)(outputAttr.exprId) - } - val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan) - child match { - // A special optimization for OneRowRelation. - // TODO: add a more general rule to optimize join with OneRowRelation. - case _: OneRowRelation => domain - case _ => Join(child, domain, Inner, None, JoinHint.NONE) - } - } else { - throw new UnsupportedOperationException( - s"Unable to rewrite domain join with conditions: $conditions\n$d") - } - } - } - /** * Construct a new child plan by left joining the given subqueries to a base plan. * This method returns the child plan and an attribute mapping @@ -571,7 +521,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]() val newChild = subqueries.foldLeft(child) { case (currentChild, ScalarSubquery(sub, conditions, _)) => - val query = rewriteDomainJoins(currentChild, sub, conditions) + val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions) val origOutput = query.output.head val resultWithZeroTups = evalSubqueryOnZeroTups(query) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala index f58e473728caf..93b27035aca33 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala @@ -110,21 +110,21 @@ class DecorrelateInnerQuerySuite extends PlanTest { check(innerPlan, outerPlan, correctAnswer, Seq(x === y + b)) } - test("aggregate with correlated equality predicates - 1") { + test("aggregate with correlated equality predicates that can be pulled up") { val outerPlan = testRelation2 val minB = Alias(min(b), "min_b")() val innerPlan = Aggregate(Nil, Seq(minB), - Filter(And(OuterReference(x) === a + c, b === 3), + Filter(And(OuterReference(x) === a, b === 3), testRelation)) val correctAnswer = - Aggregate(Seq(a, c), Seq(minB, a, c), + Aggregate(Seq(a), Seq(minB, a), Filter(b === 3, testRelation)) - check(innerPlan, outerPlan, correctAnswer, Seq(x === a + c)) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a)) } - test("aggregate with correlated equality predicates - 2") { + test("aggregate with correlated equality predicates that cannot be pulled up") { val outerPlan = testRelation2 val minB = Alias(min(b), "min_b")() val innerPlan = @@ -132,12 +132,13 @@ class DecorrelateInnerQuerySuite extends PlanTest { Filter(OuterReference(x) === OuterReference(y) + a, testRelation)) val correctAnswer = - Aggregate(Seq(a), Seq(minB, a), - testRelation) - check(innerPlan, outerPlan, correctAnswer, Seq(x === y + a)) + Aggregate(Seq(x, y), Seq(minB, x, y), + Filter(x === y + a, + DomainJoin(Seq(x, y), testRelation))) + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) } - test("aggregate with correlated equality predicates - 3") { + test("aggregate with correlated equality predicates that has no attribute") { val outerPlan = testRelation2 val minB = Alias(min(b), "min_b")() val innerPlan =