From fb68d60af542335b97377d8cae433548ce422712 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 24 May 2020 10:22:17 +0800 Subject: [PATCH 1/7] More expressions should extend NullIntolerant --- pom.xml | 6 ++ sql/catalyst/pom.xml | 4 + .../sql/catalyst/expressions/TimeWindow.scala | 2 +- .../expressions/bitwiseExpressions.scala | 6 +- .../expressions/collectionOperations.scala | 39 ++++---- .../expressions/complexTypeCreator.scala | 4 +- .../catalyst/expressions/csvExpressions.scala | 3 +- .../expressions/datetimeExpressions.scala | 75 +++++++++------ .../expressions/decimalExpressions.scala | 4 +- .../spark/sql/catalyst/expressions/hash.scala | 11 ++- .../expressions/intervalExpressions.scala | 6 +- .../expressions/jsonExpressions.scala | 6 +- .../expressions/mathExpressions.scala | 23 +++-- .../expressions/regexpExpressions.scala | 6 +- .../expressions/stringExpressions.scala | 58 +++++++----- .../sql/catalyst/expressions/xml/xpath.scala | 3 +- .../NullIntolerantCheckerSuite.scala | 91 +++++++++++++++++++ 17 files changed, 245 insertions(+), 102 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala diff --git a/pom.xml b/pom.xml index 6620673a7e5fc..2b253ca58359b 100644 --- a/pom.xml +++ b/pom.xml @@ -882,6 +882,12 @@ <artifactId>jline</artifactId> <version>2.14.6</version> + <dependency> + <groupId>org.clapper</groupId> + <artifactId>classutil_${scala.binary.version}</artifactId> + <version>1.5.0</version> + <scope>test</scope> + </dependency> <groupId>org.scalatest</groupId> <artifactId>scalatest_${scala.binary.version}</artifactId> diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9edbb7fec97d0..9b2c8922d70aa 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -117,6 +117,10 @@ <groupId>org.apache.arrow</groupId> <artifactId>arrow-vector</artifactId> + <dependency> + <groupId>org.clapper</groupId> + <artifactId>classutil_${scala.binary.version}</artifactId> + </dependency> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 82d689477080d..f7fe467cea830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -144,7 +144,7 @@ object TimeWindow { case class PreciseTimestampConversion( child: Expression, fromType: DataType, - toType: DataType) extends UnaryExpression with ExpectsInputTypes { + toType: DataType) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(fromType) override def dataType: DataType = toType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 7b819db32e425..342b14eaa3390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -127,7 +127,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme > SELECT _FUNC_ 0; -1 """) -case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class BitwiseNot(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) @@ -164,7 +165,8 @@ case class BitwiseNot(child: Expression) extends
UnaryExpression with ExpectsInp 0 """, since = "3.0.0") -case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class BitwiseCount(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4fd68dcfe5156..e11e72090ced2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -141,7 +141,7 @@ object Size { """, group = "map_funcs") case class MapKeys(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -332,7 +332,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI """, group = "map_funcs") case class MapValues(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -361,7 +361,8 @@ case class MapValues(child: Expression) """, group = "map_funcs", since = "3.0.0") -case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class MapEntries(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -649,7 +650,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """, group = "map_funcs", since = "2.4.0") -case class MapFromEntries(child: Expression) extends UnaryExpression { +case class MapFromEntries(child: Expression) extends UnaryExpression with NullIntolerant { @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { @@ -873,7 +874,7 @@ object ArraySortLike { group = "array_funcs") // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ArraySortLike { + extends BinaryExpression with ArraySortLike with NullIntolerant { def this(e: Expression) = this(e, Literal(true)) @@ -1017,7 +1018,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) Reverse logic for arrays is available since 2.4.0. """ ) -case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Reverse(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { // Input types are utilized by type coercion in ImplicitTypeCasts. 
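+  // NullIntolerant: a null child evaluates straight to null, which also lets the optimizer infer an IsNotNull constraint on the child.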
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) @@ -1086,7 +1088,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI """, group = "array_funcs") case class ArrayContains(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BooleanType @@ -1185,7 +1187,7 @@ case class ArrayContains(left: Expression, right: Expression) since = "2.4.0") // scalastyle:off line.size.limit case class ArraysOverlap(left: Expression, right: Expression) - extends BinaryArrayExpressionWithImplicitCast { + extends BinaryArrayExpressionWithImplicitCast with NullIntolerant { override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => @@ -1410,7 +1412,7 @@ case class ArraysOverlap(left: Expression, right: Expression) since = "2.4.0") // scalastyle:on line.size.limit case class Slice(x: Expression, start: Expression, length: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = x.dataType @@ -1688,7 +1690,8 @@ case class ArrayJoin( """, group = "array_funcs", since = "2.4.0") -case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class ArrayMin(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def nullable: Boolean = true @@ -1755,7 +1758,8 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast """, group = "array_funcs", since = "2.4.0") -case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class ArrayMax(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def nullable: Boolean = true @@ -1831,7 +1835,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast group = "array_funcs", since = "2.4.0") case class ArrayPosition(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -1909,7 +1913,7 @@ case class ArrayPosition(left: Expression, right: Expression) """, since = "2.4.0") case class ElementAt(left: Expression, right: Expression) - extends GetMapValueUtil with GetArrayItemUtil { + extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant { @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType @@ -2245,7 +2249,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio """, group = "array_funcs", since = "2.4.0") -case class Flatten(child: Expression) extends UnaryExpression { +case class Flatten(child: Expression) extends UnaryExpression with NullIntolerant { private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] @@ -2884,7 +2888,7 @@ case class ArrayRepeat(left: Expression, right: Expression) group = "array_funcs", since = "2.4.0") case class ArrayRemove(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def 
dataType: DataType = left.dataType @@ -3081,7 +3085,7 @@ trait ArraySetLike { group = "array_funcs", since = "2.4.0") case class ArrayDistinct(child: Expression) - extends UnaryExpression with ArraySetLike with ExpectsInputTypes { + extends UnaryExpression with ArraySetLike with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -3219,7 +3223,8 @@ case class ArrayDistinct(child: Expression) /** * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ -trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike { +trait ArrayBinaryLike + extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant { override protected def dt: DataType = dataType override protected def et: DataType = elementType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 858c91a4d8e86..b7fa218c395a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -255,7 +255,7 @@ object CreateMap { {1.0:"2",3.0:"4"} """, since = "2.4.0") case class MapFromArrays(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) @@ -452,7 +452,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { since = "2.0.1") // scalastyle:on line.size.limit case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression) - extends TernaryExpression with ExpectsInputTypes { + extends TernaryExpression with ExpectsInputTypes with NullIntolerant { def this(child: Expression, pairDelim: Expression) = { this(child, pairDelim, Literal(":")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 5140db90c5954..f9ccf3c8c811f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -211,7 +211,8 @@ case class StructsToCsv( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes + with NullIntolerant { override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index afc57aa546fe8..dda3a27c3463a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -198,7 +198,7 @@ case class CurrentBatchTimestamp( group = "datetime_funcs", since = "1.5.0")
case class DateAdd(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = days @@ -234,7 +234,7 @@ case class DateAdd(startDate: Expression, days: Expression) group = "datetime_funcs", since = "1.5.0") case class DateSub(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = days @@ -266,7 +266,8 @@ case class DateSub(startDate: Expression, days: Expression) group = "datetime_funcs", since = "1.5.0") case class Hour(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -298,7 +299,8 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None) group = "datetime_funcs", since = "1.5.0") case class Minute(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -330,7 +332,8 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None) group = "datetime_funcs", since = "1.5.0") case class Second(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -353,7 +356,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None) } case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -385,7 +389,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No """, group = "datetime_funcs", since = "1.5.0") -case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfYear(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -402,7 +407,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas } abstract class NumberToTimestampBase extends UnaryExpression - with ExpectsInputTypes { + with ExpectsInputTypes with NullIntolerant { protected def upScaleFactor: Long @@ -487,7 +492,8 @@ case class MicrosToTimestamp(child: Expression) """, group = "datetime_funcs", since = "1.5.0") -case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Year(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -503,7 +509,8 @@ case class Year(child: Expression) extends UnaryExpression with 
ImplicitCastInpu } } -case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class YearOfWeek(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -528,7 +535,8 @@ case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCa """, group = "datetime_funcs", since = "1.5.0") -case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Quarter(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -553,7 +561,8 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI """, group = "datetime_funcs", since = "1.5.0") -case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Month(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -577,7 +586,8 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp 30 """, since = "1.5.0") -case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfMonth(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -647,7 +657,7 @@ case class WeekDay(child: Expression) extends DayWeek { } } -abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -665,7 +675,8 @@ abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { group = "datetime_funcs", since = "1.5.0") // scalastyle:on line.size.limit -case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class WeekOfYear(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -704,7 +715,8 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa since = "1.5.0") // scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(left: Expression, right: Expression) = this(left, right, None) @@ -859,7 +871,7 @@ case class UnixTimestamp(timeExp: Expression, format: Expression, timeZoneId: Op } abstract class ToTimestamp - extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { // The result of the conversion to timestamp is microseconds divided by this factor. // For example if the factor is 1000000, the result of the expression is in seconds. 
@@ -1150,7 +1162,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ """, group = "datetime_funcs", since = "1.5.0") -case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class LastDay(startDate: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def child: Expression = startDate override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -1188,7 +1201,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC since = "1.5.0") // scalastyle:on line.size.limit case class NextDay(startDate: Expression, dayOfWeek: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = dayOfWeek @@ -1244,7 +1257,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) * Adds an interval to timestamp. */ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { def this(start: Expression, interval: Expression) = this(start, interval, None) @@ -1302,7 +1315,7 @@ case class DateAddInterval( interval: Expression, timeZoneId: Option[String] = None, ansiEnabled: Boolean = SQLConf.get.ansiEnabled) - extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression { + extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression with NullIntolerant { override def left: Expression = start override def right: Expression = interval @@ -1376,7 +1389,7 @@ case class DateAddInterval( since = "1.5.0") // scalastyle:on line.size.limit case class FromUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override def dataType: DataType = TimestampType @@ -1436,7 +1449,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class AddMonths(startDate: Expression, numMonths: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = numMonths @@ -1490,7 +1503,8 @@ case class MonthsBetween( date2: Expression, roundOff: Expression, timeZoneId: Option[String] = None) - extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) @@ -1548,7 +1562,7 @@ case class MonthsBetween( since = "1.5.0") // scalastyle:on line.size.limit case class ToUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override def dataType: DataType = TimestampType @@ -1902,7 +1916,7 @@ case class TruncTimestamp( group = 
"datetime_funcs", since = "1.5.0") case class DateDiff(endDate: Expression, startDate: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = endDate override def right: Expression = startDate @@ -1956,7 +1970,7 @@ private case class GetTimestamp( group = "datetime_funcs", since = "3.0.0") case class MakeDate(year: Expression, month: Expression, day: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def children: Seq[Expression] = Seq(year, month, day) override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType) @@ -2027,7 +2041,8 @@ case class MakeTimestamp( sec: Expression, timezone: Option[Expression] = None, timeZoneId: Option[String] = None) - extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this( year: Expression, @@ -2303,7 +2318,7 @@ case class Extract(field: Expression, source: Expression, child: Expression) * between the given timestamps. */ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = endTimestamp override def right: Expression = startTimestamp @@ -2324,7 +2339,7 @@ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expressi * Returns the interval from the `left` date (inclusive) to the `right` date (exclusive). */ case class SubtractDates(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType) override def dataType: DataType = CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 9014ebfe2f96a..c2c70b2ab08e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ * Note: this expression is internal and created only by the optimizer, * we don't need to do type check for it. 
*/ -case class UnscaledValue(child: Expression) extends UnaryExpression { +case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" @@ -49,7 +49,7 @@ case class MakeDecimal( child: Expression, precision: Int, scale: Int, - nullOnOverflow: Boolean) extends UnaryExpression { + nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant { def this(child: Expression, precision: Int, scale: Int) = { this(child, precision, scale, !SQLConf.get.ansiEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 4c8c58ae232f4..5e21b58f070ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -53,7 +53,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} > SELECT _FUNC_('Spark'); 8cde774d6f7333752ed72cacddb05126 """) -case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Md5(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -89,7 +90,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput """) // scalastyle:on line.size.limit case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def dataType: DataType = StringType override def nullable: Boolean = true @@ -160,7 +161,8 @@ case class Sha2(left: Expression, right: Expression) > SELECT _FUNC_('Spark'); 85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c """) -case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Sha1(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -187,7 +189,8 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu > SELECT _FUNC_('Spark'); 1557323817 """) -case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Crc32(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 1a569a7b89fe1..baab224691bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -31,7 +31,7 @@ abstract class ExtractIntervalPart( val dataType: DataType, func: CalendarInterval => Any, funcName: String) - extends UnaryExpression with ExpectsInputTypes with Serializable { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType) @@ -82,7 +82,7 @@ object ExtractIntervalPart { abstract class IntervalNumOperation( interval: Expression, num: Expression) - extends 
BinaryExpression with ImplicitCastInputTypes with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def left: Expression = interval override def right: Expression = num @@ -160,7 +160,7 @@ case class MakeInterval( hours: Expression, mins: Expression, secs: Expression) - extends SeptenaryExpression with ImplicitCastInputTypes { + extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant { def this( years: Expression, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 205e5271517c3..f4568f860ac0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -519,7 +519,8 @@ case class JsonToStructs( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes + with NullIntolerant { // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder @@ -638,7 +639,8 @@ case class StructsToJson( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback + with ExpectsInputTypes with NullIntolerant { override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 66e6334e3a450..8806fc68f1306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -57,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(val f: Double => Double, name: String) - extends UnaryExpression with Serializable with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -111,7 +111,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -324,7 +324,7 @@ case class Acosh(child: Expression) -16 """) case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends 
TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) @@ -452,7 +452,8 @@ object Factorial { > SELECT _FUNC_(5); 120 """) -case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Factorial(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -732,7 +733,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia """) // scalastyle:on line.size.limit case class Bin(child: Expression) - extends UnaryExpression with Serializable with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType @@ -831,7 +832,8 @@ object Hex { > SELECT _FUNC_('Spark SQL'); 537061726B2053514C """) -case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Hex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, BinaryType, StringType)) @@ -866,7 +868,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput > SELECT decode(_FUNC_('537061726B2053514C'), 'UTF-8'); Spark SQL """) -case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Unhex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -952,7 +955,7 @@ case class Pow(left: Expression, right: Expression) 4 """) case class ShiftLeft(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -986,7 +989,7 @@ case class ShiftLeft(left: Expression, right: Expression) 2 """) case class ShiftRight(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -1020,7 +1023,7 @@ case class ShiftRight(left: Expression, right: Expression) 2 """) case class ShiftRightUnsigned(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 3f60ca388a807..28924fac48eef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -283,7 +283,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress """, 
since = "1.5.0") case class StringSplit(str: Expression, regex: Expression, limit: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = ArrayType(StringType) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -325,7 +325,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { // last regex in string, we will update the pattern iff regexp value changed. @transient private var lastRegex: UTF8String = _ @@ -433,7 +433,7 @@ object RegExpExtract { """, since = "1.5.0") case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) // last regex in string, we will update the pattern iff regexp value changed. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 0b9fb8f85fe3c..9d1016bb793b1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -334,7 +334,7 @@ trait String2StringExpression extends ImplicitCastInputTypes { """, since = "1.0.1") case class Upper(child: Expression) - extends UnaryExpression with String2StringExpression { + extends UnaryExpression with String2StringExpression with NullIntolerant { // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -356,7 +356,8 @@ case class Upper(child: Expression) sparksql """, since = "1.0.1") -case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { +case class Lower(child: Expression) + extends UnaryExpression with String2StringExpression with NullIntolerant{ // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -432,7 +433,7 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate since = "2.3.0") // scalastyle:on line.size.limit case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(srcExpr: Expression, searchExpr: Expression) = { this(srcExpr, searchExpr, Literal("")) @@ -598,7 +599,7 @@ object StringTranslate { since = "1.5.0") // scalastyle:on line.size.limit case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { @transient private var lastMatching: UTF8String = _ @transient private var lastReplace: UTF8String = _ @@ -663,7 +664,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac since = "1.5.0") // scalastyle:on line.size.limit case 
class FindInSet(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { + with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -1035,7 +1036,7 @@ case class StringTrimRight( since = "1.5.0") // scalastyle:on line.size.limit case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = str override def right: Expression = substr @@ -1077,7 +1078,7 @@ case class StringInstr(str: Expression, substr: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -1205,7 +1206,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) """, since = "1.5.0") case class StringLPad(str: Expression, len: Expression, pad: Expression = Literal(" ")) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { this(str, len, Literal(" ")) @@ -1246,7 +1247,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera """, since = "1.5.0") case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" ")) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { this(str, len, Literal(" ")) @@ -1536,7 +1537,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC Spark Sql """, since = "1.5.0") -case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class InitCap(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[DataType] = Seq(StringType) override def dataType: DataType = StringType @@ -1563,7 +1565,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI """, since = "1.5.0") case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = str override def right: Expression = times @@ -1593,7 +1595,7 @@ case class StringRepeat(str: Expression, times: Expression) """, since = "1.5.0") case class StringSpace(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -1738,7 +1740,8 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run """, since = "1.5.0") // scalastyle:on line.size.limit -case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Length(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant{ override def dataType: 
DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1766,7 +1769,8 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn 72 """, since = "2.3.0") -case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class BitLength(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1797,7 +1801,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas 9 """, since = "2.3.0") -case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class OctetLength(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1828,7 +1833,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC """, since = "1.5.0") case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { + with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -1853,7 +1858,8 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres M460 """, since = "1.5.0") -case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class SoundEx(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -1879,7 +1885,8 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT 50 """, since = "1.5.0") -case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Ascii(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -1921,7 +1928,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp """, since = "2.3.0") // scalastyle:on line.size.limit -case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Chr(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(LongType) @@ -1964,7 +1972,8 @@ case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInput U3BhcmsgU1FM """, since = "1.5.0") -case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Base64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -1992,7 +2001,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn Spark SQL """, since = "1.5.0") -case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class UnBase64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant{ override def dataType: DataType = BinaryType 
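+  // NullIntolerant: a null input string short-circuits to a null binary result instead of attempting to decode.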
override def inputTypes: Seq[DataType] = Seq(StringType) @@ -2024,7 +2034,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast since = "1.5.0") // scalastyle:on line.size.limit case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = bin override def right: Expression = charset @@ -2064,7 +2074,7 @@ case class Decode(bin: Expression, charset: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = value override def right: Expression = charset @@ -2108,7 +2118,7 @@ case class Encode(value: Expression, charset: Expression) """, since = "1.5.0") case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = x override def right: Expression = d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 55e06cb9e8471..e08a10ecac71c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -30,7 +30,8 @@ import org.apache.spark.unsafe.types.UTF8String * * This is not the world's most efficient implementation due to type conversion, but works. */ -abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback { +abstract class XPathExtract + extends BinaryExpression with ExpectsInputTypes with CodegenFallback with NullIntolerant { override def left: Expression = xml override def right: Expression = path diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala new file mode 100644 index 0000000000000..d63d8e73d5225 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.annotation.tailrec + +import org.clapper.classutil.ClassFinder +import org.objectweb.asm.Opcodes + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero +import org.apache.spark.util.Utils + +class NullIntolerantCheckerSuite extends SparkFunSuite { + + // Do not check these Expressions + private val whiteList = List( + classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod], + classOf[ToUnixTimestamp], classOf[GetTimestamp], classOf[UnixTimestamp], + classOf[CheckOverflow], classOf[NormalizeNaNAndZero], + classOf[InSet], + classOf[PrintToStderr], classOf[CodegenFallbackExpression]).map(_.getName) + + private val finder = ClassFinder(maybeOverrideAsmVersion = Some(Opcodes.ASM7)) + + private val nullIntolerantName = classOf[NullIntolerant].getName + + Seq(classOf[UnaryExpression], classOf[BinaryExpression], + classOf[TernaryExpression], classOf[QuaternaryExpression], + classOf[SeptenaryExpression]).map(_.getName).foreach { expName => + ClassFinder.concreteSubclasses(expName, finder.getClasses) + .filterNot(c => whiteList.exists(_.equals(c.name))).foreach { classInfo => + test(s"${classInfo.name}") { + // Do not check NonSQLExpressions + if (!classInfo.interfaces.exists(_.equals(classOf[NonSQLExpression].getName))) { + val evalExist = overrodeEval(classInfo.name, expName) + val nullIntolerantExist = implementedNullIntolerant(classInfo.name, expName) + if (evalExist && nullIntolerantExist) { + fail(s"${classInfo.name} should not extend $nullIntolerantName") + } else if (!evalExist && !nullIntolerantExist) { + fail(s"${classInfo.name} should extend $nullIntolerantName") + } else { + assert((!evalExist && nullIntolerantExist) || (evalExist && !nullIntolerantExist)) + } + } + } + } + } + + @tailrec + private def implementedNullIntolerant(className: String, endClassName: String): Boolean = { + val clazz = Utils.classForName(className) + val nullIntolerant = clazz.getInterfaces.exists(_.getName.equals(nullIntolerantName)) || + clazz.getInterfaces.exists { i => + Utils.classForName(i.getName).getInterfaces.exists(_.getName.equals(nullIntolerantName)) + } + val superClassName = clazz.getSuperclass.getName + if (!nullIntolerant && !superClassName.equals(endClassName)) { + implementedNullIntolerant(superClassName, endClassName) + } else { + nullIntolerant + } + } + + @tailrec + private def overrodeEval(className: String, endClassName: String): Boolean = { + val clazz = Utils.classForName(className) + val eval = clazz.getDeclaredMethods.exists(_.getName.equals("eval")) + val superClassName = clazz.getSuperclass.getName + if (!eval && !superClassName.equals(endClassName)) { + overrodeEval(superClassName, endClassName) + } else { + eval + } + } +} From c3625ca3aba36466fae5386644ace2ff8db2334a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 24 May 2020 11:42:27 +0800 Subject: [PATCH 2/7] Fix build error --- pom.xml | 3 +- sql/catalyst/pom.xml | 1 + .../spark/sql/catalyst/expressions/datetimeExpressions.scala | 2 +- .../sql/catalyst/expressions/NullIntolerantCheckerSuite.scala | 1 - tools/pom.xml | 1 - 5 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pom.xml b/pom.xml index 2b253ca58359b..594b2f4d1a1ba 100644 --- a/pom.xml +++ b/pom.xml @@ -885,8 +885,7 @@ <groupId>org.clapper</groupId> <artifactId>classutil_${scala.binary.version}</artifactId> - <version>1.5.0</version> - <scope>test</scope> + <version>1.5.1</version> <groupId>org.scalatest</groupId> <artifactId>scalatest_${scala.binary.version}</artifactId> diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9b2c8922d70aa..ab65e7f5cd3bf 100644 ---
a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -120,6 +120,7 @@ <groupId>org.clapper</groupId> <artifactId>classutil_${scala.binary.version}</artifactId> + <scope>test</scope> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index dda3a27c3463a..0d815657109db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -871,7 +871,7 @@ case class UnixTimestamp(timeExp: Expression, format: Expression, timeZoneId: Op } abstract class ToTimestamp - extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { + extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { // The result of the conversion to timestamp is microseconds divided by this factor. // For example if the factor is 1000000, the result of the expression is in seconds. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala index d63d8e73d5225..1c38c467c649c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala @@ -31,7 +31,6 @@ class NullIntolerantCheckerSuite extends SparkFunSuite { // Do not check these Expressions private val whiteList = List( classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod], - classOf[ToUnixTimestamp], classOf[GetTimestamp], classOf[UnixTimestamp], classOf[CheckOverflow], classOf[NormalizeNaNAndZero], classOf[InSet], classOf[PrintToStderr], classOf[CodegenFallbackExpression]).map(_.getName) diff --git a/tools/pom.xml b/tools/pom.xml index 6e806413ef261..a22bcd35371da 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -44,7 +44,6 @@ <groupId>org.clapper</groupId> <artifactId>classutil_${scala.binary.version}</artifactId> - <version>1.5.1</version> From 7fe349084b0c3793316807d911a0033999cccc40 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 24 May 2020 20:57:17 +0800 Subject: [PATCH 3/7] Move checker to ExpressionInfoSuite --- pom.xml | 5 -- sql/catalyst/pom.xml | 5 -- .../expressions/complexTypeCreator.scala | 2 +- .../NullIntolerantCheckerSuite.scala | 90 ------------------- .../sql/expressions/ExpressionInfoSuite.scala | 74 ++++++++++++++- tools/pom.xml | 1 + 6 files changed, 75 insertions(+), 102 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala diff --git a/pom.xml b/pom.xml index 594b2f4d1a1ba..6620673a7e5fc 100644 --- a/pom.xml +++ b/pom.xml @@ -882,11 +882,6 @@ <artifactId>jline</artifactId> <version>2.14.6</version> - <dependency> - <groupId>org.clapper</groupId> - <artifactId>classutil_${scala.binary.version}</artifactId> - <version>1.5.1</version> - </dependency> <groupId>org.scalatest</groupId> <artifactId>scalatest_${scala.binary.version}</artifactId> diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index ab65e7f5cd3bf..9edbb7fec97d0 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -117,11 +117,6 @@ <groupId>org.apache.arrow</groupId> <artifactId>arrow-vector</artifactId> - <dependency> - <groupId>org.clapper</groupId> - <artifactId>classutil_${scala.binary.version}</artifactId> - <scope>test</scope> - </dependency> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index
b7fa218c395a8..b7511dac601de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -325,7 +325,7 @@ object CreateStruct extends FunctionBuilder { */ val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { val info: ExpressionInfo = new ExpressionInfo( - "org.apache.spark.sql.catalyst.expressions.NamedStruct", + "org.apache.spark.sql.catalyst.expressions.CreateStruct", null, "struct", "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala deleted file mode 100644 index 1c38c467c649c..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullIntolerantCheckerSuite.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.annotation.tailrec - -import org.clapper.classutil.ClassFinder -import org.objectweb.asm.Opcodes - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero -import org.apache.spark.util.Utils - -class NullIntolerantCheckerSuite extends SparkFunSuite { - - // Do not check these Expressions - private val whiteList = List( - classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod], - classOf[CheckOverflow], classOf[NormalizeNaNAndZero], - classOf[InSet], - classOf[PrintToStderr], classOf[CodegenFallbackExpression]).map(_.getName) - - private val finder = ClassFinder(maybeOverrideAsmVersion = Some(Opcodes.ASM7)) - - private val nullIntolerantName = classOf[NullIntolerant].getName - - Seq(classOf[UnaryExpression], classOf[BinaryExpression], - classOf[TernaryExpression], classOf[QuaternaryExpression], - classOf[SeptenaryExpression]).map(_.getName).foreach { expName => - ClassFinder.concreteSubclasses(expName, finder.getClasses) - .filterNot(c => whiteList.exists(_.equals(c.name))).foreach { classInfo => - test(s"${classInfo.name}") { - // Do not check NonSQLExpressions - if (!classInfo.interfaces.exists(_.equals(classOf[NonSQLExpression].getName))) { - val evalExist = overrodeEval(classInfo.name, expName) - val nullIntolerantExist = implementedNullIntolerant(classInfo.name, expName) - if (evalExist && nullIntolerantExist) { - fail(s"${classInfo.name} should not extend $nullIntolerantName") - } else if (!evalExist && !nullIntolerantExist) { - fail(s"${classInfo.name} should extend $nullIntolerantName") - } else { - assert((!evalExist && nullIntolerantExist) || (evalExist && !nullIntolerantExist)) - } - } - } - } - } - - @tailrec - private def implementedNullIntolerant(className: String, endClassName: String): Boolean = { - val clazz = Utils.classForName(className) - val nullIntolerant = clazz.getInterfaces.exists(_.getName.equals(nullIntolerantName)) || - clazz.getInterfaces.exists { i => - Utils.classForName(i.getName).getInterfaces.exists(_.getName.equals(nullIntolerantName)) - } - val superClassName = clazz.getSuperclass.getName - if (!nullIntolerant && !superClassName.equals(endClassName)) { - implementedNullIntolerant(superClassName, endClassName) - } else { - nullIntolerant - } - } - - @tailrec - private def overrodeEval(className: String, endClassName: String): Boolean = { - val clazz = Utils.classForName(className) - val eval = clazz.getDeclaredMethods.exists(_.getName.equals("eval")) - val superClassName = clazz.getSuperclass.getName - if (!eval && !superClassName.equals(endClassName)) { - overrodeEval(superClassName, endClassName) - } else { - eval - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index e18514c6f93f9..fb4e028bbfadb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -21,10 +21,12 @@ import scala.collection.parallel.immutable.ParVector import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.expressions.{NonSQLExpression, _} +import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import 
org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { @@ -156,4 +158,74 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { } } } + + test("Check whether should extend NullIntolerant") { + // Only check expressions extended from these expressions + val parentExpressionNames = Seq(classOf[UnaryExpression], classOf[BinaryExpression], + classOf[TernaryExpression], classOf[QuaternaryExpression], + classOf[SeptenaryExpression]).map(_.getName) + // Do not check these expressions + val whiteList = Seq( + classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod], + classOf[CheckOverflow], classOf[NormalizeNaNAndZero], classOf[InSet], + classOf[PrintToStderr], classOf[CodegenFallbackExpression]).map(_.getName) + + spark.sessionState.functionRegistry.listFunction() + .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName) + .filterNot(c => whiteList.exists(_.equals(c))).foreach { className => + if (needToCheckNullIntolerant(className)) { + val evalExist = checkIfEvalOverrode(className) + val nullIntolerantExist = checkIfNullIntolerantMixedIn(className) + if (evalExist && nullIntolerantExist) { + fail(s"$className should not extend ${classOf[NullIntolerant].getSimpleName}") + } else if (!evalExist && !nullIntolerantExist) { + fail(s"$className should extend ${classOf[NullIntolerant].getSimpleName}") + } else { + assert((!evalExist && nullIntolerantExist) || (evalExist && !nullIntolerantExist)) + } + } + } + + def needToCheckNullIntolerant(className: String): Boolean = { + var clazz: Class[_] = Utils.classForName(className) + val isNonSQLExpr = + clazz.getInterfaces.exists(_.getName.equals(classOf[NonSQLExpression].getName)) + var checkNullIntolerant: Boolean = false + while (!checkNullIntolerant && clazz.getSuperclass != null) { + checkNullIntolerant = parentExpressionNames.exists(_.equals(clazz.getSuperclass.getName)) + if (!checkNullIntolerant) { + clazz = clazz.getSuperclass + } + } + checkNullIntolerant && !isNonSQLExpr + } + + def checkIfNullIntolerantMixedIn(className: String): Boolean = { + val nullIntolerantName = classOf[NullIntolerant].getName + var clazz: Class[_] = Utils.classForName(className) + var nullIntolerantMixedIn = false + while (!nullIntolerantMixedIn && !parentExpressionNames.exists(_.equals(clazz.getName))) { + nullIntolerantMixedIn = clazz.getInterfaces.exists(_.getName.equals(nullIntolerantName)) || + clazz.getInterfaces.exists { i => + Utils.classForName(i.getName).getInterfaces.exists(_.getName.equals(nullIntolerantName)) + } + if (!nullIntolerantMixedIn) { + clazz = clazz.getSuperclass + } + } + nullIntolerantMixedIn + } + + def checkIfEvalOverrode(className: String): Boolean = { + var clazz: Class[_] = Utils.classForName(className) + var evalOverrode: Boolean = false + while (!evalOverrode && !parentExpressionNames.exists(_.equals(clazz.getName))) { + evalOverrode = clazz.getDeclaredMethods.exists(_.getName.equals("eval")) + if (!evalOverrode) { + clazz = clazz.getSuperclass + } + } + evalOverrode + } + } } diff --git a/tools/pom.xml b/tools/pom.xml index a22bcd35371da..6e806413ef261 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -44,6 +44,7 @@ org.clapper classutil_${scala.binary.version} + 1.5.1 From b127c4134ea5c75b10093d53e8874de1ff54e807 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 26 May 
2020 11:50:57 +0800 Subject: [PATCH 4/7] fix --- .../sql/expressions/ExpressionInfoSuite.scala | 95 ++++++------------- 1 file changed, 29 insertions(+), 66 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index fb4e028bbfadb..e28bdee72f309 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql.expressions import scala.collection.parallel.immutable.ParVector import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.expressions.{NonSQLExpression, _} -import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -159,73 +158,37 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { } } - test("Check whether should extend NullIntolerant") { - // Only check expressions extended from these expressions - val parentExpressionNames = Seq(classOf[UnaryExpression], classOf[BinaryExpression], - classOf[TernaryExpression], classOf[QuaternaryExpression], - classOf[SeptenaryExpression]).map(_.getName) - // Do not check these expressions - val whiteList = Seq( - classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod], - classOf[CheckOverflow], classOf[NormalizeNaNAndZero], classOf[InSet], - classOf[PrintToStderr], classOf[CodegenFallbackExpression]).map(_.getName) - - spark.sessionState.functionRegistry.listFunction() - .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName) - .filterNot(c => whiteList.exists(_.equals(c))).foreach { className => - if (needToCheckNullIntolerant(className)) { - val evalExist = checkIfEvalOverrode(className) - val nullIntolerantExist = checkIfNullIntolerantMixedIn(className) - if (evalExist && nullIntolerantExist) { - fail(s"$className should not extend ${classOf[NullIntolerant].getSimpleName}") - } else if (!evalExist && !nullIntolerantExist) { - fail(s"$className should extend ${classOf[NullIntolerant].getSimpleName}") - } else { - assert((!evalExist && nullIntolerantExist) || (evalExist && !nullIntolerantExist)) - } - } - } + test("Check whether SQL expressions should extend NullIntolerant") { + // Only check expressions extended from these expressions because these expressions are + // NullIntolerant by default. 
+ val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression], + classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression]) - def needToCheckNullIntolerant(className: String): Boolean = { - var clazz: Class[_] = Utils.classForName(className) - val isNonSQLExpr = - clazz.getInterfaces.exists(_.getName.equals(classOf[NonSQLExpression].getName)) - var checkNullIntolerant: Boolean = false - while (!checkNullIntolerant && clazz.getSuperclass != null) { - checkNullIntolerant = parentExpressionNames.exists(_.equals(clazz.getSuperclass.getName)) - if (!checkNullIntolerant) { - clazz = clazz.getSuperclass - } - } - checkNullIntolerant && !isNonSQLExpr - } + // Do not check these expressions, because these expressions extend NullIntolerant + // and override the eval function. + val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod]) - def checkIfNullIntolerantMixedIn(className: String): Boolean = { - val nullIntolerantName = classOf[NullIntolerant].getName - var clazz: Class[_] = Utils.classForName(className) - var nullIntolerantMixedIn = false - while (!nullIntolerantMixedIn && !parentExpressionNames.exists(_.equals(clazz.getName))) { - nullIntolerantMixedIn = clazz.getInterfaces.exists(_.getName.equals(nullIntolerantName)) || - clazz.getInterfaces.exists { i => - Utils.classForName(i.getName).getInterfaces.exists(_.getName.equals(nullIntolerantName)) - } - if (!nullIntolerantMixedIn) { - clazz = clazz.getSuperclass - } - } - nullIntolerantMixedIn - } - - def checkIfEvalOverrode(className: String): Boolean = { - var clazz: Class[_] = Utils.classForName(className) - var evalOverrode: Boolean = false - while (!evalOverrode && !parentExpressionNames.exists(_.equals(clazz.getName))) { - evalOverrode = clazz.getDeclaredMethods.exists(_.getName.equals("eval")) - if (!evalOverrode) { - clazz = clazz.getSuperclass + val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction() + .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName) + .filterNot(c => ignoreSet.exists(_.getName.equals(c))) + .map(name => Utils.classForName(name)) + .filterNot(classOf[NonSQLExpression].isAssignableFrom) + + exprTypesToCheck.foreach { superClass => + candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz => + val isEvalOverrode = clazz.getMethod("eval", classOf[InternalRow]) != + superClass.getMethod("eval", classOf[InternalRow]) + val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz) + if (isEvalOverrode && isNullIntolerantMixedIn) { + fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " + + s"or add ${clazz.getName} in the ignoreSet of this test.") + } else if (!isEvalOverrode && !isNullIntolerantMixedIn) { + fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.") + } else { + assert((!isEvalOverrode && isNullIntolerantMixedIn) || + (isEvalOverrode && !isNullIntolerantMixedIn)) } } - evalOverrode } } } From 82e97e34f46f0b01c9fc49d25aa7bb1a76e140eb Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 26 May 2020 20:41:16 +0800 Subject: [PATCH 5/7] fix typo --- .../sql/catalyst/expressions/collectionOperations.scala | 2 +- .../spark/sql/catalyst/expressions/stringExpressions.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e11e72090ced2..b32e9ee05f1ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3224,7 +3224,7 @@ case class ArrayDistinct(child: Expression) * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ trait ArrayBinaryLike - extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant{ + extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant { override protected def dt: DataType = dataType override protected def et: DataType = elementType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9d1016bb793b1..7a8ab17c13f38 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -357,7 +357,7 @@ case class Upper(child: Expression) """, since = "1.0.1") case class Lower(child: Expression) - extends UnaryExpression with String2StringExpression with NullIntolerant{ + extends UnaryExpression with String2StringExpression with NullIntolerant { // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -1741,7 +1741,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run since = "1.5.0") // scalastyle:on line.size.limit case class Length(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant{ + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -2002,7 +2002,7 @@ case class Base64(child: Expression) """, since = "1.5.0") case class UnBase64(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant{ + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) From 1bdff9584a09d42ba27f9742bf5a6e8a4c5c74d3 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 28 May 2020 18:32:36 +0800 Subject: [PATCH 6/7] Address comment --- .../sql/catalyst/expressions/arithmetic.scala | 38 +++++++------------ .../sql/expressions/ExpressionInfoSuite.scala | 10 ++--- 2 files changed, 16 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7c521838447d3..043b578691063 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -309,17 +309,11 @@ trait DivModLike extends BinaryArithmetic { override def nullable: Boolean = true - final override def eval(input: InternalRow): Any = { - val input2 = right.eval(input) - if (input2 == null || input2 == 0) { + final override def nullSafeEval(input1: Any, input2: Any): Any = { + if (input2 == 0) { null } 
else { - val input1 = left.eval(input) - if (input1 == null) { - null - } else { - evalOperation(input1, input2) - } + evalOperation(input1, input2) } } @@ -516,24 +510,18 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = true - override def eval(input: InternalRow): Any = { - val input2 = right.eval(input) - if (input2 == null || input2 == 0) { + override def nullSafeEval(input1: Any, input2: Any): Any = { + if (input2 == 0) { null } else { - val input1 = left.eval(input) - if (input1 == null) { - null - } else { - input1 match { - case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer]) - case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long]) - case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short]) - case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte]) - case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float]) - case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double]) - case d: Decimal => pmod(d, input2.asInstanceOf[Decimal]) - } + input1 match { + case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer]) + case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long]) + case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short]) + case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte]) + case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float]) + case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double]) + case d: Decimal => pmod(d, input2.asInstanceOf[Decimal]) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index e28bdee72f309..96d028714d46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -164,13 +164,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression], classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression]) - // Do not check these expressions, because these expressions extend NullIntolerant - // and override the eval function. - val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod]) - val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction() .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName) - .filterNot(c => ignoreSet.exists(_.getName.equals(c))) .map(name => Utils.classForName(name)) .filterNot(classOf[NonSQLExpression].isAssignableFrom) @@ -180,8 +175,9 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { superClass.getMethod("eval", classOf[InternalRow]) val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz) if (isEvalOverrode && isNullIntolerantMixedIn) { - fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " + - s"or add ${clazz.getName} in the ignoreSet of this test.") + fail(s"${clazz.getName} overrode the eval method and extended " + + s"${classOf[NullIntolerant].getSimpleName}, which may be incorrect. 
" + + s"You may need to override the nullSafeEval method.") } else if (!isEvalOverrode && !isNullIntolerantMixedIn) { fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.") } else { From 265abd3b8905613aa4b9878326f0004cae2baab7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 28 May 2020 22:23:42 +0800 Subject: [PATCH 7/7] revert --- .../sql/catalyst/expressions/arithmetic.scala | 38 ++++++++++++------- .../sql/expressions/ExpressionInfoSuite.scala | 10 +++-- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 043b578691063..7c521838447d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -309,11 +309,17 @@ trait DivModLike extends BinaryArithmetic { override def nullable: Boolean = true - final override def nullSafeEval(input1: Any, input2: Any): Any = { - if (input2 == 0) { + final override def eval(input: InternalRow): Any = { + val input2 = right.eval(input) + if (input2 == null || input2 == 0) { null } else { - evalOperation(input1, input2) + val input1 = left.eval(input) + if (input1 == null) { + null + } else { + evalOperation(input1, input2) + } } } @@ -510,18 +516,24 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = true - override def nullSafeEval(input1: Any, input2: Any): Any = { - if (input2 == 0) { + override def eval(input: InternalRow): Any = { + val input2 = right.eval(input) + if (input2 == null || input2 == 0) { null } else { - input1 match { - case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer]) - case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long]) - case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short]) - case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte]) - case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float]) - case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double]) - case d: Decimal => pmod(d, input2.asInstanceOf[Decimal]) + val input1 = left.eval(input) + if (input1 == null) { + null + } else { + input1 match { + case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer]) + case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long]) + case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short]) + case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte]) + case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float]) + case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double]) + case d: Decimal => pmod(d, input2.asInstanceOf[Decimal]) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 96d028714d46a..53f9757750735 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -164,8 +164,13 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression], classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression]) + // Do not check these expressions, because these expressions 
extend NullIntolerant + // and override the eval method to avoid evaluating input1 if input2 is 0. + val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod]) + val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction() .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName) + .filterNot(c => ignoreSet.exists(_.getName.equals(c))) .map(name => Utils.classForName(name)) .filterNot(classOf[NonSQLExpression].isAssignableFrom) @@ -175,9 +180,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { superClass.getMethod("eval", classOf[InternalRow]) val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz) if (isEvalOverrode && isNullIntolerantMixedIn) { - fail(s"${clazz.getName} overrode the eval method and extended " + - s"${classOf[NullIntolerant].getSimpleName}, which may be incorrect. " + - s"You may need to override the nullSafeEval method.") + fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " + + s"or add ${clazz.getName} in the ignoreSet of this test.") } else if (!isEvalOverrode && !isNullIntolerantMixedIn) { fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.") } else {