diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index d6e99312bb66e..42fe02305ded4 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -111,6 +111,8 @@ SELECT * FROM t;
 The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`).
   - `size`: This function returns null for null input under ANSI mode.
+  - `element_at`: This function throws `ArrayIndexOutOfBoundsException` for invalid indices under ANSI mode.
+  - `elt`: This function throws `ArrayIndexOutOfBoundsException` for invalid indices under ANSI mode.
 
 ### SQL Keywords
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6a6a0308057d3..c678e70c7fe0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1906,8 +1906,8 @@ case class ArrayPosition(left: Expression, right: Expression)
 @ExpressionDescription(
   usage = """
     _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
-      accesses elements from the last to the first. Returns NULL if the index exceeds the length
-      of the array.
+      accesses elements from the last to the first. If the index exceeds the length of the array,
+      the function returns NULL if ANSI mode is off, and throws ArrayIndexOutOfBoundsException if ANSI mode is on.
 
     _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
   """,
@@ -1974,7 +1974,7 @@ case class ElementAt(
       if (ordinal == 0) {
         false
       } else if (elements.length < math.abs(ordinal)) {
-        true
+        !failOnError
       } else {
         if (ordinal < 0) {
           elements(elements.length + ordinal).nullable
@@ -1996,7 +1996,7 @@ case class ElementAt(
           true
         }
       } else {
-        true
+        if (failOnError) arrayContainsNull else true
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index e0624238a9742..faceaa6c4b25a 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -232,7 +232,11 @@ case class ConcatWs(children: Seq[Expression])
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
+  usage = """
+    _FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
+    If the index is out of the range of inputs, the function returns NULL if ANSI mode is off,
+    and throws ArrayIndexOutOfBoundsException if ANSI mode is on.
+  """,
   examples = """
     Examples:
       > SELECT _FUNC_(1, 'scala', 'java');
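The two `usage` strings and the doc-page bullets above describe the same user-visible change. As a quick illustration, a minimal spark-shell sketch, assuming a local `SparkSession` named `spark` (exact message text may differ across versions):

```scala
// Out-of-range indices: NULL with ANSI off, exception with ANSI on.
spark.conf.set("spark.sql.ansi.enabled", false)
spark.sql("SELECT element_at(array(1, 2, 3), 5)").show()  // NULL
spark.sql("SELECT elt(4, '123', '456')").show()           // NULL

spark.conf.set("spark.sql.ansi.enabled", true)
spark.sql("SELECT element_at(array(1, 2, 3), 5)").collect() // ArrayIndexOutOfBoundsException: Invalid index: 5
spark.sql("SELECT elt(4, '123', '456')").collect()          // ArrayIndexOutOfBoundsException: Invalid index: 4
```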
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 21357a492e39e..9b9af2c7f023e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2146,7 +2146,9 @@ object SQLConf {
     .doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
       "throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
       "field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
-      "the SQL parser.")
+      "the SQL parser. 3. Spark will return null for null input for the function `size`. " +
+      "4. Spark will throw ArrayIndexOutOfBoundsException if invalid indices are " +
+      "used in the functions `element_at`/`elt`.")
     .version("3.0.0")
     .booleanConf
     .createWithDefault(false)
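Item 3 of the description predates this patch but is easy to conflate with the new items. A hedged sketch of what it controls, assuming otherwise-default settings (with ANSI off, the -1 result comes from the legacy `spark.sql.legacy.sizeOfNull` behavior):

```scala
// size(NULL): -1 under the legacy default, NULL under ANSI mode.
spark.conf.set("spark.sql.ansi.enabled", false)
spark.sql("SELECT size(CAST(NULL AS ARRAY<INT>))").show()  // -1

spark.conf.set("spark.sql.ansi.enabled", true)
spark.sql("SELECT size(CAST(NULL AS ARRAY<INT>))").show()  // NULL
```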
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d59d13d49cef4..455efe336064d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1118,58 +1118,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
   }
 
   test("correctly handles ElementAt nullability for arrays") {
-    // CreateArray case
-    val a = AttributeReference("a", IntegerType, nullable = false)()
-    val b = AttributeReference("b", IntegerType, nullable = true)()
-    val array = CreateArray(a :: b :: Nil)
-    assert(!ElementAt(array, Literal(1)).nullable)
-    assert(!ElementAt(array, Literal(-2)).nullable)
-    assert(ElementAt(array, Literal(2)).nullable)
-    assert(ElementAt(array, Literal(-1)).nullable)
-    assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
-    assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
-
-    // CreateArray case invalid indices
-    assert(!ElementAt(array, Literal(0)).nullable)
-    assert(ElementAt(array, Literal(4)).nullable)
-    assert(ElementAt(array, Literal(-4)).nullable)
-
-    // GetArrayStructFields case
-    val f1 = StructField("a", IntegerType, nullable = false)
-    val f2 = StructField("b", IntegerType, nullable = true)
-    val structType = StructType(f1 :: f2 :: Nil)
-    val c = AttributeReference("c", structType, nullable = false)()
-    val inputArray1 = CreateArray(c :: Nil)
-    val inputArray1ContainsNull = c.nullable
-    val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
-    assert(!ElementAt(stArray1, Literal(1)).nullable)
-    assert(!ElementAt(stArray1, Literal(-1)).nullable)
-    val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
-    assert(ElementAt(stArray2, Literal(1)).nullable)
-    assert(ElementAt(stArray2, Literal(-1)).nullable)
-
-    val d = AttributeReference("d", structType, nullable = true)()
-    val inputArray2 = CreateArray(c :: d :: Nil)
-    val inputArray2ContainsNull = c.nullable || d.nullable
-    val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
-    assert(!ElementAt(stArray3, Literal(1)).nullable)
-    assert(!ElementAt(stArray3, Literal(-2)).nullable)
-    assert(ElementAt(stArray3, Literal(2)).nullable)
-    assert(ElementAt(stArray3, Literal(-1)).nullable)
-    val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
-    assert(ElementAt(stArray4, Literal(1)).nullable)
-    assert(ElementAt(stArray4, Literal(-2)).nullable)
-    assert(ElementAt(stArray4, Literal(2)).nullable)
-    assert(ElementAt(stArray4, Literal(-1)).nullable)
-
-    // GetArrayStructFields case invalid indices
-    assert(!ElementAt(stArray3, Literal(0)).nullable)
-    assert(ElementAt(stArray3, Literal(4)).nullable)
-    assert(ElementAt(stArray3, Literal(-4)).nullable)
-
-    assert(ElementAt(stArray4, Literal(0)).nullable)
-    assert(ElementAt(stArray4, Literal(4)).nullable)
-    assert(ElementAt(stArray4, Literal(-4)).nullable)
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        // CreateArray case
+        val a = AttributeReference("a", IntegerType, nullable = false)()
+        val b = AttributeReference("b", IntegerType, nullable = true)()
+        val array = CreateArray(a :: b :: Nil)
+        assert(!ElementAt(array, Literal(1)).nullable)
+        assert(!ElementAt(array, Literal(-2)).nullable)
+        assert(ElementAt(array, Literal(2)).nullable)
+        assert(ElementAt(array, Literal(-1)).nullable)
+        assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
+        assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
+
+        // CreateArray case invalid indices
+        assert(!ElementAt(array, Literal(0)).nullable)
+        if (ansiEnabled) {
+          assert(!ElementAt(array, Literal(4)).nullable)
+          assert(!ElementAt(array, Literal(-4)).nullable)
+        } else {
+          assert(ElementAt(array, Literal(4)).nullable)
+          assert(ElementAt(array, Literal(-4)).nullable)
+        }
+
+        // GetArrayStructFields case
+        val f1 = StructField("a", IntegerType, nullable = false)
+        val f2 = StructField("b", IntegerType, nullable = true)
+        val structType = StructType(f1 :: f2 :: Nil)
+        val c = AttributeReference("c", structType, nullable = false)()
+        val inputArray1 = CreateArray(c :: Nil)
+        val inputArray1ContainsNull = c.nullable
+        val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
+        assert(!ElementAt(stArray1, Literal(1)).nullable)
+        assert(!ElementAt(stArray1, Literal(-1)).nullable)
+        val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
+        assert(ElementAt(stArray2, Literal(1)).nullable)
+        assert(ElementAt(stArray2, Literal(-1)).nullable)
+
+        val d = AttributeReference("d", structType, nullable = true)()
+        val inputArray2 = CreateArray(c :: d :: Nil)
+        val inputArray2ContainsNull = c.nullable || d.nullable
+        val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
+        assert(!ElementAt(stArray3, Literal(1)).nullable)
+        assert(!ElementAt(stArray3, Literal(-2)).nullable)
+        assert(ElementAt(stArray3, Literal(2)).nullable)
+        assert(ElementAt(stArray3, Literal(-1)).nullable)
+        val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
+        assert(ElementAt(stArray4, Literal(1)).nullable)
+        assert(ElementAt(stArray4, Literal(-2)).nullable)
+        assert(ElementAt(stArray4, Literal(2)).nullable)
+        assert(ElementAt(stArray4, Literal(-1)).nullable)
+
+        // GetArrayStructFields case invalid indices
+        assert(!ElementAt(stArray3, Literal(0)).nullable)
+        if (ansiEnabled) {
+          assert(!ElementAt(stArray3, Literal(4)).nullable)
+          assert(!ElementAt(stArray3, Literal(-4)).nullable)
+        } else {
+          assert(ElementAt(stArray3, Literal(4)).nullable)
+          assert(ElementAt(stArray3, Literal(-4)).nullable)
+        }
+
+        assert(ElementAt(stArray4, Literal(0)).nullable)
+        assert(ElementAt(stArray4, Literal(4)).nullable)
+        assert(ElementAt(stArray4, Literal(-4)).nullable)
+      }
+    }
   }
 
   test("Concat") {
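For orientation, the nullability rule these assertions pin down can be modeled in isolation. A standalone sketch, not Spark's actual implementation: `failOnError` stands in for `spark.sql.ansi.enabled`, and `elementNullable` for the per-element nullability that `CreateArray` exposes:

```scala
// Model of ElementAt.nullable for a literal ordinal over an array whose
// per-element nullability is statically known.
def elementAtNullable(
    elementNullable: Seq[Boolean],
    ordinal: Int,
    failOnError: Boolean): Boolean = {
  if (ordinal == 0) {
    false // index 0 always throws, so the result can never be NULL
  } else if (elementNullable.length < math.abs(ordinal)) {
    // Out of range: NULL without ANSI mode, exception with it.
    !failOnError
  } else if (ordinal < 0) {
    elementNullable(elementNullable.length + ordinal) // count from the end
  } else {
    elementNullable(ordinal - 1) // 1-based from the front
  }
}

// Mirrors the CreateArray assertions for array(a, b), a non-nullable, b nullable:
assert(!elementAtNullable(Seq(false, true), 1, failOnError = true))
assert(elementAtNullable(Seq(false, true), 4, failOnError = false))  // ANSI off
assert(!elementAtNullable(Seq(false, true), 4, failOnError = true))  // ANSI on
```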
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 06481ea2994f5..2eddf0ff2888e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -65,16 +65,21 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") {
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        val typeA = ArrayType(StringType)
-        val array = Literal.create(Seq("a", "b"), typeA)
+        val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
         if (ansiEnabled) {
-          val ex = intercept[Exception] {
-            checkEvaluation(GetArrayItem(array, Literal(5)), null)
-          }
-          assert(stackTraceToString(ex).contains("Invalid index: 5"))
+          checkExceptionInExpression[Exception](
+            GetArrayItem(array, Literal(5)),
+            "Invalid index: 5"
+          )
+
+          checkExceptionInExpression[Exception](
+            GetArrayItem(array, Literal(-1)),
+            "Invalid index: -1"
+          )
         } else {
           checkEvaluation(GetArrayItem(array, Literal(5)), null)
+          checkEvaluation(GetArrayItem(array, Literal(-1)), null)
         }
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index ef896a5050365..3877dd3f7c7e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -24,7 +24,7 @@ import scala.util.Random
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{ElementAt, Expression, ExpressionEvalHelper, Literal}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types._
 /**
  * Test suite for functions in [[org.apache.spark.sql.functions]].
  */
-class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
+class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper {
   import testImplicits._
 
   test("array with column name") {
@@ -3625,18 +3625,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
   test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") {
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        val df = sql("select element_at(array(1, 2, 3), 5)")
+        var df = sql("select element_at(array(1, 2, 3), 5)")
+        val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
         if (ansiEnabled) {
-          val ex = intercept[Exception] {
-            df.collect()
-          }
-          assert(ex.getMessage.contains("Invalid index: 5"))
+          val errMsg = "Invalid index: 5"
+          val ex = intercept[Exception](df.collect())
+          assert(ex.getMessage.contains(errMsg))
+          checkExceptionInExpression[Exception](ElementAt(array, Literal(5)), errMsg)
         } else {
-          checkAnswer(
-            df,
-            Row(null)
-          )
+          checkAnswer(df, Row(null))
+        }
+
+        df = sql("select element_at(array(1, 2, 3), -5)")
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: -5"
+          val ex = intercept[Exception](df.collect())
+          assert(ex.getMessage.contains(errMsg))
+          checkExceptionInExpression[Exception](ElementAt(array, Literal(-5)), errMsg)
+        } else {
+          checkAnswer(df, Row(null))
+        }
+
+        // The "SQL array indices start at 1" exception is thrown in both modes.
+        val errMsg = "SQL array indices start at 1"
+        df = sql("select element_at(array(1, 2, 3), 0)")
+        val ex = intercept[Exception](df.collect())
+        assert(ex.getMessage.contains(errMsg))
+        checkExceptionInExpression[Exception](ElementAt(array, Literal(0)), errMsg)
       }
     }
   }
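The last block of the test above sits deliberately outside the `if (ansiEnabled)` branches: index 0 is rejected in both modes, and only out-of-range indices are mode-dependent. A quick sketch, assuming a `SparkSession` named `spark`:

```scala
// element_at with index 0 fails even with ANSI mode off.
spark.conf.set("spark.sql.ansi.enabled", false)
try {
  spark.sql("SELECT element_at(array(1, 2, 3), 0)").collect()
} catch {
  case e: Exception =>
    println(e.getMessage) // contains "SQL array indices start at 1"
}
```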
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index c8e4eb6c68e15..ea2a1398cefd7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.{Elt, ExpressionEvalHelper, Literal}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
-class StringFunctionsSuite extends QueryTest with SharedSparkSession {
+class StringFunctionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper {
   import testImplicits._
 
   test("string concat") {
@@ -621,17 +622,40 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
   test("SPARK-33391: elt ArrayIndexOutOfBoundsException") {
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        val df = sql("select elt(4, '123', '456')")
+        var df = sql("select elt(4, '123', '456')")
         if (ansiEnabled) {
-          val ex = intercept[Exception] {
-            df.collect()
-          }
-          assert(ex.getMessage.contains("Invalid index: 4"))
+          val errMsg = "Invalid index: 4"
+          val ex = intercept[Exception](df.collect())
+          assert(ex.getMessage.contains(errMsg))
+          checkExceptionInExpression[Exception](
+            Elt(Seq(Literal(4), Literal("123"), Literal("456"))),
+            errMsg)
         } else {
-          checkAnswer(
-            df,
-            Row(null)
-          )
+          checkAnswer(df, Row(null))
+        }
+
+        df = sql("select elt(0, '123', '456')")
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: 0"
+          val ex = intercept[Exception](df.collect())
+          assert(ex.getMessage.contains(errMsg))
+          checkExceptionInExpression[Exception](
+            Elt(Seq(Literal(0), Literal("123"), Literal("456"))),
+            errMsg)
+        } else {
+          checkAnswer(df, Row(null))
+        }
+
+        df = sql("select elt(-1, '123', '456')")
+        if (ansiEnabled) {
+          val errMsg = "Invalid index: -1"
+          val ex = intercept[Exception](df.collect())
+          assert(ex.getMessage.contains(errMsg))
+          checkExceptionInExpression[Exception](
+            Elt(Seq(Literal(-1), Literal("123"), Literal("456"))),
+            errMsg)
+        } else {
+          checkAnswer(df, Row(null))
+        }
       }
     }
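Unlike `element_at`, `elt` has no negative indexing, so 0 and -1 are simply invalid rather than "count from the end". A sketch of the non-ANSI side of the three cases above, assuming a `SparkSession` named `spark`:

```scala
// With ANSI off, every invalid elt index maps to NULL.
spark.conf.set("spark.sql.ansi.enabled", false)
for (n <- Seq(4, 0, -1)) {
  spark.sql(s"SELECT elt($n, '123', '456')").show() // NULL for each n
}
```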