Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-34969][SPARK-34906][SQL] Followup for Refactor TreeNode's children handling methods into specialized traits #32065

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -348,7 +349,9 @@ private[spark] object SummaryBuilderImpl extends Logging {
weightExpr: Expression,
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int)
extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes {
extends TypedImperativeAggregate[SummarizerBuffer]
with ImplicitCastInputTypes
with BinaryLike[Expression] {

override def eval(state: SummarizerBuffer): Any = {
val metrics = requestedMetrics.map {
Expand All @@ -368,7 +371,8 @@ private[spark] object SummaryBuilderImpl extends Logging {

override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil

override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil
override def left: Expression = featuresExpr
override def right: Expression = weightExpr

override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
val features = vectorUDT.deserialize(featuresExpr.eval(row))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -786,7 +786,7 @@ abstract class TernaryExpression extends Expression with TernaryLike[Expression]
* An expression with four inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
*/
abstract class QuaternaryExpression extends Expression {
abstract class QuaternaryExpression extends Expression with QuaternaryLike[Expression] {

override def foldable: Boolean = children.forall(_.foldable)

Expand All @@ -797,14 +797,13 @@ abstract class QuaternaryExpression extends Expression {
* If a subclass of QuaternaryExpression overrides nullable, it probably should also override this.
*/
override def eval(input: InternalRow): Any = {
val exprs = children
val value1 = exprs(0).eval(input)
val value1 = first.eval(input)
if (value1 != null) {
val value2 = exprs(1).eval(input)
val value2 = second.eval(input)
if (value2 != null) {
val value3 = exprs(2).eval(input)
val value3 = third.eval(input)
if (value3 != null) {
val value4 = exprs(3).eval(input)
val value4 = fourth.eval(input)
if (value4 != null) {
return nullSafeEval(value1, value2, value3, value4)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
* }}}
*/
abstract class PartitionTransformExpression extends Expression with Unevaluable
with UnaryLike[Expression] {
with UnaryLike[Expression] {
override def nullable: Boolean = true
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.apache.spark.sql.types._
group = "agg_funcs",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] {
with UnaryLike[Expression] {

override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, Long
group = "agg_funcs",
since = "3.0.0")
case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] {
with UnaryLike[Expression] {

override def prettyName: String = "count_if"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.catalyst.trees.QuaternaryLike
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch
Expand Down Expand Up @@ -60,7 +61,9 @@ case class CountMinSketchAgg(
seedExpression: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes {
extends TypedImperativeAggregate[CountMinSketch]
with ExpectsInputTypes
with QuaternaryLike[Expression] {

def this(
child: Expression,
Expand Down Expand Up @@ -145,8 +148,10 @@ case class CountMinSketchAgg(
override def defaultResult: Option[Literal] =
Option(Literal.create(eval(createAggregationBuffer()), dataType))

override def children: Seq[Expression] =
Seq(child, epsExpression, confidenceExpression, seedExpression)

override def prettyName: String = "count_min_sketch"

override def first: Expression = child
override def second: Expression = epsExpression
override def third: Expression = confidenceExpression
override def fourth: Expression = seedExpression
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ import org.apache.spark.sql.types._
* Compute the covariance between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL.
*/
abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Boolean)
abstract class Covariance(val left: Expression, val right: Expression, nullOnDivideByZero: Boolean)
extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] {

override def left: Expression = x
override def right: Expression = y
override def nullable: Boolean = true
override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
Expand Down Expand Up @@ -72,14 +70,14 @@ abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Bool

protected def updateExpressionsDef: Seq[Expression] = {
val newN = n + 1.0
val dx = x - xAvg
val dy = y - yAvg
val dx = left - xAvg
val dy = right - yAvg
val dyN = dy / newN
val newXAvg = xAvg + dx / newN
val newYAvg = yAvg + dyN
val newCk = ck + dx * (y - newYAvg)
val newCk = ck + dx * (right - newYAvg)

val isNull = x.isNull || y.isNull
val isNull = left.isNull || right.isNull
Seq(
If(isNull, n, newN),
If(isNull, xAvg, newXAvg),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.apache.spark.sql.types._
group = "agg_funcs",
since = "1.0.0")
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] {
with UnaryLike[Expression] {

override def nullable: Boolean = true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType}

abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes
with UnaryLike[Expression] {
with UnaryLike[Expression] {

val child: Expression

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
* can cause GC pauses and eventually OutOfMemoryErrors.
*/
abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T]
with UnaryLike[Expression] {
with UnaryLike[Expression] {

val child: Expression

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ object GroupingSets {
group = "agg_funcs")
// scalastyle:on line.size.limit line.contains.tab
case class Grouping(child: Expression) extends Expression with Unevaluable
with UnaryLike[Expression] {
with UnaryLike[Expression] {
@transient
override lazy val references: AttributeSet =
AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -119,8 +120,6 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {

override def nullable: Boolean = arguments.exists(_.nullable)

override def children: Seq[Expression] = arguments ++ functions

/**
* Arguments of the higher-order function.
*/
Expand Down Expand Up @@ -182,7 +181,7 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
/**
* Trait for functions having as input one argument and one function.
*/
trait SimpleHigherOrderFunction extends HigherOrderFunction {
trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expression] {

def argument: Expression

Expand All @@ -202,6 +201,9 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction {

def functionForEval: Expression = functionsForEval.head

override def left: Expression = argument
override def right: Expression = function

/**
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
* in order to save null-check code.
Expand Down Expand Up @@ -694,7 +696,7 @@ case class ArrayAggregate(
zero: Expression,
merge: Expression,
finish: Expression)
extends HigherOrderFunction with CodegenFallback {
extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] {

def this(argument: Expression, zero: Expression, merge: Expression) = {
this(argument, zero, merge, LambdaFunction.identity)
Expand Down Expand Up @@ -760,6 +762,11 @@ case class ArrayAggregate(
}

override def prettyName: String = "aggregate"

override def first: Expression = argument
override def second: Expression = zero
override def third: Expression = merge
override def fourth: Expression = finish
}

/**
Expand Down Expand Up @@ -884,7 +891,7 @@ case class TransformValues(
since = "3.0.0",
group = "lambda_funcs")
case class MapZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback {
extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {

def functionForEval: Expression = functionsForEval.head

Expand Down Expand Up @@ -1045,6 +1052,10 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
}

override def prettyName: String = "map_zip_with"

override def first: Expression = left
override def second: Expression = right
override def third: Expression = function
}

// scalastyle:off line.size.limit
Expand All @@ -1063,15 +1074,15 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
group = "lambda_funcs")
// scalastyle:on line.size.limit
case class ZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback {
extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {

def functionForEval: Expression = functionsForEval.head

override def arguments: Seq[Expression] = left :: right :: Nil

override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil

override def functions: Seq[Expression] = List(function)
override def functions: Seq[Expression] = function :: Nil

override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil

Expand Down Expand Up @@ -1121,4 +1132,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
}

override def prettyName: String = "zip_with"

override def first: Expression = left
override def second: Expression = right
override def third: Expression = function
}
Original file line number Diff line number Diff line change
Expand Up @@ -1488,7 +1488,6 @@ case class WidthBucket(
numBucket: Expression)
extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def children: Seq[Expression] = Seq(value, minValue, maxValue, numBucket)
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
override def dataType: DataType = LongType
override def nullable: Boolean = true
Expand All @@ -1507,4 +1506,9 @@ case class WidthBucket(
"org.apache.spark.sql.catalyst.expressions.WidthBucket" +
s".computeBucketNumber($input, $min, $max, $numBucket)")
}

override def first: Expression = value
override def second: Expression = minValue
override def third: Expression = maxValue
override def fourth: Expression = numBucket
}
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringType, StringType, StringType, IntegerType)
override def children: Seq[Expression] = subject :: regexp :: rep :: pos :: Nil
override def prettyName: String = "regexp_replace"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -618,6 +617,11 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
"""
})
}

override def first: Expression = subject
override def second: Expression = regexp
override def third: Expression = rep
override def fourth: Expression = pos
}

object RegExpReplace {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,8 +593,6 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType),
TypeCollection(StringType, BinaryType), IntegerType, IntegerType)

override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil

override def checkInputDataTypes(): TypeCheckResult = {
val inputTypeCheck = super.checkInputDataTypes()
if (inputTypeCheck.isSuccess) {
Expand Down Expand Up @@ -631,6 +629,11 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
"org.apache.spark.sql.catalyst.expressions.Overlay" +
s".calculate($input, $replace, $pos, $len);")
}

override def first: Expression = input
override def second: Expression = replace
override def third: Expression = pos
override def fourth: Expression = len
}

object StringTranslate {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
group = "window_funcs")
// scalastyle:on line.size.limit line.contains.tab
case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction
with UnaryLike[Expression] {
with UnaryLike[Expression] {

def this() = this(Literal(1))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ case class InsertAction(
}

case class Assignment(key: Expression, value: Expression) extends Expression
with Unevaluable with BinaryLike[Expression] {
with Unevaluable with BinaryLike[Expression] {
override def nullable: Boolean = false
override def dataType: DataType = throw new UnresolvedException("nullable")
override def left: Expression = key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,11 @@ trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
def third: T
@transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
}

/**
 * Mixin for tree nodes that have exactly four children, exposed through the accessors
 * `first`, `second`, `third` and `fourth`.
 *
 * The self-type restricts this trait to being mixed into a `TreeNode[T]`.
 */
trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
// The four child accessors; each must be supplied by the concrete node.
def first: T
def second: T
def third: T
def fourth: T
// Built on first access from the four accessors, in the order first, second, third, fourth.
// Declared `final`, so implementers define the accessors instead of overriding `children`.
@transient override final lazy val children: Seq[T] = first :: second :: third :: fourth :: Nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -560,10 +560,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

case class HugeCodeIntExpression(value: Int) extends Expression {
case class HugeCodeIntExpression(value: Int) extends LeafExpression {
override def nullable: Boolean = true
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Nil
override def eval(input: InternalRow): Any = value
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Assuming HugeMethodLimit to be 8000
Expand Down
Loading