diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 9424d67e725a4..09bf49d393602 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LogicalPlan, SubqueryAlias, UnresolvedWith, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, CTERelationDef, CTERelationRef, LogicalPlan, SubqueryAlias, UnresolvedWith, WithCTE} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util.TypeUtils._ @@ -30,8 +30,7 @@ import org.apache.spark.sql.internal.SQLConf.{LEGACY_CTE_PRECEDENCE_POLICY, Lega /** * Analyze WITH nodes and substitute child plan with CTE references or CTE definitions depending * on the conditions below: - * 1. If in legacy mode, or if the query is a SQL command or DML statement, replace with CTE - * definitions, i.e., inline CTEs. + * 1. If in legacy mode, replace with CTE definitions, i.e., inline CTEs. * 2. Otherwise, replace with CTE references `CTERelationRef`s. The decision to inline or not * inline will be made later by the rule `InlineCTE` after query analysis. * @@ -46,6 +45,9 @@ import org.apache.spark.sql.internal.SQLConf.{LEGACY_CTE_PRECEDENCE_POLICY, Lega * dependency for any valid CTE query (i.e., given CTE definitions A and B with B referencing A, * A is guaranteed to appear before B). Otherwise, it must be an invalid user query, and an * analysis exception will be thrown later by relation resolving rules. + * + * If the query is a SQL command or DML statement (extends `CTEInChildren`), + * place `WithCTE` into their children. */ object CTESubstitution extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { @@ -66,18 +68,18 @@ object CTESubstitution extends Rule[LogicalPlan] { if (cteDefs.isEmpty) { substituted } else if (substituted eq firstSubstituted.get) { - WithCTE(substituted, cteDefs.toSeq) + withCTEDefs(substituted, cteDefs.toSeq) } else { var done = false substituted.resolveOperatorsWithPruning(_ => !done) { case p if p eq firstSubstituted.get => // `firstSubstituted` is the parent of all other CTEs (if any). done = true - WithCTE(p, cteDefs.toSeq) + withCTEDefs(p, cteDefs.toSeq) case p if p.children.count(_.containsPattern(CTE)) > 1 => // This is the first common parent of all CTEs. done = true - WithCTE(p, cteDefs.toSeq) + withCTEDefs(p, cteDefs.toSeq) } } } @@ -242,7 +244,7 @@ object CTESubstitution extends Rule[LogicalPlan] { private def substituteCTE( plan: LogicalPlan, alwaysInline: Boolean, - cteRelations: Seq[(String, CTERelationDef)]): LogicalPlan = + cteRelations: Seq[(String, CTERelationDef)]): LogicalPlan = { plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(RELATION_TIME_TRAVEL, UNRESOLVED_RELATION, PLAN_EXPRESSION)) { case RelationTimeTravel(UnresolvedRelation(Seq(table), _, _), _, _) @@ -266,4 +268,21 @@ object CTESubstitution extends Rule[LogicalPlan] { e.withNewPlan(apply(substituteCTE(e.plan, alwaysInline, cteRelations))) } } + } + + /** + * Finds all logical nodes that should have `WithCTE` in their children like + * `InsertIntoStatement`, put `WithCTE` on top of the children and don't place `WithCTE` + * on top of the plan. If there are no such nodes, put `WithCTE` on the top. + */ + private def withCTEDefs(p: LogicalPlan, cteDefs: Seq[CTERelationDef]): LogicalPlan = { + val withCTE = WithCTE(p, cteDefs) + var onTop = true + val newPlan = p.resolveOperatorsDown { + case cteInChildren: CTEInChildren => + onTop = false + cteInChildren.withCTE(withCTE) + } + if (onTop) withCTE else WithCTE(newPlan, cteDefs) + } }