-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-33386][SQL] Accessing array elements in ElementAt/Elt/GetArrayItem should fail if index is out of bounds #30297
Changes from 6 commits
da4f5ee
7dcba56
f8cfa5b
6109838
a9312a0
6729d7b
36d235a
4b161a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1906,8 +1906,10 @@ case class ArrayPosition(left: Expression, right: Expression) | |
@ExpressionDescription( | ||
usage = """ | ||
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, | ||
accesses elements from the last to the first. Returns NULL if the index exceeds the length | ||
of the array. | ||
accesses elements from the last to the first. The function returns NULL | ||
if the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false. | ||
If `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException | ||
for invalid indices. | ||
|
||
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map | ||
""", | ||
|
@@ -1919,9 +1921,14 @@ case class ArrayPosition(left: Expression, right: Expression) | |
b | ||
""", | ||
since = "2.4.0") | ||
case class ElementAt(left: Expression, right: Expression) | ||
case class ElementAt( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please update the usage above of |
||
left: Expression, | ||
right: Expression, | ||
failOnError: Boolean = SQLConf.get.ansiEnabled) | ||
viirya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant { | ||
|
||
def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) | ||
|
||
@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType | ||
|
||
@transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull | ||
|
@@ -1969,7 +1976,7 @@ case class ElementAt(left: Expression, right: Expression) | |
if (ordinal == 0) { | ||
false | ||
} else if (elements.length < math.abs(ordinal)) { | ||
true | ||
!failOnError | ||
} else { | ||
if (ordinal < 0) { | ||
elements(elements.length + ordinal).nullable | ||
|
@@ -1991,7 +1998,7 @@ case class ElementAt(left: Expression, right: Expression) | |
true | ||
} | ||
} else { | ||
true | ||
if (failOnError) arrayContainsNull else true | ||
} | ||
} | ||
|
||
|
@@ -2008,7 +2015,11 @@ case class ElementAt(left: Expression, right: Expression) | |
val array = value.asInstanceOf[ArrayData] | ||
val index = ordinal.asInstanceOf[Int] | ||
if (array.numElements() < math.abs(index)) { | ||
null | ||
if (failOnError) { | ||
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: can we include the total number of elements in the error message too? Sometimes that is helpful for debugging. |
||
} else { | ||
null | ||
} | ||
} else { | ||
val idx = if (index == 0) { | ||
throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") | ||
|
@@ -2042,10 +2053,17 @@ case class ElementAt(left: Expression, right: Expression) | |
} else { | ||
"" | ||
} | ||
|
||
val failOnErrorBranch = if (failOnError) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: it should be |
||
s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we remove |
||
} else { | ||
s"${ev.isNull} = true;" | ||
} | ||
|
||
s""" | ||
|int $index = (int) $eval2; | ||
|if ($eval1.numElements() < Math.abs($index)) { | ||
| ${ev.isNull} = true; | ||
| $failOnErrorBranch | ||
|} else { | ||
| if ($index == 0) { | ||
| throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.catalyst.analysis._ | ||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} | ||
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} | ||
import org.apache.spark.sql.internal.SQLConf | ||
import org.apache.spark.sql.types._ | ||
|
||
//////////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
@@ -222,10 +223,15 @@ case class GetArrayStructFields( | |
* | ||
* We need to do type checking here as `ordinal` expression maybe unresolved. | ||
*/ | ||
case class GetArrayItem(child: Expression, ordinal: Expression) | ||
case class GetArrayItem( | ||
child: Expression, | ||
ordinal: Expression, | ||
failOnError: Boolean = SQLConf.get.ansiEnabled) | ||
extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue | ||
with NullIntolerant { | ||
|
||
def this(child: Expression, ordinal: Expression) = this(child, ordinal, SQLConf.get.ansiEnabled) | ||
|
||
// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. | ||
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) | ||
|
||
|
@@ -240,7 +246,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression) | |
protected override def nullSafeEval(value: Any, ordinal: Any): Any = { | ||
val baseValue = value.asInstanceOf[ArrayData] | ||
val index = ordinal.asInstanceOf[Number].intValue() | ||
if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) { | ||
if (index >= baseValue.numElements() || index < 0) { | ||
if (failOnError) { | ||
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index") | ||
} else { | ||
null | ||
} | ||
} else if (baseValue.isNullAt(index)) { | ||
null | ||
} else { | ||
baseValue.get(index, dataType) | ||
|
@@ -251,15 +263,25 @@ case class GetArrayItem(child: Expression, ordinal: Expression) | |
nullSafeCodeGen(ctx, ev, (eval1, eval2) => { | ||
val index = ctx.freshName("index") | ||
val nullCheck = if (child.dataType.asInstanceOf[ArrayType].containsNull) { | ||
s" || $eval1.isNullAt($index)" | ||
s"""else if ($eval1.isNullAt($index)) { | ||
${ev.isNull} = true; | ||
} | ||
""" | ||
} else { | ||
"" | ||
} | ||
|
||
val failOnErrorBranch = if (failOnError) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: it should be |
||
s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we remove |
||
} else { | ||
s"${ev.isNull} = true;" | ||
} | ||
|
||
s""" | ||
final int $index = (int) $eval2; | ||
if ($index >= $eval1.numElements() || $index < 0$nullCheck) { | ||
${ev.isNull} = true; | ||
} else { | ||
if ($index >= $eval1.numElements() || $index < 0) { | ||
$failOnErrorBranch | ||
} $nullCheck else { | ||
${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; | ||
} | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult | |
import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} | ||
import org.apache.spark.sql.internal.SQLConf | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.unsafe.UTF8StringBuilder | ||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String} | ||
|
@@ -231,15 +232,24 @@ case class ConcatWs(children: Seq[Expression]) | |
*/ | ||
// scalastyle:off line.size.limit | ||
@ExpressionDescription( | ||
usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", | ||
usage = """ | ||
_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2. | ||
The function returns NULL if the index exceeds the length of the array | ||
and `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true, | ||
it throws ArrayIndexOutOfBoundsException for invalid indices. | ||
""", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_(1, 'scala', 'java'); | ||
scala | ||
""", | ||
since = "2.0.0") | ||
// scalastyle:on line.size.limit | ||
case class Elt(children: Seq[Expression]) extends Expression { | ||
case class Elt( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
children: Seq[Expression], | ||
failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression { | ||
|
||
def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) | ||
|
||
private lazy val indexExpr = children.head | ||
private lazy val inputExprs = children.tail.toArray | ||
|
@@ -275,7 +285,11 @@ case class Elt(children: Seq[Expression]) extends Expression { | |
} else { | ||
val index = indexObj.asInstanceOf[Int] | ||
if (index <= 0 || index > inputExprs.length) { | ||
null | ||
if (failOnError) { | ||
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index") | ||
} else { | ||
null | ||
} | ||
} else { | ||
inputExprs(index - 1).eval(input) | ||
} | ||
|
@@ -323,6 +337,16 @@ case class Elt(children: Seq[Expression]) extends Expression { | |
""".stripMargin | ||
}.mkString) | ||
|
||
val failOnErrorBranch = if (failOnError) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: it should be |
||
s""" | ||
|if (!$indexMatched) { | ||
| throw new ArrayIndexOutOfBoundsException("Invalid index: " + ${index.value}); | ||
|} | ||
""".stripMargin | ||
} else { | ||
"" | ||
} | ||
|
||
ev.copy( | ||
code""" | ||
|${index.code} | ||
|
@@ -332,6 +356,7 @@ case class Elt(children: Seq[Expression]) extends Expression { | |
|do { | ||
| $codes | ||
|} while (false); | ||
|$failOnErrorBranch | ||
|final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; | ||
|final boolean ${ev.isNull} = ${ev.value} == null; | ||
""".stripMargin) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2144,9 +2144,10 @@ object SQLConf { | |
|
||
val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled") | ||
.doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " + | ||
"throw a runtime exception if an overflow occurs in any operation on integral/decimal " + | ||
"field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " + | ||
"the SQL parser.") | ||
"throw an exception at runtime if the inputs to a SQL operator/function are invalid, " + | ||
"e.g. overflow in arithmetic operations, out-of-range index when accessing array elements. " + | ||
"2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " + | ||
"the SQL parser. 3. Spark will returns null for null input for function `size`.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will returns -> will return, and perhaps capital |
||
.version("3.0.0") | ||
.booleanConf | ||
.createWithDefault(false) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to make
ElementAt
behavior consistent on map type? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ANSI mode support for the map type will be added in the next PR.