Skip to content

Commit

Permalink
[SPARK-12340][SQL] fix Int overflow in the SparkPlan.executeTake, RDD…
Browse files Browse the repository at this point in the history
….take and AsyncRDDActions.takeAsync

I have closed pull request #10487. And I create this pull request to resolve the problem.

spark jira
https://issues.apache.org/jira/browse/SPARK-12340

Author: QiangCai <[email protected]>

Closes #10562 from QiangCai/bugfix.
  • Loading branch information
QiangCai authored and sarutak committed Jan 6, 2016
1 parent b2467b3 commit 5d871ea
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 14 deletions.
12 changes: 6 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
val localProperties = self.context.getLocalProperties
// Cached thread pool to handle aggregation of subtasks.
implicit val executionContext = AsyncRDDActions.futureExecutionContext
val results = new ArrayBuffer[T](num)
val results = new ArrayBuffer[T]
val totalParts = self.partitions.length

/*
Expand All @@ -77,13 +77,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
This implementation is non-blocking, asynchronously handling the
results of each job and triggering the next job using callbacks on futures.
*/
def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
if (results.size >= num || partsScanned >= totalParts) {
Future.successful(results.toSeq)
} else {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
Expand All @@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
}

val left = num - results.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt

val buf = new Array[Array[T]](p.size)
self.context.setCallSite(callSite)
Expand All @@ -111,11 +111,11 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
Unit)
job.flatMap {_ =>
buf.foreach(results ++= _.take(num - results.size))
continue(partsScanned + numPartsToTry)
continue(partsScanned + p.size)
}
}

new ComplexFutureAction[Seq[T]](continue(0)(_))
new ComplexFutureAction[Seq[T]](continue(0L)(_))
}

/**
Expand Down
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1190,11 +1190,11 @@ abstract class RDD[T: ClassTag](
} else {
val buf = new ArrayBuffer[T]
val totalParts = this.partitions.length
var partsScanned = 0
var partsScanned = 0L
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
Expand All @@ -1209,11 +1209,11 @@ abstract class RDD[T: ClassTag](
}

val left = num - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

res.foreach(buf ++= _.take(num - buf.size))
partsScanned += numPartsToTry
partsScanned += p.size
}

buf.toArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

val buf = new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
var partsScanned = 0
var partsScanned = 0L
while (buf.size < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the first iteration, just try all partitions next.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
Expand All @@ -183,13 +183,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

val left = n - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
val sc = sqlContext.sparkContext
val res =
sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)

res.foreach(buf ++= _.take(n - buf.size))
partsScanned += numPartsToTry
partsScanned += p.size
}

buf.toArray
Expand Down
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2067,4 +2067,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
}

test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
rdd.toDF("key").registerTempTable("spark12340")
checkAnswer(
sql("select key from spark12340 limit 2147483638"),
Row(1) :: Row(2) :: Row(3) :: Nil
)
assert(rdd.take(2147483638).size === 3)
assert(rdd.takeAsync(2147483638).get.size === 3)
}

}

0 comments on commit 5d871ea

Please sign in to comment.