From d7eec4536a9f29cf9f8783b620fdf75f7b4e08bb Mon Sep 17 00:00:00 2001 From: uncleGen Date: Sat, 28 Nov 2015 15:51:43 +0800 Subject: [PATCH 1/4] bug fix --- .../main/scala/org/apache/spark/Partitioner.scala | 4 ++-- .../apache/spark/util/random/SamplingUtils.scala | 14 ++++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e4df7af81a6d2..ef9a2dab1c106 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -253,7 +253,7 @@ private[spark] object RangePartitioner { */ def sketch[K : ClassTag]( rdd: RDD[K], - sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { + sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = { val shift = rdd.id // val classTagK = classTag[K] // to avoid serializing the entire partitioner object val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => @@ -262,7 +262,7 @@ private[spark] object RangePartitioner { iter, sampleSizePerPartition, seed) Iterator((idx, n, sample)) }.collect() - val numItems = sketched.map(_._2.toLong).sum + val numItems = sketched.map(_._2).sum (numItems, sketched) } diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index c9a864ae62778..92fe3698ff081 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -34,7 +34,7 @@ private[spark] object SamplingUtils { input: Iterator[T], k: Int, seed: Long = Random.nextLong()) - : (Array[T], Int) = { + : (Array[T], Long) = { val reservoir = new Array[T](k) // Put the first k elements in the reservoir. var i = 0 @@ -52,16 +52,22 @@ private[spark] object SamplingUtils { (trimReservoir, i) } else { // If input size > k, continue the sampling process. + var l = i.toLong val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() - val replacementIndex = rand.nextInt(i) + val replacementIndex = l < Int.MaxValue match { + case true => + rand.nextInt(l.toInt) + case false => + rand.nextInt() + } if (replacementIndex < k) { reservoir(replacementIndex) = item } - i += 1 + l += 1 } - (reservoir, i) + (reservoir, l) } } From 986fa5c3de0d0bbf7ffd921e9c76305ac82428f2 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Sat, 28 Nov 2015 15:58:48 +0800 Subject: [PATCH 2/4] code style --- .../scala/org/apache/spark/util/random/SamplingUtils.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index 92fe3698ff081..f17ac7fb6a8e3 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -57,10 +57,8 @@ private[spark] object SamplingUtils { while (input.hasNext) { val item = input.next() val replacementIndex = l < Int.MaxValue match { - case true => - rand.nextInt(l.toInt) - case false => - rand.nextInt() + case true => rand.nextInt(l.toInt) + case false => rand.nextInt() } if (replacementIndex < k) { reservoir(replacementIndex) = item From bb2ed41112901ab64d6b74b39389e60d76c173f3 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Sat, 28 Nov 2015 17:47:43 +0800 Subject: [PATCH 3/4] modify some pointless and verbose code --- .../org/apache/spark/util/random/SamplingUtils.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index f17ac7fb6a8e3..0a8c59bbaae96 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -56,12 +56,13 @@ private[spark] object SamplingUtils { val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() - val replacementIndex = l < Int.MaxValue match { - case true => rand.nextInt(l.toInt) - case false => rand.nextInt() + val replacementIndex = if (l < Int.MaxValue) { + rand.nextInt(l.toInt) + } else { + rand.nextLong() } if (replacementIndex < k) { - reservoir(replacementIndex) = item + reservoir(replacementIndex.toInt) = item } l += 1 } From 3cafda5294c382f06d816bb2479db2e16d9bd6dc Mon Sep 17 00:00:00 2001 From: uncleGen Date: Mon, 30 Nov 2015 11:04:45 +0800 Subject: [PATCH 4/4] bug fix --- .../scala/org/apache/spark/util/random/SamplingUtils.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index 0a8c59bbaae96..f98932a470165 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -56,11 +56,7 @@ private[spark] object SamplingUtils { val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() - val replacementIndex = if (l < Int.MaxValue) { - rand.nextInt(l.toInt) - } else { - rand.nextLong() - } + val replacementIndex = (rand.nextDouble() * l).toLong if (replacementIndex < k) { reservoir(replacementIndex.toInt) = item }