Merge pull request apache#528 from mengxr/sample. Closes apache#528.
 Refactor RDD sampling and add randomSplit to RDD (update)

Replace SampledRDD with PartitionwiseSampledRDD, which accepts a RandomSampler instance as input. The existing sampling without and with replacement is handled by BernoulliSampler and PoissonSampler, respectively. The benefits are:

1) RDD.randomSplit is implemented in the same way, related to https://github.com/apache/incubator-spark/pull/513
2) Stratified sampling and importance sampling can be implemented in the same manner as well.

Unit tests are included for samplers and RDD.randomSplit.

This should perform better than my previous pull request, in which the BernoulliSampler creates many Iterator instances:
https://github.com/apache/incubator-spark/pull/513
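
For reference, a minimal usage sketch of the new API (illustrative only; assumes an existing SparkContext sc):

    val rdd = sc.parallelize(1 to 1000, 4)

    // ~10% sample without replacement (BernoulliSampler under the hood)
    val sampled = rdd.sample(withReplacement = false, fraction = 0.1, seed = 42)

    // 60/40 split; the weights are normalized, so they need not sum to 1
    val Array(train, test) = rdd.randomSplit(Array(0.6, 0.4), seed = 42L)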

Author: Xiangrui Meng <[email protected]>

== Merge branch commits ==

commit e8ce957e5f0a600f2dec057924f4a2ca6adba373
Author: Xiangrui Meng <[email protected]>
Date:   Mon Feb 3 12:21:08 2014 -0800

    more docs to PartitionwiseSampledRDD

commit fbb4586d0478ff638b24bce95f75ff06f713d43b
Author: Xiangrui Meng <[email protected]>
Date:   Mon Feb 3 00:44:23 2014 -0800

    move XORShiftRandom to util.random and use it in BernoulliSampler

commit 987456b0ee8612fd4f73cb8c40967112dc3c4c2d
Author: Xiangrui Meng <[email protected]>
Date:   Sat Feb 1 11:06:59 2014 -0800

    relax assertions in SortingSuite because the RangePartitioner has large variance in this case

commit 3690aae416b2dc9b2f9ba32efa465ba7948477f4
Author: Xiangrui Meng <[email protected]>
Date:   Sat Feb 1 09:56:28 2014 -0800

    test split ratio of RDD.randomSplit

commit 8a410bc933a60c4d63852606f8bbc812e416d6ae
Author: Xiangrui Meng <[email protected]>
Date:   Sat Feb 1 09:25:22 2014 -0800

    add a test to ensure seed distribution and minor style update

commit ce7e866f674c30ab48a9ceb09da846d5362ab4b6
Author: Xiangrui Meng <[email protected]>
Date:   Fri Jan 31 18:06:22 2014 -0800

    minor style change

commit 750912b4d77596ed807d361347bd2b7e3b9b7a74
Author: Xiangrui Meng <[email protected]>
Date:   Fri Jan 31 18:04:54 2014 -0800

    fix some long lines

commit c446a25c38d81db02821f7f194b0ce5ab4ed7ff5
Author: Xiangrui Meng <[email protected]>
Date:   Fri Jan 31 17:59:59 2014 -0800

    add complement to BernoulliSampler and minor style changes

commit dbe2bc2bd888a7bdccb127ee6595840274499403
Author: Xiangrui Meng <[email protected]>
Date:   Fri Jan 31 17:45:08 2014 -0800

    switch to partition-wise sampling for better performance

commit a1fca5232308feb369339eac67864c787455bb23
Merge: ac712e4 cf6128f
Author: Xiangrui Meng <[email protected]>
Date:   Fri Jan 31 16:33:09 2014 -0800

    Merge branch 'sample' of github.com:mengxr/incubator-spark into sample

commit cf6128fb672e8c589615adbd3eaa3cbdb72bd461
Author: Xiangrui Meng <[email protected]>
Date:   Sun Jan 26 14:40:07 2014 -0800

    set SampledRDD deprecated in 1.0

commit f430f847c3df91a3894687c513f23f823f77c255
Author: Xiangrui Meng <[email protected]>
Date:   Sun Jan 26 14:38:59 2014 -0800

    update code style

commit a8b5e2021a9204e318c80a44d00c5c495f1befb6
Author: Xiangrui Meng <[email protected]>
Date:   Sun Jan 26 12:56:27 2014 -0800

    move package random to util.random

commit ab0fa2c4965033737a9e3a9bf0a59cbb0df6a6f5
Author: Xiangrui Meng <[email protected]>
Date:   Sun Jan 26 12:50:35 2014 -0800

    add Apache headers and update code style

commit 985609fe1a55655ad11966e05a93c18c138a403d
Author: Xiangrui Meng <[email protected]>
Date:   Sun Jan 26 11:49:25 2014 -0800

    add new lines

commit b21bddf29850a2c006a868869b8f91960a029322
Author: Xiangrui Meng <[email protected]>
Date:   Sun Jan 26 11:46:35 2014 -0800

    move samplers to random.IndependentRandomSampler and add tests

commit c02dacb4a941618e434cefc129c002915db08be6
Author: Xiangrui Meng <[email protected]>
Date:   Sat Jan 25 15:20:24 2014 -0800

    add RandomSampler

commit 8ff7ba3c5cf1fc338c29ae8b5fa06c222640e89c
Author: Xiangrui Meng <[email protected]>
Date:   Fri Jan 24 13:23:22 2014 -0800

    init impl of IndependentlySampledRDD
mengxr authored and rxin committed Feb 3, 2014
1 parent 1625d8c commit 23af00f
Showing 13 changed files with 390 additions and 16 deletions.
65 changes: 65 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rdd

import java.util.Random

import scala.reflect.ClassTag

import org.apache.spark.{TaskContext, Partition}
import org.apache.spark.util.random.RandomSampler

private[spark]
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
extends Partition with Serializable {
override val index: Int = prev.index
}

/**
* An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD,
* a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain
* a random sample of the records in the partition. The random seeds assigned to the samplers
* are guaranteed to have different values.
*
* @param prev RDD to be sampled
* @param sampler a random sampler
* @param seed random seed; defaults to System.nanoTime
* @tparam T input RDD item type
* @tparam U sampled RDD item type
*/
class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
prev: RDD[T],
sampler: RandomSampler[T, U],
seed: Long = System.nanoTime)
extends RDD[U](prev) {

override def getPartitions: Array[Partition] = {
val random = new Random(seed)
firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
}

override def getPreferredLocations(split: Partition): Seq[String] =
firstParent[T].preferredLocations(split.asInstanceOf[PartitionwiseSampledRDDPartition].prev)

override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
val thisSampler = sampler.clone
thisSampler.setSeed(split.seed)
thisSampler.sample(firstParent[T].iterator(split.prev, context))
}
}
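
A sketch of constructing the new RDD directly with a custom sampler (illustrative only; assumes an existing SparkContext sc):

    import org.apache.spark.rdd.PartitionwiseSampledRDD
    import org.apache.spark.util.random.BernoulliSampler

    val rdd = sc.parallelize(1 to 100, 2)
    // Keep roughly 25% of each partition; every partition receives its own
    // seed derived from the top-level seed, as in getPartitions above.
    val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, new BernoulliSampler[Int](0.25), 7L)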
26 changes: 24 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -45,6 +45,7 @@ import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogL

import org.apache.spark.SparkContext._
import org.apache.spark._
import org.apache.spark.util.random.{PoissonSampler, BernoulliSampler}

/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -319,8 +320,29 @@ abstract class RDD[T: ClassTag](
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
new SampledRDD(this, withReplacement, fraction, seed)
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
if (withReplacement) {
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
} else {
new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed)
}
}

/**
* Randomly splits this RDD with the provided weights.
*
* @param weights weights for the splits; they will be normalized if they don't sum to 1
* @param seed random seed; defaults to System.nanoTime
*
* @return split RDDs in an array
*/
def randomSplit(weights: Array[Double], seed: Long = System.nanoTime): Array[RDD[T]] = {
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed)
}.toArray
}
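// Worked example (illustrative, not part of the diff): weights Array(1.0, 2.0, 3.0)
// sum to 6.0, so normalization and scanLeft yield the cumulative bounds
// (0.0, 1/6, 3/6, 1.0), and sliding(2) produces the acceptance ranges
// [0, 1/6), [1/6, 3/6), [3/6, 1.0). Because all splits share one seed, each
// record's random draw falls into exactly one range, making the splits
// disjoint and jointly exhaustive.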

def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
var fraction = 0.0
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
Expand Up @@ -25,11 +25,13 @@ import cern.jet.random.engine.DRand

import org.apache.spark.{Partition, TaskContext}

@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0")
private[spark]
class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
override val index: Int = prev.index
}

@deprecated("Replaced by PartitionwiseSampledRDD", "1.0")
class SampledRDD[T: ClassTag](
prev: RDD[T],
withReplacement: Boolean,
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/util/Vector.scala
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.util

import scala.util.Random
import org.apache.spark.util.random.XORShiftRandom

class Vector(val elements: Array[Double]) extends Serializable {
def length = elements.length
26 changes: 26 additions & 0 deletions core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.random

/**
* A trait for classes with pseudorandom behavior.
*/
trait Pseudorandom {
/** Set random seed. */
def setSeed(seed: Long)
}
94 changes: 94 additions & 0 deletions core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.random

import java.util.Random
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand

/**
* A pseudorandom sampler. It is possible to change the sampled item type. For example, we might
* want to add weights for stratified sampling or importance sampling. Only transformations
* that are tied to the sampler and cannot be applied after sampling should be used here.
*
* @tparam T item type
* @tparam U sampled item type
*/
trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable {

/** Take a random sample. */
def sample(items: Iterator[T]): Iterator[U]

override def clone: RandomSampler[T, U] =
throw new NotImplementedError("clone() is not implemented.")
}
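
// Minimal sketch of a custom sampler against this trait (hypothetical, not
// part of this commit): keep every item and attach a unit weight, changing
// the item type from T to (T, Double).
class UnitWeightSampler[T] extends RandomSampler[T, (T, Double)] {
  // Deterministic sampler, so the seed is intentionally ignored.
  override def setSeed(seed: Long) {}
  // Attach a unit weight to every item without dropping any.
  override def sample(items: Iterator[T]): Iterator[(T, Double)] =
    items.map(item => (item, 1.0))
  override def clone = new UnitWeightSampler[T]
}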

/**
* A sampler based on Bernoulli trials.
*
* @param lb lower bound of the acceptance range
* @param ub upper bound of the acceptance range
* @param complement whether to use the complement of the range specified, default to false
* @tparam T item type
*/
class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
(implicit random: Random = new XORShiftRandom)
extends RandomSampler[T, T] {

def this(ratio: Double)(implicit random: Random = new XORShiftRandom)
= this(0.0d, ratio)(random)

override def setSeed(seed: Long) = random.setSeed(seed)

override def sample(items: Iterator[T]): Iterator[T] = {
items.filter { item =>
val x = random.nextDouble()
(x >= lb && x < ub) ^ complement
}
}

override def clone = new BernoulliSampler[T](lb, ub, complement) // preserve the complement flag
}
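
// Sketch (illustrative, not from the diff): two samplers over complementary
// ranges, seeded identically, split a stream into disjoint, exhaustive halves.
// RDD.randomSplit relies on the same shared-seed trick with adjacent ranges.
// In a REPL:
//   val a = new BernoulliSampler[Int](0.0, 0.5)
//   val b = new BernoulliSampler[Int](0.0, 0.5, complement = true)
//   a.setSeed(11L); b.setSeed(11L)
//   val left = a.sample((1 to 100).iterator).toSet
//   val right = b.sample((1 to 100).iterator).toSet
//   assert((left & right).isEmpty && (left ++ right).size == 100)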

/**
* A sampler based on values drawn from a Poisson distribution.
*
* @param mean mean of the Poisson distribution, i.e. the expected number of times each item is sampled
* @tparam T item type
*/
class PoissonSampler[T](mean: Double)
(implicit var poisson: Poisson = new Poisson(mean, new DRand))
extends RandomSampler[T, T] {

override def setSeed(seed: Long) {
poisson = new Poisson(mean, new DRand(seed.toInt))
}

override def sample(items: Iterator[T]): Iterator[T] = {
items.flatMap { item =>
val count = poisson.nextInt()
if (count == 0) {
Iterator.empty
} else {
Iterator.fill(count)(item)
}
}
}

override def clone = new PoissonSampler[T](mean)
}
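
The Poisson sampler approximates sampling with replacement: for a large dataset, the number of times each record appears in a with-replacement sample of expected fraction f is approximately Poisson(f), independently across records, which is what justifies the per-item flatMap above. A REPL-style sketch (illustrative only):

    val sampler = new PoissonSampler[String](0.5)
    sampler.setSeed(42L)
    // Each item appears 0, 1, 2, ... times with Poisson(0.5) probabilities,
    // so the expected output size is about half the input size.
    val sampledCount = sampler.sample(Iterator.fill(1000)("x")).size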
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.util
package org.apache.spark.util.random

import java.util.{Random => JavaRandom}
import org.apache.spark.util.Utils.timeIt
@@ -46,6 +46,10 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
seed = nextSeed
(nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
}

override def setSeed(s: Long) {
seed = s
}
}
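
// Sketch (illustrative, not from the diff): since XORShiftRandom extends
// java.util.Random, code inside the spark package can supply it explicitly
// as the RNG of a BernoulliSampler:
//   val sampler = new BernoulliSampler[Int](0.5)(new XORShiftRandom(42L))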

/** Contains benchmark method and main method to run benchmark of the RNG */
49 changes: 49 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rdd

import org.scalatest.FunSuite
import org.apache.spark.SharedSparkContext
import org.apache.spark.util.random.RandomSampler

/** A mock sampler that outputs the seed it was given. */
class MockSampler extends RandomSampler[Long, Long] {

private var s: Long = _

override def setSeed(seed: Long) {
s = seed
}

override def sample(items: Iterator[Long]): Iterator[Long] = {
Iterator(s)
}

override def clone = new MockSampler
}

class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {

test("seedDistribution") {
val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
val sampler = new MockSampler
val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
assert(sample.distinct.count == 2, "Seeds must be different.")
}
}
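
This test works because the parent RDD has two partitions and MockSampler emits the seed it was given, so sample contains exactly one seed value per partition; distinct.count == 2 therefore verifies that the two partitions received different seeds.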

15 changes: 15 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -486,6 +486,21 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}

test("randomSplit") {
val n = 600
val data = sc.parallelize(1 to n, 2)
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array(1.0, 2.0, 3.0), seed)
assert(splits.size == 3, "wrong number of splits")
assert(splits.flatMap(_.collect).sorted.toList == data.collect.toList,
"incomplete or wrong split")
val s = splits.map(_.count)
assert(math.abs(s(0) - 100) < 50) // std = 9.13
assert(math.abs(s(1) - 200) < 50) // std = 11.55
assert(math.abs(s(2) - 300) < 50) // std = 12.25
}
}
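// For reference (not part of the diff): the quoted standard deviations follow
// from the binomial model. Each of the n = 600 records lands in a given split
// with probability p independently, so the split's count has
// std = sqrt(n * p * (1 - p)): sqrt(600 * 1/6 * 5/6) ~= 9.13,
// sqrt(600 * 1/3 * 2/3) ~= 11.55, and sqrt(600 * 1/2 * 1/2) ~= 12.25.
// The tolerance of 50 is therefore about 4-5 standard deviations.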

test("runJob on an invalid partition") {
intercept[IllegalArgumentException] {
sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)
16 changes: 8 additions & 8 deletions core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
@@ -98,10 +98,10 @@ class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers
assert(sorted.collect() === pairArr.sortBy(_._1))
val partitions = sorted.collectPartitions()
logInfo("Partition lengths: " + partitions.map(_.length).mkString(", "))
partitions(0).length should be > 180
partitions(1).length should be > 180
partitions(2).length should be > 180
partitions(3).length should be > 180
val lengthArr = partitions.map(_.length)
lengthArr.foreach { len =>
assert(len > 100 && len < 400)
}
partitions(0).last should be < partitions(1).head
partitions(1).last should be < partitions(2).head
partitions(2).last should be < partitions(3).head
@@ -113,10 +113,10 @@
assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
val partitions = sorted.collectPartitions()
logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
partitions(0).length should be > 180
partitions(1).length should be > 180
partitions(2).length should be > 180
partitions(3).length should be > 180
val lengthArr = partitions.map(_.length)
lengthArr.foreach { len =>
assert(len > 100 && len < 400)
}
partitions(0).last should be > partitions(1).head
partitions(1).last should be > partitions(2).head
partitions(2).last should be > partitions(3).head