From da4f5eeb9423d2ebacb2dffaa62d58efcb5c7db9 Mon Sep 17 00:00:00 2001
From: "xuewei.linxuewei" 
Date: Mon, 9 Nov 2020 19:32:46 +0800
Subject: [PATCH 1/8] Accessing array elements should fail if the index is out
 of bounds.

Change-Id: If7650cc45aa30fd3d6549536a8af6ca01a746c39
---
 .../sql/catalyst/analysis/TypeCoercion.scala  |  4 +--
 .../expressions/ProjectionOverSchema.scala    |  6 +++--
 .../catalyst/expressions/SelectedField.scala  |  2 +-
 .../expressions/collectionOperations.scala    | 22 ++++++++++++---
 .../expressions/complexTypeExtractors.scala   | 27 ++++++++++++++++---
 .../expressions/stringExpressions.scala       | 24 +++++++++++++++--
 .../sql/catalyst/optimizer/ComplexTypes.scala |  2 +-
 .../expressions/ComplexTypeSuite.scala        | 18 +++++++++++++
 .../spark/sql/DataFrameFunctionsSuite.scala   | 19 +++++++++++++
 .../spark/sql/StringFunctionsSuite.scala      | 19 +++++++++++++
 10 files changed, 129 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index becdef8b9c603..e8dab28b5e907 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -840,8 +840,8 @@ object TypeCoercion {
     plan resolveOperators { case p =>
       p transformExpressionsUp {
         // Skip nodes if unresolved or not enough children
-        case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
-        case c @ Elt(children) =>
+        case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c
+        case c @ Elt(children, _) =>
           val index = children.head
           val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
           val newInputs = if (conf.eltOutputAsString ||
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 13c6f8db7c129..6f1d9d065ab1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -34,8 +34,10 @@ case class ProjectionOverSchema(schema: StructType) {
     expr match {
       case a: AttributeReference if fieldNames.contains(a.name) =>
         Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
-      case GetArrayItem(child, arrayItemOrdinal) =>
-        getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) }
+      case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
+        getProjection(child).map {
+          projection => GetArrayItem(projection, arrayItemOrdinal, failOnError)
+        }
       case a: GetArrayStructFields =>
         getProjection(a.child).map(p => (p, p.dataType)).map {
           case (projection, ArrayType(projSchema @ StructType(_), _)) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
index 7ba3d302d553b..adcc4be10687e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
@@ -119,7 +119,7 @@ object SelectedField {
             throw new AnalysisException(s"DataType '$x' is not supported by MapKeys.")
         }
         selectField(child, opt)
-
case GetArrayItem(child, _) => + case GetArrayItem(child, _, _) => // GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be // the top-level extractor. However it can be part of an extractor chain. val ArrayType(_, containsNull) = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index cb081b80ba096..6a6a0308057d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1919,9 +1919,14 @@ case class ArrayPosition(left: Expression, right: Expression) b """, since = "2.4.0") -case class ElementAt(left: Expression, right: Expression) +case class ElementAt( + left: Expression, + right: Expression, + failOnError: Boolean = SQLConf.get.ansiEnabled) extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant { + def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull @@ -2008,7 +2013,11 @@ case class ElementAt(left: Expression, right: Expression) val array = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Int] if (array.numElements() < math.abs(index)) { - null + if (failOnError) { + throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index") + } else { + null + } } else { val idx = if (index == 0) { throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") @@ -2042,10 +2051,17 @@ case class ElementAt(left: Expression, right: Expression) } else { "" } + + val failOnErrorBranch = if (failOnError) { + s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin + } else { + s"${ev.isNull} = true;" + } + s""" |int $index = (int) $eval2; |if ($eval1.numElements() < Math.abs($index)) { - | ${ev.isNull} = true; + | $failOnErrorBranch |} else { | if ($index == 0) { | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 60afe140960cc..3a52bfc40c4aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -222,10 +223,15 @@ case class GetArrayStructFields( * * We need to do type checking here as `ordinal` expression maybe unresolved. 
*/ -case class GetArrayItem(child: Expression, ordinal: Expression) +case class GetArrayItem( + child: Expression, + ordinal: Expression, + failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue with NullIntolerant { + def this(child: Expression, ordinal: Expression) = this(child, ordinal, SQLConf.get.ansiEnabled) + // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) @@ -240,7 +246,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression) protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) { + if (index >= baseValue.numElements() || index < 0) { + if (failOnError) { + throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index") + } else { + null + } + } else if (baseValue.isNullAt(index)) { null } else { baseValue.get(index, dataType) @@ -255,9 +267,18 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } else { "" } + + val failOnErrorBranch = if (failOnError) { + s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin + } else { + s"${ev.isNull} = true;" + } + s""" final int $index = (int) $eval2; - if ($index >= $eval1.numElements() || $index < 0$nullCheck) { + if ($index >= $eval1.numElements() || $index < 0) { + $failOnErrorBranch + } else if (false$nullCheck) { ${ev.isNull} = true; } else { ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 1fe990207160c..e0624238a9742 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -239,7 +240,11 @@ case class ConcatWs(children: Seq[Expression]) """, since = "2.0.0") // scalastyle:on line.size.limit -case class Elt(children: Seq[Expression]) extends Expression { +case class Elt( + children: Seq[Expression], + failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression { + + def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) private lazy val indexExpr = children.head private lazy val inputExprs = children.tail.toArray @@ -275,7 +280,11 @@ case class Elt(children: Seq[Expression]) extends Expression { } else { val index = indexObj.asInstanceOf[Int] if (index <= 0 || index > inputExprs.length) { - null + if (failOnError) { + throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index") + } else { + null + } } else { inputExprs(index - 1).eval(input) } @@ -323,6 +332,16 @@ case class 
Elt(children: Seq[Expression]) extends Expression { """.stripMargin }.mkString) + val failOnErrorBranch = if (failOnError) { + s""" + |if (!$indexMatched) { + | throw new ArrayIndexOutOfBoundsException("Invalid index: " + ${index.value}); + |} + """.stripMargin + } else { + "" + } + ev.copy( code""" |${index.code} @@ -332,6 +351,7 @@ case class Elt(children: Seq[Expression]) extends Expression { |do { | $codes |} while (false); + |$failOnErrorBranch |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 2ac8f62b67b3d..7a21ce254a235 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -61,7 +61,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty) // Remove redundant map lookup. - case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) => + case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx), _) => // Instead of creating the array and then selecting one row, remove array creation // altogether. if (idx >= 0 && idx < elems.size) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 38e32ff2518f7..06481ea2994f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -62,6 +62,24 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } + test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + val typeA = ArrayType(StringType) + val array = Literal.create(Seq("a", "b"), typeA) + + if (ansiEnabled) { + val ex = intercept[Exception] { + checkEvaluation(GetArrayItem(array, Literal(5)), null) + } + assert(stackTraceToString(ex).contains("Invalid index: 5")) + } else { + checkEvaluation(GetArrayItem(array, Literal(5)), null) + } + } + } + } + test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") { // CreateArray case val a = AttributeReference("a", IntegerType, nullable = false)() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 585a835024e19..ef896a5050365 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3621,6 +3621,25 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(map(map_entries($"m"), lit(1))), Row(Map(Seq(Row(1, "a")) -> 1))) } + + test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + val df = sql("select 
element_at(array(1, 2, 3), 5)") + if (ansiEnabled) { + val ex = intercept[Exception] { + df.collect() + } + assert(ex.getMessage.contains("Invalid index: 5")) + } else { + checkAnswer( + df, + Row(null) + ) + } + } + } + } } object DataFrameFunctionsSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 8d5166b5398cc..c8e4eb6c68e15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -617,4 +617,23 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { ) } + + test("SPARK-33391: elt ArrayIndexOutOfBoundsException") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + val df = sql("select elt(4, '123', '456')") + if (ansiEnabled) { + val ex = intercept[Exception] { + df.collect() + } + assert(ex.getMessage.contains("Invalid index: 4")) + } else { + checkAnswer( + df, + Row(null) + ) + } + } + } + } } From 7dcba564b70afa24f3d018e9524ec2eb0a23e943 Mon Sep 17 00:00:00 2001 From: "xuewei.linxuewei" Date: Tue, 10 Nov 2020 20:43:21 +0800 Subject: [PATCH 2/8] 1. update doc. 2. add more UT scenario. 3. update nullability calculation for ElementAt and update UT. Change-Id: Ieec69fe5171dfee2863ad93f84bb660076214de0 --- docs/sql-ref-ansi-compliance.md | 2 + .../expressions/collectionOperations.scala | 8 +- .../expressions/stringExpressions.scala | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 4 +- .../CollectionExpressionsSuite.scala | 118 ++++++++++-------- .../expressions/ComplexTypeSuite.scala | 17 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 37 ++++-- .../spark/sql/StringFunctionsSuite.scala | 44 +++++-- 8 files changed, 151 insertions(+), 85 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index d6e99312bb66e..42fe02305ded4 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -111,6 +111,8 @@ SELECT * FROM t; The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`). - `size`: This function returns null for null input under ANSI mode. + - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices under ANSI mode. + - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices under ANSI mode. ### SQL Keywords diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6a6a0308057d3..c678e70c7fe0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1906,8 +1906,8 @@ case class ArrayPosition(left: Expression, right: Expression) @ExpressionDescription( usage = """ _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, - accesses elements from the last to the first. Returns NULL if the index exceeds the length - of the array. + accesses elements from the last to the first. If the index exceeds the length of the array, + Returns NULL if Ansi mode is off; Throws ArrayIndexOutOfBoundsException when Ansi mode is on. 
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map """, @@ -1974,7 +1974,7 @@ case class ElementAt( if (ordinal == 0) { false } else if (elements.length < math.abs(ordinal)) { - true + if (failOnError) false else true } else { if (ordinal < 0) { elements(elements.length + ordinal).nullable @@ -1996,7 +1996,7 @@ case class ElementAt( true } } else { - true + if (failOnError) arrayContainsNull else true } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e0624238a9742..faceaa6c4b25a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -232,7 +232,11 @@ case class ConcatWs(children: Seq[Expression]) */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", + usage = """ + _FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2. + If the index exceeds the length of the array, Returns NULL if Ansi mode is off; + Throws ArrayIndexOutOfBoundsException when Ansi mode is on. + """, examples = """ Examples: > SELECT _FUNC_(1, 'scala', 'java'); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 21357a492e39e..9b9af2c7f023e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2146,7 +2146,9 @@ object SQLConf { .doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " + "throw a runtime exception if an overflow occurs in any operation on integral/decimal " + "field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " + - "the SQL parser.") + "the SQL parser. 3. Spark will returns null for null input for function `size`. " + + "4. 
Spark will throw ArrayIndexOutOfBoundsException if invalid indices " + + "used on function `element_at`/`elt`.") .version("3.0.0") .booleanConf .createWithDefault(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d59d13d49cef4..455efe336064d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1118,58 +1118,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("correctly handles ElementAt nullability for arrays") { - // CreateArray case - val a = AttributeReference("a", IntegerType, nullable = false)() - val b = AttributeReference("b", IntegerType, nullable = true)() - val array = CreateArray(a :: b :: Nil) - assert(!ElementAt(array, Literal(1)).nullable) - assert(!ElementAt(array, Literal(-2)).nullable) - assert(ElementAt(array, Literal(2)).nullable) - assert(ElementAt(array, Literal(-1)).nullable) - assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable) - assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable) - - // CreateArray case invalid indices - assert(!ElementAt(array, Literal(0)).nullable) - assert(ElementAt(array, Literal(4)).nullable) - assert(ElementAt(array, Literal(-4)).nullable) - - // GetArrayStructFields case - val f1 = StructField("a", IntegerType, nullable = false) - val f2 = StructField("b", IntegerType, nullable = true) - val structType = StructType(f1 :: f2 :: Nil) - val c = AttributeReference("c", structType, nullable = false)() - val inputArray1 = CreateArray(c :: Nil) - val inputArray1ContainsNull = c.nullable - val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull) - assert(!ElementAt(stArray1, Literal(1)).nullable) - assert(!ElementAt(stArray1, Literal(-1)).nullable) - val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull) - assert(ElementAt(stArray2, Literal(1)).nullable) - assert(ElementAt(stArray2, Literal(-1)).nullable) - - val d = AttributeReference("d", structType, nullable = true)() - val inputArray2 = CreateArray(c :: d :: Nil) - val inputArray2ContainsNull = c.nullable || d.nullable - val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull) - assert(!ElementAt(stArray3, Literal(1)).nullable) - assert(!ElementAt(stArray3, Literal(-2)).nullable) - assert(ElementAt(stArray3, Literal(2)).nullable) - assert(ElementAt(stArray3, Literal(-1)).nullable) - val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull) - assert(ElementAt(stArray4, Literal(1)).nullable) - assert(ElementAt(stArray4, Literal(-2)).nullable) - assert(ElementAt(stArray4, Literal(2)).nullable) - assert(ElementAt(stArray4, Literal(-1)).nullable) - - // GetArrayStructFields case invalid indices - assert(!ElementAt(stArray3, Literal(0)).nullable) - assert(ElementAt(stArray3, Literal(4)).nullable) - assert(ElementAt(stArray3, Literal(-4)).nullable) - - assert(ElementAt(stArray4, Literal(0)).nullable) - assert(ElementAt(stArray4, Literal(4)).nullable) - assert(ElementAt(stArray4, Literal(-4)).nullable) + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + // CreateArray case + val a = 
AttributeReference("a", IntegerType, nullable = false)() + val b = AttributeReference("b", IntegerType, nullable = true)() + val array = CreateArray(a :: b :: Nil) + assert(!ElementAt(array, Literal(1)).nullable) + assert(!ElementAt(array, Literal(-2)).nullable) + assert(ElementAt(array, Literal(2)).nullable) + assert(ElementAt(array, Literal(-1)).nullable) + assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable) + assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable) + + // CreateArray case invalid indices + assert(!ElementAt(array, Literal(0)).nullable) + if (ansiEnabled) { + assert(!ElementAt(array, Literal(4)).nullable) + assert(!ElementAt(array, Literal(-4)).nullable) + } else { + assert(ElementAt(array, Literal(4)).nullable) + assert(ElementAt(array, Literal(-4)).nullable) + } + + // GetArrayStructFields case + val f1 = StructField("a", IntegerType, nullable = false) + val f2 = StructField("b", IntegerType, nullable = true) + val structType = StructType(f1 :: f2 :: Nil) + val c = AttributeReference("c", structType, nullable = false)() + val inputArray1 = CreateArray(c :: Nil) + val inputArray1ContainsNull = c.nullable + val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull) + assert(!ElementAt(stArray1, Literal(1)).nullable) + assert(!ElementAt(stArray1, Literal(-1)).nullable) + val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull) + assert(ElementAt(stArray2, Literal(1)).nullable) + assert(ElementAt(stArray2, Literal(-1)).nullable) + + val d = AttributeReference("d", structType, nullable = true)() + val inputArray2 = CreateArray(c :: d :: Nil) + val inputArray2ContainsNull = c.nullable || d.nullable + val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull) + assert(!ElementAt(stArray3, Literal(1)).nullable) + assert(!ElementAt(stArray3, Literal(-2)).nullable) + assert(ElementAt(stArray3, Literal(2)).nullable) + assert(ElementAt(stArray3, Literal(-1)).nullable) + val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull) + assert(ElementAt(stArray4, Literal(1)).nullable) + assert(ElementAt(stArray4, Literal(-2)).nullable) + assert(ElementAt(stArray4, Literal(2)).nullable) + assert(ElementAt(stArray4, Literal(-1)).nullable) + + // GetArrayStructFields case invalid indices + assert(!ElementAt(stArray3, Literal(0)).nullable) + if (ansiEnabled) { + assert(!ElementAt(stArray3, Literal(4)).nullable) + assert(!ElementAt(stArray3, Literal(-4)).nullable) + } else { + assert(ElementAt(stArray3, Literal(4)).nullable) + assert(ElementAt(stArray3, Literal(-4)).nullable) + } + + assert(ElementAt(stArray4, Literal(0)).nullable) + assert(ElementAt(stArray4, Literal(4)).nullable) + assert(ElementAt(stArray4, Literal(-4)).nullable) + } + } } test("Concat") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 06481ea2994f5..2eddf0ff2888e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -65,16 +65,21 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { 
- val typeA = ArrayType(StringType) - val array = Literal.create(Seq("a", "b"), typeA) + val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) if (ansiEnabled) { - val ex = intercept[Exception] { - checkEvaluation(GetArrayItem(array, Literal(5)), null) - } - assert(stackTraceToString(ex).contains("Invalid index: 5")) + checkExceptionInExpression[Exception]( + GetArrayItem(array, Literal(5)), + "Invalid index: 5" + ) + + checkExceptionInExpression[Exception]( + GetArrayItem(array, Literal(-1)), + "Invalid index: -1" + ) } else { checkEvaluation(GetArrayItem(array, Literal(5)), null) + checkEvaluation(GetArrayItem(array, Literal(-1)), null) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ef896a5050365..3877dd3f7c7e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{ElementAt, Expression, ExpressionEvalHelper, Literal} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC} @@ -36,7 +36,7 @@ import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. */ -class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { +class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper{ import testImplicits._ test("array with column name") { @@ -3625,18 +3625,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - val df = sql("select element_at(array(1, 2, 3), 5)") + var df = sql("select element_at(array(1, 2, 3), 5)") + val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) if (ansiEnabled) { - val ex = intercept[Exception] { - df.collect() - } - assert(ex.getMessage.contains("Invalid index: 5")) + val errMsg = "Invalid index: 5" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + checkExceptionInExpression[Exception](ElementAt(array, Literal(5)), errMsg) } else { - checkAnswer( - df, - Row(null) - ) + checkAnswer(df, Row(null)) + } + + df = sql("select element_at(array(1, 2, 3), -5)") + if (ansiEnabled) { + val errMsg = "Invalid index: -5" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + checkExceptionInExpression[Exception](ElementAt(array, Literal(-5)), errMsg) + } else { + checkAnswer(df, Row(null)) } + + // SQL array indices start at 1 exception throws for both mode. 
+ val errMsg = "SQL array indices start at 1" + df = sql("select element_at(array(1, 2, 3), 0)") + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + checkExceptionInExpression[Exception](ElementAt(array, Literal(0)), errMsg) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index c8e4eb6c68e15..ea2a1398cefd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.{Elt, ExpressionEvalHelper, Literal} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -class StringFunctionsSuite extends QueryTest with SharedSparkSession { +class StringFunctionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper { import testImplicits._ test("string concat") { @@ -621,17 +622,40 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { test("SPARK-33391: elt ArrayIndexOutOfBoundsException") { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - val df = sql("select elt(4, '123', '456')") + var df = sql("select elt(4, '123', '456')") if (ansiEnabled) { - val ex = intercept[Exception] { - df.collect() - } - assert(ex.getMessage.contains("Invalid index: 4")) + val errMsg = "Invalid index: 4" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + checkExceptionInExpression[Exception]( + Elt(Seq(Literal(4), Literal("123"), Literal("456"))), + errMsg) } else { - checkAnswer( - df, - Row(null) - ) + checkAnswer(df, Row(null)) + } + + df = sql("select elt(0, '123', '456')") + if (ansiEnabled) { + val errMsg = "Invalid index: 0" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + checkExceptionInExpression[Exception]( + Elt(Seq(Literal(0), Literal("123"), Literal("456"))), + errMsg) + } else { + checkAnswer(df, Row(null)) + } + + df = sql("select elt(-1, '123', '456')") + if (ansiEnabled) { + val errMsg = "Invalid index: -1" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + checkExceptionInExpression[Exception]( + Elt(Seq(Literal(-1), Literal("123"), Literal("456"))), + errMsg) + } else { + checkAnswer(df, Row(null)) } } } From f8cfa5be4573a495e5fd281fd665a1c140be4c0f Mon Sep 17 00:00:00 2001 From: "xuewei.linxuewei" Date: Wed, 11 Nov 2020 14:53:48 +0800 Subject: [PATCH 3/8] UT refine and Doc refine. 
Change-Id: Ieba5163781c147d51739ad661115ae3b8edd3f22 --- docs/sql-ref-ansi-compliance.md | 9 +- .../expressions/collectionOperations.scala | 4 +- .../expressions/complexTypeExtractors.scala | 9 +- .../expressions/stringExpressions.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 9 +- .../CollectionExpressionsSuite.scala | 46 ++++++--- .../expressions/StringExpressionsSuite.scala | 32 ++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 38 +------- .../org/apache/spark/sql/DataFrameSuite.scala | 93 ++++++++++++++++++- .../spark/sql/StringFunctionsSuite.scala | 45 +-------- 10 files changed, 178 insertions(+), 111 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 42fe02305ded4..3efd208706fdc 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -111,8 +111,13 @@ SELECT * FROM t; The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`). - `size`: This function returns null for null input under ANSI mode. - - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices under ANSI mode. - - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices under ANSI mode. + - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. + - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. + +### SQL Operators + +The behavior of some SQL operators can be different under ANSI mode (`spark.sql.ansi.enabled=true`). + - `GetArrayItem`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices. ### SQL Keywords diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c678e70c7fe0c..5f23bf4c9d68b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1907,7 +1907,7 @@ case class ArrayPosition(left: Expression, right: Expression) usage = """ _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, accesses elements from the last to the first. If the index exceeds the length of the array, - Returns NULL if Ansi mode is off; Throws ArrayIndexOutOfBoundsException when Ansi mode is on. + Returns NULL if ANSI mode is off; Throws ArrayIndexOutOfBoundsException when ANSI mode is on. 
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map """, @@ -1974,7 +1974,7 @@ case class ElementAt( if (ordinal == 0) { false } else if (elements.length < math.abs(ordinal)) { - if (failOnError) false else true + !failOnError } else { if (ordinal < 0) { elements(elements.length + ordinal).nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3a52bfc40c4aa..f3896d560566d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -263,7 +263,10 @@ case class GetArrayItem( nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("index") val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) { - s" || $eval1.isNullAt($index)" + s"""else if ($eval1.isNullAt($index)) { + ${ev.isNull} = true; + } + """ } else { "" } @@ -278,9 +281,7 @@ case class GetArrayItem( final int $index = (int) $eval2; if ($index >= $eval1.numElements() || $index < 0) { $failOnErrorBranch - } else if (false$nullCheck) { - ${ev.isNull} = true; - } else { + } $nullCheck else { ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index faceaa6c4b25a..b0f20807903bd 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -234,8 +234,8 @@ case class ConcatWs(children: Seq[Expression]) @ExpressionDescription( usage = """ _FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2. - If the index exceeds the length of the array, Returns NULL if Ansi mode is off; - Throws ArrayIndexOutOfBoundsException when Ansi mode is on. + If the index exceeds the length of the array, Returns NULL if ANSI mode is off; + Throws ArrayIndexOutOfBoundsException when ANSI mode is on. """, examples = """ Examples: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9b9af2c7f023e..5a19e35e8c8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2144,11 +2144,10 @@ object SQLConf { val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled") .doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " + - "throw a runtime exception if an overflow occurs in any operation on integral/decimal " + - "field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " + - "the SQL parser. 3. Spark will returns null for null input for function `size`. " + - "4. Spark will throw ArrayIndexOutOfBoundsException if invalid indices " + - "used on function `element_at`/`elt`.") + "throw an exception at runtime if the inputs to a SQL operator/function are invalid, " + + "e.g. overflow in arithmetic operations, out-of-range index when accessing array elements. " + + "2. 
Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " + + "the SQL parser. 3. Spark will returns null for null input for function `size`.") .version("3.0.0") .booleanConf .createWithDefault(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 455efe336064d..f65d11a78abc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1133,13 +1133,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // CreateArray case invalid indices assert(!ElementAt(array, Literal(0)).nullable) - if (ansiEnabled) { - assert(!ElementAt(array, Literal(4)).nullable) - assert(!ElementAt(array, Literal(-4)).nullable) - } else { - assert(ElementAt(array, Literal(4)).nullable) - assert(ElementAt(array, Literal(-4)).nullable) - } + assert(ElementAt(array, Literal(4)).nullable == !ansiEnabled) + assert(ElementAt(array, Literal(-4)).nullable == !ansiEnabled) // GetArrayStructFields case val f1 = StructField("a", IntegerType, nullable = false) @@ -1171,13 +1166,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // GetArrayStructFields case invalid indices assert(!ElementAt(stArray3, Literal(0)).nullable) - if (ansiEnabled) { - assert(!ElementAt(stArray3, Literal(4)).nullable) - assert(!ElementAt(stArray3, Literal(-4)).nullable) - } else { - assert(ElementAt(stArray3, Literal(4)).nullable) - assert(ElementAt(stArray3, Literal(-4)).nullable) - } + assert(ElementAt(stArray3, Literal(4)).nullable == !ansiEnabled) + assert(ElementAt(stArray3, Literal(-4)).nullable == !ansiEnabled) assert(ElementAt(stArray4, Literal(0)).nullable) assert(ElementAt(stArray4, Literal(4)).nullable) @@ -1897,4 +1887,32 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Literal(stringToInterval("interval 1 year"))), Seq(Date.valueOf("2018-01-01"))) } + + test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + var expr: Expression = ElementAt(array, Literal(5)) + if (ansiEnabled) { + val errMsg = "Invalid index: 5" + checkExceptionInExpression[Exception](expr, errMsg) + } else { + checkEvaluation(expr, null) + } + + expr = ElementAt(array, Literal(-5)) + if (ansiEnabled) { + val errMsg = "Invalid index: -5" + checkExceptionInExpression[Exception](expr, errMsg) + } else { + checkEvaluation(expr, null) + } + + // SQL array indices start at 1 exception throws for both mode. 
+ expr = ElementAt(array, Literal(0)) + val errMsg = "SQL array indices start at 1" + checkExceptionInExpression[Exception](expr, errMsg) + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 967ccc42c632d..35e613f861bc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -968,4 +968,34 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateUnsafeProjection.generate( Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil) } + + test("SPARK-33391: elt ArrayIndexOutOfBoundsException") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + var expr: Expression = Elt(Seq(Literal(4), Literal("123"), Literal("456"))) + if (ansiEnabled) { + val errMsg = "Invalid index: 4" + checkExceptionInExpression[Exception](expr, errMsg) + } else { + checkEvaluation(expr, null) + } + + expr = Elt(Seq(Literal(0), Literal("123"), Literal("456"))) + if (ansiEnabled) { + val errMsg = "Invalid index: 0" + checkExceptionInExpression[Exception](expr, errMsg) + } else { + checkEvaluation(expr, null) + } + + expr = Elt(Seq(Literal(-1), Literal("123"), Literal("456"))) + if (ansiEnabled) { + val errMsg = "Invalid index: -1" + checkExceptionInExpression[Exception](expr, errMsg) + } else { + checkEvaluation(expr, null) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3877dd3f7c7e5..585a835024e19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{ElementAt, Expression, ExpressionEvalHelper, Literal} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC} @@ -36,7 +36,7 @@ import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. 
*/ -class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper{ +class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ test("array with column name") { @@ -3621,40 +3621,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession with Exp df.select(map(map_entries($"m"), lit(1))), Row(Map(Seq(Row(1, "a")) -> 1))) } - - test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") { - Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - var df = sql("select element_at(array(1, 2, 3), 5)") - val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - if (ansiEnabled) { - val errMsg = "Invalid index: 5" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - checkExceptionInExpression[Exception](ElementAt(array, Literal(5)), errMsg) - } else { - checkAnswer(df, Row(null)) - } - - df = sql("select element_at(array(1, 2, 3), -5)") - if (ansiEnabled) { - val errMsg = "Invalid index: -5" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - checkExceptionInExpression[Exception](ElementAt(array, Literal(-5)), errMsg) - } else { - checkAnswer(df, Row(null)) - } - - // SQL array indices start at 1 exception throws for both mode. - val errMsg = "SQL array indices start at 1" - df = sql("select element_at(array(1, 2, 3), 0)") - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - checkExceptionInExpression[Exception](ElementAt(array, Literal(0)), errMsg) - } - } - } } object DataFrameFunctionsSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 321f4966178d7..a6680735b885a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.Uuid +import org.apache.spark.sql.catalyst.expressions.{Expression, GetArrayItem, Literal, Uuid} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -2567,6 +2567,97 @@ class DataFrameSuite extends QueryTest val df = l.join(r, $"col2" === $"col4", "LeftOuter") checkAnswer(df, Row("2", "2")) } + + test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + var df = sql("select element_at(array(1, 2, 3), 5)") + if (ansiEnabled) { + val errMsg = "Invalid index: 5" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + } else { + checkAnswer(df, Row(null)) + } + + df = sql("select element_at(array(1, 2, 3), -5)") + if (ansiEnabled) { + val errMsg = "Invalid index: -5" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + } else { + checkAnswer(df, Row(null)) + } + + // SQL array indices start at 1 exception throws for both mode. 
+ val errMsg = "SQL array indices start at 1" + df = sql("select element_at(array(1, 2, 3), 0)") + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + } + } + } + + test("SPARK-33391: elt ArrayIndexOutOfBoundsException") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + var df = sql("select elt(4, '123', '456')") + if (ansiEnabled) { + val errMsg = "Invalid index: 4" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + } else { + checkAnswer(df, Row(null)) + } + + df = sql("select elt(0, '123', '456')") + if (ansiEnabled) { + val errMsg = "Invalid index: 0" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + } else { + checkAnswer(df, Row(null)) + } + + df = sql("select elt(-1, '123', '456')") + if (ansiEnabled) { + val errMsg = "Invalid index: -1" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + } else { + checkAnswer(df, Row(null)) + } + } + } + } + + test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") { + def getArrayItem(child: Expression, ordinal: Expression): Column = { + Column(GetArrayItem(child, ordinal)) + } + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) + var df = testData.where($"key" === lit(1)).select(getArrayItem(array, Literal(5))) + if (ansiEnabled) { + val errMsg = "Invalid index: 5" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + } else { + checkAnswer(df, Row(null)) + } + + df = testData.where($"key" === lit(1)).select(getArrayItem(array, Literal(-1))) + if (ansiEnabled) { + val errMsg = "Invalid index: -1" + val ex = intercept[Exception](df.collect()) + assert(ex.getMessage.contains(errMsg)) + } else { + checkAnswer(df, Row(null)) + } + } + } + } } case class GroupByKey(a: Int, b: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ea2a1398cefd7..8d5166b5398cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{Elt, ExpressionEvalHelper, Literal} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -class StringFunctionsSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper { +class StringFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ test("string concat") { @@ -618,46 +617,4 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession with Expres ) } - - test("SPARK-33391: elt ArrayIndexOutOfBoundsException") { - Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - var df = sql("select elt(4, '123', '456')") - if (ansiEnabled) { - val errMsg = "Invalid index: 4" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - checkExceptionInExpression[Exception]( - Elt(Seq(Literal(4), Literal("123"), Literal("456"))), - errMsg) - } else { - checkAnswer(df, Row(null)) - } - - df = sql("select elt(0, '123', '456')") - if 
(ansiEnabled) { - val errMsg = "Invalid index: 0" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - checkExceptionInExpression[Exception]( - Elt(Seq(Literal(0), Literal("123"), Literal("456"))), - errMsg) - } else { - checkAnswer(df, Row(null)) - } - - df = sql("select elt(-1, '123', '456')") - if (ansiEnabled) { - val errMsg = "Invalid index: -1" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - checkExceptionInExpression[Exception]( - Elt(Seq(Literal(-1), Literal("123"), Literal("456"))), - errMsg) - } else { - checkAnswer(df, Row(null)) - } - } - } - } } From 610983835626ae5afb7bc7fd6ec4efa0aec9f548 Mon Sep 17 00:00:00 2001 From: "xuewei.linxuewei" Date: Wed, 11 Nov 2020 14:58:01 +0800 Subject: [PATCH 4/8] update doc. Change-Id: If309d81022820ec69bf891b2707acb7f60c974cb --- docs/sql-ref-ansi-compliance.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 3efd208706fdc..0b827c13d5b14 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -117,7 +117,7 @@ The behavior of some SQL functions can be different under ANSI mode (`spark.sql. ### SQL Operators The behavior of some SQL operators can be different under ANSI mode (`spark.sql.ansi.enabled=true`). - - `GetArrayItem`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices. + - `array_col[index]`: This operator throws `ArrayIndexOutOfBoundsException` if using invalid indices. ### SQL Keywords From a9312a0546e891266323423e007f65d78ce49ff4 Mon Sep 17 00:00:00 2001 From: "xuewei.linxuewei" Date: Wed, 11 Nov 2020 15:47:19 +0800 Subject: [PATCH 5/8] doc refine. Change-Id: If944a02c596fe5a27edb48fff1b7e02f9e4de849 --- docs/sql-ref-ansi-compliance.md | 2 +- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- .../spark/sql/catalyst/expressions/stringExpressions.scala | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 0b827c13d5b14..c2b36033e318e 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -110,7 +110,7 @@ SELECT * FROM t; ### SQL Functions The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`). - - `size`: This function returns null for null input under ANSI mode. + - `size`: This function returns null for null input. - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5f23bf4c9d68b..f4732d8fe0d13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1906,8 +1906,9 @@ case class ArrayPosition(left: Expression, right: Expression) @ExpressionDescription( usage = """ _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, - accesses elements from the last to the first. If the index exceeds the length of the array, - Returns NULL if ANSI mode is off; Throws ArrayIndexOutOfBoundsException when ANSI mode is on. 
+ accesses elements from the last to the first. The function returns NULL + if the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false. + If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException for invalid indices. _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map """, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0f20807903bd..fb5ba488d2e5f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -234,8 +234,9 @@ case class ConcatWs(children: Seq[Expression]) @ExpressionDescription( usage = """ _FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2. - If the index exceeds the length of the array, Returns NULL if ANSI mode is off; - Throws ArrayIndexOutOfBoundsException when ANSI mode is on. + The function returns NULL if the index exceeds the length of the array + and `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true, + it throws ArrayIndexOutOfBoundsException for invalid indices. """, examples = """ Examples: From 6729d7bcf5f458d4feaac8f0ebc05d0688444405 Mon Sep 17 00:00:00 2001 From: "xuewei.linxuewei" Date: Wed, 11 Nov 2020 15:59:47 +0800 Subject: [PATCH 6/8] line 100 exceed. Change-Id: Icaa8aecd5bd21812b66b216518beebd608251fb8 --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f4732d8fe0d13..c835349b3ca21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1908,7 +1908,8 @@ case class ArrayPosition(left: Expression, right: Expression) _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, accesses elements from the last to the first. The function returns NULL if the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false. - If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException for invalid indices. + If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException + for invalid indices. _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map """, From 36d235ab0941320613adff066276887404c3c8b6 Mon Sep 17 00:00:00 2001 From: "xuewei.linxuewei" Date: Thu, 12 Nov 2020 11:53:35 +0800 Subject: [PATCH 7/8] 1. fix GetArrayItem nullablity issue. 2. move DataFrameSuite code into array.sql and ansi/array.sql 3. add numElements to exception message. 4. other code refine. 
Change-Id: Ieb322ed7b036fc3322fd3b814c8508bfef266378 --- .../expressions/collectionOperations.scala | 27 +- .../expressions/complexTypeExtractors.scala | 37 ++- .../expressions/stringExpressions.scala | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../CollectionExpressionsSuite.scala | 6 +- .../expressions/ComplexTypeSuite.scala | 6 +- .../expressions/StringExpressionsSuite.scala | 8 +- .../resources/sql-tests/inputs/ansi/array.sql | 1 + .../test/resources/sql-tests/inputs/array.sql | 12 + .../sql-tests/results/ansi/array.sql.out | 234 ++++++++++++++++++ .../resources/sql-tests/results/array.sql.out | 67 ++++- .../org/apache/spark/sql/DataFrameSuite.scala | 93 +------ 12 files changed, 364 insertions(+), 135 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/ansi/array.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c835349b3ca21..075c1aa057901 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1986,24 +1986,9 @@ case class ElementAt( } } - override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = { - if (ordinal.foldable && !ordinal.nullable) { - val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue() - child match { - case CreateArray(ar, _) => - nullability(ar, intOrdinal) - case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) => - nullability(elements, intOrdinal) || field.nullable - case _ => - true - } - } else { - if (failOnError) arrayContainsNull else true - } - } - override def nullable: Boolean = left.dataType match { - case _: ArrayType => computeNullabilityFromArray(left, right) + case _: ArrayType => + computeNullabilityFromArray(left, right, failOnError, nullability) case _: MapType => true } @@ -2016,7 +2001,8 @@ case class ElementAt( val index = ordinal.asInstanceOf[Int] if (array.numElements() < math.abs(index)) { if (failOnError) { - throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index") + throw new ArrayIndexOutOfBoundsException( + s"Invalid index: $index, numElements: ${array.numElements()}") } else { null } @@ -2055,7 +2041,10 @@ case class ElementAt( } val failOnErrorBranch = if (failOnError) { - s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin + s"""throw new ArrayIndexOutOfBoundsException( + | "Invalid index: " + $index + ", numElements: " + $eval1.numElements() + |); + """.stripMargin } else { s"${ev.isNull} = true;" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index f3896d560566d..9cc7f169f19b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -240,15 +240,25 @@ case class GetArrayItem( override def left: Expression = child override def right: Expression = ordinal - override def nullable: Boolean = computeNullabilityFromArray(left, right) + override def nullable: Boolean = + 
computeNullabilityFromArray(left, right, failOnError, nullability) override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType + private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = { + if (ordinal >= 0 && ordinal < elements.length) { + elements(ordinal).nullable + } else { + !failOnError + } + } + protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() if (index >= baseValue.numElements() || index < 0) { if (failOnError) { - throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index") + throw new ArrayIndexOutOfBoundsException( + s"Invalid index: $index, numElements: ${baseValue.numElements()}") } else { null } @@ -272,7 +282,10 @@ case class GetArrayItem( } val failOnErrorBranch = if (failOnError) { - s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin + s"""throw new ArrayIndexOutOfBoundsException( + | "Invalid index: " + $index + ", numElements: " + $eval1.numElements() + |); + """.stripMargin } else { s"${ev.isNull} = true;" } @@ -295,20 +308,24 @@ case class GetArrayItem( trait GetArrayItemUtil { /** `Null` is returned for invalid ordinals. */ - protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = { + protected def computeNullabilityFromArray( + child: Expression, + ordinal: Expression, + failOnError: Boolean, + nullability: (Seq[Expression], Int) => Boolean): Boolean = { + val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull if (ordinal.foldable && !ordinal.nullable) { val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue() child match { - case CreateArray(ar, _) if intOrdinal < ar.length => - ar(intOrdinal).nullable - case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) - if intOrdinal < elements.length => - elements(intOrdinal).nullable || field.nullable + case CreateArray(ar, _) => + nullability(ar, intOrdinal) + case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) => + nullability(elements, intOrdinal) || field.nullable case _ => true } } else { - true + if (failOnError) arrayContainsNull else true } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index fb5ba488d2e5f..afbe95e92d0f7 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -286,7 +286,8 @@ case class Elt( val index = indexObj.asInstanceOf[Int] if (index <= 0 || index > inputExprs.length) { if (failOnError) { - throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index") + throw new ArrayIndexOutOfBoundsException( + s"Invalid index: $index, numElements: ${inputExprs.length}") } else { null } @@ -340,7 +341,8 @@ case class Elt( val failOnErrorBranch = if (failOnError) { s""" |if (!$indexMatched) { - | throw new ArrayIndexOutOfBoundsException("Invalid index: " + ${index.value}); + | throw new ArrayIndexOutOfBoundsException( + | "Invalid index: " + ${index.value} + ", numElements: " + ${inputExprs.length}); |} """.stripMargin } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 
5a19e35e8c8e1..ef988052affcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2147,7 +2147,7 @@ object SQLConf { "throw an exception at runtime if the inputs to a SQL operator/function are invalid, " + "e.g. overflow in arithmetic operations, out-of-range index when accessing array elements. " + "2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " + - "the SQL parser. 3. Spark will returns null for null input for function `size`.") + "the SQL parser. 3. Spark will return NULL for null input for function `size`.") .version("3.0.0") .booleanConf .createWithDefault(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f65d11a78abc2..6ee88c9eaef86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1888,13 +1888,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(Date.valueOf("2018-01-01"))) } - test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") { + test("SPARK-33386: element_at ArrayIndexOutOfBoundsException") { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) var expr: Expression = ElementAt(array, Literal(5)) if (ansiEnabled) { - val errMsg = "Invalid index: 5" + val errMsg = "Invalid index: 5, numElements: 3" checkExceptionInExpression[Exception](expr, errMsg) } else { checkEvaluation(expr, null) @@ -1902,7 +1902,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper expr = ElementAt(array, Literal(-5)) if (ansiEnabled) { - val errMsg = "Invalid index: -5" + val errMsg = "Invalid index: -5, numElements: 3" checkExceptionInExpression[Exception](expr, errMsg) } else { checkEvaluation(expr, null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 2eddf0ff2888e..67ab2071de037 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -62,7 +62,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } - test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") { + test("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException") { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) @@ -70,12 +70,12 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { if (ansiEnabled) { checkExceptionInExpression[Exception]( GetArrayItem(array, Literal(5)), - "Invalid index: 5" + "Invalid index: 5, numElements: 2" ) checkExceptionInExpression[Exception]( GetArrayItem(array, Literal(-1)), - "Invalid index: -1" + "Invalid index: -1, numElements: 2" ) } else { 
checkEvaluation(GetArrayItem(array, Literal(5)), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 35e613f861bc1..a1b6cec24f23f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -969,12 +969,12 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil) } - test("SPARK-33391: elt ArrayIndexOutOfBoundsException") { + test("SPARK-33386: elt ArrayIndexOutOfBoundsException") { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { var expr: Expression = Elt(Seq(Literal(4), Literal("123"), Literal("456"))) if (ansiEnabled) { - val errMsg = "Invalid index: 4" + val errMsg = "Invalid index: 4, numElements: 2" checkExceptionInExpression[Exception](expr, errMsg) } else { checkEvaluation(expr, null) @@ -982,7 +982,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { expr = Elt(Seq(Literal(0), Literal("123"), Literal("456"))) if (ansiEnabled) { - val errMsg = "Invalid index: 0" + val errMsg = "Invalid index: 0, numElements: 2" checkExceptionInExpression[Exception](expr, errMsg) } else { checkEvaluation(expr, null) @@ -990,7 +990,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { expr = Elt(Seq(Literal(-1), Literal("123"), Literal("456"))) if (ansiEnabled) { - val errMsg = "Invalid index: -1" + val errMsg = "Invalid index: -1, numElements: 2" checkExceptionInExpression[Exception](expr, errMsg) } else { checkEvaluation(expr, null) diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/array.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/array.sql new file mode 100644 index 0000000000000..662756cbfb0b0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/array.sql @@ -0,0 +1 @@ +--IMPORT array.sql diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 984321ab795fc..f73b653659eb4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -90,3 +90,15 @@ select size(date_array), size(timestamp_array) from primitive_arrays; + +-- index out of range for array elements +select element_at(array(1, 2, 3), 5); +select element_at(array(1, 2, 3), -5); +select element_at(array(1, 2, 3), 0); + +select elt(4, '123', '456'); +select elt(0, '123', '456'); +select elt(-1, '123', '456'); + +select array(1, 2, 3)[5]; +select array(1, 2, 3)[-1]; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out new file mode 100644 index 0000000000000..12a77e36273fa --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -0,0 +1,234 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 20 + + +-- !query +create temporary view data as select * from values + ("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))), + ("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223))) + as data(a, b, c) +-- !query schema +struct<> +-- 
!query output
+
+
+
+-- !query
+select * from data
+-- !query schema
+struct<a:string,b:array<int>,c:array<array<int>>>
+-- !query output
+one [11,12,13] [[111,112,113],[121,122,123]]
+two [21,22,23] [[211,212,213],[221,222,223]]
+
+
+-- !query
+select a, b[0], b[0] + b[1] from data
+-- !query schema
+struct<a:string,b[0]:int,(b[0] + b[1]):int>
+-- !query output
+one 11 23
+two 21 43
+
+
+-- !query
+select a, c[0][0] + c[0][0 + 1] from data
+-- !query schema
+struct<a:string,(c[0][0] + c[0][(0 + 1)]):int>
+-- !query output
+one 223
+two 423
+
+
+-- !query
+create temporary view primitive_arrays as select * from values (
+  array(true),
+  array(2Y, 1Y),
+  array(2S, 1S),
+  array(2, 1),
+  array(2L, 1L),
+  array(9223372036854775809, 9223372036854775808),
+  array(2.0D, 1.0D),
+  array(float(2.0), float(1.0)),
+  array(date '2016-03-14', date '2016-03-13'),
+  array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000')
+) as primitive_arrays(
+  boolean_array,
+  tinyint_array,
+  smallint_array,
+  int_array,
+  bigint_array,
+  decimal_array,
+  double_array,
+  float_array,
+  date_array,
+  timestamp_array
+)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+select * from primitive_arrays
+-- !query schema
+struct<boolean_array:array<boolean>,tinyint_array:array<tinyint>,smallint_array:array<smallint>,int_array:array<int>,bigint_array:array<bigint>,decimal_array:array<decimal(19,0)>,double_array:array<double>,float_array:array<float>,date_array:array<date>,timestamp_array:array<timestamp>>
+-- !query output
+[true] [2,1] [2,1] [2,1] [2,1] [9223372036854775809,9223372036854775808] [2.0,1.0] [2.0,1.0] [2016-03-14,2016-03-13] [2016-11-15 20:54:00,2016-11-12 20:54:00]
+
+
+-- !query
+select
+  array_contains(boolean_array, true), array_contains(boolean_array, false),
+  array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y),
+  array_contains(smallint_array, 2S), array_contains(smallint_array, 0S),
+  array_contains(int_array, 2), array_contains(int_array, 0),
+  array_contains(bigint_array, 2L), array_contains(bigint_array, 0L),
+  array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1),
+  array_contains(double_array, 2.0D), array_contains(double_array, 0.0D),
+  array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)),
+  array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'),
+  array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000')
+from primitive_arrays
+-- !query schema
+struct<array_contains(boolean_array, true):boolean,array_contains(boolean_array, false):boolean,array_contains(tinyint_array, 2):boolean,array_contains(tinyint_array, 0):boolean,array_contains(smallint_array, 2):boolean,array_contains(smallint_array, 0):boolean,array_contains(int_array, 2):boolean,array_contains(int_array, 0):boolean,array_contains(bigint_array, 2):boolean,array_contains(bigint_array, 0):boolean,array_contains(decimal_array, 9223372036854775809):boolean,array_contains(decimal_array, 1):boolean,array_contains(double_array, 2.0):boolean,array_contains(double_array, 0.0):boolean,array_contains(float_array, CAST(2.0 AS FLOAT)):boolean,array_contains(float_array, CAST(0.0 AS FLOAT)):boolean,array_contains(date_array, DATE '2016-03-14'):boolean,array_contains(date_array, DATE '2016-01-01'):boolean,array_contains(timestamp_array, TIMESTAMP '2016-11-15 20:54:00'):boolean,array_contains(timestamp_array, TIMESTAMP '2016-01-01 20:54:00'):boolean>
+-- !query output
+true false true false true false true false true false true false true false true false true false true false
+
+
+-- !query
+select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data
+-- !query schema
+struct<array_contains(b, 11):boolean,array_contains(c, array(111, 112, 113)):boolean>
+-- !query output
+false false
+true true
+
+
+-- !query
+select
+  sort_array(boolean_array),
+  sort_array(tinyint_array),
+  sort_array(smallint_array),
+  sort_array(int_array),
+  sort_array(bigint_array),
+  sort_array(decimal_array),
+  sort_array(double_array),
+  sort_array(float_array),
+  sort_array(date_array),
+  sort_array(timestamp_array)
+from primitive_arrays
+-- !query schema
+struct<sort_array(boolean_array, true):array<boolean>,sort_array(tinyint_array, true):array<tinyint>,sort_array(smallint_array, true):array<smallint>,sort_array(int_array, true):array<int>,sort_array(bigint_array, true):array<bigint>,sort_array(decimal_array, true):array<decimal(19,0)>,sort_array(double_array, true):array<double>,sort_array(float_array, true):array<float>,sort_array(date_array, true):array<date>,sort_array(timestamp_array, true):array<timestamp>>
+-- !query output
+[true] [1,2] [1,2] [1,2] [1,2] [9223372036854775808,9223372036854775809] [1.0,2.0] [1.0,2.0] [2016-03-13,2016-03-14] [2016-11-12 20:54:00,2016-11-15
20:54:00]
+
+
+-- !query
+select sort_array(array('b', 'd'), '1')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'sort_array(array('b', 'd'), '1')' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7
+
+
+-- !query
+select sort_array(array('b', 'd'), cast(NULL as boolean))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'sort_array(array('b', 'd'), CAST(NULL AS BOOLEAN))' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7
+
+
+-- !query
+select
+  size(boolean_array),
+  size(tinyint_array),
+  size(smallint_array),
+  size(int_array),
+  size(bigint_array),
+  size(decimal_array),
+  size(double_array),
+  size(float_array),
+  size(date_array),
+  size(timestamp_array)
+from primitive_arrays
+-- !query schema
+struct<size(boolean_array):int,size(tinyint_array):int,size(smallint_array):int,size(int_array):int,size(bigint_array):int,size(decimal_array):int,size(double_array):int,size(float_array):int,size(date_array):int,size(timestamp_array):int>
+-- !query output
+1 2 2 2 2 2 2 2 2 2
+
+
+-- !query
+select element_at(array(1, 2, 3), 5)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: 5, numElements: 3
+
+
+-- !query
+select element_at(array(1, 2, 3), -5)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: -5, numElements: 3
+
+
+-- !query
+select element_at(array(1, 2, 3), 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+select elt(4, '123', '456')
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: 4, numElements: 2
+
+
+-- !query
+select elt(0, '123', '456')
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: 0, numElements: 2
+
+
+-- !query
+select elt(-1, '123', '456')
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: -1, numElements: 2
+
+
+-- !query
+select array(1, 2, 3)[5]
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: 5, numElements: 3
+
+
+-- !query
+select array(1, 2, 3)[-1]
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+Invalid index: -1, numElements: 3
diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
index 2c2b1a7856304..9bf0d89ed71fe 100644
--- a/sql/core/src/test/resources/sql-tests/results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 12
+-- Number of queries: 20
 
 
 -- !query
@@ -160,3 +160,68 @@ from primitive_arrays
 struct<size(boolean_array):int,size(tinyint_array):int,size(smallint_array):int,size(int_array):int,size(bigint_array):int,size(decimal_array):int,size(double_array):int,size(float_array):int,size(date_array):int,size(timestamp_array):int>
 -- !query output
 1 2 2 2 2 2 2 2 2 2
+
+
+-- !query
+select element_at(array(1, 2, 3), 5)
+-- !query schema
+struct<element_at(array(1, 2, 3), 5):int>
+-- !query output
+NULL
+
+
+-- !query
+select element_at(array(1, 2, 3), -5)
+-- !query schema
+struct<element_at(array(1, 2, 3), -5):int>
+-- !query output
+NULL
+
+
+-- !query
+select element_at(array(1, 2, 3), 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+select elt(4, '123', '456')
+-- !query schema
+struct<elt(4, 123, 456):string>
+-- !query output
+NULL
+
+
+-- !query
+select elt(0, '123', '456')
+-- !query schema
+struct<elt(0, 123, 456):string>
+-- !query output
+NULL
+
+
+-- !query
+select elt(-1, '123', '456')
+-- !query schema
+struct<elt(-1, 123, 456):string>
+-- !query output
+NULL
+
+
+-- !query
+select array(1, 2, 3)[5]
+-- !query schema
+struct<array(1, 2, 3)[5]:int>
+-- !query output
+NULL
+
+
+-- !query
+select array(1, 2, 3)[-1]
+-- !query schema
+struct<array(1, 2, 3)[-1]:int>
+-- !query output
+NULL
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a6680735b885a..321f4966178d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Expression, GetArrayItem, Literal, Uuid}
+import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -2567,97 +2567,6 @@ class DataFrameSuite extends QueryTest
     val df = l.join(r, $"col2" === $"col4", "LeftOuter")
     checkAnswer(df, Row("2", "2"))
   }
-
-  test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") {
-    Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
-        var df = sql("select element_at(array(1, 2, 3), 5)")
-        if (ansiEnabled) {
-          val errMsg = "Invalid index: 5"
-          val ex = intercept[Exception](df.collect())
-          assert(ex.getMessage.contains(errMsg))
-        } else {
-          checkAnswer(df, Row(null))
-        }
-
-        df = sql("select element_at(array(1, 2, 3), -5)")
-        if (ansiEnabled) {
-          val errMsg = "Invalid index: -5"
-          val ex = intercept[Exception](df.collect())
-          assert(ex.getMessage.contains(errMsg))
-        } else {
-          checkAnswer(df, Row(null))
-        }
-
-        // SQL array indices start at 1 exception throws for both mode.
- val errMsg = "SQL array indices start at 1" - df = sql("select element_at(array(1, 2, 3), 0)") - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - } - } - } - - test("SPARK-33391: elt ArrayIndexOutOfBoundsException") { - Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - var df = sql("select elt(4, '123', '456')") - if (ansiEnabled) { - val errMsg = "Invalid index: 4" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - } else { - checkAnswer(df, Row(null)) - } - - df = sql("select elt(0, '123', '456')") - if (ansiEnabled) { - val errMsg = "Invalid index: 0" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - } else { - checkAnswer(df, Row(null)) - } - - df = sql("select elt(-1, '123', '456')") - if (ansiEnabled) { - val errMsg = "Invalid index: -1" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - } else { - checkAnswer(df, Row(null)) - } - } - } - } - - test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") { - def getArrayItem(child: Expression, ordinal: Expression): Column = { - Column(GetArrayItem(child, ordinal)) - } - Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) - var df = testData.where($"key" === lit(1)).select(getArrayItem(array, Literal(5))) - if (ansiEnabled) { - val errMsg = "Invalid index: 5" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - } else { - checkAnswer(df, Row(null)) - } - - df = testData.where($"key" === lit(1)).select(getArrayItem(array, Literal(-1))) - if (ansiEnabled) { - val errMsg = "Invalid index: -1" - val ex = intercept[Exception](df.collect()) - assert(ex.getMessage.contains(errMsg)) - } else { - checkAnswer(df, Row(null)) - } - } - } - } } case class GroupByKey(a: Int, b: Int) From 4b161a4ef4f3be2f1a239c0c19f8aa62b5a03706 Mon Sep 17 00:00:00 2001 From: "xuewei.linxuewei" Date: Thu, 12 Nov 2020 13:24:41 +0800 Subject: [PATCH 8/8] naming. 
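
The renamed variable holds the code emitted for the out-of-bound index case
in ElementAt, GetArrayItem and Elt; failOnError only decides whether that
branch throws or nulls out the result, so naming it after the situation
(index out of bound) is clearer than naming it after the flag. Behavior is
unchanged; a quick check of the expected runtime semantics (a spark-shell
sketch under assumed session defaults, results as in the golden files above):

    // Default (non-ANSI) mode: the out-of-bound branch nulls out the result.
    spark.sql("select elt(4, '123', '456')").first.isNullAt(0)  // true

    // ANSI mode: the same branch throws, with numElements in the message.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("select elt(4, '123', '456')").collect()
    // java.lang.ArrayIndexOutOfBoundsException: Invalid index: 4, numElements: 2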
Change-Id: I719078352ab04f3ab096e62fd2b6a1c06bbdffd3 --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- .../sql/catalyst/expressions/complexTypeExtractors.scala | 4 ++-- .../spark/sql/catalyst/expressions/stringExpressions.scala | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 075c1aa057901..ee98ebf5a8a50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2040,7 +2040,7 @@ case class ElementAt( "" } - val failOnErrorBranch = if (failOnError) { + val indexOutOfBoundBranch = if (failOnError) { s"""throw new ArrayIndexOutOfBoundsException( | "Invalid index: " + $index + ", numElements: " + $eval1.numElements() |); @@ -2052,7 +2052,7 @@ case class ElementAt( s""" |int $index = (int) $eval2; |if ($eval1.numElements() < Math.abs($index)) { - | $failOnErrorBranch + | $indexOutOfBoundBranch |} else { | if ($index == 0) { | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9cc7f169f19b1..363d388692c9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -281,7 +281,7 @@ case class GetArrayItem( "" } - val failOnErrorBranch = if (failOnError) { + val indexOutOfBoundBranch = if (failOnError) { s"""throw new ArrayIndexOutOfBoundsException( | "Invalid index: " + $index + ", numElements: " + $eval1.numElements() |); @@ -293,7 +293,7 @@ case class GetArrayItem( s""" final int $index = (int) $eval2; if ($index >= $eval1.numElements() || $index < 0) { - $failOnErrorBranch + $indexOutOfBoundBranch } $nullCheck else { ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index afbe95e92d0f7..16e22940495f1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -338,7 +338,7 @@ case class Elt( """.stripMargin }.mkString) - val failOnErrorBranch = if (failOnError) { + val indexOutOfBoundBranch = if (failOnError) { s""" |if (!$indexMatched) { | throw new ArrayIndexOutOfBoundsException( @@ -358,7 +358,7 @@ case class Elt( |do { | $codes |} while (false); - |$failOnErrorBranch + |$indexOutOfBoundBranch |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin)
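
As an end-to-end summary of the behavior this series implements (a
spark-shell sketch; outputs transcribed from the golden files above):

    // Legacy behavior (spark.sql.ansi.enabled=false): out-of-bound access
    // degrades to NULL for element_at, elt and the [] operator alike.
    spark.sql("select element_at(array(1, 2, 3), 5)").first  // [null]

    // ANSI behavior: the access fails fast with the enriched message.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("select array(1, 2, 3)[5]").collect()
    // java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5, numElements: 3

    // Index 0 is rejected in both modes, since SQL array indices start at 1.
    spark.sql("select element_at(array(1, 2, 3), 0)").collect()
    // java.lang.ArrayIndexOutOfBoundsException: SQL array indices start at 1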