Accessing array elements should fail if the index is out of bounds.
Change-Id: If7650cc45aa30fd3d6549536a8af6ca01a746c39
leanken-zz committed Nov 9, 2020
1 parent c269b53 commit d1e0ae3
Showing 10 changed files with 129 additions and 14 deletions.
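For context, the user-facing change: with spark.sql.ansi.enabled set to true, out-of-bounds array access through element_at, elt, and the array[i] syntax now throws ArrayIndexOutOfBoundsException instead of silently returning NULL. A minimal sketch of both modes, assuming a build that contains this commit and an active SparkSession named spark:

    // Default (non-ANSI) behavior is unchanged: out-of-bounds access yields NULL.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT element_at(array(1, 2, 3), 5)").show()    // NULL

    // Under ANSI mode the same access now fails fast.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT element_at(array(1, 2, 3), 5)").collect()
    // java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5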
@@ -840,8 +840,8 @@ object TypeCoercion {
plan resolveOperators { case p =>
p transformExpressionsUp {
// Skip nodes if unresolved or not enough children
-case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
-case c @ Elt(children) =>
+case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c
+case c @ Elt(children, _) =>
val index = children.head
val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
val newInputs = if (conf.eltOutputAsString ||
@@ -34,8 +34,10 @@ case class ProjectionOverSchema(schema: StructType) {
expr match {
case a: AttributeReference if fieldNames.contains(a.name) =>
Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
-case GetArrayItem(child, arrayItemOrdinal) =>
-  getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) }
+case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
+  getProjection(child).map {
+    projection => GetArrayItem(projection, arrayItemOrdinal, failOnError)
+  }
case a: GetArrayStructFields =>
getProjection(a.child).map(p => (p, p.dataType)).map {
case (projection, ArrayType(projSchema @ StructType(_), _)) =>
@@ -119,7 +119,7 @@ object SelectedField {
throw new AnalysisException(s"DataType '$x' is not supported by MapKeys.")
}
selectField(child, opt)
-case GetArrayItem(child, _) =>
+case GetArrayItem(child, _, _) =>
// GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be
// the top-level extractor. However it can be part of an extractor chain.
val ArrayType(_, containsNull) = child.dataType
@@ -1919,9 +1919,14 @@ case class ArrayPosition(left: Expression, right: Expression)
b
""",
since = "2.4.0")
-case class ElementAt(left: Expression, right: Expression)
+case class ElementAt(
+    left: Expression,
+    right: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {

+  def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)

@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

@transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull
@@ -1978,7 +1983,11 @@ case class ElementAt(left: Expression, right: Expression)
val array = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Int]
if (array.numElements() < math.abs(index)) {
-null
+if (failOnError) {
+  throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
+} else {
+  null
+}
} else {
val idx = if (index == 0) {
throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
@@ -2012,10 +2021,17 @@ case class ElementAt(left: Expression, right: Expression)
} else {
""
}

+val failOnErrorBranch = if (failOnError) {
+  s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin
+} else {
+  s"${ev.isNull} = true;"
+}

s"""
|int $index = (int) $eval2;
|if ($eval1.numElements() < Math.abs($index)) {
-|  ${ev.isNull} = true;
+|  $failOnErrorBranch
|} else {
| if ($index == 0) {
| throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
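At the expression level, the new flag gives ElementAt the following semantics. A sketch, assuming the Catalyst classes above are on the classpath (expressions are normally built through SQL or the DataFrame API rather than by hand; the results follow the eval logic in this hunk):

    import org.apache.spark.sql.catalyst.expressions.{ElementAt, Literal}
    import org.apache.spark.sql.types._

    val arr = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))

    // element_at is 1-based; negative indices count from the end.
    ElementAt(arr, Literal(1), failOnError = false).eval()    // 1
    ElementAt(arr, Literal(-1), failOnError = false).eval()   // 3

    // Out of bounds: null with failOnError = false, an exception with true.
    ElementAt(arr, Literal(5), failOnError = false).eval()    // null
    ElementAt(arr, Literal(5), failOnError = true).eval()
    // java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5

    // Index 0 throws in either mode: SQL array indices start at 1.
    ElementAt(arr, Literal(0), failOnError = false).eval()
    // java.lang.ArrayIndexOutOfBoundsException: SQL array indices start at 1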
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -222,10 +223,15 @@ case class GetArrayStructFields(
*
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
-case class GetArrayItem(child: Expression, ordinal: Expression)
+case class GetArrayItem(
+    child: Expression,
+    ordinal: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
with NullIntolerant {

+  def this(child: Expression, ordinal: Expression) = this(child, ordinal, SQLConf.get.ansiEnabled)

// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)

@@ -240,7 +246,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
-if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) {
+if (index >= baseValue.numElements() || index < 0) {
+  if (failOnError) {
+    throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
+  } else {
+    null
+  }
+} else if (baseValue.isNullAt(index)) {
null
} else {
baseValue.get(index, dataType)
@@ -255,9 +267,18 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
} else {
""
}

+val failOnErrorBranch = if (failOnError) {
+  s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin
+} else {
+  s"${ev.isNull} = true;"
+}

s"""
final int $index = (int) $eval2;
-if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
+if ($index >= $eval1.numElements() || $index < 0) {
+  $failOnErrorBranch
+} else if (false$nullCheck) {
${ev.isNull} = true;
} else {
${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
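For contrast, GetArrayItem backs the array[i] extraction syntax and is zero-based. A sketch under the same assumptions (Catalyst returns its internal string type, UTF8String, for string elements):

    import org.apache.spark.sql.catalyst.expressions.{GetArrayItem, Literal}
    import org.apache.spark.sql.types._

    val arr = Literal.create(Seq("a", "b"), ArrayType(StringType))

    GetArrayItem(arr, Literal(0), failOnError = false).eval()   // "a" as UTF8String

    // Out of range: null with failOnError = false, the pre-existing behavior...
    GetArrayItem(arr, Literal(5), failOnError = false).eval()   // null
    // ...and an exception with failOnError = true, per the new branch above.
    GetArrayItem(arr, Literal(5), failOnError = true).eval()
    // java.lang.ArrayIndexOutOfBoundsException: Invalid index: 5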
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -239,7 +240,11 @@ case class ConcatWs(children: Seq[Expression])
""",
since = "2.0.0")
// scalastyle:on line.size.limit
-case class Elt(children: Seq[Expression]) extends Expression {
+case class Elt(
+    children: Seq[Expression],
+    failOnError: Boolean = SQLConf.get.ansiEnabled) extends Expression {

+  def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)

private lazy val indexExpr = children.head
private lazy val inputExprs = children.tail.toArray
@@ -275,7 +280,11 @@ case class Elt(children: Seq[Expression]) extends Expression {
} else {
val index = indexObj.asInstanceOf[Int]
if (index <= 0 || index > inputExprs.length) {
-null
+if (failOnError) {
+  throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
+} else {
+  null
+}
} else {
inputExprs(index - 1).eval(input)
}
@@ -323,6 +332,16 @@ case class Elt(children: Seq[Expression]) extends Expression {
""".stripMargin
}.mkString)

+val failOnErrorBranch = if (failOnError) {
+  s"""
+     |if (!$indexMatched) {
+     |  throw new ArrayIndexOutOfBoundsException("Invalid index: " + ${index.value});
+     |}
+   """.stripMargin
+} else {
+  ""
+}

ev.copy(
code"""
|${index.code}
@@ -332,6 +351,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
|do {
| $codes
|} while (false);
+|$failOnErrorBranch
|final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal;
|final boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
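The same pattern for Elt, whose first child is a 1-based index into the remaining children. A sketch under the same assumptions:

    import org.apache.spark.sql.catalyst.expressions.{Elt, Literal}

    Elt(Seq(Literal(1), Literal("123"), Literal("456")), failOnError = false).eval()
    // "123" as UTF8String

    // Index 4 with only two string inputs: null with failOnError = false...
    Elt(Seq(Literal(4), Literal("123"), Literal("456")), failOnError = false).eval()  // null
    // ...and an exception with failOnError = true.
    Elt(Seq(Literal(4), Literal("123"), Literal("456")), failOnError = true).eval()
    // java.lang.ArrayIndexOutOfBoundsException: Invalid index: 4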
@@ -61,7 +61,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty)

// Remove redundant map lookup.
-case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) =>
+case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx), _) =>
// Instead of creating the array and then selecting one row, remove array creation
// altogether.
if (idx >= 0 && idx < elems.size) {
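The one-line change above only threads the new constructor field through the rule's pattern. For illustration, a sketch of what SimplifyExtractValueOps does when the index is a literal; the plan text is abbreviated and illustrative, not exact output:

    // Selecting a literal index from a freshly built array never materializes
    // the array; the optimizer substitutes the element expression directly.
    val df = spark.sql("SELECT array(10, 20, 30)[1] AS x")
    df.queryExecution.optimizedPlan
    // Project [20 AS x#...]   (the CreateArray is gone)
    df.show()   // 20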
@@ -62,6 +62,24 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
}

test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val typeA = ArrayType(StringType)
val array = Literal.create(Seq("a", "b"), typeA)

if (ansiEnabled) {
val ex = intercept[Exception] {
checkEvaluation(GetArrayItem(array, Literal(5)), null)
}
assert(stackTraceToString(ex).contains("Invalid index: 5"))
} else {
checkEvaluation(GetArrayItem(array, Literal(5)), null)
}
}
}
}

test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
@@ -3621,6 +3621,25 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
df.select(map(map_entries($"m"), lit(1))),
Row(Map(Seq(Row(1, "a")) -> 1)))
}

test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val df = sql("select element_at(array(1, 2, 3), 5)")
if (ansiEnabled) {
val ex = intercept[Exception] {
df.collect()
}
assert(ex.getMessage.contains("Invalid index: 5"))
} else {
checkAnswer(
df,
Row(null)
)
}
}
}
}
}

object DataFrameFunctionsSuite {
@@ -617,4 +617,23 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
)

}

test("SPARK-33391: elt ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val df = sql("select elt(4, '123', '456')")
if (ansiEnabled) {
val ex = intercept[Exception] {
df.collect()
}
assert(ex.getMessage.contains("Invalid index: 4"))
} else {
checkAnswer(
df,
Row(null)
)
}
}
}
}
}
