Commit

1. Update doc.
2. Add more UT scenarios.
3. Update nullability calculation for ElementAt and update UT.

Change-Id: Ieec69fe5171dfee2863ad93f84bb660076214de0
leanken-zz committed Nov 10, 2020
1 parent da4f5ee commit 7dcba56
Showing 8 changed files with 151 additions and 85 deletions.
2 changes: 2 additions & 0 deletions docs/sql-ref-ansi-compliance.md
@@ -111,6 +111,8 @@ SELECT * FROM t;

The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
- `size`: This function returns null for null input under ANSI mode.
- `element_at`: This function throws `ArrayIndexOutOfBoundsException` for invalid indices under ANSI mode.
- `elt`: This function throws `ArrayIndexOutOfBoundsException` for invalid indices under ANSI mode.
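
A minimal sketch of the difference, assuming a running `SparkSession` bound to a variable named `spark`:

```scala
// ANSI mode off (the default): invalid indices yield NULL.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT element_at(array(1, 2, 3), 5)").show()  // NULL
spark.sql("SELECT elt(4, 'a', 'b')").show()               // NULL

// ANSI mode on: the same queries throw ArrayIndexOutOfBoundsException.
spark.conf.set("spark.sql.ansi.enabled", "true")
```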

### SQL Keywords

@@ -1906,8 +1906,8 @@ case class ArrayPosition(left: Expression, right: Expression)
@ExpressionDescription(
usage = """
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
accesses elements from the last to the first. Returns NULL if the index exceeds the length
of the array.
accesses elements from the last to the first. If the index exceeds the length of the array,
returns NULL if ANSI mode is off, or throws ArrayIndexOutOfBoundsException if ANSI mode is on.
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
""",
@@ -1974,7 +1974,7 @@ case class ElementAt(
if (ordinal == 0) {
false
} else if (elements.length < math.abs(ordinal)) {
true
if (failOnError) false else true
} else {
if (ordinal < 0) {
elements(elements.length + ordinal).nullable
@@ -1996,7 +1996,7 @@
true
}
} else {
true
if (failOnError) arrayContainsNull else true
}
}
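
Taken together, the two changed branches mean: for an out-of-range ordinal the result can only be NULL when ANSI mode is off, because under ANSI mode the expression throws instead. A simplified standalone sketch of that rule, using a hypothetical helper (the real method also covers maps and non-foldable ordinals):

```scala
// Simplified model of ElementAt.nullable when per-element nullability
// is statically known (e.g. the CreateArray case exercised in the tests).
def elementAtNullable(
    elementsNullable: Seq[Boolean],
    ordinal: Int,
    failOnError: Boolean): Boolean = {
  if (ordinal == 0) {
    false // always throws "SQL array indices start at 1", never returns NULL
  } else if (elementsNullable.length < math.abs(ordinal)) {
    !failOnError // out of range: ANSI mode throws rather than returning NULL
  } else if (ordinal < 0) {
    elementsNullable(elementsNullable.length + ordinal) // count from the end
  } else {
    elementsNullable(ordinal - 1) // 1-based positive index
  }
}
```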

@@ -232,7 +232,11 @@ case class ConcatWs(children: Seq[Expression])
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
usage = """
_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
If the index exceeds the length of the array, returns NULL if ANSI mode is off,
or throws ArrayIndexOutOfBoundsException if ANSI mode is on.
""",
examples = """
Examples:
> SELECT _FUNC_(1, 'scala', 'java');
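
A small sketch of `elt` in practice (assuming a `spark` session):

```scala
spark.sql("SELECT elt(2, 'scala', 'java')").show()  // java (n is 1-based)
// Out-of-range n: NULL with ANSI mode off, ArrayIndexOutOfBoundsException on.
spark.sql("SELECT elt(0, 'scala', 'java')").show()
```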
@@ -2146,7 +2146,9 @@ object SQLConf {
.doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
"throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
"field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
"the SQL parser.")
"the SQL parser. 3. Spark will returns null for null input for function `size`. " +
"4. Spark will throw ArrayIndexOutOfBoundsException if invalid indices " +
"used on function `element_at`/`elt`.")
.version("3.0.0")
.booleanConf
.createWithDefault(false)
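
For reference, the flag can be toggled per session in either of these equivalent ways (a sketch, assuming a `spark` session):

```scala
spark.conf.set("spark.sql.ansi.enabled", "true")  // programmatic
spark.sql("SET spark.sql.ansi.enabled=true")      // via SQL
```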
@@ -1118,58 +1118,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
}

test("correctly handles ElementAt nullability for arrays") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Literal(-2)).nullable)
assert(ElementAt(array, Literal(2)).nullable)
assert(ElementAt(array, Literal(-1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)

// CreateArray case invalid indices
assert(!ElementAt(array, Literal(0)).nullable)
assert(ElementAt(array, Literal(4)).nullable)
assert(ElementAt(array, Literal(-4)).nullable)

// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(1)).nullable)
assert(!ElementAt(stArray1, Literal(-1)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(1)).nullable)
assert(ElementAt(stArray2, Literal(-1)).nullable)

val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(1)).nullable)
assert(!ElementAt(stArray3, Literal(-2)).nullable)
assert(ElementAt(stArray3, Literal(2)).nullable)
assert(ElementAt(stArray3, Literal(-1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(1)).nullable)
assert(ElementAt(stArray4, Literal(-2)).nullable)
assert(ElementAt(stArray4, Literal(2)).nullable)
assert(ElementAt(stArray4, Literal(-1)).nullable)

// GetArrayStructFields case invalid indices
assert(!ElementAt(stArray3, Literal(0)).nullable)
assert(ElementAt(stArray3, Literal(4)).nullable)
assert(ElementAt(stArray3, Literal(-4)).nullable)

assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(4)).nullable)
assert(ElementAt(stArray4, Literal(-4)).nullable)
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Literal(-2)).nullable)
assert(ElementAt(array, Literal(2)).nullable)
assert(ElementAt(array, Literal(-1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)

// CreateArray case invalid indices
assert(!ElementAt(array, Literal(0)).nullable)
if (ansiEnabled) {
assert(!ElementAt(array, Literal(4)).nullable)
assert(!ElementAt(array, Literal(-4)).nullable)
} else {
assert(ElementAt(array, Literal(4)).nullable)
assert(ElementAt(array, Literal(-4)).nullable)
}

// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(1)).nullable)
assert(!ElementAt(stArray1, Literal(-1)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(1)).nullable)
assert(ElementAt(stArray2, Literal(-1)).nullable)

val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(1)).nullable)
assert(!ElementAt(stArray3, Literal(-2)).nullable)
assert(ElementAt(stArray3, Literal(2)).nullable)
assert(ElementAt(stArray3, Literal(-1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(1)).nullable)
assert(ElementAt(stArray4, Literal(-2)).nullable)
assert(ElementAt(stArray4, Literal(2)).nullable)
assert(ElementAt(stArray4, Literal(-1)).nullable)

// GetArrayStructFields case invalid indices
assert(!ElementAt(stArray3, Literal(0)).nullable)
if (ansiEnabled) {
assert(!ElementAt(stArray3, Literal(4)).nullable)
assert(!ElementAt(stArray3, Literal(-4)).nullable)
} else {
assert(ElementAt(stArray3, Literal(4)).nullable)
assert(ElementAt(stArray3, Literal(-4)).nullable)
}

assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(4)).nullable)
assert(ElementAt(stArray4, Literal(-4)).nullable)
}
}
}

test("Concat") {
@@ -65,16 +65,21 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val typeA = ArrayType(StringType)
val array = Literal.create(Seq("a", "b"), typeA)
val array = Literal.create(Seq("a", "b"), ArrayType(StringType))

if (ansiEnabled) {
val ex = intercept[Exception] {
checkEvaluation(GetArrayItem(array, Literal(5)), null)
}
assert(stackTraceToString(ex).contains("Invalid index: 5"))
checkExceptionInExpression[Exception](
GetArrayItem(array, Literal(5)),
"Invalid index: 5"
)

checkExceptionInExpression[Exception](
GetArrayItem(array, Literal(-1)),
"Invalid index: -1"
)
} else {
checkEvaluation(GetArrayItem(array, Literal(5)), null)
checkEvaluation(GetArrayItem(array, Literal(-1)), null)
}
}
}
@@ -24,7 +24,7 @@ import scala.util.Random

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{ElementAt, Expression, ExpressionEvalHelper, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types._
/**
* Test suite for functions in [[org.apache.spark.sql.functions]].
*/
class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper {
import testImplicits._

test("array with column name") {
@@ -3625,18 +3625,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val df = sql("select element_at(array(1, 2, 3), 5)")
var df = sql("select element_at(array(1, 2, 3), 5)")
val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
if (ansiEnabled) {
val ex = intercept[Exception] {
df.collect()
}
assert(ex.getMessage.contains("Invalid index: 5"))
val errMsg = "Invalid index: 5"
val ex = intercept[Exception](df.collect())
assert(ex.getMessage.contains(errMsg))
checkExceptionInExpression[Exception](ElementAt(array, Literal(5)), errMsg)
} else {
checkAnswer(
df,
Row(null)
)
checkAnswer(df, Row(null))
}

df = sql("select element_at(array(1, 2, 3), -5)")
if (ansiEnabled) {
val errMsg = "Invalid index: -5"
val ex = intercept[Exception](df.collect())
assert(ex.getMessage.contains(errMsg))
checkExceptionInExpression[Exception](ElementAt(array, Literal(-5)), errMsg)
} else {
checkAnswer(df, Row(null))
}

// The "SQL array indices start at 1" exception is thrown in both modes.
val errMsg = "SQL array indices start at 1"
df = sql("select element_at(array(1, 2, 3), 0)")
val ex = intercept[Exception](df.collect())
assert(ex.getMessage.contains(errMsg))
checkExceptionInExpression[Exception](ElementAt(array, Literal(0)), errMsg)
}
}
}
@@ -17,12 +17,13 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Elt, ExpressionEvalHelper, Literal}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession


class StringFunctionsSuite extends QueryTest with SharedSparkSession {
class StringFunctionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper {
import testImplicits._

test("string concat") {
@@ -621,17 +622,40 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
test("SPARK-33391: elt ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val df = sql("select elt(4, '123', '456')")
var df = sql("select elt(4, '123', '456')")
if (ansiEnabled) {
val ex = intercept[Exception] {
df.collect()
}
assert(ex.getMessage.contains("Invalid index: 4"))
val errMsg = "Invalid index: 4"
val ex = intercept[Exception](df.collect())
assert(ex.getMessage.contains(errMsg))
checkExceptionInExpression[Exception](
Elt(Seq(Literal(4), Literal("123"), Literal("456"))),
errMsg)
} else {
checkAnswer(
df,
Row(null)
)
checkAnswer(df, Row(null))
}

df = sql("select elt(0, '123', '456')")
if (ansiEnabled) {
val errMsg = "Invalid index: 0"
val ex = intercept[Exception](df.collect())
assert(ex.getMessage.contains(errMsg))
checkExceptionInExpression[Exception](
Elt(Seq(Literal(0), Literal("123"), Literal("456"))),
errMsg)
} else {
checkAnswer(df, Row(null))
}

df = sql("select elt(-1, '123', '456')")
if (ansiEnabled) {
val errMsg = "Invalid index: -1"
val ex = intercept[Exception](df.collect())
assert(ex.getMessage.contains(errMsg))
checkExceptionInExpression[Exception](
Elt(Seq(Literal(-1), Literal("123"), Literal("456"))),
errMsg)
} else {
checkAnswer(df, Row(null))
}
}
}
