Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-28481][SQL] More expressions should extend NullIntolerant #28626

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ object TimeWindow {
case class PreciseTimestampConversion(
child: Expression,
fromType: DataType,
toType: DataType) extends UnaryExpression with ExpectsInputTypes {
toType: DataType) extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(fromType)
override def dataType: DataType = toType
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,11 @@ trait DivModLike extends BinaryArithmetic {

override def nullable: Boolean = true

final override def eval(input: InternalRow): Any = {
val input2 = right.eval(input)
if (input2 == null || input2 == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see the difference now. Previously we can skip evaluating input1 if input2 is 0. Can we change it back and add comment to explain it? sorry for the back and forth!

final override def nullSafeEval(input1: Any, input2: Any): Any = {
if (input2 == 0) {
null
} else {
val input1 = left.eval(input)
if (input1 == null) {
null
} else {
evalOperation(input1, input2)
}
evalOperation(input1, input2)
}
}

Expand Down Expand Up @@ -516,24 +510,18 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {

override def nullable: Boolean = true

override def eval(input: InternalRow): Any = {
val input2 = right.eval(input)
if (input2 == null || input2 == 0) {
override def nullSafeEval(input1: Any, input2: Any): Any = {
if (input2 == 0) {
null
} else {
val input1 = left.eval(input)
if (input1 == null) {
null
} else {
input1 match {
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
}
input1 match {
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
> SELECT _FUNC_ 0;
-1
""")
case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class BitwiseNot(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)

Expand Down Expand Up @@ -164,7 +165,8 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
0
""",
since = "3.0.0")
case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class BitwiseCount(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ object Size {
""",
group = "map_funcs")
case class MapKeys(child: Expression)
extends UnaryExpression with ExpectsInputTypes {
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

Expand Down Expand Up @@ -332,7 +332,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
""",
group = "map_funcs")
case class MapValues(child: Expression)
extends UnaryExpression with ExpectsInputTypes {
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

Expand Down Expand Up @@ -361,7 +361,8 @@ case class MapValues(child: Expression)
""",
group = "map_funcs",
since = "3.0.0")
case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class MapEntries(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

Expand Down Expand Up @@ -649,7 +650,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
""",
group = "map_funcs",
since = "2.4.0")
case class MapFromEntries(child: Expression) extends UnaryExpression {
case class MapFromEntries(child: Expression) extends UnaryExpression with NullIntolerant {

@transient
private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
Expand Down Expand Up @@ -873,7 +874,7 @@ object ArraySortLike {
group = "array_funcs")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ArraySortLike {
extends BinaryExpression with ArraySortLike with NullIntolerant {

def this(e: Expression) = this(e, Literal(true))

Expand Down Expand Up @@ -1017,7 +1018,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
Reverse logic for arrays is available since 2.4.0.
"""
)
case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
case class Reverse(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
Expand Down Expand Up @@ -1086,7 +1088,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
""",
group = "array_funcs")
case class ArrayContains(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = BooleanType

Expand Down Expand Up @@ -1185,7 +1187,7 @@ case class ArrayContains(left: Expression, right: Expression)
since = "2.4.0")
// scalastyle:off line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
extends BinaryArrayExpressionWithImplicitCast {
extends BinaryArrayExpressionWithImplicitCast with NullIntolerant {

override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
Expand Down Expand Up @@ -1410,7 +1412,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
since = "2.4.0")
// scalastyle:on line.size.limit
case class Slice(x: Expression, start: Expression, length: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = x.dataType

Expand Down Expand Up @@ -1688,7 +1690,8 @@ case class ArrayJoin(
""",
group = "array_funcs",
since = "2.4.0")
case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
case class ArrayMin(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def nullable: Boolean = true

Expand Down Expand Up @@ -1755,7 +1758,8 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
""",
group = "array_funcs",
since = "2.4.0")
case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
case class ArrayMax(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def nullable: Boolean = true

Expand Down Expand Up @@ -1831,7 +1835,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
group = "array_funcs",
since = "2.4.0")
case class ArrayPosition(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)
Expand Down Expand Up @@ -1909,7 +1913,7 @@ case class ArrayPosition(left: Expression, right: Expression)
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression)
extends GetMapValueUtil with GetArrayItemUtil {
extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {

@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

Expand Down Expand Up @@ -2245,7 +2249,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
""",
group = "array_funcs",
since = "2.4.0")
case class Flatten(child: Expression) extends UnaryExpression {
case class Flatten(child: Expression) extends UnaryExpression with NullIntolerant {

private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]

Expand Down Expand Up @@ -2884,7 +2888,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
group = "array_funcs",
since = "2.4.0")
case class ArrayRemove(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = left.dataType

Expand Down Expand Up @@ -3081,7 +3085,7 @@ trait ArraySetLike {
group = "array_funcs",
since = "2.4.0")
case class ArrayDistinct(child: Expression)
extends UnaryExpression with ArraySetLike with ExpectsInputTypes {
extends UnaryExpression with ArraySetLike with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

Expand Down Expand Up @@ -3219,7 +3223,8 @@ case class ArrayDistinct(child: Expression)
/**
* Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]].
*/
trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike {
trait ArrayBinaryLike
extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant {
override protected def dt: DataType = dataType
override protected def et: DataType = elementType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ object CreateMap {
{1.0:"2",3.0:"4"}
""", since = "2.4.0")
case class MapFromArrays(left: Expression, right: Expression)
extends BinaryExpression with ExpectsInputTypes {
extends BinaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)

Expand Down Expand Up @@ -476,7 +476,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
since = "2.0.1")
// scalastyle:on line.size.limit
case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression)
extends TernaryExpression with ExpectsInputTypes {
extends TernaryExpression with ExpectsInputTypes with NullIntolerant {

def this(child: Expression, pairDelim: Expression) = {
this(child, pairDelim, Literal(":"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ case class StructsToCsv(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes
with NullIntolerant {
override def nullable: Boolean = true

def this(options: Map[String, String], child: Expression) = this(options, child, None)
Expand Down
Loading