Skip to content

Commit

Permalink
SPARK-1939: Refactor takeSample method in RDD
Browse files Browse the repository at this point in the history
Reviewer comments addressed:
- commons-math3 is now a test-only dependency; bumped up to v3.3
- comments added to explain what computeFraction is doing
- fixed the unit test for computeFraction to use BinomialDistribution for
without-replacement sampling
- stylistic fixes
  • Loading branch information
dorx committed May 30, 2014
1 parent 1441977 commit ffea61a
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 32 deletions.
1 change: 1 addition & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
Expand Down
33 changes: 23 additions & 10 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ abstract class RDD[T: ClassTag](
* @return sample of specified size in an array
*/
def takeSample(withReplacement: Boolean,
num: Int,
seed: Long = Utils.random.nextLong): Array[T] = {
num: Int,
seed: Long = Utils.random.nextLong): Array[T] = {
var fraction = 0.0
var total = 0
val multiplier = 3.0
Expand Down Expand Up @@ -431,18 +431,31 @@ abstract class RDD[T: ClassTag](
Utils.randomizeInPlace(samples, rand).take(total)
}

private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean) : Double = {
/**
* Let p = num / total, where num is the sample size and total is the total number of
* datapoints in the RDD. We're trying to compute q > p such that
* - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
* where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
* i.e. the failure rate of not having a sufficiently large sample < 0.0001.
* Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
* num > 12, but we need a slightly larger q (9 empirically determined).
* - when sampling without replacement, we're drawing each datapoint with prob_i
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
* rate, where success rate is defined the same as in sampling with replacement.
*
* @param num sample size
* @param total size of RDD
* @param withReplacement whether sampling with replacement
* @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
*/
private[rdd] def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
val fraction = num.toDouble / total
if (withReplacement) {
var numStDev = 5
if (num < 12) {
// special case to guarantee sample size for small s
numStDev = 9
}
val numStDev = if (num < 12) 9 else 5
fraction + numStDev * math.sqrt(fraction / total)
} else {
val delta = 0.00005
val gamma = - math.log(delta)/total
val delta = 1e-4
val gamma = - math.log(delta) / total
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
}

/**
* Return a sampler which is the complement of the range specified of the current sampler.
* Return a sampler that is the complement of the range specified of the current sampler.
*/
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)

Expand Down
36 changes: 17 additions & 19 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import scala.reflect.ClassTag

import org.scalatest.FunSuite

import org.apache.commons.math3.distribution.BinomialDistribution
import org.apache.commons.math3.distribution.PoissonDistribution

import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
Expand Down Expand Up @@ -496,29 +498,25 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}

test("computeFraction") {
// test that the computed fraction guarantees enough datapoints in the sample with a failure rate <= 0.0001
// test that the computed fraction guarantees enough datapoints
// in the sample with a failure rate <= 0.0001
val data = new EmptyRDD[Int](sc)
val n = 100000

for (s <- 1 to 15) {
val frac = data.computeFraction(s, n, true)
val qpois = new PoissonDistribution(frac * n)
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
val poisson = new PoissonDistribution(frac * n)
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- 1 to 15) {
val frac = data.computeFraction(s, n, false)
val qpois = new PoissonDistribution(frac * n)
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(1, 10, 100, 1000)) {
for (s <- List(20, 100, 1000)) {
val frac = data.computeFraction(s, n, true)
val qpois = new PoissonDistribution(frac * n)
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
val poisson = new PoissonDistribution(frac * n)
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(1, 10, 100, 1000)) {
val frac = data.computeFraction(s, n, false)
val qpois = new PoissonDistribution(frac * n)
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
val binomial = new BinomialDistribution(n, frac)
assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
}
}

Expand All @@ -530,37 +528,37 @@ class RDDSuite extends FunSuite with SharedSparkContext {
val sample = data.takeSample(withReplacement=false, num=num)
assert(sample.size === num) // Got exactly num elements
assert(sample.toSet.size === num) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size === 20) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 100, seed)
assert(sample.size === 100) // Got only 100 elements
assert(sample.toSet.size === 100) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
{
val sample = data.takeSample(withReplacement=true, num=20)
assert(sample.size === 20) // Got exactly 100 elements
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
{
val sample = data.takeSample(withReplacement=true, num=n)
assert(sample.size === n) // Got exactly 100 elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, n, seed)
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.2</version>
<version>3.3</version>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
Expand Down
2 changes: 1 addition & 1 deletion project/SparkBuild.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ object SparkBuild extends Build {
libraryDependencies ++= Seq(
"com.google.guava" % "guava" % "14.0.1",
"org.apache.commons" % "commons-lang3" % "3.3.2",
"org.apache.commons" % "commons-math3" % "3.2",
"org.apache.commons" % "commons-math3" % "3.3" % "test",
"com.google.code.findbugs" % "jsr305" % "1.3.9",
"log4j" % "log4j" % "1.2.17",
"org.slf4j" % "slf4j-api" % slf4jVersion,
Expand Down

0 comments on commit ffea61a

Please sign in to comment.