diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index 397dbb28f8e3e..109ccbd964aca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
@@ -348,7 +349,9 @@ private[spark] object SummaryBuilderImpl extends Logging {
       weightExpr: Expression,
       mutableAggBufferOffset: Int,
       inputAggBufferOffset: Int)
-    extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes {
+    extends TypedImperativeAggregate[SummarizerBuffer]
+    with ImplicitCastInputTypes
+    with BinaryLike[Expression] {
   override def eval(state: SummarizerBuffer): Any = {
     val metrics = requestedMetrics.map {
@@ -368,7 +371,8 @@ private[spark] object SummaryBuilderImpl extends Logging {
   override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil
-  override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil
+  override def left: Expression = featuresExpr
+  override def right: Expression = weightExpr
   override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
     val features = vectorUDT.deserialize(featuresExpr.eval(row))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 42892e25fa0e8..297aa2b4cb9e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TernaryLike, TreeNode, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -786,7 +786,7 @@ abstract class TernaryExpression extends Expression with TernaryLike[Expression]
  * An expression with four inputs and one output. The output is by default evaluated to null
  * if any input is evaluated to null.
  */
-abstract class QuaternaryExpression extends Expression {
+abstract class QuaternaryExpression extends Expression with QuaternaryLike[Expression] {
   override def foldable: Boolean = children.forall(_.foldable)
@@ -797,14 +797,13 @@ abstract class QuaternaryExpression extends Expression {
    * If subclass of QuaternaryExpression override nullable, probably should also override this.
    */
   override def eval(input: InternalRow): Any = {
-    val exprs = children
-    val value1 = exprs(0).eval(input)
+    val value1 = first.eval(input)
     if (value1 != null) {
-      val value2 = exprs(1).eval(input)
+      val value2 = second.eval(input)
       if (value2 != null) {
-        val value3 = exprs(2).eval(input)
+        val value3 = third.eval(input)
         if (value3 != null) {
-          val value4 = exprs(3).eval(input)
+          val value4 = fourth.eval(input)
           if (value4 != null) {
             return nullSafeEval(value1, value2, value3, value4)
           }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
index 9d34368b6c541..05d553757e742 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
  * }}}
  */
 abstract class PartitionTransformExpression extends Expression with Unevaluable
-  with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   override def nullable: Boolean = true
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 490e14afe992b..36004b0ea6244 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types._
   group = "agg_funcs",
   since = "1.0.0")
 case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
-  with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
index 53a3fd6b6c23d..c1c4c84497bcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, Long
   group = "agg_funcs",
   since = "3.0.0")
 case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes
-  with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   override def prettyName: String = "count_if"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index 45d55a085a717..a838a0a0e8977 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
+import org.apache.spark.sql.catalyst.trees.QuaternaryLike
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.CountMinSketch
@@ -60,7 +61,9 @@ case class CountMinSketchAgg(
     seedExpression: Expression,
     override val mutableAggBufferOffset: Int,
     override val inputAggBufferOffset: Int)
-  extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes {
+  extends TypedImperativeAggregate[CountMinSketch]
+  with ExpectsInputTypes
+  with QuaternaryLike[Expression] {
   def this(
       child: Expression,
@@ -145,8 +148,10 @@ case class CountMinSketchAgg(
   override def defaultResult: Option[Literal] =
     Option(Literal.create(eval(createAggregationBuffer()), dataType))
-  override def children: Seq[Expression] =
-    Seq(child, epsExpression, confidenceExpression, seedExpression)
-
   override def prettyName: String = "count_min_sketch"
+
+  override def first: Expression = child
+  override def second: Expression = epsExpression
+  override def third: Expression = confidenceExpression
+  override def fourth: Expression = seedExpression
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index 160ee92b00447..8fcee104d276b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -27,11 +27,9 @@ import org.apache.spark.sql.types._
  * Compute the covariance between two expressions.
  * When applied on empty data (i.e., count is zero), it returns NULL.
  */
-abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Boolean)
+abstract class Covariance(val left: Expression, val right: Expression, nullOnDivideByZero: Boolean)
   extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] {
-  override def left: Expression = x
-  override def right: Expression = y
   override def nullable: Boolean = true
   override def dataType: DataType = DoubleType
   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
@@ -72,14 +70,14 @@ abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Bool
   protected def updateExpressionsDef: Seq[Expression] = {
     val newN = n + 1.0
-    val dx = x - xAvg
-    val dy = y - yAvg
+    val dx = left - xAvg
+    val dy = right - yAvg
     val dyN = dy / newN
     val newXAvg = xAvg + dx / newN
     val newYAvg = yAvg + dyN
-    val newCk = ck + dx * (y - newYAvg)
+    val newCk = ck + dx * (right - newYAvg)
-    val isNull = x.isNull || y.isNull
+    val isNull = left.isNull || right.isNull
     Seq(
       If(isNull, n, newN),
       If(isNull, xAvg, newXAvg),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 9cb8097041fed..f412a3ec31e0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types._
   group = "agg_funcs",
   since = "1.0.0")
 case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
-  with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   override def nullable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala
index 25c099525ef81..5ffc0f6ce3a42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType}
 abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes
-  with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   val child: Expression
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 5f1d03264fa74..d8a76d7add262 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
 * can cause GC paused and eventually OutOfMemory Errors.
 */
 abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T]
-  with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   val child: Expression
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index 44aed2554ccf9..66cd1403ac8f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -175,7 +175,7 @@ object GroupingSets {
   group = "agg_funcs")
 // scalastyle:on line.size.limit line.contains.tab
 case class Grouping(child: Expression) extends Expression with Unevaluable
-  with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   @transient
   override lazy val references: AttributeSet =
     AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index a8df4fb6cd9f2..bbfdf7135824c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -119,8 +120,6 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
   override def nullable: Boolean = arguments.exists(_.nullable)
-  override def children: Seq[Expression] = arguments ++ functions
-
   /**
    * Arguments of the higher ordered function.
    */
@@ -182,7 +181,7 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
 /**
  * Trait for functions having as input one argument and one function.
  */
-trait SimpleHigherOrderFunction extends HigherOrderFunction {
+trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expression] {
   def argument: Expression
@@ -202,6 +201,9 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction {
   def functionForEval: Expression = functionsForEval.head
+  override def left: Expression = argument
+  override def right: Expression = function
+
   /**
    * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
    * in order to save null-check code.
@@ -694,7 +696,7 @@ case class ArrayAggregate(
     zero: Expression,
     merge: Expression,
     finish: Expression)
-  extends HigherOrderFunction with CodegenFallback {
+  extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] {
   def this(argument: Expression, zero: Expression, merge: Expression) = {
     this(argument, zero, merge, LambdaFunction.identity)
@@ -760,6 +762,11 @@ case class ArrayAggregate(
   }
   override def prettyName: String = "aggregate"
+
+  override def first: Expression = argument
+  override def second: Expression = zero
+  override def third: Expression = merge
+  override def fourth: Expression = finish
 }
 /**
@@ -884,7 +891,7 @@ case class TransformValues(
   since = "3.0.0",
   group = "lambda_funcs")
 case class MapZipWith(left: Expression, right: Expression, function: Expression)
-  extends HigherOrderFunction with CodegenFallback {
+  extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
   def functionForEval: Expression = functionsForEval.head
@@ -1045,6 +1052,10 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
   }
   override def prettyName: String = "map_zip_with"
+
+  override def first: Expression = left
+  override def second: Expression = right
+  override def third: Expression = function
 }
 // scalastyle:off line.size.limit
@@ -1063,7 +1074,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
   group = "lambda_funcs")
 // scalastyle:on line.size.limit
 case class ZipWith(left: Expression, right: Expression, function: Expression)
-  extends HigherOrderFunction with CodegenFallback {
+  extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
   def functionForEval: Expression = functionsForEval.head
@@ -1071,7 +1082,7 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
   override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil
-  override def functions: Seq[Expression] = List(function)
+  override def functions: Seq[Expression] = function :: Nil
   override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
@@ -1121,4 +1132,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
   }
   override def prettyName: String = "zip_with"
+
+  override def first: Expression = left
+  override def second: Expression = right
+  override def third: Expression = function
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 7ddb00b62b89c..3b58f3d868d3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1488,7 +1488,6 @@ case class WidthBucket(
     numBucket: Expression)
   extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
-  override def children: Seq[Expression] = Seq(value, minValue, maxValue, numBucket)
   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
   override def dataType: DataType = LongType
   override def nullable: Boolean = true
@@ -1507,4 +1506,9 @@ case class WidthBucket(
       "org.apache.spark.sql.catalyst.expressions.WidthBucket" +
         s".computeBucketNumber($input, $min, $max, $numBucket)")
   }
+
+  override def first: Expression = value
+  override def second: Expression = minValue
+  override def third: Expression = maxValue
+  override def fourth: Expression = numBucket
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index bd2d8375782d5..9fdab350ceb95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -562,7 +562,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   override def dataType: DataType = StringType
   override def inputTypes: Seq[AbstractDataType] =
     Seq(StringType, StringType, StringType, IntegerType)
-  override def children: Seq[Expression] = subject :: regexp :: rep :: pos :: Nil
   override def prettyName: String = "regexp_replace"
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -618,6 +617,11 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
       """
     })
   }
+
+  override def first: Expression = subject
+  override def second: Expression = regexp
+  override def third: Expression = rep
+  override def fourth: Expression = pos
 }
 object RegExpReplace {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index c6b7738f8c24d..714f1d6dc4bfc 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -593,8 +593,6 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType),
     TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
-  override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil
-
   override def checkInputDataTypes(): TypeCheckResult = {
     val inputTypeCheck = super.checkInputDataTypes()
     if (inputTypeCheck.isSuccess) {
@@ -631,6 +629,11 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
       "org.apache.spark.sql.catalyst.expressions.Overlay" +
         s".calculate($input, $replace, $pos, $len);")
   }
+
+  override def first: Expression = input
+  override def second: Expression = replace
+  override def third: Expression = pos
+  override def fourth: Expression = len
 }
 object StringTranslate {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index d45614fa292c5..fa027d1ab0561 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -742,7 +742,7 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
 case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction
-  with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   def this() = this(Literal(1))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 3c3d642f7d36d..7af3f32b18994 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -431,7 +431,7 @@ case class InsertAction(
 }
 case class Assignment(key: Expression, value: Expression) extends Expression
-  with Unevaluable with BinaryLike[Expression] {
+  with Unevaluable with BinaryLike[Expression] {
   override def nullable: Boolean = false
   override def dataType: DataType = throw new UnresolvedException("nullable")
   override def left: Expression = key
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 06bb7baed9ce5..8b5a4c971912d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -850,3 +850,11 @@ trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
   def third: T
   @transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
 }
+
+trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
+  def first: T
+  def second: T
+  def third: T
+  def fourth: T
+  @transient override final lazy val children: Seq[T] = first :: second :: third :: fourth :: Nil
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index b118dba9e3711..026d9676f4fba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -560,10 +560,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 }
-case class HugeCodeIntExpression(value: Int) extends Expression {
+case class HugeCodeIntExpression(value: Int) extends LeafExpression {
   override def nullable: Boolean = true
   override def dataType: DataType = IntegerType
-  override def children: Seq[Expression] = Nil
   override def eval(input: InternalRow): Any = value
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     // Assuming HugeMethodLimit to be 8000
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
index c9de8c7e1a9d0..e1f070a8b66cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration
 /**
  * A special `Command` which writes data out and updates metrics.
  */
-trait DataWritingCommand extends Command {
+trait DataWritingCommand extends UnaryCommand {
   /**
    * The input query plan that produces the data to be written.
    * IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns
@@ -39,7 +39,7 @@ trait DataWritingCommand extends Command {
    */
   def query: LogicalPlan
-  override final def children: Seq[LogicalPlan] = query :: Nil
+  override final def child: LogicalPlan = query
   // Output column names of the analyzed input query plan.
   def outputColumnNames: Seq[String]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index ac6e2ba9eba4f..8bc3cedff2426 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{LeafCommand, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
 import org.apache.spark.sql.connector.ExternalCommandRunner
 import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -37,7 +37,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 * A logical command that is executed for its side-effects. `RunnableCommand`s are
 * wrapped in `ExecutedCommand` during execution.
 */
-trait RunnableCommand extends LeafCommand {
+trait RunnableCommand extends Command {
+
+  override def children: Seq[LogicalPlan] = Nil
   // The map used to record the metrics of running the command. This will be passed to
   // `ExecutedCommand` during query planning.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionExec.scala
index 57d74ab4e4e39..f3fba93b08eb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AddPartitionExec.scala
@@ -31,7 +31,7 @@ case class AddPartitionExec(
     table: SupportsPartitionManagement,
     partSpecs: Seq[ResolvedPartitionSpec],
     ignoreIfExists: Boolean,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
   import DataSourceV2Implicits._
   override def output: Seq[Attribute] = Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterNamespaceSetPropertiesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterNamespaceSetPropertiesExec.scala
index 1eebe4cdb6a86..4bde31abfc0d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterNamespaceSetPropertiesExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterNamespaceSetPropertiesExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{NamespaceChange, SupportsNamespac
 case class AlterNamespaceSetPropertiesExec(
     catalog: SupportsNamespaces,
     namespace: Seq[String],
-    props: Map[String, String]) extends V2CommandExec {
+    props: Map[String, String]) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     val changes = props.map{ case (k, v) =>
       NamespaceChange.setProperty(k, v)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableExec.scala
index 89762757ecb28..cc1c73b020d68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableExec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 case class AlterTableExec(
     catalog: TableCatalog,
     ident: Identifier,
-    changes: Seq[TableChange]) extends V2CommandExec {
+    changes: Seq[TableChange]) extends LeafV2CommandExec {
   override def output: Seq[Attribute] = Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
index 56e008d1d95f8..5b4b9e314ab1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdenti
 import org.apache.spark.sql.execution.command.CreateViewCommand
 import org.apache.spark.storage.StorageLevel
-trait BaseCacheTableExec extends V2CommandExec {
+trait BaseCacheTableExec extends LeafV2CommandExec {
   def relationName: String
   def planToCache: LogicalPlan
   def dataFrameForCachedPlan: DataFrame
@@ -117,7 +117,7 @@ case class CacheTableAsSelectExec(
 case class UncacheTableExec(
     relation: LogicalPlan,
-    cascade: Boolean) extends V2CommandExec {
+    cascade: Boolean) extends LeafV2CommandExec {
   override def run(): Seq[InternalRow] = {
     val sparkSession = sqlContext.sparkSession
     sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, relation, cascade)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala
index da567f569b9d5..dba84d2385aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala
@@ -34,7 +34,7 @@ case class CreateNamespaceExec(
     namespace: Seq[String],
     ifNotExists: Boolean,
     private var properties: Map[String, String])
-  extends V2CommandExec {
+  extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
     import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala
index 752c04313ffaa..be7331b0d7dc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala
@@ -33,7 +33,7 @@ case class CreateTableExec(
     tableSchema: StructType,
     partitioning: Seq[Transform],
     tableProperties: Map[String, String],
-    ignoreIfExists: Boolean) extends V2CommandExec {
+    ignoreIfExists: Boolean) extends LeafV2CommandExec {
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
   override protected def run(): Seq[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala
index f0a45c249dc10..05893a67b3728 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.sources.Filter
 case class DeleteFromTableExec(
     table: SupportsDelete,
     condition: Array[Filter],
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     table.deleteWhere(condition)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala
index ab8c5617aa36b..f7d79a1259ea9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 case class DescribeColumnExec(
     override val output: Seq[Attribute],
     column: Attribute,
-    isExtended: Boolean) extends V2CommandExec {
+    isExtended: Boolean) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     val rows = new ArrayBuffer[InternalRow]()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala
index 2da96b769a41a..bd8a4f06fe114 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala
@@ -31,7 +31,7 @@ case class DescribeNamespaceExec(
     output: Seq[Attribute],
     catalog: SupportsNamespaces,
     namespace: Seq[String],
-    isExtended: Boolean) extends V2CommandExec {
+    isExtended: Boolean) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     val rows = new ArrayBuffer[InternalRow]()
     val ns = namespace.toArray
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala
index 769d76a9b1c2c..c20189efc91fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsMetadataCo
 case class DescribeTableExec(
     output: Seq[Attribute],
     table: Table,
-    isExtended: Boolean) extends V2CommandExec {
+    isExtended: Boolean) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     val rows = new ArrayBuffer[InternalRow]()
     addSchema(rows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala
index 2efd901884c5e..dbd5cbd874945 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala
@@ -30,7 +30,7 @@ case class DropNamespaceExec(
     namespace: Seq[String],
     ifExists: Boolean,
     cascade: Boolean)
-  extends V2CommandExec {
+  extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionExec.scala
index 50e14483a9afd..8dea5d66e4bfa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropPartitionExec.scala
@@ -30,7 +30,7 @@ case class DropPartitionExec(
     partSpecs: Seq[ResolvedPartitionSpec],
     ignoreIfNotExists: Boolean,
     purge: Boolean,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
   import DataSourceV2Implicits._
   override def output: Seq[Attribute] = Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala
index 8b2b5e835513e..1e0627fb6dfdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala
@@ -30,7 +30,7 @@ case class DropTableExec(
     ident: Identifier,
     ifExists: Boolean,
     purge: Boolean,
-    invalidateCache: () => Unit) extends V2CommandExec {
+    invalidateCache: () => Unit) extends LeafV2CommandExec {
   override def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala
index bfd7c6f729cc8..05aa52decd21a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 case class RefreshTableExec(
     catalog: TableCatalog,
     ident: Identifier,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     catalog.invalidateTable(ident)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenamePartitionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenamePartitionExec.scala
index 20b2dd1ab83cc..1db29c80739ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenamePartitionExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenamePartitionExec.scala
@@ -29,7 +29,7 @@ case class RenamePartitionExec(
     table: SupportsPartitionManagement,
     from: ResolvedPartitionSpec,
     to: ResolvedPartitionSpec,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
   override def output: Seq[Attribute] = Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala
index e44ad64b6c268..f5ea355182a70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala
@@ -33,7 +33,7 @@ case class RenameTableExec(
     newIdent: Identifier,
     invalidateCache: () => Option[StorageLevel],
     cacheTable: (SparkSession, LogicalPlan, Option[String], StorageLevel) => Unit)
-  extends V2CommandExec {
+  extends LeafV2CommandExec {
   override def output: Seq[Attribute] = Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
index 18c8a9fb90f87..749cbf631b03e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
@@ -35,7 +35,7 @@ case class ReplaceTableExec(
     partitioning: Seq[Transform],
     tableProperties: Map[String, String],
     orCreate: Boolean,
-    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {
@@ -59,7 +59,7 @@ case class AtomicReplaceTableExec(
     partitioning: Seq[Transform],
     tableProperties: Map[String, String],
     orCreate: Boolean,
-    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(identifier)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetCatalogAndNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetCatalogAndNamespaceExec.scala
index b13cea266707b..fab95bffac25d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetCatalogAndNamespaceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/SetCatalogAndNamespaceExec.scala
@@ -28,7 +28,7 @@ case class SetCatalogAndNamespaceExec(
     catalogManager: CatalogManager,
     catalogName: Option[String],
     namespace: Option[Seq[String]])
-  extends V2CommandExec {
+  extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     // The catalog is updated first because CatalogManager resets the current namespace
     // when the current catalog is set.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
index 121ae1c5b1176..0977452a6ca08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowCurrentNamespaceExec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
 case class ShowCurrentNamespaceExec(
     output: Seq[Attribute],
     catalogManager: CatalogManager)
-  extends V2CommandExec {
+  extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     Seq(toCatalystRow(catalogManager.currentCatalog.name, catalogManager.currentNamespace.quoted))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
index 4e1633e1460ec..33d7337fab635 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Table}
 case class ShowTablePropertiesExec(
     output: Seq[Attribute],
     catalogTable: Table,
-    propertyKey: Option[String]) extends V2CommandExec {
+    propertyKey: Option[String]) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     import scala.collection.JavaConverters._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncatePartitionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncatePartitionExec.scala
index 135005b64973d..6ebfbc1e55959 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncatePartitionExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncatePartitionExec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement
 case class TruncatePartitionExec(
     table: SupportsPartitionManagement,
     partSpec: ResolvedPartitionSpec,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
   override def output: Seq[Attribute] = Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala
index 69261b3084776..948dc1bc8c87c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TruncateTableExec.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.TruncatableTable
 */
 case class TruncateTableExec(
     table: TruncatableTable,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
   override def output: Seq[Attribute] = Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
index bf5bc98ff2489..f99a4db5e39f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
@@ -55,9 +55,8 @@ case class OverwriteByExpressionExecV1(
     write: V1Write) extends V1FallbackWriters
 /** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
-sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write {
+sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
   override def output: Seq[Attribute] = Nil
-  override final def children: Seq[SparkPlan] = Nil
   def table: SupportsWrite
   def refreshCache: () => Unit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala
index 1e33405f35e1e..3c67a3d968fe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.trees.LeafLike
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.StructType
@@ -57,8 +58,6 @@ abstract class V2CommandExec extends SparkPlan {
     sqlContext.sparkContext.parallelize(result, 1)
   }
-  override def children: Seq[SparkPlan] = Nil
-
   override def producedAttributes: AttributeSet = outputSet
   protected def toCatalystRow(values: Any*): InternalRow = {
@@ -69,3 +68,5 @@ abstract class V2CommandExec extends SparkPlan {
     RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer()
   }
 }
+
+trait LeafV2CommandExec extends V2CommandExec with LeafLike[SparkPlan]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index a1c6133a24c82..6914330bb289d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -24,7 +24,7 @@ import scala.util.Random
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
@@ -3632,8 +3632,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 }
 object DataFrameFunctionsSuite {
-  case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback {
-    override def children: Seq[Expression] = Seq(child)
+  case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback {
     override def nullable: Boolean = child.nullable
     override def dataType: DataType = child.dataType
     override lazy val resolved = true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 8f449037a5e01..311bc52515827 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.LeafLike
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -358,8 +359,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
   }
 }
-case class EmptyGenerator() extends Generator {
-  override def children: Seq[Expression] = Nil
+case class EmptyGenerator() extends Generator with LeafLike[Expression] {
   override def elementSchema: StructType = new StructType().add("id", IntegerType)
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index 54fc090d9f061..abe94c2a0b410 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
@@ -232,8 +233,9 @@ object TypedImperativeAggregateSuite {
       nullable: Boolean = false,
       mutableAggBufferOffset: Int = 0,
       inputAggBufferOffset: Int = 0)
-    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
-
+    extends TypedImperativeAggregate[MaxValue]
+    with ImplicitCastInputTypes
+    with UnaryLike[Expression] {
     override def createAggregationBuffer(): MaxValue = {
       // Returns Int.MinValue if all inputs are null
@@ -270,8 +272,6 @@ object TypedImperativeAggregateSuite {
     override lazy val deterministic: Boolean = true
-    override def children: Seq[Expression] = Seq(child)
-
     override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
     override def dataType: DataType = IntegerType
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
index 31b24301767af..0ef7b3383e086 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
@@ -22,6 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.hive.execution.TestingTypedCount.State
 import org.apache.spark.sql.types._
@@ -32,12 +33,11 @@ case class TestingTypedCount(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[TestingTypedCount.State] {
+  extends TypedImperativeAggregate[TestingTypedCount.State]
+  with UnaryLike[Expression] {
   def this(child: Expression) = this(child, 0, 0)
-  override def children: Seq[Expression] = child :: Nil
-
   override def dataType: DataType = LongType
   override def nullable: Boolean = false