From 0945baf90660a101ae0f86a39d4c91ca74ae5ee3 Mon Sep 17 00:00:00 2001
From: Ali Afroozeh
Date: Fri, 9 Apr 2021 15:06:26 +0200
Subject: [PATCH] [SPARK-34989] Improve the performance of mapChildren and withNewChildren methods
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What changes were proposed in this pull request?

One of the main performance bottlenecks in query compilation is the pair of overly generic tree transformation methods `mapChildren` and `withNewChildren` (defined in `TreeNode`). These methods use a generic implementation to iterate over the children and rely on reflection to create new instances. We have observed that, especially for queries with large query plans, a significant share of CPU cycles is wasted in these methods. In this PR we make these methods more efficient by delegating the iteration and instantiation to concrete node types. The benchmarks show that we can expect a significant improvement in total query compilation time for queries with large query plans (30-80%) and about 20% on average.

#### Problem detail

The `mapChildren` method in `TreeNode` is overly generic and costly. More specifically, this method:
- iterates over all the fields of a node using Scala's product iterator. While the iteration itself is not reflection-based (the Scala compiler generates the `Product` code), we create many anonymous functions and visit many nested structures (recursive calls). The anonymous functions (presumably compiled to Java anonymous inner classes) also rank high in the object allocation profiles, so we are putting unnecessary pressure on the GC here.
- performs many comparisons. For each element returned by the product iterator, we check whether it is a child (i.e., contained in the list of children) and only then transform it. We could avoid this by iterating over the children directly, but the current implementation needs to gather all the fields (while transforming only the children) so that it can instantiate the object via reflection.
- creates objects using reflection by delegating to the `makeCopy` method, which is several orders of magnitude slower than calling the constructor directly.

#### Solution

The proposed solution in this PR is rather straightforward: we rewrite the `mapChildren` method in terms of the `children` and `withNewChildren` methods. The default `withNewChildren` method suffers from the same problems as `mapChildren`, so we make it more efficient by specializing it in concrete classes. Just as each concrete query plan node already defines its children, it should also define how it can be reconstructed given a new list of children. In most cases the implementation is a one-liner, thanks to the `copy` method of Scala case classes. Note that we cannot abstract over the `copy` method: the compiler generates it for a case class only if no type higher in the hierarchy already defines it. For most concrete nodes, the implementation of `withNewChildren` looks like this:
```
override def withNewChildren(newChildren: Seq[LogicalPlan]): LogicalPlan =
  copy(children = newChildren)
```
The current `withNewChildren` method has two properties that we should preserve:
- It returns the same instance if the provided children are the same as its current children, i.e., it preserves referential equality.
- It copies tags and maintains the origin links when a new copy is created.

These properties are hard to enforce in the concrete node type implementations; see the sketch below.
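Concretely, the two properties can be stated as assertions (a minimal sketch, not part of this patch; `plan` and `newChildren` are hypothetical placeholders for an arbitrary `TreeNode` and a transformed child list, while `origin` is the existing `TreeNode` field):

```
// Property 1: passing the same children must return the very same instance
// (referential equality, checked with Scala's `eq`).
assert(plan.withNewChildren(plan.children) eq plan)

// Property 2: a genuine copy must keep the origin (and tags) of the original node.
val copied = plan.withNewChildren(newChildren)
assert(copied.origin == plan.origin)
```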
Therefore, we propose a template method `withNewChildrenInternal` that is overridden by the concrete classes, while the `withNewChildren` method takes care of referential equality and copying:
```
override def withNewChildren(newChildren: Seq[LogicalPlan]): LogicalPlan = {
  if (childrenFastEquals(children, newChildren)) {
    this
  } else {
    CurrentOrigin.withOrigin(origin) {
      val res = withNewChildrenInternal(newChildren)
      res.copyTagsFrom(this)
      res
    }
  }
}
```
With the refactoring done in a previous PR (https://github.com/apache/spark/pull/31932), most tree node types fall into one of the categories `Leaf`, `Unary`, `Binary` or `Ternary`. These traits provide a more efficient implementation of `mapChildren` and define a more specialized version of `withNewChildrenInternal` that avoids creating unnecessary lists. For example, the `mapChildren` method in `UnaryLike` is defined as follows:
```
override final def mapChildren(f: T => T): T = {
  val newChild = f(child)
  if (newChild fastEquals child) {
    this.asInstanceOf[T]
  } else {
    CurrentOrigin.withOrigin(origin) {
      val res = withNewChildInternal(newChild)
      res.copyTagsFrom(this.asInstanceOf[T])
      res
    }
  }
}
```

#### Results

With this PR, we have observed significant performance improvements in query compilation time, more specifically in the analysis and optimization phases. The table below shows the TPC-DS queries that sped up by more than 25% in compilation time. The biggest speedups are observed in queries with large query plans.

| Query | Speedup |
| ----- | ------- |
| q4    | 29%     |
| q9    | 81%     |
| q14a  | 31%     |
| q14b  | 28%     |
| q22   | 33%     |
| q33   | 29%     |
| q34   | 25%     |
| q39   | 27%     |
| q41   | 27%     |
| q44   | 26%     |
| q47   | 28%     |
| q48   | 76%     |
| q49   | 46%     |
| q56   | 26%     |
| q58   | 43%     |
| q59   | 46%     |
| q60   | 50%     |
| q65   | 59%     |
| q66   | 46%     |
| q67   | 52%     |
| q69   | 31%     |
| q70   | 30%     |
| q96   | 26%     |
| q98   | 32%     |

#### Binary incompatibility

Changing `withNewChildren` in `TreeNode` breaks binary compatibility for code compiled against older versions of Spark, because all concrete `TreeNode` subclasses are now expected to implement the `withNewChildrenInternal` method. This is a problem, for example, when users write custom expressions. We still consider this change the right choice, since it forces every expression newly added to Catalyst to implement the method efficiently and prevents future regressions. Note that we have not completely removed the old implementation; it has been renamed to `legacyWithNewChildren`. This method eases the transition and will be removed in the future. There are expressions, such as `UpdateFields`, that define their children in a complex way; writing `withNewChildren` for them requires refactoring the expression itself. For now, these expressions keep using the old, slow method. A future PR will address them.

### Does this PR introduce _any_ user-facing change?

This PR does not introduce user-facing changes, but it may break binary compatibility of code compiled against older versions. See the binary incompatibility section above.

### How was this patch tested?

This PR is mainly a refactoring and passes the existing tests.

Closes #32030 from dbaliafroozeh/ImprovedMapChildren.
Authored-by: Ali Afroozeh Signed-off-by: herman --- .../spark/sql/avro/AvroDataToCatalyst.scala | 3 + .../spark/sql/avro/CatalystDataToAvro.scala | 3 + .../org/apache/spark/ml/stat/Summarizer.scala | 4 + .../sql/catalyst/analysis/unresolved.scala | 30 +++ .../expressions/CallMethodViaReflection.scala | 3 + .../spark/sql/catalyst/expressions/Cast.scala | 6 + .../catalyst/expressions/DynamicPruning.scala | 6 + .../expressions/PartitionTransforms.scala | 5 + .../sql/catalyst/expressions/PythonUDF.scala | 3 + .../sql/catalyst/expressions/ScalaUDF.scala | 3 + .../sql/catalyst/expressions/SortOrder.scala | 6 + .../SubExprEvaluationRuntime.scala | 3 + .../sql/catalyst/expressions/TimeWindow.scala | 6 + .../sql/catalyst/expressions/TryCast.scala | 3 + .../ApproxCountDistinctForIntervals.scala | 4 + .../aggregate/ApproximatePercentile.scala | 4 + .../expressions/aggregate/Average.scala | 3 + .../aggregate/CentralMomentAgg.scala | 18 ++ .../catalyst/expressions/aggregate/Corr.scala | 3 + .../expressions/aggregate/Count.scala | 3 + .../expressions/aggregate/CountIf.scala | 3 + .../aggregate/CountMinSketchAgg.scala | 8 + .../expressions/aggregate/Covariance.scala | 7 + .../expressions/aggregate/First.scala | 2 + .../aggregate/HyperLogLogPlusPlus.scala | 3 + .../catalyst/expressions/aggregate/Last.scala | 2 + .../catalyst/expressions/aggregate/Max.scala | 2 + .../expressions/aggregate/MaxByAndMinBy.scala | 6 + .../catalyst/expressions/aggregate/Min.scala | 2 + .../expressions/aggregate/Percentile.scala | 7 + .../expressions/aggregate/PivotFirst.scala | 4 + .../expressions/aggregate/Product.scala | 3 + .../catalyst/expressions/aggregate/Sum.scala | 2 + .../aggregate/UnevaluableAggs.scala | 4 + .../aggregate/bitwiseAggregates.scala | 9 + .../expressions/aggregate/collect.scala | 6 + .../expressions/aggregate/interfaces.scala | 10 + .../sql/catalyst/expressions/arithmetic.scala | 36 ++++ .../expressions/bitwiseExpressions.scala | 18 ++ .../expressions/codegen/javaCode.scala | 8 +- .../expressions/collectionOperations.scala | 96 +++++++++ .../expressions/complexTypeCreator.scala | 26 +++ .../expressions/complexTypeExtractors.scala | 14 ++ .../expressions/conditionalExpressions.scala | 10 + .../expressions/constraintExpressions.scala | 8 +- .../catalyst/expressions/csvExpressions.scala | 9 + .../expressions/datetimeExpressions.scala | 162 +++++++++++++++ .../expressions/decimalExpressions.scala | 15 ++ .../sql/catalyst/expressions/generators.scala | 18 ++ .../sql/catalyst/expressions/grouping.scala | 11 + .../spark/sql/catalyst/expressions/hash.scala | 18 ++ .../expressions/higherOrderFunctions.scala | 58 +++++- .../expressions/intervalExpressions.scala | 67 +++++- .../expressions/jsonExpressions.scala | 22 ++ .../expressions/mathExpressions.scala | 114 ++++++++-- .../spark/sql/catalyst/expressions/misc.scala | 11 + .../expressions/namedExpressions.scala | 3 + .../expressions/nullExpressions.scala | 24 +++ .../expressions/objects/objects.scala | 70 ++++++- .../sql/catalyst/expressions/predicates.scala | 36 ++++ .../expressions/randomExpressions.scala | 4 + .../expressions/regexpExpressions.scala | 30 +++ .../expressions/stringExpressions.scala | 134 ++++++++++++ .../sql/catalyst/expressions/subquery.scala | 9 + .../expressions/windowExpressions.scala | 39 ++++ .../sql/catalyst/expressions/xml/xpath.scala | 24 +++ .../optimizer/CostBasedJoinReorder.scala | 3 + .../optimizer/NormalizeFloatingNumbers.scala | 3 + .../plans/logical/EventTimeWatermark.scala | 3 + .../plans/logical/ScriptTransformation.scala | 3 + 
.../plans/logical/basicLogicalOperators.scala | 73 +++++++ .../sql/catalyst/plans/logical/hints.scala | 6 + .../sql/catalyst/plans/logical/object.scala | 52 ++++- .../logical/pythonLogicalOperators.scala | 20 +- .../catalyst/plans/logical/statements.scala | 11 +- .../catalyst/plans/logical/v2Commands.scala | 196 ++++++++++++++++-- .../plans/physical/partitioning.scala | 11 + .../catalyst/streaming/WriteToStream.scala | 2 + .../streaming/WriteToStreamStatement.scala | 3 + .../spark/sql/catalyst/trees/TreeNode.scala | 151 +++++++++++++- .../analysis/AnalysisErrorSuite.scala | 2 + .../catalyst/analysis/TypeCoercionSuite.scala | 10 + .../analysis/UnsupportedOperationsSuite.scala | 2 + .../SubexpressionEliminationSuite.scala | 2 + .../ConvertToLocalRelationSuite.scala | 3 + .../sql/catalyst/plans/LogicalPlanSuite.scala | 3 + .../logical/LogicalPlanIntegritySuite.scala | 2 + .../sql/catalyst/trees/TreeNodeSuite.scala | 12 +- .../sql/execution/CollectMetricsExec.scala | 3 + .../apache/spark/sql/execution/Columnar.scala | 6 + .../spark/sql/execution/ExpandExec.scala | 3 + .../spark/sql/execution/GenerateExec.scala | 3 + .../apache/spark/sql/execution/SortExec.scala | 3 + .../SparkScriptTransformationExec.scala | 3 + .../SubqueryAdaptiveBroadcastExec.scala | 3 + .../sql/execution/SubqueryBroadcastExec.scala | 3 + .../sql/execution/WholeStageCodegenExec.scala | 6 + .../adaptive/CustomShuffleReaderExec.scala | 3 + .../aggregate/HashAggregateExec.scala | 3 + .../aggregate/ObjectHashAggregateExec.scala | 3 + .../aggregate/SortAggregateExec.scala | 3 + .../aggregate/TypedAggregateExpression.scala | 8 + .../spark/sql/execution/aggregate/udaf.scala | 7 + .../execution/basicPhysicalOperators.scala | 18 ++ .../command/AnalyzeColumnCommand.scala | 2 +- .../command/AnalyzePartitionCommand.scala | 2 +- .../command/AnalyzeTableCommand.scala | 2 +- .../command/AnalyzeTablesCommand.scala | 2 +- .../InsertIntoDataSourceDirCommand.scala | 2 +- .../sql/execution/command/SetCommand.scala | 5 +- .../spark/sql/execution/command/cache.scala | 2 +- .../sql/execution/command/commands.scala | 12 +- .../command/createDataSourceTables.scala | 5 +- .../spark/sql/execution/command/ddl.scala | 30 +-- .../sql/execution/command/functions.scala | 10 +- .../sql/execution/command/resources.scala | 13 +- .../spark/sql/execution/command/tables.scala | 30 +-- .../spark/sql/execution/command/views.scala | 6 +- .../datasources/FileFormatWriter.scala | 3 + .../InsertIntoDataSourceCommand.scala | 4 +- .../InsertIntoHadoopFsRelationCommand.scala | 3 + .../SaveIntoDataSourceCommand.scala | 4 +- .../spark/sql/execution/datasources/ddl.scala | 10 +- .../v2/WriteToDataSourceV2Exec.scala | 32 ++- .../spark/sql/execution/debug/package.scala | 3 + .../exchange/BroadcastExchangeExec.scala | 3 + .../exchange/ShuffleExchangeExec.scala | 3 + .../joins/BroadcastHashJoinExec.scala | 4 + .../joins/BroadcastNestedLoopJoinExec.scala | 4 + .../joins/CartesianProductExec.scala | 4 + .../joins/ShuffledHashJoinExec.scala | 4 + .../execution/joins/SortMergeJoinExec.scala | 4 + .../apache/spark/sql/execution/limit.scala | 17 +- .../apache/spark/sql/execution/objects.scala | 33 +++ .../python/AggregateInPandasExec.scala | 3 + .../python/ArrowEvalPythonExec.scala | 3 + .../python/BatchEvalPythonExec.scala | 3 + .../python/FlatMapCoGroupsInPandasExec.scala | 4 + .../python/FlatMapGroupsInPandasExec.scala | 3 + .../execution/python/MapInPandasExec.scala | 3 + .../execution/python/WindowInPandasExec.scala | 3 + .../streaming/EventTimeWatermarkExec.scala | 3 + 
.../FlatMapGroupsWithStateExec.scala | 3 + .../StreamingSymmetricHashJoinExec.scala | 4 + .../WriteToContinuousDataSource.scala | 2 + .../WriteToContinuousDataSourceExec.scala | 3 + .../sources/WriteToMicroBatchDataSource.scala | 3 + .../streaming/statefulOperators.scala | 9 + .../execution/streaming/streamingLimits.scala | 6 + .../apache/spark/sql/execution/subquery.scala | 3 + .../sql/execution/window/WindowExec.scala | 3 + .../q14a.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q14a/explain.txt | 2 +- .../approved-plans-v1_4/q5.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q5/explain.txt | 2 +- .../approved-plans-v1_4/q77.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q77/explain.txt | 2 +- .../approved-plans-v1_4/q80.sf100/explain.txt | 2 +- .../approved-plans-v1_4/q80/explain.txt | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 2 + .../spark/sql/ExtraStrategiesSuite.scala | 5 +- .../sql/SparkSessionExtensionSuite.scala | 21 ++ .../sql/TypedImperativeAggregateSuite.scala | 6 +- .../BaseScriptTransformationSuite.scala | 3 + .../sql/execution/ColumnarRulesSuite.scala | 1 + .../spark/sql/execution/ExchangeSuite.scala | 3 + .../spark/sql/execution/PlannerSuite.scala | 2 + .../spark/sql/execution/ReferenceSort.scala | 3 + .../sql/util/DataFrameCallbackSuite.scala | 4 +- .../CreateHiveTableAsSelectCommand.scala | 6 + .../HiveScriptTransformationExec.scala | 3 + .../execution/InsertIntoHiveDirCommand.scala | 3 + .../hive/execution/InsertIntoHiveTable.scala | 3 + .../org/apache/spark/sql/hive/hiveUDFs.scala | 12 ++ .../hive/execution/TestingTypedCount.scala | 3 + 175 files changed, 2213 insertions(+), 146 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 64fb588e98506..b4965003ba33d 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -134,4 +134,7 @@ private[avro] case class AvroDataToCatalyst( """ }) } + + override protected def withNewChildInternal(newChild: Expression): AvroDataToCatalyst = + copy(child = newChild) } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala index 53910b752fdd6..5d79c44ad422e 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala @@ -64,4 +64,7 @@ private[avro] case class CatalystDataToAvro( defineCodeGen(ctx, ev, input => s"(byte[]) $expr.nullSafeEval($input)") } + + override protected def withNewChildInternal(newChild: Expression): CatalystDataToAvro = + copy(child = newChild) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 109ccbd964aca..a3dd133a4ce8d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -374,6 +374,10 @@ private[spark] object SummaryBuilderImpl extends Logging { override def left: Expression = featuresExpr override def right: Expression = weightExpr + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MetricsAggregate = + copy(featuresExpr = newLeft, weightExpr = newRight) + override def 
update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = { val features = vectorUDT.deserialize(featuresExpr.eval(row)) val weight = weightExpr.eval(row).asInstanceOf[Double] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 3fc3db30fa7b1..3b2f4ca79cbc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -263,6 +263,9 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio override def terminate(): TraversableOnce[InternalRow] = throw QueryExecutionErrors.cannotTerminateGeneratorError(this) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): UnresolvedGenerator = copy(children = newChildren) } case class UnresolvedFunction( @@ -284,6 +287,15 @@ case class UnresolvedFunction( val distinct = if (isDistinct) "distinct " else "" s"'$name($distinct${children.mkString(", ")})" } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): UnresolvedFunction = { + if (filter.isDefined) { + copy(arguments = newChildren.dropRight(1), filter = Some(newChildren.last)) + } else { + copy(arguments = newChildren) + } + } } object UnresolvedFunction { @@ -441,6 +453,8 @@ case class MultiAlias(child: Expression, names: Seq[String]) override def toString: String = s"$child AS $names" + override protected def withNewChildInternal(newChild: Expression): MultiAlias = + copy(child = newChild) } /** @@ -475,6 +489,11 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def toString: String = s"$child[$extraction]" override def sql: String = s"${child.sql}[${extraction.sql}]" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): UnresolvedExtractValue = { + copy(child = newLeft, extraction = newRight) + } } /** @@ -499,6 +518,9 @@ case class UnresolvedAlias( override def newInstance(): NamedExpression = throw new UnresolvedException("newInstance") override lazy val resolved = false + + override protected def withNewChildInternal(newChild: Expression): UnresolvedAlias = + copy(child = newChild) } /** @@ -520,6 +542,9 @@ case class UnresolvedSubqueryColumnAliases( override def output: Seq[Attribute] = Nil override lazy val resolved = false + + override protected def withNewChildInternal( + newChild: LogicalPlan): UnresolvedSubqueryColumnAliases = copy(child = newChild) } /** @@ -541,6 +566,9 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq override def dataType: DataType = throw new UnresolvedException("dataType") override def nullable: Boolean = throw new UnresolvedException("nullable") override lazy val resolved = false + + override protected def withNewChildInternal(newChild: Expression): UnresolvedDeserializer = + copy(deserializer = newChild) } case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression @@ -587,6 +615,8 @@ case class UnresolvedHaving( extends UnaryNode { override lazy val resolved: Boolean = false override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedHaving = + copy(child = newChild) } /** diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 0de17d420f0c9..7cb830d115689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -114,6 +114,9 @@ case class CallMethodViaReflection(children: Seq[Expression]) /** A temporary buffer used to hold intermediate results returned by children. */ @transient private lazy val buffer = new Array[Object](argExprs.length) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CallMethodViaReflection = copy(children = newChildren) } object CallMethodViaReflection { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 879b154a84761..1e1b7eeca0f35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1812,6 +1812,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } else { s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}" } + + override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild) } /** @@ -1841,6 +1843,8 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St Some(SQLConf.STORE_ASSIGNMENT_POLICY.key), Some(SQLConf.StoreAssignmentPolicy.LEGACY.toString)) + override protected def withNewChildInternal(newChild: Expression): AnsiCast = + copy(child = newChild) } object AnsiCast { @@ -1998,4 +2002,6 @@ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: S case DecimalType => DecimalType.SYSTEM_DEFAULT case _ => target.asInstanceOf[DataType] } + + override protected def withNewChildInternal(newChild: Expression): UpCast = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala index 550fa4c3f73e4..de4b874637f09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala @@ -78,6 +78,9 @@ case class DynamicPruningSubquery( buildKeys = buildKeys.map(_.canonicalized), exprId = ExprId(0)) } + + override protected def withNewChildInternal(newChild: Expression): DynamicPruningSubquery = + copy(pruningKey = newChild) } /** @@ -94,4 +97,7 @@ case class DynamicPruningExpression(child: Expression) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.genCode(ctx) } + + override protected def withNewChildInternal(newChild: Expression): DynamicPruningExpression = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala index 05d553757e742..ab390618d4c5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala @@ -43,6 +43,7 @@ abstract class PartitionTransformExpression extends Expression with Unevaluable */ case class Years(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + override protected def withNewChildInternal(newChild: Expression): Years = copy(child = newChild) } /** @@ -50,6 +51,7 @@ case class Years(child: Expression) extends PartitionTransformExpression { */ case class Months(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + override protected def withNewChildInternal(newChild: Expression): Months = copy(child = newChild) } /** @@ -57,6 +59,7 @@ case class Months(child: Expression) extends PartitionTransformExpression { */ case class Days(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + override protected def withNewChildInternal(newChild: Expression): Days = copy(child = newChild) } /** @@ -64,6 +67,7 @@ case class Days(child: Expression) extends PartitionTransformExpression { */ case class Hours(child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + override protected def withNewChildInternal(newChild: Expression): Hours = copy(child = newChild) } /** @@ -71,4 +75,5 @@ case class Hours(child: Expression) extends PartitionTransformExpression { */ case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression { override def dataType: DataType = IntegerType + override protected def withNewChildInternal(newChild: Expression): Bucket = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index da2e1821feb0f..73f8c300b4ae7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -73,4 +73,7 @@ case class PythonUDF( // `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result. 
this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDF = + copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 4086e7698e7b1..375ae95acfc39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1195,4 +1195,7 @@ case class ScalaUDF( resultConverter(result) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ScalaUDF = + copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d9923b5d022e0..9aef25ce60599 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -88,6 +88,9 @@ case class SortOrder( children.exists(required.child.semanticEquals) && direction == required.direction && nullOrdering == required.nullOrdering } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): SortOrder = + copy(child = newChildren.head, sameOrderExpressions = newChildren.tail) } object SortOrder { @@ -226,4 +229,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } override def dataType: DataType = LongType + + override protected def withNewChildInternal(newChild: Expression): SortPrefix = + copy(child = newChild.asInstanceOf[SortOrder]) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala index a1f7ba3008775..0f224fefe3911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala @@ -140,6 +140,9 @@ case class ExpressionProxy( } override def hashCode(): Int = this.id.hashCode() + + override protected def withNewChildInternal(newChild: Expression): ExpressionProxy = + copy(child = newChild) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index f7fe467cea830..ed1d77017c120 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -92,6 +92,9 @@ case class TimeWindow( } dataTypeCheck } + + override protected def withNewChildInternal(newChild: Expression): TimeWindow = + copy(timeColumn = newChild) } object TimeWindow { @@ -155,4 +158,7 @@ case class PreciseTimestampConversion( """.stripMargin) } override def nullSafeEval(input: Any): Any = input + + override protected def withNewChildInternal(newChild: Expression): PreciseTimestampConversion = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala index 88563898fc818..0f63de1bf7e45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala @@ -85,6 +85,9 @@ case class TryCast(child: Expression, dataType: DataType, timeZoneId: Option[Str override def typeCheckFailureMessage: String = AnsiCast.typeCheckFailureMessage(child.dataType, dataType, None, None) + override protected def withNewChildInternal(newChild: Expression): TryCast = + copy(child = newChild) + override def toString: String = { s"try_cast($child as ${dataType.simpleString})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index 42dc6f6b200d0..19e212d1f9e69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -249,4 +249,8 @@ case class ApproxCountDistinctForIntervals( override def getLong(offset: Int): Long = array(offset) override def setLong(offset: Int, value: Long): Unit = { array(offset) = value } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ApproxCountDistinctForIntervals = + copy(child = newLeft, endpointsExpression = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 4e4a06a628453..38d8d7d71ead8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -208,6 +208,10 @@ case class ApproximatePercentile( override def deserialize(bytes: Array[Byte]): PercentileDigest = { ApproximatePercentile.serializer.deserialize(bytes) } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): ApproximatePercentile = + copy(child = newFirst, percentageExpression = newSecond, accuracyExpression = newThird) } object ApproximatePercentile { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 36004b0ea6244..90e91ae41856c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -93,4 +93,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit coalesce(child.cast(sumDataType), Literal.default(sumDataType))), /* count = */ If(child.isNull, count, count + 1L) ) + + override protected def withNewChildInternal(newChild: Expression): Average = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 4ca933ff45d02..c5c78e5062f56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -167,6 +167,9 @@ case class StddevPop( } override def prettyName: String = "stddev_pop" + + override protected def withNewChildInternal(newChild: Expression): StddevPop = + copy(child = newChild) } // Compute the sample standard deviation of a column @@ -197,6 +200,9 @@ case class StddevSamp( override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp") + + override protected def withNewChildInternal(newChild: Expression): StddevSamp = + copy(child = newChild) } // Compute the population variance of a column @@ -223,6 +229,9 @@ case class VariancePop( } override def prettyName: String = "var_pop" + + override protected def withNewChildInternal(newChild: Expression): VariancePop = + copy(child = newChild) } // Compute the sample variance of a column @@ -250,6 +259,9 @@ case class VarianceSamp( } override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp") + + override protected def withNewChildInternal(newChild: Expression): VarianceSamp = + copy(child = newChild) } @ExpressionDescription( @@ -278,6 +290,9 @@ case class Skewness( If(n === 0.0, Literal.create(null, DoubleType), If(m2 === 0.0, divideByZeroEvalResult, sqrt(n) * m3 / sqrt(m2 * m2 * m2))) } + + override protected def withNewChildInternal(newChild: Expression): Skewness = + copy(child = newChild) } @ExpressionDescription( @@ -306,4 +321,7 @@ case class Kurtosis( } override def prettyName: String = "kurtosis" + + override protected def withNewChildInternal(newChild: Expression): Kurtosis = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index d819971478ecf..c798004fe7843 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -127,4 +127,7 @@ case class Corr( } override def prettyName: String = "corr" + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Corr = + copy(x = newLeft, y = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 189d21603e70f..1d13155ef6898 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -89,6 +89,9 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { ) } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Count = + copy(children = newChildren) } object Count { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala index c1c4c84497bcd..d4fdd5115b59d 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala @@ -56,4 +56,7 @@ case class CountIf(predicate: Expression) extends UnevaluableAggregate with Impl s"function $prettyName requires boolean type, not ${predicate.dataType.catalogString}" ) } + + override protected def withNewChildInternal(newChild: Expression): CountIf = + copy(predicate = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala index a838a0a0e8977..38d0db1e7610c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala @@ -154,4 +154,12 @@ case class CountMinSketchAgg( override def second: Expression = epsExpression override def third: Expression = confidenceExpression override def fourth: Expression = seedExpression + + override protected def withNewChildrenInternal(first: Expression, second: Expression, + third: Expression, fourth: Expression): CountMinSketchAgg = + copy( + child = first, + epsExpression = second, + confidenceExpression = third, + seedExpression = fourth) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index 8fcee104d276b..9ea9b3782032b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -109,6 +109,10 @@ case class CovPopulation( If(n === 0.0, Literal.create(null, DoubleType), ck / n) } override def prettyName: String = "covar_pop" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): CovPopulation = + copy(left = newLeft, right = newRight) } @@ -135,4 +139,7 @@ case class CovSample( If(n === 1.0, divideByZeroEvalResult, ck / (n - 1.0))) } override def prettyName: String = "covar_samp" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): CovSample = copy(left = newLeft, right = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index accd15a711503..ea994af0e6168 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -118,6 +118,8 @@ case class First(child: Expression, ignoreNulls: Boolean) override lazy val evaluateExpression: AttributeReference = first override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" + + override protected def withNewChildInternal(newChild: Expression): First = copy(child = newChild) } object FirstLast { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 
430c25cee2a93..9b0493f3e68a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -138,6 +138,9 @@ case class HyperLogLogPlusPlus( override def eval(buffer: InternalRow): Any = { hllppHelper.query(buffer, mutableAggBufferOffset) } + + override protected def withNewChildInternal(newChild: Expression): HyperLogLogPlusPlus = + copy(child = newChild) } object HyperLogLogPlusPlus { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index e3c427d584489..0fe6199cd8c31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -115,4 +115,6 @@ case class Last(child: Expression, ignoreNulls: Boolean) override lazy val evaluateExpression: AttributeReference = last override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}" + + override protected def withNewChildInternal(newChild: Expression): Last = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 42721ea48c7ca..b802678ec0468 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -62,4 +62,6 @@ case class Max(child: Expression) extends DeclarativeAggregate with UnaryLike[Ex } override lazy val evaluateExpression: AttributeReference = max + + override protected def withNewChildInternal(newChild: Expression): Max = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala index e402bcae144ad..664bc32ccc464 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -110,6 +110,9 @@ case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = greatest(oldExpr, newExpr) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): MaxBy = + copy(valueExpr = newLeft, orderingExpr = newRight) } @ExpressionDescription( @@ -130,4 +133,7 @@ case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression = least(oldExpr, newExpr) + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): MinBy = + copy(valueExpr = newLeft, orderingExpr = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 84410c7de3229..9c5c7bbda4dc7 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -62,4 +62,6 @@ case class Min(child: Expression) extends DeclarativeAggregate with UnaryLike[Ex } override lazy val evaluateExpression: AttributeReference = min + + override protected def withNewChildInternal(newChild: Expression): Min = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index b81c523ce32ba..5bce4d348c726 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -304,4 +304,11 @@ case class Percentile( bis.close() } } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Percentile = copy( + child = newFirst, + percentageExpression = newSecond, + frequencyExpression = newThird + ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 422fcab5bf890..b90e46e1545d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -153,5 +153,9 @@ case class PivotFirst( override val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): PivotFirst = + copy(pivotColumn = newLeft, valueColumn = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala index 50c74f1c49a99..3af3944fd47d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala @@ -59,4 +59,7 @@ case class Product(child: Expression) Seq(coalesce(coalesce(product.left, one) * product.right, product.left)) override lazy val evaluateExpression: Expression = product + + override protected def withNewChildInternal(newChild: Expression): Product = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index f412a3ec31e0f..56eebedddf08d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -148,4 +148,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) case _ => sum } + + override protected def withNewChildInternal(newChild: Expression): Sum = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala index 5b914c4333687..878d853aca3cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -56,6 +56,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression) since = "3.0.0") case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_and") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(arg = newChild) } @ExpressionDescription( @@ -73,4 +75,6 @@ case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) { since = "3.0.0") case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) { override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_or") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(arg = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala index 5ffc0f6ce3a42..86a16ad389b5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala @@ -69,6 +69,9 @@ case class BitAndAgg(child: Expression) extends BitAggregate { override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { BitwiseAnd(left, right) } + + override protected def withNewChildInternal(newChild: Expression): BitAndAgg = + copy(child = newChild) } @ExpressionDescription( @@ -87,6 +90,9 @@ case class BitOrAgg(child: Expression) extends BitAggregate { override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { BitwiseOr(left, right) } + + override protected def withNewChildInternal(newChild: Expression): BitOrAgg = + copy(child = newChild) } @ExpressionDescription( @@ -105,4 +111,7 @@ case class BitXorAgg(child: Expression) extends BitAggregate { override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { BitwiseXor(left, right) } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index d8a76d7add262..a8db8211a9e4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -125,6 +125,9 @@ case class CollectList( override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { new GenericArrayData(buffer.toArray) } + + override protected def withNewChildInternal(newChild: Expression): CollectList = + copy(child = newChild) } /** @@ -191,4 +194,7 @@ case class CollectSet( override def prettyName: String = "collect_set" override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty + + override protected def withNewChildInternal(newChild: Expression): CollectSet = + 
copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index e0c6ce7208c94..281734c6f14ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -164,6 +164,16 @@ case class AggregateExpression( case _ => aggFuncStr } } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): AggregateExpression = + if (filter.isDefined) { + copy( + aggregateFunction = newChildren(0).asInstanceOf[AggregateFunction], + filter = Some(newChildren(1))) + } else { + copy(aggregateFunction = newChildren(0).asInstanceOf[AggregateFunction]) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 64ea579e5ca05..28851918429aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -105,6 +105,9 @@ case class UnaryMinus( case funcName => s"$funcName(${child.sql})" } } + + override protected def withNewChildInternal(newChild: Expression): UnaryMinus = + copy(child = newChild) } @ExpressionDescription( @@ -131,6 +134,9 @@ case class UnaryPositive(child: Expression) protected override def nullSafeEval(input: Any): Any = input override def sql: String = s"(+ ${child.sql})" + + override protected def withNewChildInternal(newChild: Expression): UnaryPositive = + copy(child = newChild) } /** @@ -183,6 +189,8 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) + + override protected def withNewChildInternal(newChild: Expression): Abs = copy(child = newChild) } abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { @@ -309,6 +317,9 @@ case class Add( } override def exactMathMethod: Option[String] = Some("addExact") + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Add = + copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -352,6 +363,9 @@ case class Subtract( } override def exactMathMethod: Option[String] = Some("subtractExact") + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Subtract = copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -380,6 +394,9 @@ case class Multiply( protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) override def exactMathMethod: Option[String] = Some("multiplyExact") + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Multiply = copy(left = newLeft, right = newRight) } // Common base trait for Divide and Remainder, since these two classes are almost identical @@ -506,6 +523,9 @@ case class Divide( } override def evalOperation(left: Any, right: Any): Any = div(left, right) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Divide = copy(left = newLeft, right = newRight) } // scalastyle:off line.size.limit @@ -553,6 +573,10 @@ case class IntegralDivide( } 
override def evalOperation(left: Any, right: Any): Any = div(left, right) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): IntegralDivide = + copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -607,6 +631,9 @@ case class Remainder( } override def evalOperation(left: Any, right: Any): Any = mod(left, right) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Remainder = copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -791,6 +818,9 @@ case class Pmod( } override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Pmod = + copy(left = newLeft, right = newRight) } /** @@ -866,6 +896,9 @@ case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression |$codes """.stripMargin) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Least = + copy(children = newChildren) } /** @@ -941,4 +974,7 @@ case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpress |$codes """.stripMargin) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Greatest = + copy(children = newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index a1fb68ea169c5..3940c65593ec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -56,6 +56,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme } protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BitwiseAnd = copy(left = newLeft, right = newRight) } /** @@ -92,6 +95,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BitwiseOr = copy(left = newLeft, right = newRight) } /** @@ -128,6 +134,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme } protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BitwiseXor = copy(left = newLeft, right = newRight) } /** @@ -169,6 +178,9 @@ case class BitwiseNot(child: Expression) protected override def nullSafeEval(input: Any): Any = not(input) override def sql: String = s"~${child.sql}" + + override protected def withNewChildInternal(newChild: Expression): BitwiseNot = + copy(child = newChild) } @ExpressionDescription( @@ -204,6 +216,9 @@ case class BitwiseCount(child: Expression) case IntegerType => java.lang.Long.bitCount(input.asInstanceOf[Int]) case LongType => java.lang.Long.bitCount(input.asInstanceOf[Long]) } + + override protected def withNewChildInternal(newChild: Expression): BitwiseCount = + copy(child = newChild) } object BitwiseGetUtil { @@ -262,4 +277,7 @@ case class BitwiseGet(left: Expression, right: Expression) override def 
prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bit_get") + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): BitwiseGet = copy(left = newLeft, right = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 689858dc6ee67..c840cdfd8b2dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -22,7 +22,7 @@ import java.lang.{Boolean => JBool} import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{LeafLike, TreeNode} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{BooleanType, DataType} @@ -298,11 +298,13 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends } buf.toString } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Block]): Block = + super.legacyWithNewChildren(newChildren) } -case object EmptyBlock extends Block with Serializable { +case object EmptyBlock extends Block with Serializable with LeafLike[Block] { override val code: String = "" - override def children: Seq[Block] = Seq.empty } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d3fad8cb329c2..125e796a98c2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -125,6 +125,8 @@ case class Size(child: Expression, legacySizeOfNull: Boolean) defineCodeGen(ctx, ev, c => s"($c).numElements()") } } + + override protected def withNewChildInternal(newChild: Expression): Size = copy(child = newChild) } object Size { @@ -159,6 +161,9 @@ case class MapKeys(child: Expression) } override def prettyName: String = "map_keys" + + override protected def withNewChildInternal(newChild: Expression): MapKeys = + copy(child = newChild) } @ExpressionDescription( @@ -321,6 +326,9 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI } override def prettyName: String = "arrays_zip" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ArraysZip = + copy(children = newChildren) } /** @@ -351,6 +359,9 @@ case class MapValues(child: Expression) } override def prettyName: String = "map_values" + + override protected def withNewChildInternal(newChild: Expression): MapValues = + copy(child = newChild) } /** @@ -523,6 +534,8 @@ case class MapEntries(child: Expression) } override def prettyName: String = "map_entries" + + override def withNewChildInternal(newChild: Expression): MapEntries = copy(child = newChild) } /** @@ -642,6 +655,9 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres } override def prettyName: String = "map_concat" + + override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): MapConcat = + copy(children = newChildren) } /** @@ -720,6 +736,9 @@ case class MapFromEntries(child: Expression) extends 
UnaryExpression with NullIn } override def prettyName: String = "map_from_entries" + + override protected def withNewChildInternal(newChild: Expression): MapFromEntries = + copy(child = newChild) } @@ -919,6 +938,10 @@ case class SortArray(base: Expression, ascendingOrder: Expression) } override def prettyName: String = "sort_array" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): SortArray = + copy(base = newLeft, ascendingOrder = newRight) } /** @@ -1007,6 +1030,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) } override def freshCopy(): Shuffle = Shuffle(child, randomSeed) + + override def withNewChildInternal(newChild: Expression): Shuffle = copy(child = newChild) } /** @@ -1083,6 +1108,9 @@ case class Reverse(child: Expression) } override def prettyName: String = "reverse" + + override protected def withNewChildInternal(newChild: Expression): Reverse = + copy(child = newChild) } /** @@ -1180,6 +1208,10 @@ case class ArrayContains(left: Expression, right: Expression) } override def prettyName: String = "array_contains" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayContains = + copy(left = newLeft, right = newRight) } /** @@ -1403,6 +1435,10 @@ case class ArraysOverlap(left: Expression, right: Expression) } override def prettyName: String = "arrays_overlap" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArraysOverlap = + copy(left = newLeft, right = newRight) } /** @@ -1516,6 +1552,10 @@ case class Slice(x: Expression, start: Expression, length: Expression) |} """.stripMargin } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Slice = + copy(x = newFirst, start = newSecond, length = newThird) } /** @@ -1559,6 +1599,16 @@ case class ArrayJoin( Seq(array, delimiter) } + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + if (nullReplacement.isDefined) { + copy( + array = newChildren(0), + delimiter = newChildren(1), + nullReplacement = Some(newChildren(2))) + } else { + copy(array = newChildren(0), delimiter = newChildren(1)) + } + override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -1756,6 +1806,9 @@ case class ArrayMin(child: Expression) } override def prettyName: String = "array_min" + + override protected def withNewChildInternal(newChild: Expression): ArrayMin = + copy(child = newChild) } /** @@ -1824,6 +1877,9 @@ case class ArrayMax(child: Expression) } override def prettyName: String = "array_max" + + override protected def withNewChildInternal(newChild: Expression): ArrayMax = + copy(child = newChild) } @@ -1903,6 +1959,10 @@ case class ArrayPosition(left: Expression, right: Expression) """.stripMargin }) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayPosition = + copy(left = newLeft, right = newRight) } /** @@ -2085,6 +2145,9 @@ case class ElementAt( } override def prettyName: String = "element_at" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ElementAt = copy(left = newLeft, right = newRight) } /** @@ -2291,6 +2354,9 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio override def toString: String = s"concat(${children.mkString(", ")})" override def sql: String = 
s"concat(${children.map(_.sql).mkString(", ")})" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Concat = + copy(children = newChildren) } /** @@ -2403,6 +2469,9 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran } override def prettyName: String = "flatten" + + override protected def withNewChildInternal(newChild: Expression): Flatten = + copy(child = newChild) } @ExpressionDescription( @@ -2460,6 +2529,15 @@ case class Sequence( override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt + override def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): TimeZoneAwareExpression = { + if (stepOpt.isDefined) { + copy(start = newChildren(0), stop = newChildren(1), stepOpt = Some(newChildren(2))) + } else { + copy(start = newChildren(0), stop = newChildren(1)) + } + } + override def foldable: Boolean = children.forall(_.foldable) override def nullable: Boolean = children.exists(_.nullable) @@ -2949,6 +3027,8 @@ case class ArrayRepeat(left: Expression, right: Expression) """.stripMargin } + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayRepeat = copy(left = newLeft, right = newRight) } /** @@ -3063,6 +3143,9 @@ case class ArrayRemove(left: Expression, right: Expression) } override def prettyName: String = "array_remove" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayRemove = copy(left = newLeft, right = newRight) } /** @@ -3295,6 +3378,9 @@ case class ArrayDistinct(child: Expression) } override def prettyName: String = "array_distinct" + + override protected def withNewChildInternal(newChild: Expression): ArrayDistinct = + copy(child = newChild) } /** @@ -3497,6 +3583,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi } override def prettyName: String = "array_union" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayUnion = copy(left = newLeft, right = newRight) } object ArrayUnion { @@ -3780,6 +3869,10 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina } override def prettyName: String = "array_intersect" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayIntersect = + copy(left = newLeft, right = newRight) } /** @@ -4004,4 +4097,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL } override def prettyName: String = "array_except" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayExcept = copy(left = newLeft, right = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3c016a7a54995..f1456c4c8e079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -102,6 +102,9 @@ case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolea } override def prettyName: String = "array" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CreateArray = + copy(children = newChildren) } object CreateArray { @@ -254,6 +257,9 @@ case class CreateMap(children: Seq[Expression], 
useStringTypeWhenEmpty: Boolean) } override def prettyName: String = "map" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CreateMap = + copy(children = newChildren) } object CreateMap { @@ -314,6 +320,10 @@ case class MapFromArrays(left: Expression, right: Expression) } override def prettyName: String = "map_from_arrays" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MapFromArrays = + copy(left = newLeft, right = newRight) } /** @@ -493,6 +503,9 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ") s"$alias($childrenSQL)" }.getOrElse(super.sql) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CreateNamedStruct = copy(children = newChildren) } /** @@ -576,6 +589,13 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E } override def prettyName: String = "str_to_map" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy( + text = newFirst, + pairDelim = newSecond, + keyValueDelim = newThird + ) } /** @@ -627,6 +647,9 @@ case class WithField(name: String, valExpr: Expression) "WithField.nullable should not be called.") override def prettyName: String = "WithField" + + override protected def withNewChildInternal(newChild: Expression): WithField = + copy(valExpr = newChild) } /** @@ -659,6 +682,9 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat case e: Expression => e } + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + super.legacyWithNewChildren(newChildren) + override def dataType: StructType = StructType(newFields) override def nullable: Boolean = structExpr.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 139d9a584ccbe..f64cc8a28b566 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -138,6 +138,9 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] } }) } + + override protected def withNewChildInternal(newChild: Expression): GetStructField = + copy(child = newChild) } /** @@ -212,6 +215,9 @@ case class GetArrayStructFields( """ }) } + + override protected def withNewChildInternal(newChild: Expression): GetArrayStructFields = + copy(child = newChild) } /** @@ -292,6 +298,10 @@ case class GetArrayItem( """ }) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): GetArrayItem = + copy(child = newLeft, ordinal = newRight) } /** @@ -470,4 +480,8 @@ case class GetMapValue( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType], failOnError) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): GetMapValue = + copy(child = newLeft, key = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index a062dd49a3c92..e708d56cd89c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -95,6 +95,13 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def toString: String = s"if ($predicate) $trueValue else $falseValue" override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy( + predicate = newFirst, + trueValue = newSecond, + falseValue = newThird + ) } /** @@ -132,6 +139,9 @@ case class CaseWhen( override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + super.legacyWithNewChildren(newChildren) + // both then and else expressions should be considered. @transient override lazy val inputTypesForMerging: Seq[DataType] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala index 5bfae7b77e096..8feaf52ecb134 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -36,6 +36,12 @@ case class KnownNotNull(child: Expression) extends TaggingExpression { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.genCode(ctx).copy(isNull = FalseLiteral) } + + override protected def withNewChildInternal(newChild: Expression): KnownNotNull = + copy(child = newChild) } -case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression +case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression { + override protected def withNewChildInternal(newChild: Expression): KnownFloatingPointNormalized = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index ac47020de4d46..79bbc103c92d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -140,6 +140,9 @@ case class CsvToStructs( override def inputTypes: Seq[AbstractDataType] = StringType :: Nil override def prettyName: String = "from_csv" + + override protected def withNewChildInternal(newChild: Expression): CsvToStructs = + copy(child = newChild) } /** @@ -197,6 +200,9 @@ case class SchemaOfCsv( } override def prettyName: String = "schema_of_csv" + + override protected def withNewChildInternal(newChild: Expression): SchemaOfCsv = + copy(child = newChild) } /** @@ -264,4 +270,7 @@ case class StructsToCsv( override def inputTypes: Seq[AbstractDataType] = StructType :: Nil override def prettyName: String = "to_csv" + + override protected def withNewChildInternal(newChild: Expression): StructsToCsv = + copy(child = newChild) } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 355064e73dfab..ba9d458c0ae5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -251,6 +251,9 @@ case class DateAdd(startDate: Expression, days: Expression) } override def prettyName: String = "date_add" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateAdd = copy(startDate = newLeft, days = newRight) } /** @@ -286,6 +289,9 @@ case class DateSub(startDate: Expression, days: Expression) } override def prettyName: String = "date_sub" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateSub = copy(startDate = newLeft, days = newRight) } trait GetTimeField extends UnaryExpression @@ -323,6 +329,7 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None) extends Ge override def withTimeZone(timeZoneId: String): Hour = copy(timeZoneId = Option(timeZoneId)) override val func = DateTimeUtils.getHours override val funcName = "getHours" + override protected def withNewChildInternal(newChild: Expression): Hour = copy(child = newChild) } @ExpressionDescription( @@ -339,6 +346,7 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None) extends override def withTimeZone(timeZoneId: String): Minute = copy(timeZoneId = Option(timeZoneId)) override val func = DateTimeUtils.getMinutes override val funcName = "getMinutes" + override protected def withNewChildInternal(newChild: Expression): Minute = copy(child = newChild) } @ExpressionDescription( @@ -355,6 +363,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None) extends override def withTimeZone(timeZoneId: String): Second = copy(timeZoneId = Option(timeZoneId)) override val func = DateTimeUtils.getSeconds override val funcName = "getSeconds" + override protected def withNewChildInternal(newChild: Expression): Second = + copy(child = newChild) } case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None) @@ -366,6 +376,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No copy(timeZoneId = Option(timeZoneId)) override val func = DateTimeUtils.getSecondsWithFraction override val funcName = "getSecondsWithFraction" + override protected def withNewChildInternal(newChild: Expression): SecondWithFraction = + copy(child = newChild) } trait GetDateField extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { @@ -398,6 +410,8 @@ trait GetDateField extends UnaryExpression with ImplicitCastInputTypes with Null case class DayOfYear(child: Expression) extends GetDateField { override val func = DateTimeUtils.getDayInYear override val funcName = "getDayInYear" + override protected def withNewChildInternal(newChild: Expression): DayOfYear = + copy(child = newChild) } @ExpressionDescription( @@ -421,6 +435,9 @@ case class DateFromUnixDate(child: Expression) extends UnaryExpression defineCodeGen(ctx, ev, c => c) override def prettyName: String = "date_from_unix_date" + + override protected def withNewChildInternal(newChild: Expression): DateFromUnixDate = + copy(child = newChild) } @ExpressionDescription( @@ -444,6 +461,9 @@ case class UnixDate(child: Expression) extends UnaryExpression 
defineCodeGen(ctx, ev, c => c) override def prettyName: String = "unix_date" + + override protected def withNewChildInternal(newChild: Expression): UnixDate = + copy(child = newChild) } abstract class IntegralToTimestampBase extends UnaryExpression @@ -531,6 +551,9 @@ case class SecondsToTimestamp(child: Expression) extends UnaryExpression } override def prettyName: String = "timestamp_seconds" + + override protected def withNewChildInternal(newChild: Expression): SecondsToTimestamp = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -550,6 +573,9 @@ case class MillisToTimestamp(child: Expression) override def upScaleFactor: Long = MICROS_PER_MILLIS override def prettyName: String = "timestamp_millis" + + override protected def withNewChildInternal(newChild: Expression): MillisToTimestamp = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -569,6 +595,9 @@ case class MicrosToTimestamp(child: Expression) override def upScaleFactor: Long = 1L override def prettyName: String = "timestamp_micros" + + override protected def withNewChildInternal(newChild: Expression): MicrosToTimestamp = + copy(child = newChild) } abstract class TimestampToLongBase extends UnaryExpression @@ -608,6 +637,9 @@ case class UnixSeconds(child: Expression) extends TimestampToLongBase { override def scaleFactor: Long = MICROS_PER_SECOND override def prettyName: String = "unix_seconds" + + override protected def withNewChildInternal(newChild: Expression): UnixSeconds = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -625,6 +657,9 @@ case class UnixMillis(child: Expression) extends TimestampToLongBase { override def scaleFactor: Long = MICROS_PER_MILLIS override def prettyName: String = "unix_millis" + + override protected def withNewChildInternal(newChild: Expression): UnixMillis = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -642,6 +677,9 @@ case class UnixMicros(child: Expression) extends TimestampToLongBase { override def scaleFactor: Long = 1L override def prettyName: String = "unix_micros" + + override protected def withNewChildInternal(newChild: Expression): UnixMicros = + copy(child = newChild) } @ExpressionDescription( @@ -656,11 +694,15 @@ case class UnixMicros(child: Expression) extends TimestampToLongBase { case class Year(child: Expression) extends GetDateField { override val func = DateTimeUtils.getYear override val funcName = "getYear" + override protected def withNewChildInternal(newChild: Expression): Year = + copy(child = newChild) } case class YearOfWeek(child: Expression) extends GetDateField { override val func = DateTimeUtils.getWeekBasedYear override val funcName = "getWeekBasedYear" + override protected def withNewChildInternal(newChild: Expression): YearOfWeek = + copy(child = newChild) } @ExpressionDescription( @@ -675,6 +717,8 @@ case class YearOfWeek(child: Expression) extends GetDateField { case class Quarter(child: Expression) extends GetDateField { override val func = DateTimeUtils.getQuarter override val funcName = "getQuarter" + override protected def withNewChildInternal(newChild: Expression): Quarter = + copy(child = newChild) } @ExpressionDescription( @@ -689,6 +733,7 @@ case class Quarter(child: Expression) extends GetDateField { case class Month(child: Expression) extends GetDateField { override val func = DateTimeUtils.getMonth override val funcName = "getMonth" + override protected def withNewChildInternal(newChild: Expression): Month = copy(child = newChild) } @ExpressionDescription( @@ -703,6 +748,8 @@ case class 
Month(child: Expression) extends GetDateField { case class DayOfMonth(child: Expression) extends GetDateField { override val func = DateTimeUtils.getDayOfMonth override val funcName = "getDayOfMonth" + override protected def withNewChildInternal(newChild: Expression): DayOfMonth = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -719,6 +766,8 @@ case class DayOfMonth(child: Expression) extends GetDateField { case class DayOfWeek(child: Expression) extends GetDateField { override val func = DateTimeUtils.getDayOfWeek override val funcName = "getDayOfWeek" + override protected def withNewChildInternal(newChild: Expression): DayOfWeek = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -735,6 +784,8 @@ case class DayOfWeek(child: Expression) extends GetDateField { case class WeekDay(child: Expression) extends GetDateField { override val func = DateTimeUtils.getWeekDay override val funcName = "getWeekDay" + override protected def withNewChildInternal(newChild: Expression): WeekDay = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -751,6 +802,8 @@ case class WeekDay(child: Expression) extends GetDateField { case class WeekOfYear(child: Expression) extends GetDateField { override val func = DateTimeUtils.getWeekOfYear override val funcName = "getWeekOfYear" + override protected def withNewChildInternal(newChild: Expression): WeekOfYear = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -814,6 +867,10 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override protected def formatString: Expression = right override protected def isParsing: Boolean = false + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateFormatClass = + copy(left = newLeft, right = newRight) } /** @@ -859,6 +916,10 @@ case class ToUnixTimestamp( } override def prettyName: String = "to_unix_timestamp" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ToUnixTimestamp = + copy(timeExp = newLeft, format = newRight) } // scalastyle:off line.size.limit @@ -915,6 +976,10 @@ case class UnixTimestamp( } override def prettyName: String = "unix_timestamp" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): UnixTimestamp = + copy(timeExp = newLeft, format = newRight) } abstract class ToTimestamp @@ -1120,6 +1185,10 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override protected def formatString: Expression = format override protected def isParsing: Boolean = false + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): FromUnixTime = + copy(sec = newLeft, format = newRight) } /** @@ -1152,6 +1221,9 @@ case class LastDay(startDate: Expression) } override def prettyName: String = "last_day" + + override protected def withNewChildInternal(newChild: Expression): LastDay = + copy(startDate = newChild) } /** @@ -1249,6 +1321,10 @@ case class NextDay( } override def prettyName: String = "next_day" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): NextDay = + copy(startDate = newLeft, dayOfWeek = newRight) } /** @@ -1292,6 +1368,10 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S }) } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TimeAdd = + copy(start = newLeft, interval = newRight) } /** @@ -1305,6 
+1385,8 @@ case class DatetimeSub( override def exprsReplaced: Seq[Expression] = Seq(start, interval) override def toString: String = s"$start - $interval" override def mkString(childrenString: Seq[String]): String = childrenString.mkString(" - ") + override protected def withNewChildInternal(newChild: Expression): DatetimeSub = + copy(child = newChild) } /** @@ -1367,6 +1449,10 @@ case class DateAddInterval( override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateAddInterval = + copy(start = newLeft, interval = newRight) } sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { @@ -1447,6 +1533,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression) extends UTCTime override val func = DateTimeUtils.fromUTCTime override val funcName: String = "fromUTCTime" override val prettyName: String = "from_utc_timestamp" + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): FromUTCTimestamp = + copy(left = newLeft, right = newRight) } /** @@ -1478,6 +1567,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) extends UTCTimest override val func = DateTimeUtils.toUTCTime override val funcName: String = "toUTCTime" override val prettyName: String = "to_utc_timestamp" + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ToUTCTimestamp = + copy(left = newLeft, right = newRight) } abstract class AddMonthsBase extends BinaryExpression with ImplicitCastInputTypes @@ -1517,6 +1609,10 @@ case class AddMonths(startDate: Expression, numMonths: Expression) extends AddMo override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) override def prettyName: String = "add_months" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): AddMonths = + copy(startDate = newLeft, numMonths = newRight) } // Adds the year-month interval to the date @@ -1528,6 +1624,10 @@ case class DateAddYMInterval(date: Expression, interval: Expression) extends Add override def toString: String = s"$left + $right" override def sql: String = s"${left.sql} + ${right.sql}" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateAddYMInterval = + copy(date = newLeft, interval = newRight) } // Adds the year-month interval to the timestamp @@ -1562,6 +1662,10 @@ case class TimestampAddYMInterval( s"""$dtu.timestampAddMonths($micros, $months, $zid)""" }) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TimestampAddYMInterval = + copy(timestamp = newLeft, interval = newRight) } /** @@ -1628,6 +1732,10 @@ case class MonthsBetween( } override def prettyName: String = "months_between" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): MonthsBetween = + copy(date1 = newFirst, date2 = newSecond, roundOff = newThird) } /** @@ -1672,6 +1780,9 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr override def flatArguments: Iterator[Any] = Iterator(left, format) override def prettyName: String = "to_date" + + override protected def withNewChildInternal(newChild: Expression): ParseToDate = + copy(child = newChild) } /** @@ -1714,6 +1825,9 @@ case class ParseToTimestamp(left: Expression, format: 
Option[Expression], child: override def prettyName: String = "to_timestamp" override def dataType: DataType = TimestampType + + override protected def withNewChildInternal(newChild: Expression): ParseToTimestamp = + copy(child = newChild) } trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { @@ -1849,6 +1963,10 @@ case class TruncDate(date: Expression, format: Expression) (date: String, fmt: String) => s"truncDate($date, $fmt);" } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TruncDate = + copy(date = newLeft, format = newRight) } /** @@ -1920,6 +2038,10 @@ case class TruncTimestamp( s"truncTimestamp($date, $fmt, $zid);" } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TruncTimestamp = + copy(format = newLeft, timestamp = newRight) } /** @@ -1952,6 +2074,10 @@ case class DateDiff(endDate: Expression, startDate: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (end, start) => s"$end - $start") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DateDiff = + copy(endDate = newLeft, startDate = newRight) } /** @@ -1969,6 +2095,10 @@ private case class GetTimestamp( override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): GetTimestamp = + copy(left = newLeft, right = newRight) } @ExpressionDescription( @@ -2032,6 +2162,10 @@ case class MakeDate( } override def prettyName: String = "make_date" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): MakeDate = + copy(year = newFirst, month = newSecond, day = newThird) } // scalastyle:off line.size.limit @@ -2198,6 +2332,20 @@ case class MakeTimestamp( } override def prettyName: String = "make_timestamp" + +// override def children: Seq[Expression] = Seq(year, month, day, hour, min, sec) ++ timezone + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): MakeTimestamp = { + val timezoneOpt = if (timezone.isDefined) Some(newChildren(6)) else None + copy( + year = newChildren(0), + month = newChildren(1), + day = newChildren(2), + hour = newChildren(3), + min = newChildren(4), + sec = newChildren(5), + timezone = timezoneOpt) + } } object DatePart { @@ -2284,6 +2432,9 @@ case class DatePart(field: Expression, source: Expression, child: Expression) override def exprsReplaced: Seq[Expression] = Seq(field, source) override def prettyName: String = "date_part" + + override protected def withNewChildInternal(newChild: Expression): DatePart = + copy(child = newChild) } // scalastyle:off line.size.limit @@ -2349,6 +2500,9 @@ case class Extract(field: Expression, source: Expression, child: Expression) override def mkString(childrenString: Seq[String]): String = { prettyName + childrenString.mkString("(", " FROM ", ")") } + + override protected def withNewChildInternal(newChild: Expression): Extract = + copy(child = newChild) } /** @@ -2401,6 +2555,10 @@ case class SubtractTimestamps( defineCodeGen(ctx, ev, (end, start) => s"new org.apache.spark.unsafe.types.CalendarInterval(0, 0, $end - $start)") } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): SubtractTimestamps = + copy(left = newLeft, right = newRight) } object 
SubtractTimestamps { @@ -2452,6 +2610,10 @@ case class SubtractDates( s"$dtu.subtractDates($leftDays, $rightDays)" }) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): SubtractDates = + copy(left = newLeft, right = newRight) } object SubtractDates { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index b987beda6407e..7165bca201a9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -40,6 +40,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression with NullInt override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } + + override protected def withNewChildInternal(newChild: Expression): UnscaledValue = + copy(child = newChild) } /** @@ -89,6 +92,9 @@ case class MakeDecimal( |""".stripMargin }) } + + override protected def withNewChildInternal(newChild: Expression): MakeDecimal = + copy(child = newChild) } object MakeDecimal { @@ -111,6 +117,9 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) } /** @@ -145,6 +154,9 @@ case class CheckOverflow( override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)" override def sql: String = child.sql + + override protected def withNewChildInternal(newChild: Expression): CheckOverflow = + copy(child = newChild) } // A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`. 
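The decimal expressions above (`UnscaledValue`, `MakeDecimal`, `CheckOverflow`) all use the same one-line rebuild, and most narrow the return type to their own concrete class, while `PromotePrecision` keeps the wider `Expression` return type; the template method only requires a compatible type. A hedged, standalone sketch of the one-liner follows, with illustrative `Expr`/`Negate` names that are not patch code.

```scala
// Editorial sketch, not part of the patch: the compiler-generated case-class
// `copy` rebuilds the node in a single constructor call, and declaring the
// concrete return type (Negate rather than Expr) keeps the precise static
// type available to callers.
trait Expr

case class Negate(child: Expr) extends Expr {
  protected def withNewChildInternal(newChild: Expr): Negate =
    copy(child = newChild)
}
```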
@@ -194,4 +206,7 @@ case class CheckOverflowInSum( override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)" override def sql: String = child.sql + + override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index f10ceea519cce..fef9bb338d834 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -118,6 +118,9 @@ case class UserDefinedGenerator( } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): UserDefinedGenerator = copy(children = newChildren) } /** @@ -227,6 +230,9 @@ case class Stack(children: Seq[Expression]) extends Generator { |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Stack = + copy(children = newChildren) } /** @@ -253,6 +259,9 @@ case class ReplicateRows(children: Seq[Expression]) extends Generator with Codeg InternalRow(fields: _*) } } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): ReplicateRows = copy(children = newChildren) } /** @@ -269,6 +278,9 @@ case class GeneratorOuter(child: Generator) extends UnaryExpression with Generat override def elementSchema: StructType = child.elementSchema override lazy val resolved: Boolean = false + + override protected def withNewChildInternal(newChild: Expression): GeneratorOuter = + copy(child = newChild.asInstanceOf[Generator]) } /** @@ -369,6 +381,8 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with // scalastyle:on line.size.limit case class Explode(child: Expression) extends ExplodeBase { override val position: Boolean = false + override protected def withNewChildInternal(newChild: Expression): Explode = + copy(child = newChild) } /** @@ -394,6 +408,8 @@ case class Explode(child: Expression) extends ExplodeBase { // scalastyle:on line.size.limit line.contains.tab case class PosExplode(child: Expression) extends ExplodeBase { override val position = true + override protected def withNewChildInternal(newChild: Expression): PosExplode = + copy(child = newChild) } /** @@ -445,4 +461,6 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.genCode(ctx) } + + override protected def withNewChildInternal(newChild: Expression): Inline = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index bf28efabcd561..0dd82bed15082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -111,6 +111,8 @@ case class Cube( children: Seq[Expression]) extends BaseGroupingSets { override def groupingSets: Seq[Seq[Expression]] = groupingSetIndexes.map(_.map(children)) override def 
selectedGroupByExprs: Seq[Seq[Expression]] = BaseGroupingSets.cubeExprs(groupingSets) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Cube = + copy(children = newChildren) } object Cube { @@ -125,6 +127,8 @@ case class Rollup( override def groupingSets: Seq[Seq[Expression]] = groupingSetIndexes.map(_.map(children)) override def selectedGroupByExprs: Seq[Seq[Expression]] = BaseGroupingSets.rollupExprs(groupingSets) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Rollup = + copy(children = newChildren) } object Rollup { @@ -142,6 +146,9 @@ case class GroupingSets( // Includes the `userGivenGroupByExprs` in the children, which will be included in the final // GROUP BY expressions, so that `SELECT c ... GROUP BY (a, b, c) GROUPING SETS (a, b)` works. override def children: Seq[Expression] = flatGroupingSets ++ userGivenGroupByExprs + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): GroupingSets = + super.legacyWithNewChildren(newChildren).asInstanceOf[GroupingSets] } object GroupingSets { @@ -184,6 +191,8 @@ case class Grouping(child: Expression) extends Expression with Unevaluable AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) override def dataType: DataType = ByteType override def nullable: Boolean = false + override protected def withNewChildInternal(newChild: Expression): Grouping = + copy(child = newChild) } /** @@ -223,6 +232,8 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une override def dataType: DataType = GroupingID.dataType override def nullable: Boolean = false override def prettyName: String = "grouping_id" + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): GroupingID = + copy(groupByExprs = newChildren) } object GroupingID { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 9738559b6d67a..f23c1e56ce4e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -69,6 +69,8 @@ case class Md5(child: Expression) defineCodeGen(ctx, ev, c => s"UTF8String.fromString(${classOf[DigestUtils].getName}.md5Hex($c))") } + + override protected def withNewChildInternal(newChild: Expression): Md5 = copy(child = newChild) } /** @@ -152,6 +154,9 @@ case class Sha2(left: Expression, right: Expression) """ }) } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Sha2 = + copy(left = newLeft, right = newRight) } /** @@ -182,6 +187,8 @@ case class Sha1(child: Expression) s"UTF8String.fromString(${classOf[DigestUtils].getName}.sha1Hex($c))" ) } + + override protected def withNewChildInternal(newChild: Expression): Sha1 = copy(child = newChild) } /** @@ -221,6 +228,8 @@ case class Crc32(child: Expression) """ }) } + + override protected def withNewChildInternal(newChild: Expression): Crc32 = copy(child = newChild) } @@ -598,6 +607,9 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpress override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { Murmur3HashFunction.hash(value, dataType, seed).toInt } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Murmur3Hash = + copy(children = newChildren) } object Murmur3HashFunction 
extends InterpretedHashFunction { @@ -638,6 +650,9 @@ case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpressio override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = { XxHash64Function.hash(value, dataType, seed) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): XxHash64 = + copy(children = newChildren) } object XxHash64Function extends InterpretedHashFunction { @@ -842,6 +857,9 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { |$code """.stripMargin } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): HiveHash = + copy(children = newChildren) } object HiveHashFunction extends InterpretedHashFunction { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index bbfdf7135824c..a0f9dc2f58b20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -103,6 +103,12 @@ case class LambdaFunction( lazy val bound: Boolean = arguments.forall(_.resolved) override def eval(input: InternalRow): Any = function.eval(input) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): LambdaFunction = + copy( + function = newChildren.head, + arguments = newChildren.tail.asInstanceOf[Seq[NamedExpression]]) } object LambdaFunction { @@ -219,6 +225,7 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr nullSafeEval(inputRow, value) } } + } trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { @@ -289,6 +296,10 @@ case class ArrayTransform( } override def prettyName: String = "transform" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayTransform = + copy(argument = newLeft, function = newRight) } /** @@ -378,6 +389,10 @@ case class ArraySort( } override def prettyName: String = "array_sort" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArraySort = + copy(argument = newLeft, function = newRight) } object ArraySort { @@ -448,6 +463,10 @@ case class MapFilter( override def functionType: AbstractDataType = BooleanType override def prettyName: String = "map_filter" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MapFilter = + copy(argument = newLeft, function = newRight) } /** @@ -513,6 +532,10 @@ case class ArrayFilter( } override def prettyName: String = "filter" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayFilter = + copy(argument = newLeft, function = newRight) } /** @@ -594,6 +617,10 @@ case class ArrayExists( } override def prettyName: String = "exists" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayExists = + copy(argument = newLeft, function = newRight) } object ArrayExists { @@ -670,6 +697,10 @@ case class ArrayForAll( } override def prettyName: String = "forall" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayForAll = + copy(argument = newLeft, function = newRight) } /** @@ -767,6 +798,10 @@ case class ArrayAggregate( override def 
second: Expression = zero override def third: Expression = merge override def fourth: Expression = finish + + override protected def withNewChildrenInternal(first: Expression, second: Expression, + third: Expression, fourth: Expression): ArrayAggregate = + copy(argument = first, zero = second, merge = third, finish = fourth) } /** @@ -802,7 +837,7 @@ case class TransformKeys( } @transient lazy val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + _, Seq(keyVar: NamedLambdaVariable, valueVar: NamedLambdaVariable), _) = function private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) @@ -821,6 +856,10 @@ case class TransformKeys( } override def prettyName: String = "transform_keys" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TransformKeys = + copy(argument = newLeft, function = newRight) } /** @@ -852,7 +891,7 @@ case class TransformValues( } @transient lazy val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + _, Seq(keyVar: NamedLambdaVariable, valueVar: NamedLambdaVariable), _) = function override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val map = argumentValue.asInstanceOf[MapData] @@ -869,6 +908,10 @@ case class TransformValues( } override def prettyName: String = "transform_values" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): TransformValues = + copy(argument = newLeft, function = newRight) } /** @@ -1056,6 +1099,13 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) override def first: Expression = left override def second: Expression = right override def third: Expression = function + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): MapZipWith = + copy( + left = newFirst, + right = newSecond, + function = newThird) } // scalastyle:off line.size.limit @@ -1136,4 +1186,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression) override def first: Expression = left override def second: Expression = right override def third: Expression = function + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): ZipWith = + copy(left = newFirst, right = newSecond, function = newThird) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 23cf0bcafbe10..4311b38bdc78c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -49,22 +49,40 @@ abstract class ExtractIntervalPart( } case class ExtractIntervalYears(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") + extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalYears = + copy(child = newChild) +} case class ExtractIntervalMonths(child: Expression) - extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") + extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") { + override protected def 
withNewChildInternal(newChild: Expression): ExtractIntervalMonths = + copy(child = newChild) +} case class ExtractIntervalDays(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") + extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalDays = + copy(child = newChild) +} case class ExtractIntervalHours(child: Expression) - extends ExtractIntervalPart(child, LongType, getHours, "getHours") + extends ExtractIntervalPart(child, LongType, getHours, "getHours") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalHours = + copy(child = newChild) +} case class ExtractIntervalMinutes(child: Expression) - extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") + extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalMinutes = + copy(child = newChild) +} case class ExtractIntervalSeconds(child: Expression) - extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds") + extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds") { + override protected def withNewChildInternal(newChild: Expression): ExtractIntervalSeconds = + copy(child = newChild) +} object ExtractIntervalPart { @@ -119,6 +137,10 @@ case class MultiplyInterval( if (failOnError) multiplyExact else multiply override protected def operationName: String = if (failOnError) "multiplyExact" else "multiply" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MultiplyInterval = + copy(interval = newLeft, num = newRight) } case class DivideInterval( @@ -131,6 +153,10 @@ case class DivideInterval( if (failOnError) divideExact else divide override protected def operationName: String = if (failOnError) "divideExact" else "divide" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DivideInterval = + copy(interval = newLeft, num = newRight) } // scalastyle:off line.size.limit @@ -251,6 +277,19 @@ case class MakeInterval( } override def prettyName: String = "make_interval" + + // Seq(years, months, weeks, days, hours, mins, secs) + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): MakeInterval = + copy( + years = newChildren(0), + months = newChildren(1), + weeks = newChildren(2), + days = newChildren(3), + hours = newChildren(4), + mins = newChildren(5), + secs = newChildren(6) + ) } // Multiply an year-month interval by a numeric @@ -298,6 +337,10 @@ case class MultiplyYMInterval( } override def toString: String = s"($left * $right)" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MultiplyYMInterval = + copy(interval = newLeft, num = newRight) } // Multiply a day-time interval by a numeric @@ -340,6 +383,10 @@ case class MultiplyDTInterval( } override def toString: String = s"($left * $right)" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): MultiplyDTInterval = + copy(interval = newLeft, num = newRight) } // Divide an year-month interval by a numeric @@ -394,6 +441,10 @@ case class DivideYMInterval( } override def toString: String = s"($left / $right)" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DivideYMInterval = + copy(interval = newLeft, num = 
newRight) } // Divide a day-time interval by a numeric @@ -437,4 +488,8 @@ case class DivideDTInterval( } override def toString: String = s"($left / $right)" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): DivideDTInterval = + copy(interval = newLeft, num = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index b217110f075a7..6a56bbf1916bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -335,6 +335,10 @@ case class GetJsonObject(json: Expression, path: Expression) false } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): GetJsonObject = + copy(json = newLeft, path = newRight) } // scalastyle:off line.size.limit line.contains.tab @@ -498,6 +502,9 @@ case class JsonTuple(children: Seq[Expression]) generator.copyCurrentStructure(parser) } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): JsonTuple = + copy(children = newChildren) } /** @@ -609,6 +616,9 @@ case class JsonToStructs( } override def prettyName: String = "from_json" + + override protected def withNewChildInternal(newChild: Expression): JsonToStructs = + copy(child = newChild) } /** @@ -731,6 +741,9 @@ case class StructsToJson( override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil override def prettyName: String = "to_json" + + override protected def withNewChildInternal(newChild: Expression): StructsToJson = + copy(child = newChild) } /** @@ -805,6 +818,9 @@ case class SchemaOfJson( } override def prettyName: String = "schema_of_json" + + override protected def withNewChildInternal(newChild: Expression): SchemaOfJson = + copy(child = newChild) } /** @@ -874,6 +890,9 @@ case class LengthOfJsonArray(child: Expression) extends UnaryExpression } length } + + override protected def withNewChildInternal(newChild: Expression): LengthOfJsonArray = + copy(child = newChild) } /** @@ -943,4 +962,7 @@ case class JsonObjectKeys(child: Expression) extends UnaryExpression with Codege } new GenericArrayData(arrayBufferOfKeys.toArray) } + + override protected def withNewChildInternal(newChild: Expression): JsonObjectKeys = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 3b58f3d868d3c..516eeb9929e80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -187,7 +187,9 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") """, since = "1.4.0", group = "math_funcs") -case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") +case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") { + override protected def withNewChildInternal(newChild: Expression): Acos = copy(child = newChild) +} @ExpressionDescription( usage = """ @@ -203,7 +205,9 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS" """, since = "1.4.0", group = "math_funcs") -case class 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 3b58f3d868d3c..516eeb9929e80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -187,7 +187,9 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI")
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
+case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") {
+  override protected def withNewChildInternal(newChild: Expression): Acos = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -203,7 +205,9 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN")
+case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") {
+  override protected def withNewChildInternal(newChild: Expression): Asin = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -217,7 +221,9 @@ case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN")
+case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") {
+  override protected def withNewChildInternal(newChild: Expression): Atan = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the cube root of `expr`.",
@@ -228,7 +234,9 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
+case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") {
+  override protected def withNewChildInternal(newChild: Expression): Cbrt = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the smallest integer not smaller than `expr`.",
@@ -267,6 +275,8 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Ceil = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -285,7 +295,9 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
+case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") {
+  override protected def withNewChildInternal(newChild: Expression): Cos = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -303,7 +315,9 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
+case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") {
+  override protected def withNewChildInternal(newChild: Expression): Cosh = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -324,6 +338,7 @@ case class Acosh(child: Expression)
     defineCodeGen(ctx, ev, c =>
       s"java.lang.StrictMath.log($c + java.lang.Math.sqrt($c * $c - 1.0))")
   }
+  override protected def withNewChildInternal(newChild: Expression): Acosh = copy(child = newChild)
 }
 
 /**
@@ -372,6 +387,10 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
       """
     )
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
+    copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird)
 }
 
 @ExpressionDescription(
@@ -387,6 +406,7 @@ case class Exp(child: Expression) extends UnaryMathExpression(StrictMath.exp, "E
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.exp($c)")
   }
+  override protected def withNewChildInternal(newChild: Expression): Exp = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -402,6 +422,7 @@ case class Expm1(child: Expression) extends UnaryMathExpression(StrictMath.expm1
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.expm1($c)")
   }
+  override protected def withNewChildInternal(newChild: Expression): Expm1 = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -441,6 +462,8 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Floor = copy(child = newChild)
 }
 
 object Factorial {
@@ -514,6 +537,9 @@ case class Factorial(child: Expression)
       """
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Factorial =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -527,6 +553,7 @@ case class Factorial(child: Expression)
   group = "math_funcs")
 case class Log(child: Expression) extends UnaryLogExpression(StrictMath.log, "LOG") {
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("ln")
+  override protected def withNewChildInternal(newChild: Expression): Log = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -551,6 +578,7 @@ case class Log2(child: Expression)
       """
     )
   }
+  override protected def withNewChildInternal(newChild: Expression): Log2 = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -562,7 +590,9 @@ case class Log2(child: Expression)
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10, "LOG10")
+case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10, "LOG10") {
+  override protected def withNewChildInternal(newChild: Expression): Log10 = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns log(1 + `expr`).",
@@ -575,6 +605,7 @@ case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10,
   group = "math_funcs")
 case class Log1p(child: Expression) extends UnaryLogExpression(StrictMath.log1p, "LOG1P") {
   protected override val yAsymptote: Double = -1.0
+  override protected def withNewChildInternal(newChild: Expression): Log1p = copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -591,6 +622,7 @@ case class Log1p(child: Expression) extends UnaryLogExpression(StrictMath.log1p,
 case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
   override def funcName: String = "rint"
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rint")
+  override protected def withNewChildInternal(newChild: Expression): Rint = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -602,7 +634,9 @@ case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
+case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") {
+  override protected def withNewChildInternal(newChild: Expression): Signum = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.",
@@ -617,7 +651,9 @@ case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "S
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
+case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") {
+  override protected def withNewChildInternal(newChild: Expression): Sin = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -634,7 +670,9 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
+case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") {
+  override protected def withNewChildInternal(newChild: Expression): Sinh = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -656,6 +694,7 @@ case class Asinh(child: Expression)
       s"$c == Double.NEGATIVE_INFINITY ? Double.NEGATIVE_INFINITY : " +
         s"java.lang.StrictMath.log($c + java.lang.Math.sqrt($c * $c + 1.0))")
   }
+  override protected def withNewChildInternal(newChild: Expression): Asinh = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -667,7 +706,9 @@ case class Asinh(child: Expression)
   """,
   since = "1.1.1",
   group = "math_funcs")
-case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
+case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") {
+  override protected def withNewChildInternal(newChild: Expression): Sqrt = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -684,7 +725,9 @@ case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
+case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") {
+  override protected def withNewChildInternal(newChild: Expression): Tan = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -706,6 +749,7 @@ case class Cot(child: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"${ev.value} = 1 / java.lang.Math.tan($c);")
   }
+  override protected def withNewChildInternal(newChild: Expression): Cot = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -724,7 +768,9 @@ case class Cot(child: Expression)
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
+case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") {
+  override protected def withNewChildInternal(newChild: Expression): Tanh = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -747,6 +793,7 @@ case class Atanh(child: Expression)
     defineCodeGen(ctx, ev, c =>
      s"0.5 * (java.lang.StrictMath.log1p($c) - java.lang.StrictMath.log1p(- $c))")
   }
+  override protected def withNewChildInternal(newChild: Expression): Atanh = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -764,6 +811,8 @@ case class Atanh(child: Expression)
   group = "math_funcs")
 case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
   override def funcName: String = "toDegrees"
+  override protected def withNewChildInternal(newChild: Expression): ToDegrees =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -781,6 +830,8 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre
   group = "math_funcs")
 case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
   override def funcName: String = "toRadians"
+  override protected def withNewChildInternal(newChild: Expression): ToRadians =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -811,6 +862,8 @@ case class Bin(child: Expression)
     defineCodeGen(ctx, ev, (c) =>
       s"UTF8String.fromString(java.lang.Long.toBinaryString($c))")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Bin = copy(child = newChild)
 }
 
 object Hex {
@@ -923,6 +976,8 @@ case class Hex(child: Expression)
       })
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Hex = copy(child = newChild)
 }
 
 /**
@@ -958,6 +1013,8 @@ case class Unhex(child: Expression)
       """
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Unhex = copy(child = newChild)
 }
@@ -996,6 +1053,9 @@ case class Atan2(left: Expression, right: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1012,6 +1072,8 @@ case class Pow(left: Expression, right: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.StrictMath.pow($c1, $c2)")
   }
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
@@ -1048,6 +1110,9 @@ case class ShiftLeft(left: Expression, right: Expression)
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (left, right) => s"$left << $right")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ShiftLeft = copy(left = newLeft, right = newRight)
 }
@@ -1084,6 +1149,9 @@ case class ShiftRight(left: Expression, right: Expression)
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (left, right) => s"$left >> $right")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ShiftRight = copy(left = newLeft, right = newRight)
 }
@@ -1120,6 +1188,10 @@ case class ShiftRightUnsigned(left: Expression, right: Expression)
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ShiftRightUnsigned =
+    copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1132,7 +1204,10 @@ case class ShiftRightUnsigned(left: Expression, right: Expression)
   since = "1.4.0",
   group = "math_funcs")
 case class Hypot(left: Expression, right: Expression)
-  extends BinaryMathExpression(math.hypot, "HYPOT")
+  extends BinaryMathExpression(math.hypot, "HYPOT") {
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Hypot =
+    copy(left = newLeft, right = newRight)
+}
 
 /**
@@ -1190,6 +1265,9 @@ case class Logarithm(left: Expression, right: Expression)
         """)
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Logarithm = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -1387,6 +1465,8 @@ case class Round(child: Expression, scale: Expression)
   extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP")
     with Serializable with ImplicitCastInputTypes {
   def this(child: Expression) = this(child, Literal(0))
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Round =
+    copy(child = newLeft, scale = newRight)
 }
 
 /**
@@ -1409,6 +1489,8 @@ case class BRound(child: Expression, scale: Expression)
   extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN")
     with Serializable with ImplicitCastInputTypes {
   def this(child: Expression) = this(child, Literal(0))
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): BRound = copy(child = newLeft, scale = newRight)
 }
 
 object WidthBucket {
@@ -1511,4 +1593,8 @@ case class WidthBucket(
   override def second: Expression = minValue
   override def third: Expression = maxValue
   override def fourth: Expression = numBucket
+
+  override protected def withNewChildrenInternal(
+      first: Expression, second: Expression, third: Expression, fourth: Expression): WidthBucket =
+    copy(value = first, minValue = second, maxValue = third, numBucket = fourth)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 6b3b949af24cf..9e854cf5fd891 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -51,6 +51,9 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
          | ${ev.value} = $c;
        """.stripMargin)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): PrintToStderr =
+    copy(child = newChild)
 }
 
 /**
@@ -100,6 +103,9 @@ case class RaiseError(child: Expression, dataType: DataType)
       value = JavaCode.defaultLiteral(dataType)
     )
   }
+
+  override protected def withNewChildInternal(newChild: Expression): RaiseError =
+    copy(child = newChild)
 }
 
 object RaiseError {
@@ -133,6 +139,9 @@ case class AssertTrue(left: Expression, right: Expression, child: Expression)
 
   override def flatArguments: Iterator[Any] = Iterator(left, right)
   override def exprsReplaced: Seq[Expression] = Seq(left, right)
+
+  override protected def withNewChildInternal(newChild: Expression): AssertTrue =
+    copy(child = newChild)
 }
 
 object AssertTrue {
@@ -268,4 +277,6 @@ case class TypeOf(child: Expression) extends UnaryExpression {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, _ => s"""UTF8String.fromString(${child.dataType.catalogString})""")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): TypeOf = copy(child = newChild)
 }
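`AssertTrue` above is worth a second look: `left` and `right` are kept only for `flatArguments` and `exprsReplaced`, so the original arguments can still be printed and restored, while the rewritten runtime expression `child` is the node's only real tree child. `withNewChildInternal` therefore copies just `child`. The same split in a standalone toy (invented names, not the Catalyst API):

```
sealed trait Node { def children: Seq[Node] }

// `original` is remembered only for display; `replacement` is the sole tree child.
case class Replaced(original: Node, replacement: Node) extends Node {
  override def children: Seq[Node] = Seq(replacement)
  def withNewChildInternal(newChild: Node): Replaced = copy(replacement = newChild)
}
```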
else "" s"${child.sql} AS $qualifierPrefix${quoteIfNeeded(name)}" } + + override protected def withNewChildInternal(newChild: Expression): Alias = + copy(child = newChild)(exprId, qualifier, explicitMetadata, nonInheritableMetadataKeys) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index d508129c190b9..2c2df6bf438b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -120,6 +120,9 @@ case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpress |} while (false); """.stripMargin) } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Coalesce = + copy(children = newChildren) } @@ -141,6 +144,8 @@ case class IfNull(left: Expression, right: Expression, child: Expression) override def flatArguments: Iterator[Any] = Iterator(left, right) override def exprsReplaced: Seq[Expression] = Seq(left, right) + + override protected def withNewChildInternal(newChild: Expression): IfNull = copy(child = newChild) } @@ -162,6 +167,8 @@ case class NullIf(left: Expression, right: Expression, child: Expression) override def flatArguments: Iterator[Any] = Iterator(left, right) override def exprsReplaced: Seq[Expression] = Seq(left, right) + + override protected def withNewChildInternal(newChild: Expression): NullIf = copy(child = newChild) } @@ -182,6 +189,8 @@ case class Nvl(left: Expression, right: Expression, child: Expression) extends R override def flatArguments: Iterator[Any] = Iterator(left, right) override def exprsReplaced: Seq[Expression] = Seq(left, right) + + override protected def withNewChildInternal(newChild: Expression): Nvl = copy(child = newChild) } @@ -205,6 +214,8 @@ case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child: override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3) override def exprsReplaced: Seq[Expression] = Seq(expr1, expr2, expr3) + + override protected def withNewChildInternal(newChild: Expression): Nvl2 = copy(child = newChild) } @@ -249,6 +260,8 @@ case class IsNaN(child: Expression) extends UnaryExpression ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) } } + + override protected def withNewChildInternal(newChild: Expression): IsNaN = copy(child = newChild) } /** @@ -311,6 +324,9 @@ case class NaNvl(left: Expression, right: Expression) }""") } } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): NaNvl = + copy(left = newLeft, right = newRight) } @@ -339,6 +355,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { } override def sql: String = s"(${child.sql} IS NULL)" + + override protected def withNewChildInternal(newChild: Expression): IsNull = copy(child = newChild) } @@ -374,6 +392,9 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } override def sql: String = s"(${child.sql} IS NOT NULL)" + + override protected def withNewChildInternal(newChild: Expression): IsNotNull = + copy(child = newChild) } @@ -466,4 +487,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = FalseLiteral) } + + override protected 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index d508129c190b9..2c2df6bf438b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -120,6 +120,9 @@ case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpress
       |} while (false);
      """.stripMargin)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Coalesce =
+    copy(children = newChildren)
 }
@@ -141,6 +144,8 @@ case class IfNull(left: Expression, right: Expression, child: Expression)
 
   override def flatArguments: Iterator[Any] = Iterator(left, right)
   override def exprsReplaced: Seq[Expression] = Seq(left, right)
+
+  override protected def withNewChildInternal(newChild: Expression): IfNull = copy(child = newChild)
 }
@@ -162,6 +167,8 @@ case class NullIf(left: Expression, right: Expression, child: Expression)
 
   override def flatArguments: Iterator[Any] = Iterator(left, right)
   override def exprsReplaced: Seq[Expression] = Seq(left, right)
+
+  override protected def withNewChildInternal(newChild: Expression): NullIf = copy(child = newChild)
 }
@@ -182,6 +189,8 @@ case class Nvl(left: Expression, right: Expression, child: Expression) extends R
 
   override def flatArguments: Iterator[Any] = Iterator(left, right)
   override def exprsReplaced: Seq[Expression] = Seq(left, right)
+
+  override protected def withNewChildInternal(newChild: Expression): Nvl = copy(child = newChild)
 }
@@ -205,6 +214,8 @@ case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child:
 
   override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3)
   override def exprsReplaced: Seq[Expression] = Seq(expr1, expr2, expr3)
+
+  override protected def withNewChildInternal(newChild: Expression): Nvl2 = copy(child = newChild)
 }
@@ -249,6 +260,8 @@ case class IsNaN(child: Expression) extends UnaryExpression
         ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral)
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): IsNaN = copy(child = newChild)
 }
 
 /**
@@ -311,6 +324,9 @@ case class NaNvl(left: Expression, right: Expression)
           }""")
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): NaNvl =
+    copy(left = newLeft, right = newRight)
 }
@@ -339,6 +355,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
   }
 
   override def sql: String = s"(${child.sql} IS NULL)"
+
+  override protected def withNewChildInternal(newChild: Expression): IsNull = copy(child = newChild)
 }
@@ -374,6 +392,9 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
   }
 
   override def sql: String = s"(${child.sql} IS NOT NULL)"
+
+  override protected def withNewChildInternal(newChild: Expression): IsNotNull =
+    copy(child = newChild)
 }
@@ -466,4 +487,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
       |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
      """.stripMargin, isNull = FalseLiteral)
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): AtLeastNNonNulls = copy(children = newChildren)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5be521683381d..5ae0cef7b400c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
 import java.lang.reflect.{Method, Modifier}
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.{Builder, IndexedSeq, WrappedArray}
+import scala.collection.mutable.{Builder, WrappedArray}
 import scala.reflect.ClassTag
 import scala.util.{Properties, Try}
 
@@ -279,6 +279,9 @@ case class StaticInvoke(
       """
     ev.copy(code = code)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(arguments = newChildren)
 }
 
 /**
@@ -400,6 +403,9 @@ case class Invoke(
   }
 
   override def toString: String = s"$targetObject.$functionName"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Invoke =
+    copy(targetObject = newChildren.head, arguments = newChildren.tail)
 }
 
 object NewInstance {
@@ -506,6 +512,9 @@ case class NewInstance(
   }
 
   override def toString: String = s"newInstance($cls)"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): NewInstance =
+    copy(arguments = newChildren)
 }
 
 /**
@@ -543,6 +552,9 @@ case class UnwrapOption(
       """
     ev.copy(code = code)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): UnwrapOption =
+    copy(child = newChild)
 }
 
 /**
@@ -573,6 +585,9 @@ case class WrapOption(child: Expression, optType: DataType)
       """
     ev.copy(code = code, isNull = FalseLiteral)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): WrapOption =
+    copy(child = newChild)
 }
 
 object LambdaVariable {
@@ -659,6 +674,9 @@ case class UnresolvedMapObjects(
   override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse {
     throw QueryExecutionErrors.customCollectionClsNotResolvedError
   }
+
+  override protected def withNewChildInternal(newChild: Expression): UnresolvedMapObjects =
+    copy(child = newChild)
 }
 
 object MapObjects {
@@ -1025,6 +1043,13 @@ case class MapObjects private(
       """
     ev.copy(code = code, isNull = genInputData.isNull)
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
+    copy(
+      loopVar = newFirst.asInstanceOf[LambdaVariable],
+      lambdaFunction = newSecond,
+      inputData = newThird)
 }
 
 /**
@@ -1044,6 +1069,9 @@ case class UnresolvedCatalystToExternalMap(
   override lazy val resolved = false
 
   override def dataType: DataType = ObjectType(collClass)
+
+  override protected def withNewChildInternal(
+      newChild: Expression): UnresolvedCatalystToExternalMap = copy(child = newChild)
 }
 
 object CatalystToExternalMap {
@@ -1214,6 +1242,15 @@ case class CatalystToExternalMap private(
       """
     ev.copy(code = code, isNull = genInputData.isNull)
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): CatalystToExternalMap =
+    copy(
+      keyLoopVar = newChildren(0).asInstanceOf[LambdaVariable],
+      keyLambdaFunction = newChildren(1),
+      valueLoopVar = newChildren(2).asInstanceOf[LambdaVariable],
+      valueLambdaFunction = newChildren(3),
+      inputData = newChildren(4))
 }
 
 object ExternalMapToCatalyst {
@@ -1437,6 +1474,15 @@ case class ExternalMapToCatalyst private(
       """
     ev.copy(code = code, isNull = inputMap.isNull)
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): ExternalMapToCatalyst =
+    copy(
+      keyLoopVar = newChildren(0).asInstanceOf[LambdaVariable],
+      keyConverter = newChildren(1),
+      valueLoopVar = newChildren(2).asInstanceOf[LambdaVariable],
+      valueConverter = newChildren(3),
+      inputData = newChildren(4))
 }
 
 /**
@@ -1487,6 +1533,9 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
       """.stripMargin
     ev.copy(code = code, isNull = FalseLiteral)
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): CreateExternalRow = copy(children = newChildren)
 }
 
 /**
@@ -1516,6 +1565,9 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
   }
 
   override def dataType: DataType = BinaryType
+
+  override protected def withNewChildInternal(newChild: Expression): EncodeUsingSerializer =
+    copy(child = newChild)
 }
 
 /**
@@ -1548,6 +1600,9 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
   }
 
   override def dataType: DataType = ObjectType(tag.runtimeClass)
+
+  override protected def withNewChildInternal(newChild: Expression): DecodeUsingSerializer[T] =
+    copy(child = newChild)
 }
 
 /**
@@ -1629,6 +1684,10 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
       """.stripMargin
     ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): InitializeJavaBean =
+    super.legacyWithNewChildren(newChildren).asInstanceOf[InitializeJavaBean]
 }
 
 /**
@@ -1676,6 +1735,9 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
       """
     ev.copy(code = code, isNull = FalseLiteral, value = childGen.value)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): AssertNotNull =
+    copy(child = newChild)
 }
 
 /**
@@ -1727,6 +1789,9 @@ case class GetExternalRowField(
       """
     ev.copy(code = code, isNull = FalseLiteral)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): GetExternalRowField =
+    copy(child = newChild)
 }
 
 /**
@@ -1801,4 +1866,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
       """
     ev.copy(code = code, isNull = input.isNull)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): ValidateExternalType =
+    copy(child = newChild)
 }
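Two wrinkles are visible in `objects.scala`. Nodes such as `MapObjects`, `CatalystToExternalMap`, and `ExternalMapToCatalyst` declare some children with the narrower type `LambdaVariable`, but the hook hands every child back as `Expression`, so those slots are restored with `asInstanceOf` casts. And `InitializeJavaBean`, whose children are spread through a `Map[String, Expression]` of setters, simply delegates to the slower `legacyWithNewChildren` rather than risk a fragile positional rebuild. A toy version of the cast case (invented types):

```
sealed trait Node

case class Ref(name: String) extends Node  // plays the role of LambdaVariable

case class Loop(ref: Ref, body: Node) extends Node {
  def children: Seq[Node] = Seq(ref, body)
  // Children come back as the general type, so the narrow slot is re-cast.
  // This is safe as long as callers replace children with values of the
  // same shape, which transformations are expected to preserve.
  def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Loop =
    copy(ref = newChildren(0).asInstanceOf[Ref], body = newChildren(1))
}
```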
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 33eb120e009ed..d9d0643a9130c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -322,6 +322,8 @@ case class Not(child: Expression)
   }
 
   override def sql: String = s"(NOT ${child.sql})"
+
+  override protected def withNewChildInternal(newChild: Expression): Not = copy(child = newChild)
 }
 
 /**
@@ -379,6 +381,9 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
   override def nullable: Boolean = children.exists(_.nullable)
   override def toString: String = s"$value IN ($query)"
   override def sql: String = s"(${value.sql} IN (${query.sql}))"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): InSubquery =
+    copy(values = newChildren.dropRight(1), query = newChildren.last.asInstanceOf[ListQuery])
 }
@@ -520,6 +525,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
     val listSQL = list.map(_.sql).mkString(", ")
     s"($valueSQL IN ($listSQL))"
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): In =
+    copy(value = newChildren.head, list = newChildren.tail)
 }
 
 /**
@@ -625,6 +633,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
       .mkString(", ")
     s"($valueSQL IN ($listSQL))"
   }
+
+  override protected def withNewChildInternal(newChild: Expression): InSet = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -708,6 +718,9 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
         """)
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): And =
+    copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -792,6 +805,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
         """)
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Or =
+    copy(left = newLeft, right = newRight)
 }
@@ -877,6 +893,9 @@ case class EqualTo(left: Expression, right: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): EqualTo = copy(left = newLeft, right = newRight)
 }
 
 // TODO: although map type is not orderable, technically map type should be able to be used
@@ -938,6 +957,10 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
        boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
           (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral)
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): EqualNullSafe =
+    copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -970,6 +993,9 @@ case class LessThan(left: Expression, right: Expression)
   override def symbol: String = "<"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1002,6 +1028,9 @@ case class LessThanOrEqual(left: Expression, right: Expression)
   override def symbol: String = "<="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1034,6 +1063,9 @@ case class GreaterThan(left: Expression, right: Expression)
   override def symbol: String = ">"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1066,6 +1098,10 @@ case class GreaterThanOrEqual(left: Expression, right: Expression)
   override def symbol: String = ">="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): GreaterThanOrEqual =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
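`In` and `InSubquery` above mix one fixed slot with a variable-length tail, so instead of a wholesale copy the overrides slice `newChildren`: `head`/`tail` when the fixed child comes first, `dropRight(1)`/`last` (plus a cast) when it comes last. In toy form (invented names):

```
sealed trait Node

case class InList(value: Node, list: Seq[Node]) extends Node {
  def children: Seq[Node] = value +: list
  // The fixed child is first, so head/tail splits the sequence back apart.
  def withNewChildrenInternal(newChildren: IndexedSeq[Node]): InList =
    copy(value = newChildren.head, list = newChildren.tail)
}
```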
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 0a4c6e27d51d9..d470cadff85b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -111,6 +111,8 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG {
   override def sql: String = {
     s"rand(${if (hideSeed) "" else child.sql})"
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Rand = copy(child = newChild)
 }
 
 object Rand {
@@ -162,6 +164,8 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
   override def sql: String = {
     s"randn(${if (hideSeed) "" else child.sql})"
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Randn = copy(child = newChild)
 }
 
 object Randn {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 9fdab350ceb95..13d00faea37f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -180,6 +180,9 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
       })
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Like =
+    copy(left = newLeft, right = newRight)
 }
 
 sealed abstract class MultiLikeBase
@@ -268,10 +271,14 @@ sealed abstract class LikeAllBase extends MultiLikeBase {
 
 case class LikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase {
   override def isNotSpecified: Boolean = false
+  override protected def withNewChildInternal(newChild: Expression): LikeAll =
+    copy(child = newChild)
 }
 
 case class NotLikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase {
   override def isNotSpecified: Boolean = true
+  override protected def withNewChildInternal(newChild: Expression): NotLikeAll =
+    copy(child = newChild)
 }
 
 /**
@@ -324,10 +331,14 @@ sealed abstract class LikeAnyBase extends MultiLikeBase {
 
 case class LikeAny(child: Expression, patterns: Seq[UTF8String]) extends LikeAnyBase {
   override def isNotSpecified: Boolean = false
+  override protected def withNewChildInternal(newChild: Expression): LikeAny =
+    copy(child = newChild)
 }
 
 case class NotLikeAny(child: Expression, patterns: Seq[UTF8String]) extends LikeAnyBase {
   override def isNotSpecified: Boolean = true
+  override protected def withNewChildInternal(newChild: Expression): NotLikeAny =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.contains.tab
@@ -409,6 +420,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
       })
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): RLike =
+    copy(left = newLeft, right = newRight)
 }
@@ -467,6 +481,10 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression)
   }
 
   override def prettyName: String = "split"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringSplit =
+    copy(str = newFirst, regex = newSecond, limit = newThird)
 }
@@ -622,6 +640,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   override def second: Expression = regexp
   override def third: Expression = rep
   override def fourth: Expression = pos
+
+  override protected def withNewChildrenInternal(
+      first: Expression, second: Expression, third: Expression, fourth: Expression): RegExpReplace =
+    copy(subject = first, regexp = second, rep = third, pos = fourth)
 }
 
 object RegExpReplace {
@@ -765,6 +787,10 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
         }"""
     })
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): RegExpExtract =
+    copy(subject = newFirst, regexp = newSecond, idx = newThird)
 }
 
 /**
@@ -868,4 +894,8 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
       """
     })
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): RegExpExtractAll =
+    copy(subject = newFirst, regexp = newSecond, idx = newThird)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 714f1d6dc4bfc..3d5f812af9c2e 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -227,6 +227,9 @@ case class ConcatWs(children: Seq[Expression])
       """)
     }
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ConcatWs =
+    copy(children = newChildren)
 }
 
 /**
@@ -366,6 +369,9 @@ case class Elt(
        |final boolean ${ev.isNull} = ${ev.value} == null;
      """.stripMargin)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Elt =
+    copy(children = newChildren)
 }
@@ -403,6 +409,8 @@ case class Upper(child: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Upper = copy(child = newChild)
 }
 
 /**
@@ -430,6 +438,8 @@ case class Lower(child: Expression)
 
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("lower")
+
+  override protected def withNewChildInternal(newChild: Expression): Lower = copy(child = newChild)
 }
 
 /** A base trait for functions that compare two strings, returning a boolean. */
@@ -454,6 +464,8 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
   }
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -464,6 +476,8 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
   }
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -474,6 +488,8 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
   }
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -522,6 +538,10 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp
   override def third: Expression = replaceExpr
 
   override def prettyName: String = "replace"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringReplace =
+    copy(srcExpr = newFirst, searchExpr = newSecond, replaceExpr = newThird)
 }
 
 object Overlay {
@@ -634,6 +654,10 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
   override def second: Expression = replace
   override def third: Expression = pos
   override def fourth: Expression = len
+
+  override protected def withNewChildrenInternal(
+      first: Expression, second: Expression, third: Expression, fourth: Expression): Overlay =
+    copy(input = first, replace = second, pos = third, len = fourth)
 }
 
 object StringTranslate {
@@ -731,6 +755,10 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
   override def second: Expression = matchingExpr
   override def third: Expression = replaceExpr
   override def prettyName: String = "translate"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringTranslate =
+    copy(srcExpr = newFirst, matchingExpr = newSecond, replaceExpr = newThird)
 }
 
 /**
@@ -769,6 +797,9 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
   override def dataType: DataType = IntegerType
 
   override def prettyName: String = "find_in_set"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): FindInSet = copy(left = newLeft, right = newRight)
 }
 
 trait String2TrimExpression extends Expression with ImplicitCastInputTypes {
@@ -926,6 +957,11 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None)
     srcString.trim(trimString)
 
   override val trimMethod: String = "trim"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(
+      srcStr = newChildren.head,
+      trimStr = if (trimStr.isDefined) Some(newChildren.last) else None)
 }
 
 /**
@@ -974,6 +1010,9 @@ case class StringTrimBoth(srcStr: Expression, trimStr: Option[Expression], child
 
   override def flatArguments: Iterator[Any] = Iterator(srcStr, trimStr)
   override def prettyName: String = "btrim"
+
+  override protected def withNewChildInternal(newChild: Expression): StringTrimBoth =
+    copy(child = newChild)
 }
 
 object StringTrimLeft {
@@ -1027,6 +1066,12 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None
     srcString.trimLeft(trimString)
 
   override val trimMethod: String = "trimLeft"
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): StringTrimLeft =
+    copy(
+      srcStr = newChildren.head,
+      trimStr = if (trimStr.isDefined) Some(newChildren.last) else None)
 }
 
 object StringTrimRight {
@@ -1082,6 +1127,12 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non
     srcString.trimRight(trimString)
 
   override val trimMethod: String = "trimRight"
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): StringTrimRight =
+    copy(
+      srcStr = newChildren.head,
+      trimStr = if (trimStr.isDefined) Some(newChildren.last) else None)
 }
 
 /**
@@ -1120,6 +1171,9 @@ case class StringInstr(str: Expression, substr: Expression)
     defineCodeGen(ctx, ev, (l, r) =>
       s"($l).indexOf($r, 0) + 1")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): StringInstr = copy(str = newLeft, substr = newRight)
 }
 
 /**
@@ -1164,6 +1218,10 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): SubstringIndex =
+    copy(strExpr = newFirst, delimExpr = newSecond, countExpr = newThird)
 }
 
 /**
@@ -1258,6 +1316,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
 
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("locate")
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringLocate =
+    copy(substr = newFirst, str = newSecond, start = newThird)
+
 }
 
 /**
@@ -1302,6 +1365,10 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera
   }
 
   override def prettyName: String = "lpad"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringLPad =
+    copy(str = newFirst, len = newSecond, pad = newThird)
 }
 
 /**
@@ -1347,6 +1414,10 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera
   }
 
   override def prettyName: String = "rpad"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringRPad =
+    copy(str = newFirst, len = newSecond, pad = newThird)
 }
 
 object ParseUrl {
@@ -1519,6 +1590,9 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge
       }
     }
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl =
+    copy(children = newChildren)
 }
 
 /**
@@ -1606,6 +1680,9 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
 
   override def prettyName: String = getTagValue(
     FunctionRegistry.FUNC_ALIAS).getOrElse("format_string")
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): FormatString = FormatString(newChildren: _*)
 }
 
 /**
@@ -1638,6 +1715,9 @@ case class InitCap(child: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): InitCap =
+    copy(child = newChild)
 }
 
 /**
@@ -1669,6 +1749,9 @@ case class StringRepeat(str: Expression, times: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): StringRepeat = copy(str = newLeft, times = newRight)
 }
 
 /**
@@ -1700,6 +1783,9 @@ case class StringSpace(child: Expression)
   }
 
   override def prettyName: String = "space"
+
+  override protected def withNewChildInternal(newChild: Expression): StringSpace =
+    copy(child = newChild)
 }
 
 /**
@@ -1767,6 +1853,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
       }
     })
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Substring =
+    copy(str = newFirst, pos = newSecond, len = newThird)
+
 }
 
 /**
@@ -1791,6 +1882,8 @@ case class Right(str: Expression, len: Expression, child: Expression) extends Ru
 
   override def flatArguments: Iterator[Any] = Iterator(str, len)
   override def exprsReplaced: Seq[Expression] = Seq(str, len)
+
+  override protected def withNewChildInternal(newChild: Expression): Right = copy(child = newChild)
 }
 
 /**
@@ -1814,6 +1907,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run
 
   override def flatArguments: Iterator[Any] = Iterator(str, len)
   override def exprsReplaced: Seq[Expression] = Seq(str, len)
+  override protected def withNewChildInternal(newChild: Expression): Left = copy(child = newChild)
 }
 
 /**
@@ -1851,6 +1945,8 @@ case class Length(child: Expression)
       case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Length = copy(child = newChild)
 }
 
 /**
@@ -1883,6 +1979,9 @@ case class BitLength(child: Expression)
   }
 
   override def prettyName: String = "bit_length"
+
+  override protected def withNewChildInternal(newChild: Expression): BitLength =
+    copy(child = newChild)
 }
 
 /**
@@ -1916,6 +2015,9 @@ case class OctetLength(child: Expression)
   }
 
   override def prettyName: String = "octet_length"
+
+  override protected def withNewChildInternal(newChild: Expression): OctetLength =
+    copy(child = newChild)
 }
 
 /**
@@ -1943,6 +2045,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
     nullSafeCodeGen(ctx, ev, (left, right) =>
       s"${ev.value} = $left.levenshteinDistance($right);")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Levenshtein = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -1969,6 +2074,9 @@ case class SoundEx(child: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"$c.soundex()")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): SoundEx =
+    copy(child = newChild)
 }
 
 /**
@@ -2012,6 +2120,8 @@ case class Ascii(child: Expression)
         }
        """})
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Ascii = copy(child = newChild)
 }
 
 /**
@@ -2060,6 +2170,8 @@ case class Chr(child: Expression)
       """
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Chr = copy(child = newChild)
 }
 
 /**
@@ -2090,6 +2202,8 @@ case class Base64(child: Expression)
          ${classOf[CommonsBase64].getName}.encodeBase64($child));
        """})
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Base64 = copy(child = newChild)
 }
 
 /**
@@ -2119,6 +2233,9 @@ case class UnBase64(child: Expression)
          ${ev.value} = ${classOf[CommonsBase64].getName}.decodeBase64($child.toString());
        """})
   }
+
+  override protected def withNewChildInternal(newChild: Expression): UnBase64 =
+    copy(child = newChild)
 }
 
 object Decode {
@@ -2178,6 +2295,8 @@ case class Decode(params: Seq[Expression], child: Expression) extends RuntimeRep
 
   override def flatArguments: Iterator[Any] = Iterator(params)
   override def exprsReplaced: Seq[Expression] = params
+
+  override protected def withNewChildInternal(newChild: Expression): Decode = copy(child = newChild)
 }
 
 /**
@@ -2219,6 +2338,10 @@ case class StringDecode(bin: Expression, charset: Expression)
         }
       """)
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): StringDecode =
+    copy(bin = newLeft, charset = newRight)
 }
 
 /**
@@ -2259,6 +2382,9 @@ case class Encode(value: Expression, charset: Expression)
           org.apache.spark.unsafe.Platform.throwException(e);
         }""")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Encode = copy(value = newLeft, charset = newRight)
 }
 
 /**
@@ -2439,6 +2565,9 @@ case class FormatNumber(x: Expression, d: Expression)
   }
 
   override def prettyName: String = "format_number"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): FormatNumber = copy(x = newLeft, d = newRight)
 }
 
 /**
@@ -2509,4 +2638,9 @@ case class Sentences(
     }
     new GenericArrayData(result.toSeq)
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Sentences =
+    copy(str = newFirst, language = newSecond, country = newThird)
+
 }
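The trim family above (`StringTrim`, `StringTrimLeft`, `StringTrimRight`) stores its optional trim string as `trimStr: Option[Expression]`, so the node has either one or two children. The overrides rebuild that shape by taking `head` for the source string and re-wrapping `last` in `Some` only when the option was defined. The same logic standalone (invented types):

```
sealed trait Node

case class Trim(src: Node, trimStr: Option[Node]) extends Node {
  def children: Seq[Node] = src +: trimStr.toSeq  // arity is 1 or 2
  def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Trim =
    copy(
      src = newChildren.head,
      trimStr = if (trimStr.isDefined) Some(newChildren.last) else None)
}
```

Checking `trimStr.isDefined`, rather than the length of `newChildren`, keeps the node's arity stable across rewrites.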
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index ff8856708c6d1..ea6e427a95b5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -238,6 +238,9 @@ case class ScalarSubquery(
       children.map(_.canonicalized),
       ExprId(0))
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren)
 }
 
 object ScalarSubquery {
@@ -283,6 +286,9 @@ case class ListQuery(
       ExprId(0),
       childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery =
+    copy(children = newChildren)
 }
 
 /**
@@ -325,4 +331,7 @@ case class Exists(
       children.map(_.canonicalized),
       ExprId(0))
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists =
+    copy(children = newChildren)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index fa027d1ab0561..ff486bfbdef75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -47,6 +47,13 @@ case class WindowSpecDefinition(
 
   override def children: Seq[Expression] = partitionSpec ++ orderSpec :+ frameSpecification
 
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): WindowSpecDefinition =
+    copy(
+      partitionSpec = newChildren.take(partitionSpec.size),
+      orderSpec = newChildren.drop(partitionSpec.size).dropRight(1).asInstanceOf[Seq[SortOrder]],
+      frameSpecification = newChildren.last.asInstanceOf[WindowFrame])
+
   override lazy val resolved: Boolean =
     childrenResolved && checkInputDataTypes().isSuccess &&
       frameSpecification.isInstanceOf[SpecifiedWindowFrame]
@@ -266,6 +273,10 @@ case class SpecifiedWindowFrame(
       case _ => true
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): SpecifiedWindowFrame =
+    copy(lower = newLeft, upper = newRight)
 }
 
 case class UnresolvedWindowExpression(
@@ -275,6 +286,9 @@ case class UnresolvedWindowExpression(
   override def dataType: DataType = throw new UnresolvedException("dataType")
   override def nullable: Boolean = throw new UnresolvedException("nullable")
   override lazy val resolved = false
+
+  override protected def withNewChildInternal(newChild: Expression): UnresolvedWindowExpression =
+    copy(child = newChild)
 }
 
 case class WindowExpression(
@@ -290,6 +304,10 @@ case class WindowExpression(
 
   override def toString: String = s"$windowFunction $windowSpec"
   override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): WindowExpression =
+    copy(windowFunction = newLeft, windowSpec = newRight.asInstanceOf[WindowSpecDefinition])
 }
 
 /**
@@ -458,6 +476,10 @@ case class Lead(
   override def first: Expression = input
   override def second: Expression = offset
   override def third: Expression = default
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Lead =
+    copy(input = newFirst, offset = newSecond, default = newThird)
 }
 
 /**
@@ -513,6 +535,10 @@ case class Lag(
   override def first: Expression = input
   override def second: Expression = inputOffset
   override def third: Expression = default
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Lag =
+    copy(input = newFirst, inputOffset = newSecond, default = newThird)
 }
 
 abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction {
@@ -698,6 +724,10 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
   override def prettyName: String = "nth_value"
   override def sql: String =
     s"$prettyName(${input.sql}, ${offset.sql})${if (ignoreNulls) " ignore nulls" else ""}"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): NthValue =
+    copy(input = newLeft, offset = newRight)
 }
 
 /**
@@ -800,6 +830,9 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow
   )
 
   override val evaluateExpression = bucket
+
+  override protected def withNewChildInternal(
+      newChild: Expression): NTile = copy(buckets = newChild)
 }
 
 /**
@@ -884,6 +917,8 @@ abstract class RankLike extends AggregateWindowFunction {
 case class Rank(children: Seq[Expression]) extends RankLike {
   def this() = this(Nil)
   override def withOrder(order: Seq[Expression]): Rank = Rank(order)
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Rank =
+    copy(children = newChildren)
 }
 
 /**
@@ -925,6 +960,8 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {
   override val aggBufferAttributes = rank +: orderAttrs
   override val initialValues = zero +: orderInit
   override def prettyName: String = "dense_rank"
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): DenseRank =
+    copy(children = newChildren)
 }
 
 /**
@@ -966,4 +1003,6 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase
   override val evaluateExpression =
     If(n > one, (rank - one).cast(DoubleType) / (n - one).cast(DoubleType), 0.0d)
   override def prettyName: String = "percent_rank"
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PercentRank =
+    copy(children = newChildren)
 }
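`WindowSpecDefinition` above is the most involved reconstruction in the patch: its children are the concatenation `partitionSpec ++ orderSpec :+ frameSpecification`, so the override recovers each segment by length with `take`/`drop`/`dropRight`/`last` and restores the segment types (`SortOrder`, `WindowFrame`) with casts. The same slicing in a self-contained toy, minus the casts since everything here is one type:

```
sealed trait Node

case class Spec(partition: Seq[Node], order: Seq[Node], frame: Node) extends Node {
  def children: Seq[Node] = partition ++ order :+ frame
  // Segment lengths are taken from the current instance, which holds as
  // long as callers hand back the same number of children they were given.
  def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Spec =
    copy(
      partition = newChildren.take(partition.size),
      order = newChildren.drop(partition.size).dropRight(1),
      frame = newChildren.last)
}
```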
Expression, path: Expression) extends XPathExtract { val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString) UTF8String.fromString(ret) } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): Expression = copy(xml = newLeft, path = newRight) } // scalastyle:off line.size.limit @@ -233,4 +254,7 @@ case class XPathList(xml: Expression, path: Expression) extends XPathExtract { null } } + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): XPathList = copy(xml = newLeft, path = newRight) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 828f768f17701..2a288ffd8ecf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -107,6 +107,9 @@ case class OrderedJoin( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): OrderedJoin = + copy(left = newLeft, right = newRight) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index ac8766cd74367..a6444b13acd02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -211,4 +211,7 @@ case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with E nullSafeCodeGen(ctx, ev, codeToNormalize) } + + override protected def withNewChildInternal(newChild: Expression): NormalizeNaNAndZero = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala index b6bf7cd85d472..bf3f93de97f8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -61,4 +61,7 @@ case class EventTimeWatermark( a } } + + override protected def withNewChildInternal(newChild: LogicalPlan): EventTimeWatermark = + copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index 30bff884b2249..6299976911ee4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -35,6 +35,9 @@ case class ScriptTransformation( ioschema: ScriptInputOutputSchema) extends UnaryNode { @transient override lazy val references: AttributeSet = AttributeSet(input.flatMap(_.references)) + + override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation = + copy(child = newChild) } /** 
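Most of the `withNewChildrenInternal` overrides above are one-line delegations to the generated case-class `copy`. The only fiddly cases are nodes whose children are the concatenation of several sequences, such as `WindowSpecDefinition` at the top of this diff: the flat `newChildren` list has to be sliced back positionally, with the same sizes and in the same order. A minimal, self-contained sketch of that slicing pattern (the `Expr`/`VariadicNode` names are hypothetical, not classes from this patch):

```
// Hypothetical toy node, for illustration only: its children are the
// concatenation of two sequences plus one trailing child, so a flat
// newChildren list must be sliced back with the same sizes and order.
case class Expr(name: String)

case class VariadicNode(keys: Seq[Expr], values: Seq[Expr], last: Expr) {
  def children: IndexedSeq[Expr] = (keys ++ values :+ last).toIndexedSeq

  def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): VariadicNode = {
    assert(newChildren.size == children.size, "Incorrect number of children")
    copy(
      keys = newChildren.take(keys.size),                             // first |keys| entries
      values = newChildren.slice(keys.size, keys.size + values.size), // next |values| entries
      last = newChildren.last)                                        // trailing child
  }
}

object SliceDemo extends App {
  val n = VariadicNode(Seq(Expr("a")), Seq(Expr("b"), Expr("c")), Expr("d"))
  val m = n.withNewChildrenInternal(n.children.map(e => Expr(e.name.toUpperCase)))
  assert(m == VariadicNode(Seq(Expr("A")), Seq(Expr("B"), Expr("C")), Expr("D")))
}
```

Getting a slice boundary wrong silently reassigns children to the wrong parameter, which is why the size assertion mirrors the one in `TreeNode.withNewChildren`.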
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 962ce938d2954..ba54be7679ec1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -41,6 +41,8 @@ import org.apache.spark.util.random.RandomSampler
 case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): ReturnAnswer =
+    copy(child = newChild)
 }
 
 /**
@@ -52,6 +54,8 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
  */
 case class Subquery(child: LogicalPlan, correlated: Boolean) extends OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): Subquery =
+    copy(child = newChild)
 }
 
 object Subquery {
@@ -78,6 +82,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
 
   override lazy val validConstraints: ExpressionSet =
     getAllValidConstraints(projectList)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Project =
+    copy(child = newChild)
 }
 
 /**
@@ -136,6 +143,9 @@ case class Generate(
   }
 
   def output: Seq[Attribute] = requiredChildOutput ++ qualifiedGeneratorOutput
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Generate =
+    copy(child = newChild)
 }
 
 case class Filter(condition: Expression, child: LogicalPlan)
@@ -149,6 +159,9 @@ case class Filter(condition: Expression, child: LogicalPlan)
       .filterNot(SubqueryExpression.hasCorrelatedSubquery)
     child.constraints.union(ExpressionSet(predicates))
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Filter =
+    copy(child = newChild)
 }
 
 abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
@@ -201,6 +214,9 @@ case class Intersect(
       Some(children.flatMap(_.maxRows).min)
     }
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: LogicalPlan, newRight: LogicalPlan): Intersect = copy(left = newLeft, right = newRight)
 }
 
 case class Except(
@@ -214,6 +230,9 @@ case class Except(
   override def metadataOutput: Seq[Attribute] = Nil
 
   override protected lazy val validConstraints: ExpressionSet = leftConstraints
+
+  override protected def withNewChildrenInternal(
+    newLeft: LogicalPlan, newRight: LogicalPlan): Except = copy(left = newLeft, right = newRight)
 }
 
 /** Factory for constructing new `Union` nodes. */
@@ -326,6 +345,9 @@ case class Union(
       .map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
       .reduce(merge(_, _))
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union =
+    copy(children = newChildren)
 }
 
 case class Join(
@@ -436,6 +458,9 @@ case class Join(
       || e.asInstanceOf[JoinHint].leftHint.isDefined
       || e.asInstanceOf[JoinHint].rightHint.isDefined)
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: LogicalPlan, newRight: LogicalPlan): Join = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -461,6 +486,9 @@ case class InsertIntoDir(
   override def output: Seq[Attribute] = Seq.empty
   override def metadataOutput: Seq[Attribute] = Nil
   override lazy val resolved: Boolean = false
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoDir =
+    copy(child = newChild)
 }
 
 /**
@@ -515,6 +543,9 @@ case class View(
       case _ => false
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): View =
+    copy(child = newChild)
 }
 
 object View {
@@ -548,12 +579,16 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)])
   }
 
   override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): With = copy(child = newChild)
 }
 
 case class WithWindowDefinition(
     windowDefinitions: Map[String, WindowSpecDefinition],
     child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): WithWindowDefinition =
+    copy(child = newChild)
 }
 
 /**
@@ -569,6 +604,7 @@ case class Sort(
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = child.maxRows
   override def outputOrdering: Seq[SortOrder] = order
+  override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
 }
 
 /** Factory for constructing new `Range` nodes. */
@@ -739,6 +775,9 @@ case class Aggregate(
     val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
     getAllValidConstraints(nonAgg)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Aggregate =
+    copy(child = newChild)
 }
 
 case class Window(
@@ -753,6 +792,9 @@ case class Window(
   override def producedAttributes: AttributeSet = windowOutputSet
 
   def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute))
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Window =
+    copy(child = newChild)
 }
 
 object Expand {
@@ -869,6 +911,9 @@ case class Expand(
   // This operator can reuse attributes (for example making them null when doing a roll up) so
   // the constraints of the child may no longer be valid.
   override protected lazy val validConstraints: ExpressionSet = ExpressionSet()
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Expand =
+    copy(child = newChild)
 }
 
 /**
@@ -901,6 +946,8 @@ case class Pivot(
     groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg
   }
   override def metadataOutput: Seq[Attribute] = Nil
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Pivot = copy(child = newChild)
 }
 
 /**
@@ -950,6 +997,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderP
       case _ => None
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): GlobalLimit =
+    copy(child = newChild)
 }
 
 /**
@@ -967,6 +1017,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr
       case _ => None
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): LocalLimit =
+    copy(child = newChild)
 }
 
 /**
@@ -987,6 +1040,8 @@ case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservi
       case _ => None
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Tail = copy(child = newChild)
 }
 
 /**
@@ -1013,6 +1068,9 @@ case class SubqueryAlias(
   }
 
   override def doCanonicalize(): LogicalPlan = child.canonicalized
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): SubqueryAlias =
+    copy(child = newChild)
 }
 
 object SubqueryAlias {
@@ -1066,6 +1124,9 @@ case class Sample(
 
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Sample =
+    copy(child = newChild)
 }
 
 /**
@@ -1074,6 +1135,8 @@ case class Sample(
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): Distinct =
+    copy(child = newChild)
 }
 
 /**
@@ -1104,6 +1167,8 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
       case _ => RoundRobinPartitioning(numPartitions)
     }
   }
+  override protected def withNewChildInternal(newChild: LogicalPlan): Repartition =
+    copy(child = newChild)
 }
 
 /**
@@ -1145,6 +1210,9 @@ case class RepartitionByExpression(
   }
 
   override def shuffle: Boolean = true
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): RepartitionByExpression =
+    copy(child = newChild)
 }
 
 object RepartitionByExpression {
@@ -1178,6 +1246,8 @@ case class Deduplicate(
     child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate =
+    copy(child = newChild)
 }
 
 /**
@@ -1206,4 +1276,7 @@ case class CollectMetrics(
   }
 
   override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): CollectMetrics =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index 4b5e278fccdfb..5bda94cea9527 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -31,6 +31,9 @@ case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan
   override lazy val resolved: Boolean = false
   override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedHint =
+    copy(child = newChild)
 }
 
 /**
@@ -43,6 +46,9 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
   override def output: Seq[Attribute] = child.output
 
   override def doCanonicalize(): LogicalPlan = child.canonicalized
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): ResolvedHint =
+    copy(child = newChild)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index d383532cbd3d3..6d61a86ab5ef7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -79,7 +79,10 @@ trait ObjectConsumer extends UnaryNode {
 case class DeserializeToObject(
     deserializer: Expression,
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectProducer
+    child: LogicalPlan) extends UnaryNode with ObjectProducer {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DeserializeToObject =
+    copy(child = newChild)
+}
 
 /**
  * Takes the input object from child and turns it into unsafe row using the given serializer
@@ -90,6 +93,9 @@ case class SerializeFromObject(
     child: LogicalPlan) extends ObjectConsumer {
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): SerializeFromObject =
+    copy(child = newChild)
 }
 
 object MapPartitions {
@@ -111,7 +117,10 @@ object MapPartitions {
 case class MapPartitions(
     func: Iterator[Any] => Iterator[Any],
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer {
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapPartitions =
+    copy(child = newChild)
+}
 
 object MapPartitionsInR {
   def apply(
@@ -159,6 +168,9 @@ case class MapPartitionsInR(
 
   override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema,
     outputObjAttr, child)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapPartitionsInR =
+    copy(child = newChild)
 }
 
 /**
@@ -182,6 +194,9 @@ case class MapPartitionsInRWithArrow(
     inputSchema, StructType.fromAttributes(output), child)
 
   override val producedAttributes = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapPartitionsInRWithArrow =
+    copy(child = newChild)
 }
 
 object MapElements {
@@ -207,7 +222,10 @@ case class MapElements(
     argumentClass: Class[_],
     argumentSchema: StructType,
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer {
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapElements =
+    copy(child = newChild)
+}
 
 object TypedFilter {
   def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
@@ -251,6 +269,9 @@ case class TypedFilter(
     val funcObj = Literal.create(func, ObjectType(funcMethod._1))
     Invoke(funcObj, funcMethod._2, BooleanType, input :: Nil)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): TypedFilter =
+    copy(child = newChild)
 }
 
 object FunctionUtils {
@@ -334,6 +355,9 @@ case class AppendColumns(
   override def output: Seq[Attribute] = child.output ++ newColumns
 
   def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): AppendColumns =
+    copy(child = newChild)
 }
 
 /**
@@ -346,6 +370,9 @@ case class AppendColumnsWithObject(
     child: LogicalPlan) extends ObjectConsumer {
 
   override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): AppendColumnsWithObject =
+    copy(child = newChild)
 }
 
 /** Factory for constructing new `MapGroups` nodes. */
@@ -382,7 +409,10 @@ case class MapGroups(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectProducer
+    child: LogicalPlan) extends UnaryNode with ObjectProducer {
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapGroups =
+    copy(child = newChild)
+}
 
 /** Internal class representing State */
 trait LogicalGroupState[S]
@@ -453,6 +483,9 @@ case class FlatMapGroupsWithState(
   if (isMapGroupsWithState) {
     assert(outputMode == OutputMode.Update)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsWithState =
+    copy(child = newChild)
 }
 
 /** Factory for constructing new `FlatMapGroupsInR` nodes. */
@@ -513,6 +546,9 @@ case class FlatMapGroupsInR(
   override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema,
     keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr,
     child)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInR =
+    copy(child = newChild)
 }
 
 /**
@@ -537,6 +573,9 @@ case class FlatMapGroupsInRWithArrow(
     inputSchema, StructType.fromAttributes(output), keyDeserializer, groupingAttributes, child)
 
   override val producedAttributes = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInRWithArrow =
+    copy(child = newChild)
 }
 
 /** Factory for constructing new `CoGroup` nodes. */
@@ -584,4 +623,7 @@ case class CoGroup(
     rightAttr: Seq[Attribute],
     outputObjAttr: Attribute,
     left: LogicalPlan,
-    right: LogicalPlan) extends BinaryNode with ObjectProducer
+    right: LogicalPlan) extends BinaryNode with ObjectProducer {
+  override protected def withNewChildrenInternal(
+    newLeft: LogicalPlan, newRight: LogicalPlan): CoGroup = copy(left = newLeft, right = newRight)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 62f2d598b96dc..ba8352cf6ac89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -37,6 +37,9 @@ case class FlatMapGroupsInPandas(
    * from the input.
   */
  override val producedAttributes = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInPandas =
+    copy(child = newChild)
 }
 
 /**
@@ -49,6 +52,9 @@ case class MapInPandas(
     child: LogicalPlan) extends UnaryNode {
 
   override val producedAttributes = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapInPandas =
+    copy(child = newChild)
 }
 
 /**
@@ -70,6 +76,10 @@ case class FlatMapCoGroupsInPandas(
   def leftAttributes: Seq[Attribute] = left.output.take(leftGroupingLen)
 
   def rightAttributes: Seq[Attribute] = right.output.take(rightGroupingLen)
+
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapCoGroupsInPandas =
+    copy(left = newLeft, right = newRight)
 }
 
 trait BaseEvalPython extends UnaryNode {
@@ -89,7 +99,10 @@ trait BaseEvalPython {
 case class BatchEvalPython(
     udfs: Seq[PythonUDF],
     resultAttrs: Seq[Attribute],
-    child: LogicalPlan) extends BaseEvalPython
+    child: LogicalPlan) extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): BatchEvalPython =
+    copy(child = newChild)
+}
 
 /**
  * A logical plan that evaluates a [[PythonUDF]] with Apache Arrow.
@@ -98,4 +111,7 @@ case class ArrowEvalPython(
     udfs: Seq[PythonUDF],
     resultAttrs: Seq[Attribute],
     child: LogicalPlan,
-    evalType: Int) extends BaseEvalPython
+    evalType: Int) extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): ArrowEvalPython =
+    copy(child = newChild)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index d600c15004d1e..44550ae2844ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -167,6 +167,8 @@ case class CreateTableAsSelectStatement(
     ifNotExists: Boolean) extends UnaryParsedStatement {
 
   override def child: LogicalPlan = asSelect
+  override protected def withNewChildInternal(newChild: LogicalPlan): CreateTableAsSelectStatement =
+    copy(asSelect = newChild)
 }
 
 /**
@@ -181,7 +183,10 @@ case class CreateViewStatement(
     child: LogicalPlan,
     allowExisting: Boolean,
     replace: Boolean,
-    viewType: ViewType) extends UnaryParsedStatement
+    viewType: ViewType) extends UnaryParsedStatement {
+  override protected def withNewChildInternal(newChild: LogicalPlan): CreateViewStatement =
+    copy(child = newChild)
+}
 
 /**
  * A REPLACE TABLE command, as parsed from SQL.
@@ -220,6 +225,8 @@ case class ReplaceTableAsSelectStatement(
     orCreate: Boolean) extends UnaryParsedStatement {
 
   override def child: LogicalPlan = asSelect
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): ReplaceTableAsSelectStatement = copy(asSelect = newChild)
 }
 
@@ -300,6 +307,8 @@ case class InsertIntoStatement(
     "IF NOT EXISTS is only valid with static partitions")
 
   override def child: LogicalPlan = query
+  override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoStatement =
+    copy(query = newChild)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 889509db25c4c..8b7f2db0584b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -77,6 +77,8 @@ case class AppendData(
     write: Option[Write] = None) extends V2WriteCommand {
   override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery)
   override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable)
+  override protected def withNewChildInternal(newChild: LogicalPlan): AppendData =
+    copy(query = newChild)
 }
 
 object AppendData {
@@ -115,6 +117,9 @@ case class OverwriteByExpression(
   override def withNewTable(newTable: NamedRelation): OverwriteByExpression = {
     copy(table = newTable)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): OverwriteByExpression =
+    copy(query = newChild)
 }
 
 object OverwriteByExpression {
@@ -150,6 +155,9 @@ case class OverwritePartitionsDynamic(
   override def withNewTable(newTable: NamedRelation): OverwritePartitionsDynamic = {
     copy(table = newTable)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): OverwritePartitionsDynamic =
+    copy(query = newChild)
 }
 
 object OverwritePartitionsDynamic {
@@ -222,6 +230,9 @@ case class CreateTableAsSelect(
   override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
     this.copy(partitioning = rewritten)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): CreateTableAsSelect =
+    copy(query = newChild)
 }
 
 /**
@@ -272,6 +283,9 @@ case class ReplaceTableAsSelect(
   override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
     this.copy(partitioning = rewritten)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): ReplaceTableAsSelect =
+    copy(query = newChild)
 }
 
 /**
@@ -291,6 +305,8 @@ case class DropNamespace(
     ifExists: Boolean,
     cascade: Boolean) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(namespace = newChild)
 }
 
 /**
@@ -301,6 +317,8 @@ case class DescribeNamespace(
     extended: Boolean,
     override val output: Seq[Attribute] = DescribeNamespace.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): DescribeNamespace =
+    copy(namespace = newChild)
 }
 
 object DescribeNamespace {
@@ -319,6 +337,8 @@ case class SetNamespaceProperties(
     namespace: LogicalPlan,
     properties: Map[String, String]) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetNamespaceProperties =
+    copy(namespace = newChild)
 }
 
 /**
@@ -328,6 +348,8 @@ case class SetNamespaceLocation(
     namespace: LogicalPlan,
     location: String) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetNamespaceLocation =
+    copy(namespace = newChild)
 }
 
 /**
@@ -338,6 +360,8 @@ case class ShowNamespaces(
     pattern: Option[String],
     override val output: Seq[Attribute] = ShowNamespaces.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowNamespaces =
+    copy(namespace = newChild)
 }
 
 object ShowNamespaces {
@@ -355,6 +379,8 @@ case class DescribeRelation(
     isExtended: Boolean,
     override val output: Seq[Attribute] = DescribeRelation.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = relation
+  override protected def withNewChildInternal(newChild: LogicalPlan): DescribeRelation =
+    copy(relation = newChild)
 }
 
 object DescribeRelation {
@@ -370,6 +396,8 @@ case class DescribeColumn(
     isExtended: Boolean,
     override val output: Seq[Attribute] = DescribeColumn.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = relation
+  override protected def withNewChildInternal(newChild: LogicalPlan): DescribeColumn =
+    copy(relation = newChild)
 }
 
 object DescribeColumn {
@@ -383,6 +411,8 @@ case class DeleteFromTable(
     table: LogicalPlan,
     condition: Option[Expression]) extends UnaryCommand with SupportsSubquery {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): DeleteFromTable =
+    copy(table = newChild)
 }
 
 /**
@@ -393,6 +423,8 @@ case class UpdateTable(
     assignments: Seq[Assignment],
     condition: Option[Expression]) extends UnaryCommand with SupportsSubquery {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
+    copy(table = newChild)
 }
 
 /**
@@ -407,6 +439,9 @@ case class MergeIntoTable(
   def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
   override def left: LogicalPlan = targetTable
   override def right: LogicalPlan = sourceTable
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): MergeIntoTable =
+    copy(targetTable = newLeft, sourceTable = newRight)
 }
 
 sealed abstract class MergeAction extends Expression with Unevaluable {
@@ -416,28 +451,49 @@ sealed abstract class MergeAction extends Expression with Unevaluable {
   override def children: Seq[Expression] = condition.toSeq
 }
 
-case class DeleteAction(condition: Option[Expression]) extends MergeAction
+case class DeleteAction(condition: Option[Expression]) extends MergeAction {
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): DeleteAction =
+    copy(condition = if (condition.isDefined) Some(newChildren(0)) else None)
+}
 
 case class UpdateAction(
     condition: Option[Expression],
     assignments: Seq[Assignment]) extends MergeAction {
   override def children: Seq[Expression] = condition.toSeq ++ assignments
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): UpdateAction =
+    copy(
+      condition = if (condition.isDefined) Some(newChildren.head) else None,
+      assignments = (if (condition.isDefined) newChildren.tail else newChildren).asInstanceOf[Seq[Assignment]])
 }
 
 case class UpdateStarAction(condition: Option[Expression]) extends MergeAction {
   override def children: Seq[Expression] = condition.toSeq
   override lazy val resolved = false
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): UpdateStarAction =
+    copy(condition = if (condition.isDefined) Some(newChildren(0)) else None)
 }
 
 case class InsertAction(
     condition: Option[Expression],
     assignments: Seq[Assignment]) extends MergeAction {
   override def children: Seq[Expression] = condition.toSeq ++ assignments
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): InsertAction =
+    copy(
+      condition = if (condition.isDefined) Some(newChildren.head) else None,
+      assignments = (if (condition.isDefined) newChildren.tail else newChildren).asInstanceOf[Seq[Assignment]])
 }
 
 case class InsertStarAction(condition: Option[Expression]) extends MergeAction {
   override def children: Seq[Expression] = condition.toSeq
   override lazy val resolved = false
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): InsertStarAction =
+    copy(condition = if (condition.isDefined) Some(newChildren(0)) else None)
 }
 
 case class Assignment(key: Expression, value: Expression) extends Expression
@@ -446,6 +502,8 @@ case class Assignment(key: Expression, value: Expression) extends Expression
   override def dataType: DataType = throw new UnresolvedException("nullable")
   override def left: Expression = key
   override def right: Expression = value
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight)
 }
 
 /**
@@ -462,7 +520,10 @@ case class Assignment(key: Expression, value: Expression) extends Expression
 case class DropTable(
     child: LogicalPlan,
     ifExists: Boolean,
-    purge: Boolean) extends UnaryCommand
+    purge: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DropTable =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan for no-op command handling non-existing table.
@@ -509,7 +570,10 @@ case class AlterTable(
 case class RenameTable(
     child: LogicalPlan,
     newName: Seq[String],
-    isView: Boolean) extends UnaryCommand
+    isView: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RenameTable =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the SHOW TABLES command.
@@ -519,6 +583,8 @@ case class ShowTables(
     pattern: Option[String],
     override val output: Seq[Attribute] = ShowTables.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowTables =
+    copy(namespace = newChild)
 }
 
 object ShowTables {
@@ -537,6 +603,8 @@ case class ShowTableExtended(
     partitionSpec: Option[PartitionSpec],
     override val output: Seq[Attribute] = ShowTableExtended.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowTableExtended =
+    copy(namespace = newChild)
 }
 
 object ShowTableExtended {
@@ -558,6 +626,8 @@ case class ShowViews(
     pattern: Option[String],
     override val output: Seq[Attribute] = ShowViews.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowViews =
+    copy(namespace = newChild)
 }
 
 object ShowViews {
@@ -578,7 +648,10 @@ case class SetCatalogAndNamespace(
 /**
  * The logical plan of the REFRESH TABLE command.
  */
-case class RefreshTable(child: LogicalPlan) extends UnaryCommand
+case class RefreshTable(child: LogicalPlan) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RefreshTable =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the SHOW CURRENT NAMESPACE command.
@@ -597,6 +670,8 @@ case class ShowTableProperties(
     propertyKey: Option[String],
     override val output: Seq[Attribute] = ShowTableProperties.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(table = newChild)
 }
 
 object ShowTableProperties {
@@ -615,7 +690,10 @@ object ShowTableProperties {
 * where the `text` is the new comment written as a string literal; or `NULL` to drop the comment.
 *
 */
-case class CommentOnNamespace(child: LogicalPlan, comment: String) extends UnaryCommand
+case class CommentOnNamespace(child: LogicalPlan, comment: String) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): CommentOnNamespace =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan that defines or changes the comment of an TABLE for v2 catalogs.
@@ -627,17 +705,26 @@ case class CommentOnNamespace(child: LogicalPlan, comment: String) extends Unary
 * where the `text` is the new comment written as a string literal; or `NULL` to drop the comment.
 *
 */
-case class CommentOnTable(child: LogicalPlan, comment: String) extends UnaryCommand
+case class CommentOnTable(child: LogicalPlan, comment: String) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): CommentOnTable =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the REFRESH FUNCTION command.
 */
-case class RefreshFunction(child: LogicalPlan) extends UnaryCommand
+case class RefreshFunction(child: LogicalPlan) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RefreshFunction =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the DESCRIBE FUNCTION command.
 */
-case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends UnaryCommand
+case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DescribeFunction =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the DROP FUNCTION command.
@@ -645,7 +732,10 @@ case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends Una
 case class DropFunction(
     child: LogicalPlan,
     ifExists: Boolean,
-    isTemp: Boolean) extends UnaryCommand
+    isTemp: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DropFunction =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the SHOW FUNCTIONS command.
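The `MergeAction` subclasses earlier in this file, and `ShowFunctions` just below, share one more reconstruction subtlety: the child list is built from an `Option` plus a sequence (`condition.toSeq ++ assignments`, or `child.toSeq`), so the rebuild has to branch on whether the option was defined. A standalone sketch of just that slicing rule (`splitOptionalHead` is an illustrative helper, not something this patch adds):

```
// When children = opt.toSeq ++ rest, the first element of newChildren
// belongs to the option only if the option was defined; otherwise the
// whole list is `rest`. Taking `tail` unconditionally would drop a child.
def splitOptionalHead[T](
    hasHead: Boolean, newChildren: IndexedSeq[T]): (Option[T], IndexedSeq[T]) =
  if (hasHead) (Some(newChildren.head), newChildren.tail)
  else (None, newChildren)

// e.g. an UpdateAction-like node would rebuild itself as:
//   val (cond, assigns) = splitOptionalHead(condition.isDefined, newChildren)
//   copy(condition = cond, assignments = assigns.asInstanceOf[Seq[Assignment]])
```

This is why the `UpdateAction` and `InsertAction` overrides above guard on `condition.isDefined` for both fields rather than taking `newChildren.tail` unconditionally.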
@@ -657,6 +747,9 @@ case class ShowFunctions(
     pattern: Option[String],
     override val output: Seq[Attribute] = ShowFunctions.getOutputAttrs) extends Command {
   override def children: Seq[LogicalPlan] = child.toSeq
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[LogicalPlan]): ShowFunctions =
+    copy(child = if (child.isDefined) Some(newChildren.head) else None)
 }
 
 object ShowFunctions {
@@ -671,7 +764,10 @@ object ShowFunctions {
 case class AnalyzeTable(
     child: LogicalPlan,
     partitionSpec: Map[String, Option[String]],
-    noScan: Boolean) extends UnaryCommand
+    noScan: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): AnalyzeTable =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the ANALYZE TABLES command.
@@ -680,6 +776,8 @@ case class AnalyzeTables(
     namespace: LogicalPlan,
     noScan: Boolean) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): AnalyzeTables =
+    copy(namespace = newChild)
 }
 
 /**
@@ -691,6 +789,9 @@ case class AnalyzeColumn(
     allColumns: Boolean) extends UnaryCommand {
   require(columnNames.isDefined ^ allColumns, "Parameter `columnNames` or `allColumns` are " +
    "mutually exclusive. Only one of them should be specified.")
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): AnalyzeColumn =
+    copy(child = newChild)
 }
 
 /**
@@ -705,7 +806,10 @@ case class AddPartitions(
     table: LogicalPlan,
     parts: Seq[PartitionSpec],
-    ifNotExists: Boolean) extends V2PartitionCommand
+    ifNotExists: Boolean) extends V2PartitionCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): AddPartitions =
+    copy(table = newChild)
+}
 
 /**
 * The logical plan of the ALTER TABLE DROP PARTITION command.
@@ -723,7 +827,10 @@ case class DropPartitions(
     table: LogicalPlan,
     parts: Seq[PartitionSpec],
     ifExists: Boolean,
-    purge: Boolean) extends V2PartitionCommand
+    purge: Boolean) extends V2PartitionCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DropPartitions =
+    copy(table = newChild)
+}
 
 /**
 * The logical plan of the ALTER TABLE ... RENAME TO PARTITION command.
@@ -731,12 +838,18 @@ case class RenamePartitions(
     table: LogicalPlan,
     from: PartitionSpec,
-    to: PartitionSpec) extends V2PartitionCommand
+    to: PartitionSpec) extends V2PartitionCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RenamePartitions =
+    copy(table = newChild)
+}
 
 /**
 * The logical plan of the ALTER TABLE ... RECOVER PARTITIONS command.
 */
-case class RecoverPartitions(child: LogicalPlan) extends UnaryCommand
+case class RecoverPartitions(child: LogicalPlan) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RecoverPartitions =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the LOAD DATA INTO TABLE command.
@@ -746,7 +859,10 @@ case class LoadData(
     path: String,
     isLocal: Boolean,
     isOverwrite: Boolean,
-    partition: Option[TablePartitionSpec]) extends UnaryCommand
+    partition: Option[TablePartitionSpec]) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): LoadData =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the SHOW CREATE TABLE command.
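As the binary-compatibility section explains, user-defined expressions compiled against older Spark versions stop linking because `withNewChildrenInternal` is now abstract on `TreeNode`; recompiling with a one-line override is usually all that is needed. A hypothetical third-party expression, shown only to illustrate the shape of the migration (it is not part of this patch):

```
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.DataType

// Hypothetical third-party expression. Before this patch only dataType and
// eval were required; after it, UnaryLike's template method must also be
// overridden, typically via the compiler-generated case-class copy.
case class MyPassThrough(child: Expression)
  extends UnaryExpression with CodegenFallback {
  override def dataType: DataType = child.dataType
  override def eval(input: InternalRow): Any = child.eval(input)
  override protected def withNewChildInternal(newChild: Expression): MyPassThrough =
    copy(child = newChild)
}
```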
@@ -754,7 +870,10 @@ case class LoadData(
 case class ShowCreateTable(
     child: LogicalPlan,
     asSerde: Boolean = false,
-    override val output: Seq[Attribute] = ShowCreateTable.getoutputAttrs) extends UnaryCommand
+    override val output: Seq[Attribute] = ShowCreateTable.getoutputAttrs) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowCreateTable =
+    copy(child = newChild)
+}
 
 object ShowCreateTable {
   def getoutputAttrs: Seq[Attribute] = {
@@ -768,7 +887,10 @@ object ShowCreateTable {
 case class ShowColumns(
     child: LogicalPlan,
     namespace: Option[Seq[String]],
-    override val output: Seq[Attribute] = ShowColumns.getOutputAttrs) extends UnaryCommand
+    override val output: Seq[Attribute] = ShowColumns.getOutputAttrs) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowColumns =
+    copy(child = newChild)
+}
 
 object ShowColumns {
   def getOutputAttrs: Seq[Attribute] = {
@@ -781,6 +903,8 @@ object ShowColumns {
 */
 case class TruncateTable(table: LogicalPlan) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): TruncateTable =
+    copy(table = newChild)
 }
 
 /**
@@ -790,6 +914,8 @@ case class TruncatePartition(
     table: LogicalPlan,
     partitionSpec: PartitionSpec) extends V2PartitionCommand {
   override def allowPartialPartitionSpec: Boolean = true
+  override protected def withNewChildInternal(newChild: LogicalPlan): TruncatePartition =
+    copy(table = newChild)
 }
 
 /**
@@ -801,6 +927,8 @@ case class ShowPartitions(
     override val output: Seq[Attribute] = ShowPartitions.getOutputAttrs) extends V2PartitionCommand {
   override def allowPartialPartitionSpec: Boolean = true
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowPartitions =
+    copy(table = newChild)
 }
 
 object ShowPartitions {
@@ -814,7 +942,10 @@
 */
 case class DropView(
     child: LogicalPlan,
-    ifExists: Boolean) extends UnaryCommand
+    ifExists: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DropView =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the MSCK REPAIR TABLE command.
@@ -822,7 +953,10 @@ case class DropView(
 case class RepairTable(
     child: LogicalPlan,
     enableAddPartitions: Boolean,
-    enableDropPartitions: Boolean) extends UnaryCommand
+    enableDropPartitions: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RepairTable =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the ALTER VIEW ... AS command.
@@ -833,6 +967,9 @@ case class AlterViewAs(
     query: LogicalPlan) extends BinaryCommand {
   override def left: LogicalPlan = child
   override def right: LogicalPlan = query
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan =
+    copy(child = newLeft, query = newRight)
 }
 
 /**
@@ -840,7 +977,10 @@ case class AlterViewAs(
 */
 case class SetViewProperties(
     child: LogicalPlan,
-    properties: Map[String, String]) extends UnaryCommand
+    properties: Map[String, String]) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetViewProperties =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the ALTER VIEW ... UNSET TBLPROPERTIES command.
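The behavior the new fast paths must preserve is observable from the outside: rewriting a node with an unchanged child list hands back the very same instance, while a genuinely new child produces a copy with tags and origin carried over. A small usage-style sketch, assuming this patch is applied (`FastPathDemo` is a scratch object, not part of the change):

```
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan}

object FastPathDemo extends App {
  val leaf: LogicalPlan = LocalRelation()      // zero-column leaf relation
  val plan: LogicalPlan = Distinct(leaf)

  // Unchanged children: withNewChildren and an identity mapChildren both
  // short-circuit and return `this` rather than a reflective copy.
  assert(plan.withNewChildren(Seq(leaf)) eq plan)
  assert(plan.mapChildren(identity) eq plan)

  // A different child goes through withNewChildInternal and yields a copy.
  val rewritten = plan.withNewChildren(Seq(Distinct(leaf)))
  assert(!(rewritten eq plan))
}
```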
@@ -848,7 +988,10 @@ case class SetViewProperties(
 case class UnsetViewProperties(
     child: LogicalPlan,
     propertyKeys: Seq[String],
-    ifExists: Boolean) extends UnaryCommand
+    ifExists: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): UnsetViewProperties =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the ALTER TABLE ... SET [SERDE|SERDEPROPERTIES] command.
@@ -857,7 +1000,10 @@ case class SetTableSerDeProperties(
     child: LogicalPlan,
     serdeClassName: Option[String],
     serdeProperties: Option[Map[String, String]],
-    partitionSpec: Option[TablePartitionSpec]) extends UnaryCommand
+    partitionSpec: Option[TablePartitionSpec]) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetTableSerDeProperties =
+    copy(child = newChild)
+}
 
 /**
 * The logical plan of the CACHE TABLE command.
@@ -894,6 +1040,8 @@ case class SetTableLocation(
     partitionSpec: Option[TablePartitionSpec],
     location: String) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetTableLocation =
+    copy(table = newChild)
 }
 
 /**
@@ -903,6 +1051,8 @@ case class SetTableProperties(
     table: LogicalPlan,
     properties: Map[String, String]) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(table = newChild)
 }
 
 /**
@@ -913,4 +1063,6 @@ case class UnsetTableProperties(
     propertyKeys: Seq[String],
     ifExists: Boolean) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(table = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index c4002aa441a50..0f8c7887b2b1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -235,6 +235,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   * than numPartitions) based on hashing expressions.
*/ def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) } /** @@ -284,6 +287,10 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) } } } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): RangePartitioning = + copy(ordering = newChildren.asInstanceOf[Seq[SortOrder]]) } /** @@ -326,6 +333,10 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def toString: String = { partitionings.map(_.toString).mkString("(", " or ", ")") } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): PartitioningCollection = + super.legacyWithNewChildren(newChildren).asInstanceOf[PartitioningCollection] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala index 990ae302dbbee..2a29137355a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala @@ -39,5 +39,7 @@ case class WriteToStream( override def child: LogicalPlan = inputQuery + override protected def withNewChildInternal(newChild: LogicalPlan): WriteToStream = + copy(inputQuery = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala index 34a4c13efb62e..407c70a591d72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala @@ -57,5 +57,8 @@ case class WriteToStreamStatement( override def output: Seq[Attribute] = Nil override def child: LogicalPlan = inputQuery + + override protected def withNewChildInternal(newChild: LogicalPlan): WriteToStreamStatement = + copy(inputQuery = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 8fc62382bdbba..3fab95cbe4c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -246,11 +246,50 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arr } + private def childrenFastEquals( + originalChildren: IndexedSeq[BaseType], newChildren: IndexedSeq[BaseType]): Boolean = { + val size = originalChildren.size + var i = 0 + while (i < size) { + if (!originalChildren(i).fastEquals(newChildren(i))) return false + i += 1 + } + true + } + + // This is a temporary solution, we will change the type of children to IndexedSeq in a + // followup PR + private def asIndexedSeq(seq: Seq[BaseType]): IndexedSeq[BaseType] = { + if (seq.isInstanceOf[IndexedSeq[BaseType]]) { + seq.asInstanceOf[IndexedSeq[BaseType]] + } else { + seq.toIndexedSeq + } + } + + final def withNewChildren(newChildren: Seq[BaseType]): BaseType = { + val childrenIndexedSeq = asIndexedSeq(children) + val newChildrenIndexedSeq = asIndexedSeq(newChildren) + 
assert(newChildrenIndexedSeq.size == childrenIndexedSeq.size, "Incorrect number of children") + if (childrenIndexedSeq.isEmpty || + childrenFastEquals(newChildrenIndexedSeq, childrenIndexedSeq)) { + this + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildrenInternal(newChildrenIndexedSeq) + res.copyTagsFrom(this) + res + } + } + } + + protected def withNewChildrenInternal(newChildren: IndexedSeq[BaseType]): BaseType + /** * Returns a copy of this node with the children replaced. * TODO: Validate somewhere (in debug mode?) that children are ordered correctly. */ - def withNewChildren(newChildren: Seq[BaseType]): BaseType = { + protected final def legacyWithNewChildren(newChildren: Seq[BaseType]): BaseType = { assert(newChildren.size == children.size, "Incorrect number of children") var changed = false val remainingNewChildren = newChildren.toBuffer @@ -355,7 +394,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { */ def mapChildren(f: BaseType => BaseType): BaseType = { if (containsChild.nonEmpty) { - mapChildren(f, forceCopy = false) + withNewChildren(children.map(f)) } else { this } @@ -844,24 +883,96 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { trait LeafLike[T <: TreeNode[T]] { self: TreeNode[T] => override final def children: Seq[T] = Nil + override final def mapChildren(f: T => T): T = this.asInstanceOf[T] + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = this.asInstanceOf[T] } trait UnaryLike[T <: TreeNode[T]] { self: TreeNode[T] => def child: T - @transient override final lazy val children: Seq[T] = child :: Nil + @transient override final lazy val children: Seq[T] = IndexedSeq(child) + + override final def mapChildren(f: T => T): T = { + val newChild = f(child) + if (newChild fastEquals child) { + this.asInstanceOf[T] + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildInternal(newChild) + res.copyTagsFrom(this.asInstanceOf[T]) + res + } + } + } + + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = { + assert(newChildren.size == 1, "Incorrect number of children") + withNewChildInternal(newChildren.head) + } + + protected def withNewChildInternal(newChild: T): T } trait BinaryLike[T <: TreeNode[T]] { self: TreeNode[T] => def left: T def right: T - @transient override final lazy val children: Seq[T] = left :: right :: Nil + @transient override final lazy val children: Seq[T] = IndexedSeq(left, right) + + override final def mapChildren(f: T => T): T = { + var newLeft = f(left) + newLeft = if (newLeft fastEquals left) left else newLeft + var newRight = f(right) + newRight = if (newRight fastEquals right) right else newRight + + if (newLeft.eq(left) && newRight.eq(right)) { + this.asInstanceOf[T] + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildrenInternal(newLeft, newRight) + res.copyTagsFrom(this.asInstanceOf[T]) + res + } + } + } + + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = { + assert(newChildren.size == 2, "Incorrect number of children") + withNewChildrenInternal(newChildren(0), newChildren(1)) + } + + protected def withNewChildrenInternal(newLeft: T, newRight: T): T } trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] => def first: T def second: T def third: T - @transient override final lazy val children: Seq[T] = first :: second :: third :: Nil + @transient override final lazy val children: Seq[T] = IndexedSeq(first, second, third) + + override final def 
mapChildren(f: T => T): T = { + var newFirst = f(first) + newFirst = if (newFirst fastEquals first) first else newFirst + var newSecond = f(second) + newSecond = if (newSecond fastEquals second) second else newSecond + var newThird = f(third) + newThird = if (newThird fastEquals third) third else newThird + + if (newFirst.eq(first) && newSecond.eq(second) && newThird.eq(third)) { + this.asInstanceOf[T] + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildrenInternal(newFirst, newSecond, newThird) + res.copyTagsFrom(this.asInstanceOf[T]) + res + } + } + } + + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = { + assert(newChildren.size == 3, "Incorrect number of children") + withNewChildrenInternal(newChildren(0), newChildren(1), newChildren(2)) + } + + protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: T): T } trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] => @@ -869,5 +980,33 @@ trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] => def second: T def third: T def fourth: T - @transient override final lazy val children: Seq[T] = first :: second :: third :: fourth :: Nil + @transient override final lazy val children: Seq[T] = IndexedSeq(first, second, third, fourth) + + override final def mapChildren(f: T => T): T = { + var newFirst = f(first) + newFirst = if (newFirst fastEquals first) first else newFirst + var newSecond = f(second) + newSecond = if (newSecond fastEquals second) second else newSecond + var newThird = f(third) + newThird = if (newThird fastEquals third) third else newThird + var newFourth = f(fourth) + newFourth = if (newFourth fastEquals fourth) fourth else newFourth + + if (newFirst.eq(first) && newSecond.eq(second) && newThird.eq(third) && newFourth.eq(fourth)) { + this.asInstanceOf[T] + } else { + CurrentOrigin.withOrigin(origin) { + val res = withNewChildrenInternal(newFirst, newSecond, newThird, newFourth) + res.copyTagsFrom(this.asInstanceOf[T]) + res + } + } + } + + override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = { + assert(newChildren.size == 4, "Incorrect number of children") + withNewChildrenInternal(newChildren(0), newChildren(1), newChildren(2), newChildren(3)) + } + + protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: T, newFourth: T): T } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index a9d9acdc8f52e..aecbf241e3947 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -88,6 +88,8 @@ case class TestFunction( extends Expression with ImplicitCastInputTypes with Unevaluable { override def nullable: Boolean = true override def dataType: DataType = StringType + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) } case class UnresolvedTestPlan() extends LeafNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index a6145c5421d48..9058e3eb3f041 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1623,12 +1623,16 @@ object TypeCoercionSuite { extends UnaryExpression with ExpectsInputTypes with Unevaluable { override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) override def dataType: DataType = NullType + override protected def withNewChildInternal(newChild: Expression): AnyTypeUnaryExpression = + copy(child = newChild) } case class NumericTypeUnaryExpression(child: Expression) extends UnaryExpression with ExpectsInputTypes with Unevaluable { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def dataType: DataType = NullType + override protected def withNewChildInternal(newChild: Expression): NumericTypeUnaryExpression = + copy(child = newChild) } case class AnyTypeBinaryOperator(left: Expression, right: Expression) @@ -1636,6 +1640,9 @@ object TypeCoercionSuite { override def dataType: DataType = NullType override def inputType: AbstractDataType = AnyDataType override def symbol: String = "anytype" + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): AnyTypeBinaryOperator = + copy(left = newLeft, right = newRight) } case class NumericTypeBinaryOperator(left: Expression, right: Expression) @@ -1643,5 +1650,8 @@ object TypeCoercionSuite { override def dataType: DataType = NullType override def inputType: AbstractDataType = NumericType override def symbol: String = "numerictype" + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): NumericTypeBinaryOperator = + copy(left = newLeft, right = newRight) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 71993e1a369ec..dc62841e058e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -998,6 +998,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { case class StreamingPlanWrapper(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def isStreaming: Boolean = true + override protected def withNewChildInternal(newChild: LogicalPlan): StreamingPlanWrapper = + copy(child = newChild) } case class TestStreamingRelation(output: Seq[Attribute]) extends LeafNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 65671d253dc53..9bfe69b1709d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -314,4 +314,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel case class CodegenFallbackExpression(child: Expression) extends UnaryExpression with CodegenFallback { override def dataType: DataType = child.dataType + override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpression = + copy(child = newChild) } diff --git 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
index 43579d4c903a1..02b6eed9ed050 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
@@ -104,4 +104,7 @@ case class ExprReuseOutput(child: Expression) extends UnaryExpression {
     row.update(0, child.eval(input))
     row
   }
+
+  override protected def withNewChildInternal(newChild: Expression): ExprReuseOutput =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 84452399de824..3784f40101702 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -66,6 +66,9 @@ class LogicalPlanSuite extends SparkFunSuite {
   case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
     override def output: Seq[Attribute] = left.output ++ right.output
+    override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan =
+      copy(left = newLeft, right = newRight)
   }

   require(relation.isStreaming === false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala
index 6f342b8d94379..009e2a731fe41 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala
@@ -28,6 +28,8 @@ class LogicalPlanIntegritySuite extends PlanTest {

 case class OutputTestPlan(child: LogicalPlan, output: Seq[Attribute]) extends UnaryNode {
   override val analyzed = true
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(child = newChild)
 }

 test("Checks if the same `ExprId` refers to a semantically-equal attribute in a plan output") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 4ad8475a0113c..0d316779d8bcb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -47,6 +47,8 @@ case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFall
   override def dataType: NullType = NullType
   override lazy val resolved = true
   override def eval(input: InternalRow): Any = null.asInstanceOf[Any]
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(optKey = if (optKey.isDefined) Some(newChildren(0)) else None)
 }

 case class ComplexPlan(exprs: Seq[Seq[Expression]])
@@ -59,6 +61,8 @@ case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable {
   override def nullable: Boolean = true
   override def dataType: NullType = NullType
   override lazy val resolved = true
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    super.legacyWithNewChildren(newChildren)
 }

 case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
@@ -67,6 +71,9 @@ case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
   override def nullable: Boolean = true
   override def dataType: NullType = NullType
   override lazy val resolved = true
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    super.legacyWithNewChildren(newChildren)
 }

 case class JsonTestTreeNode(arg: Any) extends LeafNode {
@@ -738,7 +745,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }

   object MalformedClassObject extends Serializable {
-    case class MalformedNameExpression(child: Expression) extends TaggingExpression
+    case class MalformedNameExpression(child: Expression) extends TaggingExpression {
+      override protected def withNewChildInternal(newChild: Expression): Expression =
+        copy(child = newChild)
+    }
   }

   test("SPARK-32999: TreeNode.nodeName should not throw malformed class name error") {
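`ExpressionInMap` and `SeqTupleExpression` above route through `legacyWithNewChildren` on purpose: their children are not plain constructor fields. For illustration, a fast-path alternative for the map case could look like the sketch below; this is an editorial suggestion that assumes `children` enumerates the map values in a stable iteration order, and it is not something this patch changes:

```scala
// Rebuild the map by pairing each key with its rewritten child, in the same
// iteration order that produced `children` in the first place.
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
  copy(map = map.keys.zip(newChildren).toMap)
```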
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index b0bbb52bc4990..500425e4809e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -78,6 +78,8 @@ case class CollectMetricsExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CollectMetricsExec =
+    copy(child = newChild)
 }

 object CollectMetricsExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index 8d542792a0e28..6bdd93e9230b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -201,6 +201,9 @@ case class ColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransition w
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     Seq(child.executeColumnar().asInstanceOf[RDD[InternalRow]]) // Hack because of type erasure
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ColumnarToRowExec =
+    copy(child = newChild)
 }

 /**
@@ -486,6 +489,9 @@ case class RowToColumnarExec(child: SparkPlan) extends RowToColumnarTransition {
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): RowToColumnarExec =
+    copy(child = newChild)
 }

 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 6f5bf15d82638..3fd653130e57c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -203,4 +203,7 @@ case class ExpandExec(
       |}
      """.stripMargin
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ExpandExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 0d5ec2d6c6f1c..6c7929437ffdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -325,4 +325,7 @@ case class GenerateExec(
     if (condition) Seq(code) else Seq.empty
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): GenerateExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 6b6ca531c6d3b..984a45cd058ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -202,4 +202,7 @@ case class SortExec(
     }
     super.cleanupResources()
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SortExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
index 75c91667012a3..7f3628926e351 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
@@ -72,6 +72,9 @@ case class SparkScriptTransformationExec(

     outputIterator
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkScriptTransformationExec =
+    copy(child = newChild)
 }

 case class SparkScriptTransformationWriterThread(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
index cece43090cb76..a735d913c953a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
@@ -39,4 +39,7 @@ case class SubqueryAdaptiveBroadcastExec(
     throw new UnsupportedOperationException(
       "SubqueryAdaptiveBroadcastExec does not support the execute() code path.")
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SubqueryAdaptiveBroadcastExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
index 70ba13550afbf..47cb70dde86a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
@@ -113,6 +113,9 @@ case class SubqueryBroadcastExec(
   }

   override def stringArgs: Iterator[Any] = super.stringArgs ++ Iterator(s"[id=#$id]")
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SubqueryBroadcastExec =
+    copy(child = newChild)
 }

 object SubqueryBroadcastExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 9c50dc91b6385..85bc98d194fee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -554,6 +554,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCod

   override def needCopyResult: Boolean = false
+
+  override protected def withNewChildInternal(newChild: SparkPlan): InputAdapter =
+    copy(child = newChild)
 }

 object WholeStageCodegenExec {
@@ -829,6 +832,9 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
   override def limitNotReachedChecks: Seq[String] = Nil

   override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
+
+  override protected def withNewChildInternal(newChild: SparkPlan): WholeStageCodegenExec =
+    copy(child = newChild)(codegenStageId)
 }
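`WholeStageCodegenExec` is the one node in this file whose constructor has a second parameter list, so its override cannot be the usual one-liner: Scala generates `copy` only over the first parameter list, and the stage id must be threaded through by hand (which is also why it already appears in `otherCopyArgs`). A self-contained toy illustrating the mechanics, with names that are purely illustrative:

```scala
// The generated copy has the shape copy(label: String = label)(tag: Int):
// defaults exist only for the first list; the second must always be supplied.
case class Node(label: String)(val tag: Int)

val n = Node("scan")(42)
val m = n.copy(label = "project")(n.tag) // forward the second list explicitly
assert(m.tag == 42 && m.label == "project")
```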
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
index 4639ccc11fc6a..f2eefbc028b5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
@@ -195,4 +195,7 @@ case class CustomShuffleReaderExec private(
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
     shuffleRDD.asInstanceOf[RDD[ColumnarBatch]]
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CustomShuffleReaderExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 7d45638146a71..6e23a2844d148 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -1108,6 +1108,9 @@ case class HashAggregateExec(
       s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt"
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec =
+    copy(child = newChild)
 }

 object HashAggregateExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index e5f59e0d4e9bf..559f545dc05ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -138,6 +138,9 @@ case class ObjectHashAggregateExec(
     s"ObjectHashAggregate(keys=$keyString, functions=$functionString)"
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ObjectHashAggregateExec =
+    copy(child = newChild)
 }

 object ObjectHashAggregateExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 2400ceef544d6..4fb0f44db81c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -101,4 +101,7 @@ case class SortAggregateExec(
     s"SortAggregate(key=$keyString, functions=$functionString)"
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SortAggregateExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index ea44c6013b7d9..d958790dd09b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -203,6 +203,10 @@ case class SimpleTypedAggregateExpression(
       schema: StructType): TypedAggregateExpression = {
     copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema))
   }
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): SimpleTypedAggregateExpression =
+    super.legacyWithNewChildren(newChildren).asInstanceOf[SimpleTypedAggregateExpression]
 }

 case class ComplexTypedAggregateExpression(
@@ -285,4 +289,8 @@ case class ComplexTypedAggregateExpression(
       schema: StructType): TypedAggregateExpression = {
     copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema))
   }
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): ComplexTypedAggregateExpression =
+    super.legacyWithNewChildren(newChildren).asInstanceOf[ComplexTypedAggregateExpression]
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index e6851a9af739f..1aae76e0fb29b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -454,6 +454,9 @@ case class ScalaUDAF(
   override def nodeName: String = name

   override def name: String = udafName.getOrElse(udaf.getClass.getSimpleName)
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ScalaUDAF =
+    copy(children = newChildren)
 }

 case class ScalaAggregator[IN, BUF, OUT](
@@ -520,6 +523,10 @@ case class ScalaAggregator[IN, BUF, OUT](
   override def nodeName: String = name

   override def name: String = aggregatorName.getOrElse(agg.getClass.getSimpleName)
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): ScalaAggregator[IN, BUF, OUT] =
+    copy(children = newChildren)
 }

 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index abd336006848b..b537040fe71df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -107,6 +107,9 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
        |${ExplainUtils.generateFieldString("Input", child.output)}
        |""".stripMargin
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ProjectExec =
+    copy(child = newChild)
 }

 trait GeneratePredicateHelper extends PredicateHelper {
@@ -286,6 +289,9 @@ case class FilterExec(condition: Expression, child: SparkPlan)
        |Condition : ${condition}
        |""".stripMargin
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FilterExec =
+    copy(child = newChild)
 }

 /**
@@ -392,6 +398,9 @@ case class SampleExec(
       """.stripMargin.trim
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SampleExec =
+    copy(child = newChild)
 }
@@ -687,6 +696,9 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {

   protected override def doExecute(): RDD[InternalRow] =
     sparkContext.union(children.map(_.execute()))
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): UnionExec =
+    copy(children = newChildren)
 }

 /**
@@ -720,6 +732,9 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
     child.execute().coalesce(numPartitions, shuffle = false)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CoalesceExec =
+    copy(child = newChild)
 }

 object CoalesceExec {
@@ -849,6 +864,9 @@ case class SubqueryExec(name: String, child: SparkPlan, maxNumRows: Option[Int]
   }

   override def stringArgs: Iterator[Any] = Iterator(name, child) ++ Iterator(s"[id=#$id]")
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SubqueryExec =
+    copy(child = newChild)
 }

 object SubqueryExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 641bd26c381ad..e3c2e90a42dec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types._
 case class AnalyzeColumnCommand(
     tableIdent: TableIdentifier,
     columnNames: Option[Seq[String]],
-    allColumns: Boolean) extends RunnableCommand {
+    allColumns: Boolean) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     require(columnNames.isDefined ^ allColumns, "Parameter `columnNames` or `allColumns` are " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
index 51d4c5f41b1d3..5b3cb7476608b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.util.PartitioningUtils
 case class AnalyzePartitionCommand(
     tableIdent: TableIdentifier,
     partitionSpec: Map[String, Option[String]],
-    noscan: Boolean = true) extends RunnableCommand {
+    noscan: Boolean = true) extends LeafRunnableCommand {

   private def getPartitionSpec(table: CatalogTable): Option[TablePartitionSpec] = {
     val normalizedPartitionSpec =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index d114ca015d7ca..157554e821811 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
  */
 case class AnalyzeTableCommand(
     tableIdent: TableIdentifier,
-    noScan: Boolean = true) extends RunnableCommand {
+    noScan: Boolean = true) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     CommandUtils.analyzeTable(sparkSession, tableIdent, noScan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala
index ef0701909de2e..c9b22a7d1b258 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.{Row, SparkSession}
  */
 case class AnalyzeTablesCommand(
     databaseName: Option[String],
-    noScan: Boolean) extends RunnableCommand {
+    noScan: Boolean) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
index d065bc0dab4cd..be680a733eac9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
@@ -42,7 +42,7 @@ case class InsertIntoDataSourceDirCommand(
     storage: CatalogStorageFormat,
     provider: String,
     query: LogicalPlan,
-    overwrite: Boolean) extends RunnableCommand {
+    overwrite: Boolean) extends LeafRunnableCommand {

   override def innerChildren: Seq[LogicalPlan] = query :: Nil

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
index 7d92e6e189fb2..0ebc927c552f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
@@ -34,7 +34,8 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType}
  *   set;
  * }}}
  */
-case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging {
+case class SetCommand(kv: Option[(String, Option[String])])
+  extends LeafRunnableCommand with Logging {

   private def keyValueOutput: Seq[Attribute] = {
     val schema = StructType(
@@ -169,7 +170,7 @@ object SetCommand {
  *   reset spark.sql.session.timeZone;
  * }}}
  */
-case class ResetCommand(config: Option[String]) extends RunnableCommand with IgnoreCachedData {
+case class ResetCommand(config: Option[String]) extends LeafRunnableCommand with IgnoreCachedData {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val globalInitialConfigs = sparkSession.sharedState.conf
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index 2f72af7f4b512..de5dbddbfa146 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData

 /**
  * Clear all cached data from the in-memory cache.
  */
-case object ClearCacheCommand extends RunnableCommand with IgnoreCachedData {
+case object ClearCacheCommand extends LeafRunnableCommand with IgnoreCachedData {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.catalog.clearCache()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 8bc3cedff2426..7f4f816d328da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.catalyst.trees.LeafLike
 import org.apache.spark.sql.connector.ExternalCommandRunner
 import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -48,6 +49,8 @@ trait RunnableCommand extends Command {
   def run(sparkSession: SparkSession): Seq[Row]
 }

+trait LeafRunnableCommand extends RunnableCommand with LeafLike[LogicalPlan]
+
 /**
  * A physical operator that executes the run method of a `RunnableCommand` and
  * saves the result to prevent multiple executions.
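`LeafRunnableCommand` is the load-bearing change in this file: every command that used to extend `RunnableCommand` directly now inherits the leaf-node tree machinery from `LeafLike` instead of falling back to the generic reflective path, which is what makes the long run of one-word changes below sufficient. For reference, `LeafLike` has roughly this shape (a paraphrased sketch; the authoritative definition lives in `TreeNode.scala` in the trait refactoring this PR builds on):

```scala
trait LeafLike[T <: TreeNode[T]] { self: T =>
  // A leaf has no children, so transformation is the identity and supplying
  // new children is a no-op.
  override final def mapChildren(f: T => T): T = this.asInstanceOf[T]
  override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T =
    this.asInstanceOf[T]
  override final val children: Seq[T] = Nil
}
```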
@@ -132,6 +135,9 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
   protected override def doExecute(): RDD[InternalRow] = {
     sqlContext.sparkContext.parallelize(sideEffectResult, 1)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): DataWritingCommandExec =
+    copy(child = newChild)
 }

 /**
@@ -150,7 +156,7 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
 case class ExplainCommand(
     logicalPlan: LogicalPlan,
     mode: ExplainMode)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override val output: Seq[Attribute] =
     Seq(AttributeReference("plan", StringType, nullable = true)())
@@ -167,7 +173,7 @@ case class ExplainCommand(

 /** An explain command for users to see how a streaming batch is executed. */
 case class StreamingExplainCommand(
     queryExecution: IncrementalExecution,
-    extended: Boolean) extends RunnableCommand {
+    extended: Boolean) extends LeafRunnableCommand {

   override val output: Seq[Attribute] =
     Seq(AttributeReference("plan", StringType, nullable = true)())
@@ -193,7 +199,7 @@ case class StreamingExplainCommand(
 case class ExternalCommandExecutor(
     runner: ExternalCommandRunner,
     command: String,
-    options: Map[String, String]) extends RunnableCommand {
+    options: Map[String, String]) extends LeafRunnableCommand {

   override def output: Seq[Attribute] =
     Seq(AttributeReference("command_output", StringType)())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index bb54457afdc78..bb3869ddf811e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType
  * }}}
  */
 case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     assert(table.tableType != CatalogTableType.VIEW)
@@ -227,4 +227,7 @@ case class CreateDataSourceTableAsSelectCommand(
       throw ex
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 7330f5bee9c21..c7456cd9d2058 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -69,7 +69,7 @@ case class CreateDatabaseCommand(
     path: Option[String],
     comment: Option[String],
     props: Map[String, String])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -105,7 +105,7 @@ case class DropDatabaseCommand(
     databaseName: String,
     ifExists: Boolean,
     cascade: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade)
@@ -125,7 +125,7 @@ case class DropDatabaseCommand(
 case class AlterDatabasePropertiesCommand(
     databaseName: String,
     props: Map[String, String])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -146,7 +146,7 @@ case class AlterDatabasePropertiesCommand(
  * }}}
  */
 case class AlterDatabaseSetLocationCommand(databaseName: String, location: String)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -171,7 +171,7 @@ case class DescribeDatabaseCommand(
     databaseName: String,
     extended: Boolean,
     override val output: Seq[Attribute])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val dbMetadata: CatalogDatabase =
@@ -211,7 +211,7 @@ case class DropTableCommand(
     tableName: TableIdentifier,
     ifExists: Boolean,
     isView: Boolean,
-    purge: Boolean) extends RunnableCommand {
+    purge: Boolean) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -264,7 +264,7 @@ case class AlterTableSetPropertiesCommand(
     tableName: TableIdentifier,
     properties: Map[String, String],
     isView: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -295,7 +295,7 @@ case class AlterTableUnsetPropertiesCommand(
     propKeys: Seq[String],
     ifExists: Boolean,
     isView: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -333,7 +333,7 @@ case class AlterTableUnsetPropertiesCommand(
 case class AlterTableChangeColumnCommand(
     tableName: TableIdentifier,
     columnName: String,
-    newColumn: StructField) extends RunnableCommand {
+    newColumn: StructField) extends LeafRunnableCommand {

   // TODO: support change column name/dataType/metadata/position.
   override def run(sparkSession: SparkSession): Seq[Row] = {
@@ -402,7 +402,7 @@ case class AlterTableSerDePropertiesCommand(
     serdeClassName: Option[String],
     serdeProperties: Option[Map[String, String]],
     partSpec: Option[TablePartitionSpec])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   // should never happen if we parsed things correctly
   require(serdeClassName.isDefined || serdeProperties.isDefined,
@@ -454,7 +454,7 @@ case class AlterTableAddPartitionCommand(
     tableName: TableIdentifier,
     partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])],
     ifNotExists: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -509,7 +509,7 @@ case class AlterTableRenamePartitionCommand(
     tableName: TableIdentifier,
     oldPartition: TablePartitionSpec,
     newPartition: TablePartitionSpec)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -556,7 +556,7 @@ case class AlterTableDropPartitionCommand(
     ifExists: Boolean,
     purge: Boolean,
     retainData: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -600,7 +600,7 @@ case class RepairTableCommand(
     tableName: TableIdentifier,
     enableAddPartitions: Boolean,
     enableDropPartitions: Boolean,
-    cmd: String = "MSCK REPAIR TABLE") extends RunnableCommand {
+    cmd: String = "MSCK REPAIR TABLE") extends LeafRunnableCommand {

   // These are list of statistics that can be collected quickly without requiring a scan of the data
   // see https://github.com/apache/hive/blob/master/
@@ -833,7 +833,7 @@ case class AlterTableSetLocationCommand(
     tableName: TableIdentifier,
     partitionSpec: Option[TablePartitionSpec],
     location: String)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
index af5ba4839ea10..0eda90a596999 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
@@ -55,7 +55,7 @@ case class CreateFunctionCommand(
     isTemp: Boolean,
     ignoreIfExists: Boolean,
     replace: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   if (ignoreIfExists && replace) {
     throw new AnalysisException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE" +
@@ -112,7 +112,7 @@ case class CreateFunctionCommand(
  */
 case class DescribeFunctionCommand(
     functionName: FunctionIdentifier,
-    isExtended: Boolean) extends RunnableCommand {
+    isExtended: Boolean) extends LeafRunnableCommand {

   override val output: Seq[Attribute] = {
     val schema = StructType(StructField("function_desc", StringType, nullable = false) :: Nil)
@@ -177,7 +177,7 @@ case class DropFunctionCommand(
     functionName: String,
     ifExists: Boolean,
     isTemp: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -216,7 +216,7 @@ case class ShowFunctionsCommand(
     pattern: Option[String],
     showUserFunctions: Boolean,
     showSystemFunctions: Boolean,
-    override val output: Seq[Attribute]) extends RunnableCommand {
+    override val output: Seq[Attribute]) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val dbName = db.getOrElse(sparkSession.sessionState.catalog.getCurrentDatabase)
@@ -255,7 +255,7 @@ case class ShowFunctionsCommand(
 case class RefreshFunctionCommand(
     databaseName: Option[String],
     functionName: String)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
index 691837f38d7e3..af053f72cc647 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types.StringType
 /**
  * Adds a jar to the current session so it can be used (for UDFs or serdes).
  */
-case class AddJarCommand(path: String) extends RunnableCommand {
+case class AddJarCommand(path: String) extends LeafRunnableCommand {
   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.sessionState.resourceLoader.addJar(path)
     Seq.empty[Row]
@@ -39,7 +39,7 @@ case class AddJarCommand(path: String) extends RunnableCommand {
 /**
  * Adds a file to the current session so it can be used.
  */
-case class AddFileCommand(path: String) extends RunnableCommand {
+case class AddFileCommand(path: String) extends LeafRunnableCommand {
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val recursive = !sparkSession.sessionState.conf.addSingleFileInAddFile
     sparkSession.sparkContext.addFile(path, recursive)
@@ -50,7 +50,7 @@ case class AddFileCommand(path: String) extends RunnableCommand {
 /**
  * Adds an archive to the current session so it can be used.
  */
-case class AddArchiveCommand(path: String) extends RunnableCommand {
+case class AddArchiveCommand(path: String) extends LeafRunnableCommand {
   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.sparkContext.addArchive(path)
     Seq.empty[Row]
@@ -61,7 +61,7 @@ case class AddArchiveCommand(path: String) extends RunnableCommand {
  * Returns a list of file paths that are added to resources.
  * If file paths are provided, return the ones that are added to resources.
  */
-case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends RunnableCommand {
+case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends LeafRunnableCommand {
   override val output: Seq[Attribute] = {
     AttributeReference("Results", StringType, nullable = false)() :: Nil
   }
@@ -88,7 +88,7 @@ case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends Runn
  * Returns a list of jar files that are added to resources.
  * If jar files are provided, return the ones that are added to resources.
  */
-case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
+case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends LeafRunnableCommand {
   override val output: Seq[Attribute] = {
     AttributeReference("Results", StringType, nullable = false)() :: Nil
   }
@@ -109,7 +109,8 @@ case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends Runnab
  * Returns a list of archive paths that are added to resources.
  * If archive paths are provided, return the ones that are added to resources.
  */
-case class ListArchivesCommand(archives: Seq[String] = Seq.empty[String]) extends RunnableCommand {
+case class ListArchivesCommand(archives: Seq[String] = Seq.empty[String])
+  extends LeafRunnableCommand {
   override val output: Seq[Attribute] = {
     AttributeReference("Results", StringType, nullable = false)() :: Nil
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 488c628fb8633..72168f243900f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -82,7 +82,7 @@ case class CreateTableLikeCommand(
     fileFormat: CatalogStorageFormat,
     provider: Option[String],
     properties: Map[String, String] = Map.empty,
-    ifNotExists: Boolean) extends RunnableCommand {
+    ifNotExists: Boolean) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -161,7 +161,7 @@ case class CreateTableLikeCommand(
  */
 case class CreateTableCommand(
     table: CatalogTable,
-    ignoreIfExists: Boolean) extends RunnableCommand {
+    ignoreIfExists: Boolean) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.sessionState.catalog.createTable(table, ignoreIfExists)
@@ -183,7 +183,7 @@ case class AlterTableRenameCommand(
     oldName: TableIdentifier,
     newName: TableIdentifier,
     isView: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -224,7 +224,7 @@ case class AlterTableRenameCommand(
  */
 case class AlterTableAddColumnsCommand(
     table: TableIdentifier,
-    colsToAdd: Seq[StructField]) extends RunnableCommand {
+    colsToAdd: Seq[StructField]) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
     val catalogTable = verifyAlterTableAddColumn(sparkSession.sessionState.conf, catalog, table)
@@ -300,7 +300,7 @@ case class LoadDataCommand(
     path: String,
     isLocal: Boolean,
     isOverwrite: Boolean,
-    partition: Option[TablePartitionSpec]) extends RunnableCommand {
+    partition: Option[TablePartitionSpec]) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -441,7 +441,7 @@ object LoadDataCommand {
  */
 case class TruncateTableCommand(
     tableName: TableIdentifier,
-    partitionSpec: Option[TablePartitionSpec]) extends RunnableCommand {
+    partitionSpec: Option[TablePartitionSpec]) extends LeafRunnableCommand {

   override def run(spark: SparkSession): Seq[Row] = {
     val catalog = spark.sessionState.catalog
@@ -580,7 +580,7 @@ case class TruncateTableCommand(
   }
 }

-abstract class DescribeCommandBase extends RunnableCommand {
+abstract class DescribeCommandBase extends LeafRunnableCommand {
   protected def describeSchema(
       schema: StructType,
       buffer: ArrayBuffer[Row],
@@ -745,7 +745,7 @@ case class DescribeColumnCommand(
     colNameParts: Seq[String],
     isExtended: Boolean,
     override val output: Seq[Attribute])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {

@@ -828,7 +828,7 @@ case class ShowTablesCommand(
     tableIdentifierPattern: Option[String],
     override val output: Seq[Attribute],
     isExtended: Boolean = false,
-    partitionSpec: Option[TablePartitionSpec] = None) extends RunnableCommand {
+    partitionSpec: Option[TablePartitionSpec] = None) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     // Since we need to return a Seq of rows, we will call getTables directly
@@ -888,7 +888,7 @@ case class ShowTablesCommand(
 case class ShowTablePropertiesCommand(
     table: TableIdentifier,
     propertyKey: Option[String],
-    override val output: Seq[Attribute]) extends RunnableCommand {
+    override val output: Seq[Attribute]) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -924,7 +924,7 @@ case class ShowTablePropertiesCommand(
 case class ShowColumnsCommand(
     databaseName: Option[String],
     tableName: TableIdentifier,
-    override val output: Seq[Attribute]) extends RunnableCommand {
+    override val output: Seq[Attribute]) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -955,7 +955,7 @@ case class ShowColumnsCommand(
 case class ShowPartitionsCommand(
     tableName: TableIdentifier,
     override val output: Seq[Attribute],
-    spec: Option[TablePartitionSpec]) extends RunnableCommand {
+    spec: Option[TablePartitionSpec]) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -1080,7 +1080,7 @@ trait ShowCreateTableCommandBase {
 case class ShowCreateTableCommand(
     table: TableIdentifier,
     override val output: Seq[Attribute])
-  extends RunnableCommand with ShowCreateTableCommandBase {
+  extends LeafRunnableCommand with ShowCreateTableCommandBase {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -1234,7 +1234,7 @@ case class ShowCreateTableCommand(
 case class ShowCreateTableAsSerdeCommand(
     table: TableIdentifier,
     override val output: Seq[Attribute])
-  extends RunnableCommand with ShowCreateTableCommandBase {
+  extends LeafRunnableCommand with ShowCreateTableCommandBase {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -1354,7 +1354,7 @@ case class ShowCreateTableAsSerdeCommand(
  * }}}
  */
 case class RefreshTableCommand(tableIdent: TableIdentifier)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     // Refresh the given table's metadata. If this table is cached as an InMemoryRelation,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index b302b268b3afc..93ea22682a942 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -66,7 +66,7 @@ case class CreateViewCommand(
     allowExisting: Boolean,
     replace: Boolean,
     viewType: ViewType)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   import ViewHelper._

@@ -233,7 +233,7 @@ case class CreateViewCommand(
 case class AlterViewAsCommand(
     name: TableIdentifier,
     originalText: String,
-    query: LogicalPlan) extends RunnableCommand {
+    query: LogicalPlan) extends LeafRunnableCommand {

   import ViewHelper._

@@ -301,7 +301,7 @@ case class AlterViewAsCommand(
 case class ShowViewsCommand(
     databaseName: String,
     tableIdentifierPattern: Option[String],
-    override val output: Seq[Attribute]) extends RunnableCommand {
+    override val output: Seq[Attribute]) extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 5f019557d337a..6300e10c0bb3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -68,6 +68,9 @@ object FileFormatWriter extends Logging {
           |}""".stripMargin
       })
     }
+
+    override protected def withNewChildInternal(newChild: Expression): Empty2Null =
+      copy(child = newChild)
   }

 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index bd9cc0e44fca3..789b1d714fcb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.sources.InsertableRelation


@@ -31,7 +31,7 @@ case class InsertIntoDataSourceCommand(
     logicalRelation: LogicalRelation,
     query: LogicalPlan,
     overwrite: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def innerChildren: Seq[QueryPlan[_]] = Seq(query)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index b29ccb85d77a6..267b360b474ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -270,4 +270,7 @@ case class InsertIntoHadoopFsRelationCommand(
       }
     }.toMap
   }
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): InsertIntoHadoopFsRelationCommand = copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 5195bb295f5bf..486f73cab44f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.sources.CreatableRelationProvider

 /**
@@ -36,7 +36,7 @@ case class SaveIntoDataSourceCommand(
     query: LogicalPlan,
     dataSource: CreatableRelationProvider,
     options: Map[String, String],
-    mode: SaveMode) extends RunnableCommand {
+    mode: SaveMode) extends LeafRunnableCommand {

   override def innerChildren: Seq[QueryPlan[_]] = Seq(query)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 137e50236a295..221db208bc629 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.command.{DDLUtils, RunnableCommand}
+import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand}
 import org.apache.spark.sql.execution.command.ViewHelper.createTemporaryViewRelation
 import org.apache.spark.sql.internal.StaticSQLConf
 import org.apache.spark.sql.types._
@@ -52,6 +52,10 @@ case class CreateTable(
   override def children: Seq[LogicalPlan] = query.toSeq
   override def output: Seq[Attribute] = Seq.empty
   override lazy val resolved: Boolean = false
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[LogicalPlan]): LogicalPlan =
+    copy(query = if (query.isDefined) Some(newChildren.head) else None)
 }
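`CreateTable` above is worth a second look: its only child is an `Option[LogicalPlan]`, so `children` is empty when `query` is `None`, and the `IndexedSeq` handed back by `withNewChildren` is then empty as well; the `query.isDefined` guard keeps the two in sync and makes `newChildren.head` safe. A self-contained toy model of that invariant (types simplified, not Spark code):

```scala
// An optional child contributes zero or one entries to `children`,
// and the rebuild must branch the same way.
final case class Box(payload: Option[Int]) {
  def children: Seq[Int] = payload.toSeq
  def withNewChildren(newChildren: IndexedSeq[Int]): Box =
    copy(payload = if (payload.isDefined) Some(newChildren.head) else None)
}

assert(Box(None).withNewChildren(IndexedSeq.empty) == Box(None))
assert(Box(Some(1)).withNewChildren(IndexedSeq(2)) == Box(Some(2)))
```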

 /**
@@ -63,7 +67,7 @@ case class CreateTempViewUsing(
     replace: Boolean,
     global: Boolean,
     provider: String,
-    options: Map[String, String]) extends RunnableCommand {
+    options: Map[String, String]) extends LeafRunnableCommand {

   if (tableIdent.database.isDefined) {
     throw new AnalysisException(
@@ -123,7 +127,7 @@ case class CreateTempViewUsing(
   }

 case class RefreshResource(path: String)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {

   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.catalog.refreshByPath(path)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 2ed0e06807bf0..764b63db35a7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -47,6 +47,8 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan)
     extends UnaryNode {
   override def child: LogicalPlan = query
   override def output: Seq[Attribute] = Nil
+  override protected def withNewChildInternal(newChild: LogicalPlan): WriteToDataSourceV2 =
+    copy(query = newChild)
 }

 /**
@@ -82,6 +84,9 @@ case class CreateTableAsSelectExec(
       partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CreateTableAsSelectExec =
+    copy(query = newChild)
 }

 /**
@@ -116,6 +121,9 @@ case class AtomicCreateTableAsSelectExec(
       ident, schema, partitioning.toArray, properties.asJava)
     writeToTable(catalog, stagedTable, writeOptions, ident)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): AtomicCreateTableAsSelectExec =
+    copy(query = newChild)
 }

 /**
@@ -160,6 +168,9 @@ case class ReplaceTableAsSelectExec(
       ident, schema, partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ReplaceTableAsSelectExec =
+    copy(query = newChild)
 }

 /**
@@ -207,6 +218,9 @@ case class AtomicReplaceTableAsSelectExec(
     }
     writeToTable(catalog, staged, writeOptions, ident)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): AtomicReplaceTableAsSelectExec =
+    copy(query = newChild)
 }

 /**
@@ -217,7 +231,10 @@ case class AtomicReplaceTableAsSelectExec(
 case class AppendDataExec(
     query: SparkPlan,
     refreshCache: () => Unit,
-    write: Write) extends V2ExistingTableWriteExec
+    write: Write) extends V2ExistingTableWriteExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec =
+    copy(query = newChild)
+}

 /**
  * Physical plan node for overwrite into a v2 table.
@@ -232,7 +249,10 @@ case class AppendDataExec(
 case class OverwriteByExpressionExec(
     query: SparkPlan,
     refreshCache: () => Unit,
-    write: Write) extends V2ExistingTableWriteExec
+    write: Write) extends V2ExistingTableWriteExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): OverwriteByExpressionExec =
+    copy(query = newChild)
+}

 /**
  * Physical plan node for dynamic partition overwrite into a v2 table.
@@ -246,7 +266,10 @@ case class OverwriteByExpressionExec(
 case class OverwritePartitionsDynamicExec(
     query: SparkPlan,
     refreshCache: () => Unit,
-    write: Write) extends V2ExistingTableWriteExec
+    write: Write) extends V2ExistingTableWriteExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): OverwritePartitionsDynamicExec =
+    copy(query = newChild)
+}

 case class WriteToDataSourceV2Exec(
     batchWrite: BatchWrite,
@@ -255,6 +278,9 @@ case class WriteToDataSourceV2Exec(
   override protected def run(): Seq[InternalRow] = {
     writeWithV2(batchWrite)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): WriteToDataSourceV2Exec =
+    copy(query = newChild)
 }

 trait V2ExistingTableWriteExec extends V2TableWriteExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 3cbebca14f7dc..6c744e66d7abb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -288,5 +288,8 @@ package object debug {
     }

     override def supportsColumnar: Boolean = child.supportsColumnar
+
+    override protected def withNewChildInternal(newChild: SparkPlan): DebugExec =
+      copy(child = newChild)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index ca640c43a03a0..94a8a8f0d9e5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -205,6 +205,9 @@ case class BroadcastExchangeExec(
           ex)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): BroadcastExchangeExec =
+    copy(child = newChild)
 }

 object BroadcastExchangeExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 2a7b12f7f515a..6ec376764a38f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -166,6 +166,9 @@ case class ShuffleExchangeExec(
     }
     cachedShuffleRDD
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ShuffleExchangeExec =
+    copy(child = newChild)
 }

 object ShuffleExchangeExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index cec1286c98a7e..ccbcaa2573f64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -254,4 +254,8 @@ case class BroadcastHashJoinExec(
       super.codegenAnti(ctx, input)
     }
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: SparkPlan, newRight: SparkPlan): BroadcastHashJoinExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index fa1a57a8ae3a5..acdd346c84594 100644
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index fa1a57a8ae3a5..acdd346c84594 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -548,4 +548,8 @@ case class BroadcastNestedLoopJoinExec(
        """.stripMargin
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): BroadcastNestedLoopJoinExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index b6386d0d11b4b..1b2d3731f7e8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -101,4 +101,8 @@ case class CartesianProductExec(
       }
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): CartesianProductExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index cd57408e7972d..8514fc2fc4da1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -318,4 +318,8 @@ case class ShuffledHashJoinExec(
       v => s"$v = $thisPlan.buildHashedRelation(inputs[1]);", forceInline = true)
     HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false)
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): ShuffledHashJoinExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index eabbdc8ed3243..8e0b7173ad453 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -633,6 +633,10 @@ case class SortMergeJoinExec(
        |$eagerCleanup
      """.stripMargin
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index e5a299523c79c..5114c075a72d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -73,6 +73,9 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends LimitExec {
       singlePartitionRDD.mapPartitionsInternal(_.take(limit))
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
 
 /**
@@ -95,6 +98,9 @@ case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
     // job launch, we might just have to mimic the implementation of `CollectLimitExec`.
     sparkContext.parallelize(executeCollect(), numSlices = 1)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
 
 object BaseLimitExec {
@@ -160,7 +166,10 @@ trait BaseLimitExec extends LimitExec with CodegenSupport {
 /**
  * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
  */
-case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec
+case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+}
 
 /**
  * Take the first `limit` elements of the child's single output partition.
@@ -168,6 +177,9 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
 
   override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
 
 /**
@@ -249,4 +261,7 @@ case class TakeOrderedAndProjectExec(
     s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index c08db132c946f..fa46f75abe8f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -99,6 +99,9 @@ case class DeserializeToObjectExec(
       iter.map(projection)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): DeserializeToObjectExec =
+    copy(child = newChild)
 }
 
 /**
@@ -135,6 +138,9 @@ case class SerializeFromObjectExec(
       iter.map(projection)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SerializeFromObjectExec =
+    copy(child = newChild)
 }
 
 /**
@@ -195,6 +201,9 @@ case class MapPartitionsExec(
       func(iter.map(getObject)).map(outputObject)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapPartitionsExec =
+    copy(child = newChild)
 }
 
 /**
@@ -252,6 +261,9 @@ case class MapPartitionsInRWithArrowExec(
       }.map(outputProject)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapPartitionsInRWithArrowExec =
+    copy(child = newChild)
 }
 
 /**
@@ -304,6 +316,9 @@ case class MapElementsExec(
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapElementsExec =
+    copy(child = newChild)
 }
 
 /**
@@ -333,6 +348,9 @@ case class AppendColumnsExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): AppendColumnsExec =
+    copy(child = newChild)
 }
 
 /**
@@ -366,6 +384,9 @@ case class AppendColumnsWithObjectExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): AppendColumnsWithObjectExec =
+    copy(child = newChild)
 }
 
 /**
@@ -405,6 +426,9 @@ case class MapGroupsExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapGroupsExec =
+    copy(child = newChild)
 }
 
 object MapGroupsExec {
@@ -495,6 +519,9 @@ case class FlatMapGroupsInRExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInRExec =
+    copy(child = newChild)
 }
 
 /**
@@ -577,6 +604,9 @@ case class FlatMapGroupsInRWithArrowExec(
       }.map(outputProject)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInRWithArrowExec =
+    copy(child = newChild)
 }
 
 /**
@@ -623,4 +653,7 @@ case class CoGroupExec(
       }
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): CoGroupExec = copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index dadf1129c34b5..5019008ec5e32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -154,4 +154,7 @@ case class AggregateInPandasExec(
       }
     }}
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 67f075f0785fb..096712cf93529 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -94,4 +94,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       batch.rowIterator.asScala
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 2ab7262763835..10f7966b93d1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -103,4 +103,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index b079405bdc2f8..e830ea6b54662 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -103,4 +103,8 @@ case class FlatMapCoGroupsInPandasExec(
       }
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): FlatMapCoGroupsInPandasExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 5032bc81327b9..3a3a6022f9985 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -94,4 +94,7 @@ case class FlatMapGroupsInPandasExec(
       executePython(data, output, runner)
     }}
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInPandasExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
index 71f51f1abc6f5..0434710da43ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
@@ -93,4 +93,7 @@ case class MapInPandasExec(
       }.map(unsafeProj)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapInPandasExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index 983fe9db73824..909a026bac7d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -401,4 +401,7 @@ case class WindowInPandasExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): WindowInPandasExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 20fb06a851dd7..7e094fee32547 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -125,4 +125,7 @@ case class EventTimeWatermarkExec(
       a
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): EventTimeWatermarkExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 747094b7791c1..fe788dd8b9408 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -246,4 +246,7 @@ case class FlatMapGroupsWithStateExec(
       CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsWithStateExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 73d2f826f1126..b2c8141e5db0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -620,4 +620,8 @@ case class StreamingSymmetricHashJoinExec(
 
     def numUpdatedStateRows: Long = updatedStateRowsCount
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): StreamingSymmetricHashJoinExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
index 1923fc969801e..ceb52f520df66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
@@ -28,4 +28,6 @@ case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan)
   extends UnaryNode {
   override def child: LogicalPlan = query
   override def output: Seq[Attribute] = Nil
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): WriteToContinuousDataSource = copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index f1898ad3f27ca..1e0caf4785d5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -70,4 +70,7 @@ case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan)
 
     sparkContext.emptyRDD
   }
+
+  override protected def withNewChildInternal(
+    newChild: SparkPlan): WriteToContinuousDataSourceExec = copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
index 4bacd71a55ec1..7989b941563a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
@@ -36,4 +36,7 @@ case class WriteToMicroBatchDataSource(write: StreamingWrite, query: LogicalPlan)
   def createPlan(batchId: Long): WriteToDataSourceV2 = {
     WriteToDataSourceV2(new MicroBatchWrite(batchId, write), query)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): WriteToMicroBatchDataSource =
+    copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index e52f2a17b659d..b52603ebc0443 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -281,6 +281,9 @@ case class StateStoreRestoreExec(
       ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StateStoreRestoreExec =
+    copy(child = newChild)
 }
 
 /**
@@ -436,6 +439,9 @@ case class StateStoreSaveExec(
     eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StateStoreSaveExec =
+    copy(child = newChild)
 }
 
 /** Physical operator for executing streaming Deduplicate. */
@@ -509,6 +515,9 @@ case class StreamingDeduplicateExec(
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
     eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StreamingDeduplicateExec =
+    copy(child = newChild)
 }
 
 object StreamingDeduplicateExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
index e53e0644eb268..51723a25e04e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
@@ -95,6 +95,9 @@ case class StreamingGlobalLimitExec(
   private def getValueRow(value: Long): UnsafeRow = {
     UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value)))
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StreamingGlobalLimitExec =
+    copy(child = newChild)
 }
@@ -133,4 +136,7 @@ case class StreamingLocalLimitExec(limit: Int, child: SparkPlan)
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StreamingLocalLimitExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 9c950fd8033a7..15b85013c4621 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -166,6 +166,9 @@ case class InSubqueryExec(
       exprId = ExprId(0),
       resultBroadcast = null)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): InSubqueryExec =
+    copy(child = newChild)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 6e0e36cbe5901..8011c803394d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -211,4 +211,7 @@ case class WindowExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): WindowExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt
index e4ec487623d2c..0c191216db316 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt
@@ -720,7 +720,7 @@ Input [6]: [i_brand_id#104, i_class_id#105, i_category_id#106, sales#116, number
 
 (130) Expand [codegen id : 130]
 Input [6]: [sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, i_category_id#56]
-Arguments: [List(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, i_category_id#56, 0), List(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, null, 1), List(sales#68, number_sales#69, channel#73, i_brand_id#54, null, null, 3), List(sales#68, number_sales#69, channel#73, null, null, null, 7), List(sales#68, number_sales#69, null, null, null, null, 15)], [sales#68, number_sales#69, channel#120, i_brand_id#121, i_class_id#122, i_category_id#123, spark_grouping_id#124]
+Arguments: [ArrayBuffer(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, i_category_id#56, 0), ArrayBuffer(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, null, 1), ArrayBuffer(sales#68, number_sales#69, channel#73, i_brand_id#54, null, null, 3), ArrayBuffer(sales#68, number_sales#69, channel#73, null, null, null, 7), ArrayBuffer(sales#68, number_sales#69, null, null, null, null, 15)], [sales#68, number_sales#69, channel#120, i_brand_id#121, i_class_id#122, i_category_id#123, spark_grouping_id#124]
 
 (131) HashAggregate [codegen id : 130]
 Input [7]: [sales#68, number_sales#69, channel#120, i_brand_id#121, i_class_id#122, i_category_id#123, spark_grouping_id#124]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt
index 6f61fc8e96ae1..ffcbef4ce1602 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt
@@ -625,7 +625,7 @@ Input [6]: [i_brand_id#96, i_class_id#97, i_category_id#98, sales#109, number_sa
 
 (111) Expand [codegen id : 79]
 Input [6]: [sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, i_category_id#48]
-Arguments: [List(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, i_category_id#48, 0), List(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, null, 1), List(sales#63, number_sales#64, channel#68, i_brand_id#46, null, null, 3), List(sales#63, number_sales#64, channel#68, null, null, null, 7), List(sales#63, number_sales#64, null, null, null, null, 15)], [sales#63, number_sales#64, channel#113, i_brand_id#114, i_class_id#115, i_category_id#116, spark_grouping_id#117]
+Arguments: [ArrayBuffer(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, i_category_id#48, 0), ArrayBuffer(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, null, 1), ArrayBuffer(sales#63, number_sales#64, channel#68, i_brand_id#46, null, null, 3), ArrayBuffer(sales#63, number_sales#64, channel#68, null, null, null, 7), ArrayBuffer(sales#63, number_sales#64, null, null, null, null, 15)], [sales#63, number_sales#64, channel#113, i_brand_id#114, i_class_id#115, i_category_id#116, spark_grouping_id#117]
 
 (112) HashAggregate [codegen id : 79]
 Input [7]: [sales#63, number_sales#64, channel#113, i_brand_id#114, i_class_id#115, i_category_id#116, spark_grouping_id#117]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt
index 28a457258eff7..c9a772d3163ca 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt
@@ -429,7 +429,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#95))#129,17,2) AS sales#
 
 (77) Expand [codegen id : 23]
 Input [5]: [sales#41, RETURNS#42, profit#43, channel#44, id#45]
-Arguments: [List(sales#41, returns#42, profit#43, channel#44, id#45, 0), List(sales#41, returns#42, profit#43, channel#44, null, 1), List(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#138, id#139, spark_grouping_id#140]
+Arguments: [ArrayBuffer(sales#41, returns#42, profit#43, channel#44, id#45, 0), ArrayBuffer(sales#41, returns#42, profit#43, channel#44, null, 1), ArrayBuffer(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#138, id#139, spark_grouping_id#140]
 
 (78) HashAggregate [codegen id : 23]
 Input [6]: [sales#41, returns#42, profit#43, channel#138, id#139, spark_grouping_id#140]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt
index cb130ce17795a..c01302bf69a40 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt
@@ -414,7 +414,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#95))#128,17,2) AS sales#
 
 (74) Expand [codegen id : 20]
 Input [5]: [sales#41, RETURNS#42, profit#43, channel#44, id#45]
-Arguments: [List(sales#41, returns#42, profit#43, channel#44, id#45, 0), List(sales#41, returns#42, profit#43, channel#44, null, 1), List(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#137, id#138, spark_grouping_id#139]
+Arguments: [ArrayBuffer(sales#41, returns#42, profit#43, channel#44, id#45, 0), ArrayBuffer(sales#41, returns#42, profit#43, channel#44, null, 1), ArrayBuffer(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#137, id#138, spark_grouping_id#139]
 
 (75) HashAggregate [codegen id : 20]
 Input [6]: [sales#41, returns#42, profit#43, channel#137, id#138, spark_grouping_id#139]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt
index 4b2299ca2e749..dc5a7fc792af9 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt
@@ -488,7 +488,7 @@ Input [6]: [wp_web_page_sk#77, sales#86, profit#87, wp_web_page_sk#92, returns#1
 
 (85) Expand [codegen id : 23]
 Input [5]: [sales#18, returns#37, profit#38, channel#39, id#40]
-Arguments: [List(sales#18, returns#37, profit#38, channel#39, id#40, 0), List(sales#18, returns#37, profit#38, channel#39, null, 1), List(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
+Arguments: [ArrayBuffer(sales#18, returns#37, profit#38, channel#39, id#40, 0), ArrayBuffer(sales#18, returns#37, profit#38, channel#39, null, 1), ArrayBuffer(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
 
 (86) HashAggregate [codegen id : 23]
 Input [6]: [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt
index 618da39637e23..62bd5aba36e53 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt
@@ -488,7 +488,7 @@ Input [6]: [wp_web_page_sk#77, sales#86, profit#87, wp_web_page_sk#93, returns#1
 
 (85) Expand [codegen id : 23]
 Input [5]: [sales#18, returns#37, profit#38, channel#39, id#40]
-Arguments: [List(sales#18, returns#37, profit#38, channel#39, id#40, 0), List(sales#18, returns#37, profit#38, channel#39, null, 1), List(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
+Arguments: [ArrayBuffer(sales#18, returns#37, profit#38, channel#39, id#40, 0), ArrayBuffer(sales#18, returns#37, profit#38, channel#39, null, 1), ArrayBuffer(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
 
 (86) HashAggregate [codegen id : 23]
 Input [6]: [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt
index bdb1a52a18f2d..040407d99e48d 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt
@@ -590,7 +590,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#90))#117,17,2) AS
 
 (107) Expand [codegen id : 31]
 Input [5]: [sales#42, returns#43, profit#44, channel#45, id#46]
-Arguments: [List(sales#42, returns#43, profit#44, channel#45, id#46, 0), List(sales#42, returns#43, profit#44, channel#45, null, 1), List(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
+Arguments: [ArrayBuffer(sales#42, returns#43, profit#44, channel#45, id#46, 0), ArrayBuffer(sales#42, returns#43, profit#44, channel#45, null, 1), ArrayBuffer(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
 
 (108) HashAggregate [codegen id : 31]
 Input [6]: [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt
index aa15d27d4e562..467127aa2e493 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt
@@ -590,7 +590,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#90))#117,17,2) AS
 
 (107) Expand [codegen id : 31]
 Input [5]: [sales#42, returns#43, profit#44, channel#45, id#46]
-Arguments: [List(sales#42, returns#43, profit#44, channel#45, id#46, 0), List(sales#42, returns#43, profit#44, channel#45, null, 1), List(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
+Arguments: [ArrayBuffer(sales#42, returns#43, profit#44, channel#45, id#46, 0), ArrayBuffer(sales#42, returns#43, profit#44, channel#45, null, 1), ArrayBuffer(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
 
 (108) HashAggregate [codegen id : 31]
 Input [6]: [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
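The golden-file updates above change no plan structure: only the rendered `Arguments` of `Expand` differ, because the explain output prints projection lists via the collection's `toString`, and the rebuilt plans now carry `ArrayBuffer`s where `List`s used to appear (presumably a side effect of the new children-construction path). The snippet below merely illustrates that printing behavior; it is not Spark code:

```scala
import scala.collection.mutable.ArrayBuffer

// A Seq's string form depends on its concrete runtime collection, so two
// sequences with identical elements can render differently in a plan dump.
val asList: Seq[Int] = List(1, 2, 3)
val asBuffer: Seq[Int] = ArrayBuffer(1, 2, 3)
println(asList)   // List(1, 2, 3)
println(asBuffer) // ArrayBuffer(1, 2, 3)
```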
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6914330bb289d..70dc0d09bcad5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -3637,5 +3637,7 @@ object DataFrameFunctionsSuite {
     override def dataType: DataType = child.dataType
     override lazy val resolved = true
     override def eval(input: InternalRow): Any = child.eval(input)
+    override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpr =
+      copy(child = newChild)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index 9192370cfa620..bec68fae08719 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
 import org.apache.spark.sql.test.SharedSparkSession
 
-case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
+case class FastOperator(output: Seq[Attribute]) extends LeafExecNode {
   override protected def doExecute(): RDD[InternalRow] = {
     val str = Literal("so fast").value
@@ -35,7 +35,6 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
   }
 
   override def producedAttributes: AttributeSet = outputSet
-  override def children: Seq[SparkPlan] = Nil
 }
 
 object TestStrategy extends Strategy {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 35d2513835611..d4a6d84ce2b30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -582,6 +582,10 @@ class ColumnarAlias(child: ColumnarExpression, name: String)(
     with ColumnarExpression {
 
   override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch)
+
+  override protected def withNewChildInternal(newChild: Expression): ColumnarAlias =
+    new ColumnarAlias(newChild.asInstanceOf[ColumnarExpression], name)(exprId, qualifier,
+      explicitMetadata, nonInheritableMetadataKeys)
 }
 
 class ColumnarAttributeReference(
@@ -641,6 +645,9 @@ class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   override def hashCode(): Int = super.hashCode()
+
+  override def withNewChildInternal(newChild: SparkPlan): ColumnarProjectExec =
+    new ColumnarProjectExec(projectList, newChild)
 }
 
 /**
@@ -705,6 +712,12 @@ class BrokenColumnarAdd(
     }
     ret
   }
+
+  override def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): BrokenColumnarAdd =
+    new BrokenColumnarAdd(
+      left = newLeft.asInstanceOf[ColumnarExpression],
+      right = newRight.asInstanceOf[ColumnarExpression], failOnError)
 }
 
 class CannotReplaceException(str: String) extends RuntimeException(str) {
@@ -781,6 +794,8 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
   override def child: SparkPlan = delegate.child
   override protected def doExecute(): RDD[InternalRow] = delegate.execute()
   override def outputPartitioning: Partitioning = delegate.outputPartitioning
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    super.legacyWithNewChildren(Seq(newChild))
 }
 
 /**
@@ -798,6 +813,9 @@ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends Broa
   override protected def doExecute(): RDD[InternalRow] = delegate.execute()
   override def doExecuteBroadcast[T](): Broadcast[T] = delegate.executeBroadcast()
   override def outputPartitioning: Partitioning = delegate.outputPartitioning
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    super.legacyWithNewChildren(Seq(newChild))
 }
 
 class ReplacedRowToColumnarExec(override val child: SparkPlan)
@@ -815,6 +833,9 @@ class ReplacedRowToColumnarExec(override val child: SparkPlan)
   }
 
   override def hashCode(): Int = super.hashCode()
+
+  override def withNewChildInternal(newChild: SparkPlan): ReplacedRowToColumnarExec =
+    new ReplacedRowToColumnarExec(newChild)
 }
 
 case class MyPostRule() extends Rule[SparkPlan] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index abe94c2a0b410..986e625137a77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -233,8 +233,7 @@ object TypedImperativeAggregateSuite {
       nullable: Boolean = false,
       mutableAggBufferOffset: Int = 0,
       inputAggBufferOffset: Int = 0)
-    extends TypedImperativeAggregate[MaxValue]
-    with ImplicitCastInputTypes
+    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes
     with UnaryLike[Expression] {
 
     override def createAggregationBuffer(): MaxValue = {
@@ -297,6 +296,9 @@ object TypedImperativeAggregateSuite {
       val value = stream.readInt()
       new MaxValue(value, isValueSet)
     }
+
+    override protected def withNewChildInternal(newChild: Expression): TypedMax =
+      copy(child = newChild)
   }
 
 private class MaxValue(var value: Int, var isValueSet: Boolean = false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
index cef870b249985..2011d057338c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
@@ -600,6 +600,9 @@ case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ExceptionInjectingOperator =
+    copy(child = newChild)
 }
 
 @SQLUserDefinedType(udt = classOf[SimpleTupleUDT])
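`MyShuffleExchangeExec` and `MyBroadcastExchangeExec` above are the notable exception to the `copy` one-liner: their constructor field is a concrete delegate rather than a generic child, so `copy(delegate = newChild)` would not typecheck for an arbitrary `SparkPlan`, and they route through the retained `legacyWithNewChildren` instead. A rough sketch of that fallback for a hypothetical wrapper node (invented here purely for illustration):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan, UnaryExecNode}

// Hypothetical wrapper whose field is narrower than SparkPlan, so the
// compiler-generated copy cannot accept an arbitrary new child.
case class MyWrapperExec(delegate: ProjectExec) extends UnaryExecNode {
  override def child: SparkPlan = delegate
  override def output: Seq[Attribute] = delegate.output
  override protected def doExecute(): RDD[InternalRow] = delegate.execute()

  // Keep the old reflective path, as this patch does for nodes that are
  // awkward to rebuild via copy during the transition.
  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    legacyWithNewChildren(Seq(newChild))
}
```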
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala
index dd2790040b9e8..df08acd35ef17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala
@@ -60,4 +60,5 @@ case class LeafOp(override val supportsColumnar: Boolean) extends LeafExecNode {
 case class UnaryOp(child: SparkPlan, override val supportsColumnar: Boolean) extends UnaryExecNode {
   override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp = copy(child = newChild)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index fb97e15e4df63..9776e76b541ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -40,6 +40,9 @@ case class ColumnarExchange(child: SparkPlan) extends Exchange {
 
   override protected def doExecute(): RDD[InternalRow] = throw new RanRowBased
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = throw new RanColumnar
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ColumnarExchange =
+    copy(child = newChild)
 }
 
 class ExchangeSuite extends SparkPlanTest with SharedSparkSession {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 1724f785c2ff9..0b30b8cdf2644 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1264,4 +1264,6 @@ private case class DummySparkPlan(
   ) extends SparkPlan {
   override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException
   override def output: Seq[Attribute] = Seq.empty
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
+    copy(children = newChildren)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
index a31e2382940e6..1592949fe9a9b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
@@ -58,4 +58,7 @@ case class ReferenceSort(
 
   override def outputOrdering: Seq[SortOrder] = sortOrder
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ReferenceSort =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index b17c93503804c..b3d29df1b29bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoStatement, LogicalPlan, Project}
 import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand}
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.test.SharedSparkSession
@@ -302,7 +302,7 @@ class DataFrameCallbackSuite extends QueryTest
 }
 
 /** A test command that throws `java.lang.Error` during execution. */
-case class ErrorTestCommand(foo: String) extends RunnableCommand {
+case class ErrorTestCommand(foo: String) extends LeafRunnableCommand {
 
   override val output: Seq[Attribute] = Seq(AttributeReference("foo", StringType)())
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index 283c254b39602..fe5d74f889dbb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -130,6 +130,9 @@ case class CreateHiveTableAsSelectCommand(
 
   override def writingCommandClassName: String =
     Utils.getSimpleName(classOf[InsertIntoHiveTable])
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): CreateHiveTableAsSelectCommand = copy(query = newChild)
 }
 
 /**
@@ -177,4 +180,7 @@ case class OptimizedCreateHiveTableAsSelectCommand(
 
   override def writingCommandClassName: String =
     Utils.getSimpleName(classOf[InsertIntoHadoopFsRelationCommand])
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): OptimizedCreateHiveTableAsSelectCommand = copy(query = newChild)
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
index 2059f5bff9cbb..27fdb22391226 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
@@ -184,6 +184,9 @@ private[hive] case class HiveScriptTransformationExec(
 
     outputIterator
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): HiveScriptTransformationExec =
+    copy(child = newChild)
 }
 
 private[hive] case class HiveScriptTransformationWriterThread(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
index 7ef637ed553ad..09aa1e8eea1f8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
@@ -137,5 +137,8 @@ case class InsertIntoHiveDirCommand(
 
     Seq.empty[Row]
   }
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): InsertIntoHiveDirCommand = copy(query = newChild)
 }
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index bfb24cfedb55a..fcd11e67587cf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -343,4 +343,7 @@ case class InsertIntoHiveTable(
       isSrcLocal = false)
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoHiveTable =
+    copy(query = newChild)
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 7717e6ee207d9..7c3d1617bfaeb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -110,6 +110,9 @@ private[hive] case class HiveSimpleUDF(
   override def prettyName: String = name
 
   override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }
 
 // Adapter from Catalyst ExpressionResult to Hive DeferredObject
@@ -186,6 +189,9 @@ private[hive] case class HiveGenericUDF(
   override def toString: String = {
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }
 
 /**
@@ -279,6 +285,9 @@ private[hive] case class HiveGenericUDTF(
   }
 
   override def prettyName: String = name
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }
 
 /**
@@ -528,6 +537,9 @@ private[hive] case class HiveUDAFFunction(
       buffer
     }
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }
 
 case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
index 0ef7b3383e086..ee233fbd7238f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
@@ -78,6 +78,9 @@ case class TestingTypedCount(
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
   override val prettyName: String = "typed_count"
+
+  override protected def withNewChildInternal(newChild: Expression): TestingTypedCount =
+    copy(child = newChild)
 }
 
 object TestingTypedCount {