From c284d9f0822ebdf426aa96fd22ceb3dc06c59e9c Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Sat, 15 Nov 2014 20:42:10 -0800
Subject: [PATCH] fix a race condition in zipWithIndex

---
 .../apache/spark/rdd/ZippedWithIndexRDD.scala | 31 ++++++++++---------
 .../scala/org/apache/spark/rdd/RDDSuite.scala |  5 +++
 2 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
index e2c301603b4a5..8c43a559409f2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
 private[spark]
 class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) {
 
-  override def getPartitions: Array[Partition] = {
+  /** The start index of each partition. */
+  @transient private val startIndices: Array[Long] = {
     val n = prev.partitions.size
-    val startIndices: Array[Long] =
-      if (n == 0) {
-        Array[Long]()
-      } else if (n == 1) {
-        Array(0L)
-      } else {
-        prev.context.runJob(
-          prev,
-          Utils.getIteratorSize _,
-          0 until n - 1, // do not need to count the last partition
-          false
-        ).scanLeft(0L)(_ + _)
-      }
+    if (n == 0) {
+      Array[Long]()
+    } else if (n == 1) {
+      Array(0L)
+    } else {
+      prev.context.runJob(
+        prev,
+        Utils.getIteratorSize _,
+        0 until n - 1, // do not need to count the last partition
+        allowLocal = false
+      ).scanLeft(0L)(_ + _)
+    }
+  }
+
+  override def getPartitions: Array[Partition] = {
     firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
   }
 
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6d2e696dc2fc4..e079ca3b1e896 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -739,6 +739,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("zipWithIndex chained with other RDDs (SPARK-4433)") {
+    val count = sc.parallelize(0 until 10, 2).zipWithIndex().repartition(4).count()
+    assert(count === 10)
+  }
+
  test("zipWithUniqueId") {
     val n = 10
     val data = sc.parallelize(0 until n, 3)
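
Note on the fix: before this patch, startIndices was computed inside getPartitions,
so the runJob that counts the preceding partitions could be launched while the
scheduler itself was resolving partitions for another job on the same RDD; chaining
zipWithIndex with a shuffle (as in the new test) could hit that race. The patch
hoists the computation into a @transient val that runs exactly once, on the driver,
when the RDD is constructed. Below is a minimal sketch of that pattern in isolation;
it is plain Scala with illustrative names (Before, After, countPartitions), not
Spark's API:

    import java.util.concurrent.atomic.AtomicInteger

    object InitPatternSketch {
      private val jobsLaunched = new AtomicInteger(0)

      // Stand-in for prev.context.runJob(...).scanLeft(0L)(_ + _):
      // expensive, and unsafe to launch from the thread that is
      // already scheduling a job.
      private def countPartitions(): Array[Long] = {
        jobsLaunched.incrementAndGet()
        Array(5L, 5L).scanLeft(0L)(_ + _)
      }

      // Pre-patch shape: every call to getPartitions may launch the
      // job, possibly from the scheduler's own thread.
      class Before {
        def getPartitions: Array[Long] = countPartitions()
      }

      // Post-patch shape: the job runs once, eagerly, while the object
      // is constructed on the driver; getPartitions only reads the result.
      class After {
        private val startIndices: Array[Long] = countPartitions()
        def getPartitions: Array[Long] = startIndices
      }

      def main(args: Array[String]): Unit = {
        val rdd = new After
        rdd.getPartitions
        rdd.getPartitions
        println(s"jobs launched: ${jobsLaunched.get()}") // prints 1
      }
    }

In the real class, @transient additionally keeps the computed array out of the
serialized RDD, so executors never ship it; each task only receives the start
index embedded in its ZippedWithIndexRDDPartition.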