Skip to content

Commit

Permalink
Allow num to be greater than count in all cases
Browse files Browse the repository at this point in the history
  • Loading branch information
dorx committed Jun 12, 2014
1 parent 1481b01 commit fb1452f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
13 changes: 6 additions & 7 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -396,25 +396,24 @@ abstract class RDD[T: ClassTag](
throw new IllegalArgumentException("Negative number of elements requested")
}

if (initialCount == 0) {
if (initialCount == 0 || num == 0) {
return new Array[T](0)
}

if (!withReplacement && num > initialCount) {
throw new IllegalArgumentException("Cannot create sample larger than the original when " +
"sampling without replacement")
}

val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
if (num > maxSampleSize) {
throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
s"$numStDev * math.sqrt(Int.MaxValue)")
}

val rand = new Random(seed)
if (!withReplacement && num > initialCount) {
return Utils.randomizeInPlace(this.collect(), rand)
}

val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
withReplacement)

val rand = new Random(seed)
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()

// If the first sample didn't turn out large enough, keep trying to take samples;
Expand Down
36 changes: 19 additions & 17 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,35 +366,37 @@ def takeSample(self, withReplacement, num, seed=None):
[4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
"""

#TODO remove
logging.basicConfig(level=logging.INFO)
numStDev = 10.0
initialCount = self.count()

if num < 0:
raise ValueError

if initialCount == 0:
if initialCount == 0 or num == 0:
return list()

rand = Random(seed)
if (not withReplacement) and num > initialCount:
raise ValueError
# shuffle current RDD and return
samples = self.collect()
fraction = float(num) / initialCount
num = initialCount
else:
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
if num > maxSampleSize:
raise ValueError

maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
if num > maxSampleSize:
raise ValueError
fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement)

fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement)

samples = self.sample(withReplacement, fraction, seed).collect()
samples = self.sample(withReplacement, fraction, seed).collect()

# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
rand = Random(seed)
while len(samples) < num:
#TODO add log warning for when more than one iteration was run
samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()
# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
while len(samples) < num:
#TODO add log warning for when more than one iteration was run
seed = rand.randint(0, sys.maxint)
samples = self.sample(withReplacement, fraction, seed).collect()

sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
sampler.shuffle(samples)
Expand Down

0 comments on commit fb1452f

Please sign in to comment.