Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-34974][SQL] Improve subquery decorrelation framework #32072

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,11 @@ object SubExprUtils extends PredicateHelper {
* Given a logical plan, returns TRUE if it has an outer reference and false otherwise.
*/
def hasOuterReferences(plan: LogicalPlan): Boolean = {
// NOTE(review): this text is a rendered diff without +/- markers, so two alternative bodies
// appear back to back. Presumably the Filter-only `find` is the removed version and the
// `_.expressions` one-liner is its replacement -- confirm against the pull request.
// Read literally, only the last expression supplies the returned Boolean.

// Variant 1: only Filter nodes are inspected (via their condition); every other node
// yields false.
plan.find {
case f: Filter => containsOuter(f.condition)
case other => false
}.isDefined
// Variant 2: a node matches when any of its expressions satisfies `containsOuter`,
// regardless of the operator type.
plan.find(_.expressions.exists(containsOuter)).isDefined
}

/**
* Given a list of expressions, returns the expressions which have outer references. Aggregate
* Given an expression, returns the expressions which have outer references. Aggregate
* expressions are treated in a special way. If the children of aggregate expression contains an
* outer reference, then the entire aggregate expression is marked as an outer reference.
* Example (SQL):
Expand Down Expand Up @@ -183,18 +180,18 @@ object SubExprUtils extends PredicateHelper {
* }}}
* The code below needs to change when we support the above cases.
*/
def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = {
def getOuterReferences(expr: Expression): Seq[Expression] = {
// NOTE(review): this text is a rendered diff without +/- markers. The two signature lines
// above, and the two `transformDown` traversals below, look like the before (Seq[Expression],
// wrapped in `conditions foreach`) and after (single Expression) forms of one method --
// confirm against the pull request. The braces do not balance as rendered.

// Accumulator filled as a side effect of the traversal; the tree produced by
// `transformDown` is not used, only this buffer is returned.
val outerExpressions = ArrayBuffer.empty[Expression]
conditions foreach { expr =>
expr transformDown {
// An aggregate whose leaves are all OuterReference is recorded as a whole, after
// `stripOuterReference` has been applied to it.
case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) =>
val newExpr = stripOuterReference(a)
outerExpressions += newExpr
newExpr
// Any other OuterReference: record the wrapped expression `e`, not the wrapper.
case OuterReference(e) =>
outerExpressions += e
e
}
expr transformDown {
case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) =>
// Collect and update the sub-tree so that outer references inside this aggregate
// expression will not be collected. For example: min(outer(a)) -> min(a).
val newExpr = stripOuterReference(a)
outerExpressions += newExpr
newExpr
// Any other OuterReference: record the wrapped expression `e`, not the wrapper.
case OuterReference(e) =>
outerExpressions += e
e
}
// Hand back the collected expressions as a Seq.
outerExpressions.toSeq
}
Expand All @@ -204,8 +201,7 @@ object SubExprUtils extends PredicateHelper {
* Filter operator can host outer references.
*/
def getOuterReferences(plan: LogicalPlan): Seq[Expression] = {
// NOTE(review): this text is a rendered diff without +/- markers. Presumably the
// `conditions` pair of lines is the removed implementation and the `flatMap` line is its
// replacement -- confirm against the pull request.

// Variant 1: gather the condition of each Filter node in the plan and pass the list to the
// Seq[Expression] overload.
val conditions = plan.collect { case Filter(cond, _) => cond }
getOuterReferences(conditions)
// Variant 2: apply the single-expression overload to the expressions of the plan's nodes
// and concatenate the results; this is not limited to Filter conditions.
plan.flatMap(_.expressions.flatMap(getOuterReferences))
}

/**
Expand Down
Loading