Skip to content

Commit

Permalink
Move checker to ExpressionInfoSuite
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed May 24, 2020
1 parent c3625ca commit 7fe3490
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 102 deletions.
5 changes: 0 additions & 5 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -882,11 +882,6 @@
<artifactId>jline</artifactId>
<version>2.14.6</version>
</dependency>
<dependency>
<groupId>org.clapper</groupId>
<artifactId>classutil_${scala.binary.version}</artifactId>
<version>1.5.1</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
Expand Down
5 changes: 0 additions & 5 deletions sql/catalyst/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,6 @@
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
</dependency>
<dependency>
<groupId>org.clapper</groupId>
<artifactId>classutil_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ object CreateStruct extends FunctionBuilder {
*/
val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = {
val info: ExpressionInfo = new ExpressionInfo(
"org.apache.spark.sql.catalyst.expressions.NamedStruct",
"org.apache.spark.sql.catalyst.expressions.CreateStruct",
null,
"struct",
"_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.",
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ import scala.collection.parallel.immutable.ParVector

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.expressions.{NonSQLExpression, _}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils

class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {

Expand Down Expand Up @@ -156,4 +158,74 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
}
}
}

test("Check whether should extend NullIntolerant") {
  // Invariant under test: for every registered SQL function expression, mixing in
  // `NullIntolerant` and overriding `eval` must be mutually exclusive — an expression
  // that relies on the default null-propagating `eval` of its parent should declare
  // `NullIntolerant`, and one that overrides `eval` (handling nulls itself) should not.

  // Only check expressions extended from these expressions (the fixed-arity
  // expression base classes whose default `eval` short-circuits on null inputs).
  val parentExpressionNames = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
    classOf[TernaryExpression], classOf[QuaternaryExpression],
    classOf[SeptenaryExpression]).map(_.getName)
  // Do not check these expressions (known, deliberate exceptions to the invariant).
  val whiteList = Seq(
    classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod],
    classOf[CheckOverflow], classOf[NormalizeNaNAndZero], classOf[InSet],
    classOf[PrintToStderr], classOf[CodegenFallbackExpression]).map(_.getName)

  // Walk every function registered in the session catalog, resolve it to its
  // implementing class name, drop the whitelisted ones, and verify the invariant
  // on each class that descends from one of the parent expression classes.
  spark.sessionState.functionRegistry.listFunction()
    .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
    .filterNot(c => whiteList.exists(_.equals(c))).foreach { className =>
      if (needToCheckNullIntolerant(className)) {
        val evalExist = checkIfEvalOverrode(className)
        val nullIntolerantExist = checkIfNullIntolerantMixedIn(className)
        if (evalExist && nullIntolerantExist) {
          // Overrides eval (handles nulls itself) yet still claims NullIntolerant.
          fail(s"$className should not extend ${classOf[NullIntolerant].getSimpleName}")
        } else if (!evalExist && !nullIntolerantExist) {
          // Inherits the null-propagating eval but does not declare NullIntolerant.
          fail(s"$className should extend ${classOf[NullIntolerant].getSimpleName}")
        } else {
          // Exactly one of the two holds — the allowed configurations.
          assert((!evalExist && nullIntolerantExist) || (evalExist && !nullIntolerantExist))
        }
      }
    }

  /**
   * Returns true if `className` should be subject to the NullIntolerant check:
   * i.e. some class in its superclass chain is one of `parentExpressionNames`,
   * and the class does not directly implement `NonSQLExpression`.
   *
   * NOTE(review): the NonSQLExpression test inspects only the class's *direct*
   * interfaces (`clazz.getInterfaces` on the original class), not interfaces
   * inherited via superclasses — presumably sufficient for the registered
   * expressions, but worth confirming if the check ever misses a class.
   */
  def needToCheckNullIntolerant(className: String): Boolean = {
    var clazz: Class[_] = Utils.classForName(className)
    val isNonSQLExpr =
      clazz.getInterfaces.exists(_.getName.equals(classOf[NonSQLExpression].getName))
    var checkNullIntolerant: Boolean = false
    // Climb the superclass chain until we hit one of the parent expression
    // classes or run out of superclasses (Object.getSuperclass == null).
    while (!checkNullIntolerant && clazz.getSuperclass != null) {
      checkNullIntolerant = parentExpressionNames.exists(_.equals(clazz.getSuperclass.getName))
      if (!checkNullIntolerant) {
        clazz = clazz.getSuperclass
      }
    }
    checkNullIntolerant && !isNonSQLExpr
  }

  /**
   * Returns true if `NullIntolerant` is mixed in anywhere on the chain from
   * `className` up to (but not including) the parent expression class.
   *
   * At each level this looks at the class's direct interfaces and, one level
   * deeper, at the interfaces of those interfaces — so a trait that itself
   * extends NullIntolerant is also detected. Deeper interface nesting is not
   * inspected. Termination relies on the caller having established (via
   * `needToCheckNullIntolerant`) that a parent expression class exists on the
   * superclass chain.
   */
  def checkIfNullIntolerantMixedIn(className: String): Boolean = {
    val nullIntolerantName = classOf[NullIntolerant].getName
    var clazz: Class[_] = Utils.classForName(className)
    var nullIntolerantMixedIn = false
    while (!nullIntolerantMixedIn && !parentExpressionNames.exists(_.equals(clazz.getName))) {
      nullIntolerantMixedIn = clazz.getInterfaces.exists(_.getName.equals(nullIntolerantName)) ||
        clazz.getInterfaces.exists { i =>
          Utils.classForName(i.getName).getInterfaces.exists(_.getName.equals(nullIntolerantName))
        }
      if (!nullIntolerantMixedIn) {
        clazz = clazz.getSuperclass
      }
    }
    nullIntolerantMixedIn
  }

  /**
   * Returns true if any class on the chain from `className` up to (but not
   * including) the parent expression class declares a method named "eval",
   * i.e. the expression overrides the inherited null-propagating `eval`.
   *
   * Uses `getDeclaredMethods`, so only methods declared at each level (not
   * inherited ones) are considered; matching is by name only.
   */
  def checkIfEvalOverrode(className: String): Boolean = {
    var clazz: Class[_] = Utils.classForName(className)
    var evalOverrode: Boolean = false
    while (!evalOverrode && !parentExpressionNames.exists(_.equals(clazz.getName))) {
      evalOverrode = clazz.getDeclaredMethods.exists(_.getName.equals("eval"))
      if (!evalOverrode) {
        clazz = clazz.getSuperclass
      }
    }
    evalOverrode
  }
}
}
1 change: 1 addition & 0 deletions tools/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
<dependency>
<groupId>org.clapper</groupId>
<artifactId>classutil_${scala.binary.version}</artifactId>
<version>1.5.1</version>
</dependency>
</dependencies>

Expand Down

0 comments on commit 7fe3490

Please sign in to comment.