[SPARK-2937] Separate out sampleByKeyExact as its own API in PairRDDFunctions

Separates `sampleByKeyExact` into its own method so that `sampleByKey` stays consistent with the Python API and the exact variant can carry the `Experimental` label.

Author: Doris Xin
Author: Xiangrui Meng

Closes apache#1866 from dorx/stratified and squashes the following commits:

0ad97b2 [Doris Xin] reviewer comments.
2948aae [Doris Xin] remove unrelated changes
e990325 [Doris Xin] Merge branch 'master' into stratified
555a3f9 [Doris Xin] separate out sampleByKeyExact as its own API
616e55c [Doris Xin] merge master
245439e [Doris Xin] moved minSamplingRate to getUpperBound
eaf5771 [Doris Xin] bug fixes.
17a381b [Doris Xin] fixed a merge issue and a failed unit
ea7d27f [Doris Xin] merge master
b223529 [Xiangrui Meng] use approx bounds for poisson; fix poisson mean for waitlisting; add unit tests for Java
b3013a4 [Xiangrui Meng] move math3 back to test scope
eecee5f [Doris Xin] Merge branch 'master' into stratified
f4c21f3 [Doris Xin] Reviewer comments
a10e68d [Doris Xin] style fix
a2bf756 [Doris Xin] Merge branch 'master' into stratified
680b677 [Doris Xin] use mapPartitionWithIndex instead
9884a9f [Doris Xin] style fix
bbfb8c9 [Doris Xin] Merge branch 'master' into stratified
ee9d260 [Doris Xin] addressed reviewer comments
6b5b10b [Doris Xin] Merge branch 'master' into stratified
254e03c [Doris Xin] minor fixes and Java API.
4ad516b [Doris Xin] remove unused imports from PairRDDFunctions
bd9dc6e [Doris Xin] unit bug and style violation fixed
1fe1cff [Doris Xin] Changed fractionByKey to a map to enable arg check
944a10c [Doris Xin] [SPARK-2145] Add lower bound on sampling rate
0214a76 [Doris Xin] cleanUp
90d94c0 [Doris Xin] merge master
9e74ab5 [Doris Xin] Separated out most of the logic in sampleByKey
7327611 [Doris Xin] merge master
50581fc [Doris Xin] added a TODO for logging in python
46f6c8c [Doris Xin] fixed the NPE caused by closures being cleaned before being passed into the aggregate function
7e1a481 [Doris Xin] changed the permission on SamplingUtil
1d413ce [Doris Xin] fixed checkstyle issues
9ee94ee [Doris Xin] [SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size
e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample
7cab53a [Doris Xin] fixed import bug in rdd.py
ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD
1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
dorx authored and mengxr committed Aug 10, 2014
1 parent 28dcbb5 commit b715aa0
Showing 4 changed files with 216 additions and 128 deletions.
68 changes: 31 additions & 37 deletions core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -133,68 +133,62 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
  * Return a subset of this RDD sampled by key (via stratified sampling).
  *
  * Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
- *
- * If `exact` is set to false, create the sample via simple random sampling, with one pass
- * over the RDD, to produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
- * the RDD to create a sample size that's exactly equal to the sum of
+ * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
+ * RDD, to produce a sample of size that's approximately equal to the sum of
  * math.ceil(numItems * samplingRate) over all key values.
  */
  def sampleByKey(withReplacement: Boolean,
  fractions: JMap[K, Double],
- exact: Boolean,
  seed: Long): JavaPairRDD[K, V] =
- new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))
+ new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed))
 
  /**
  * Return a subset of this RDD sampled by key (via stratified sampling).
  *
  * Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
- *
- * If `exact` is set to false, create the sample via simple random sampling, with one pass
- * over the RDD, to produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
- * the RDD to create a sample size that's exactly equal to the sum of
+ * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
+ * RDD, to produce a sample of size that's approximately equal to the sum of
  * math.ceil(numItems * samplingRate) over all key values.
  *
- * Use Utils.random.nextLong as the default seed for the random number generator
+ * Use Utils.random.nextLong as the default seed for the random number generator.
  */
  def sampleByKey(withReplacement: Boolean,
- fractions: JMap[K, Double],
- exact: Boolean): JavaPairRDD[K, V] =
- sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)
+ fractions: JMap[K, Double]): JavaPairRDD[K, V] =
+ sampleByKey(withReplacement, fractions, Utils.random.nextLong)
 
  /**
- * Return a subset of this RDD sampled by key (via stratified sampling).
- *
- * Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
+ * ::Experimental::
+ * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
+ * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
  *
- * Produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
- * simple random sampling.
+ * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
+ * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
+ * over all key values with a 99.99% confidence. When sampling without replacement, we need one
+ * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
+ * two additional passes.
  */
- def sampleByKey(withReplacement: Boolean,
+ @Experimental
+ def sampleByKeyExact(withReplacement: Boolean,
  fractions: JMap[K, Double],
  seed: Long): JavaPairRDD[K, V] =
- sampleByKey(withReplacement, fractions, false, seed)
+ new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed))
 
  /**
- * Return a subset of this RDD sampled by key (via stratified sampling).
+ * ::Experimental::
+ * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
+ * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
  *
- * Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
- *
- * Produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
- * simple random sampling.
+ * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
+ * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
+ * over all key values with a 99.99% confidence. When sampling without replacement, we need one
+ * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
+ * two additional passes.
  *
- * Use Utils.random.nextLong as the default seed for the random number generator
+ * Use Utils.random.nextLong as the default seed for the random number generator.
  */
- def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
- sampleByKey(withReplacement, fractions, false, Utils.random.nextLong)
+ @Experimental
+ def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
+ sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong)
 
  /**
  * Return the union of this RDD and another one. Any identical elements will appear multiple
51 changes: 37 additions & 14 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -197,33 +197,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
  * Return a subset of this RDD sampled by key (via stratified sampling).
  *
  * Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
- *
- * If `exact` is set to false, create the sample via simple random sampling, with one pass
- * over the RDD, to produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values; otherwise, use
- * additional passes over the RDD to create a sample size that's exactly equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling
- * without replacement, we need one additional pass over the RDD to guarantee sample size;
- * when sampling with replacement, we need two additional passes.
+ * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
+ * RDD, to produce a sample of size that's approximately equal to the sum of
+ * math.ceil(numItems * samplingRate) over all key values.
  *
  * @param withReplacement whether to sample with or without replacement
  * @param fractions map of specific keys to sampling rates
  * @param seed seed for the random number generator
- * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key
  * @return RDD containing the sampled subset
  */
  def sampleByKey(withReplacement: Boolean,
  fractions: Map[K, Double],
- exact: Boolean = false,
- seed: Long = Utils.random.nextLong): RDD[(K, V)]= {
+ seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
+
+ require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
+
+ val samplingFunc = if (withReplacement) {
+ StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed)
+ } else {
+ StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed)
+ }
+ self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
+ }
+
+ /**
+ * ::Experimental::
+ * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
+ * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
+ *
+ * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
+ * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
+ * over all key values with a 99.99% confidence. When sampling without replacement, we need one
+ * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
+ * two additional passes.
+ *
+ * @param withReplacement whether to sample with or without replacement
+ * @param fractions map of specific keys to sampling rates
+ * @param seed seed for the random number generator
+ * @return RDD containing the sampled subset
+ */
+ @Experimental
+ def sampleByKeyExact(withReplacement: Boolean,
+ fractions: Map[K, Double],
+ seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
 
  require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
 
  val samplingFunc = if (withReplacement) {
- StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed)
+ StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed)
  } else {
- StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed)
+ StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed)
  }
  self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
  }
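
The difference between the two contracts described in the doc comments above can be sketched in plain Java. This is a hypothetical, in-memory illustration only: the class `StratifiedSamplingSketch` and all of its methods are assumptions made for this example, it covers sampling without replacement only, and it does not reproduce how `StratifiedSamplingUtils` reaches the exact size with additional passes over an RDD. It only shows that the approximate variant keeps each pair independently, while the exact variant returns `ceil(numItems * samplingRate)` pairs per key.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical in-memory illustration of the two contracts; this is NOT
// Spark's StratifiedSamplingUtils and performs no distributed passes.
class StratifiedSamplingSketch {

    // Per-stratum target named in the doc comments: ceil(numItems * samplingRate).
    static long exactSize(long numItems, double samplingRate) {
        return (long) Math.ceil(numItems * samplingRate);
    }

    // Approximate contract: one pass, each pair kept independently with its
    // key's rate, so per-key counts only approximate the target size.
    static List<int[]> sampleApprox(List<int[]> pairs, Map<Integer, Double> fractions, long seed) {
        Random rng = new Random(seed);
        List<int[]> out = new ArrayList<>();
        for (int[] pair : pairs) {
            double rate = fractions.getOrDefault(pair[0], 0.0);
            if (rng.nextDouble() < rate) {
                out.add(pair);
            }
        }
        return out;
    }

    // Exact contract (without replacement): group by key, shuffle each group,
    // and keep exactly exactSize(groupSize, rate) pairs.
    // Assumes every rate lies in [0, 1].
    static List<int[]> sampleExact(List<int[]> pairs, Map<Integer, Double> fractions, long seed) {
        Map<Integer, List<int[]>> strata = new HashMap<>();
        for (int[] pair : pairs) {
            strata.computeIfAbsent(pair[0], k -> new ArrayList<>()).add(pair);
        }
        Random rng = new Random(seed);
        List<int[]> out = new ArrayList<>();
        for (Map.Entry<Integer, List<int[]>> e : strata.entrySet()) {
            List<int[]> group = e.getValue();
            double rate = fractions.getOrDefault(e.getKey(), 0.0);
            Collections.shuffle(group, rng);
            int take = (int) Math.min(group.size(), exactSize(group.size(), rate));
            out.addAll(group.subList(0, take));
        }
        return out;
    }

    // Count pairs per key; pair[0] is the key, pair[1] the value.
    static Map<Integer, Long> countByKey(List<int[]> pairs) {
        Map<Integer, Long> counts = new HashMap<>();
        for (int[] pair : pairs) {
            counts.merge(pair[0], 1L, Long::sum);
        }
        return counts;
    }

    // Same shape of data as the JavaAPISuite test: 1..8 keyed by i % 2.
    static List<int[]> testPairs() {
        List<int[]> pairs = new ArrayList<>();
        for (int i = 1; i <= 8; i++) {
            pairs.add(new int[] {i % 2, 1});
        }
        return pairs;
    }

    public static void main(String[] args) {
        Map<Integer, Double> fractions = new HashMap<>();
        fractions.put(0, 0.5);
        fractions.put(1, 1.0);
        System.out.println("approx: " + countByKey(sampleApprox(testPairs(), fractions, 1L)));
        System.out.println("exact:  " + countByKey(sampleExact(testPairs(), fractions, 1L)));
    }
}
```

With four pairs per key and rates 0.5 and 1.0, the exact variant in this sketch always yields 2 and 4 pairs, the same per-key counts that the `sampleByKeyExact` test in `JavaAPISuite` asserts. The count for key 0 in the approximate variant depends on the seed.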
20 changes: 18 additions & 2 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1239,12 +1239,28 @@ public Tuple2<Integer, Integer> call(Integer i) {
  Assert.assertTrue(worCounts.size() == 2);
  Assert.assertTrue(worCounts.get(0) > 0);
  Assert.assertTrue(worCounts.get(1) > 0);
- JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKey(true, fractions, true, 1L);
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void sampleByKeyExact() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3);
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(
+ new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) {
+ return new Tuple2<Integer, Integer>(i % 2, 1);
+ }
+ });
+ Map<Integer, Object> fractions = Maps.newHashMap();
+ fractions.put(0, 0.5);
+ fractions.put(1, 1.0);
+ JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKeyExact(true, fractions, 1L);
  Map<Integer, Long> wrExactCounts = (Map<Integer, Long>) (Object) wrExact.countByKey();
  Assert.assertTrue(wrExactCounts.size() == 2);
  Assert.assertTrue(wrExactCounts.get(0) == 2);
  Assert.assertTrue(wrExactCounts.get(1) == 4);
- JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKey(false, fractions, true, 1L);
+ JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKeyExact(false, fractions, 1L);
  Map<Integer, Long> worExactCounts = (Map<Integer, Long>) (Object) worExact.countByKey();
  Assert.assertTrue(worExactCounts.size() == 2);
  Assert.assertTrue(worExactCounts.get(0) == 2);