SPARK-1246: added min/max API to Double RDDs in the Java and Scala APIs. #140

Closed
wants to merge 1 commit
core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -153,6 +153,14 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD]
/** Add up the elements in this RDD. */
def sum(): JDouble = srdd.sum()


/** Max of the elements in this RDD. */
def max(): JDouble = srdd.max()


/** Min of the elements in this RDD. */
def min(): JDouble = srdd.min()

/**
* Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and
* count of the RDD's elements in one operation.
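For context, here is a minimal usage sketch of the new Java-API wrappers (written in Scala; it assumes a local SparkContext named sc, which is not part of this diff). The wrappers simply delegate to srdd.min()/srdd.max() as shown above:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaDoubleRDD

// Assumed setup (not part of this change): a local SparkContext.
val sc = new SparkContext("local", "min-max-sketch")

// JavaDoubleRDD wraps an RDD[Double]; min()/max() just forward to the underlying RDD.
val jrdd = new JavaDoubleRDD(sc.parallelize(Seq(3.0, 1.0, 8.0)))
println(jrdd.min()) // 1.0 (as java.lang.Double)
println(jrdd.max()) // 8.0 (as java.lang.Double)
```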
21 changes: 11 additions & 10 deletions core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -35,8 +35,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
}

/**
* Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and
* count of the RDD's elements in one operation.
* Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance,
* count, min and max of the RDD's elements in one operation.
*/
def stats(): StatCounter = {
self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
@@ -51,6 +51,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
/** Compute the standard deviation of this RDD's elements. */
def stdev(): Double = stats().stdev

/** Find the min element of this RDD's elements. */
def min(): Double = stats().min

/** Find the max element of this RDD's elements. */
def max(): Double = stats().max

/**
* Compute the sample standard deviation of this RDD's elements (which corrects for bias in
* estimating the standard deviation by dividing by N-1 instead of N).
@@ -86,14 +92,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
* If the elements in RDD do not vary (max == min) always returns a single bucket.
*/
def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
// Compute the minimum and the maxium
val (max: Double, min: Double) = self.mapPartitions { items =>
Iterator(items.foldRight(Double.NegativeInfinity,
Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) =>
(x._1.max(e), x._2.min(e))))
}.reduce { (maxmin1, maxmin2) =>
(maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
}
// Compute the minimum and the maximum from stats once
val _stats = stats()
val (max: Double, min: Double) = (_stats.max, _stats.min)
Contributor
Minor, but I'd prefer not to unnecessarily bind a val in a scope where it doesn't need to be used:

// Compute the minimum and the maximum from stats once
val (max, min) = { val sc = stats(); (sc.max, sc.min) } 

Member Author
Okay, will change that.

if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity ) {
throw new UnsupportedOperationException(
"Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
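To see how the Scala-side API reads after this change, a hedged usage sketch follows (assuming the same SparkContext sc as above and the standard implicit conversion to DoubleRDDFunctions): a single stats() pass now yields count, mean, variance, min and max, while min() and max() as written each run their own stats() pass.

```scala
import org.apache.spark.SparkContext._ // implicit RDD[Double] -> DoubleRDDFunctions

val data = sc.parallelize(Seq(1.0, 2.0, 3.0, 8.0))

// One pass over the data yields all summary statistics, including the new min/max.
val s = data.stats()
println((s.count, s.mean, s.min, s.max)) // (4, 3.5, 1.0, 8.0)

// Convenience wrappers; as implemented in this patch, each call runs its own stats() pass.
println(data.min()) // 1.0
println(data.max()) // 8.0

// histogram(bucketCount) now derives its min/max from stats() rather than a separate fold.
val (buckets, counts) = data.histogram(2)
```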
22 changes: 18 additions & 4 deletions core/src/main/scala/org/apache/spark/util/StatCounter.scala
@@ -29,6 +29,12 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
private var n: Long = 0 // Running count of our values
private var mu: Double = 0 // Running mean of our values
private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
private var _min: Double = Double.PositiveInfinity
private var _max: Double = Double.NegativeInfinity

def min: Double = _min

def max: Double = _max

merge(values)

@@ -41,6 +47,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
n += 1
mu += delta / n
m2 += delta * (value - mu)
_min = math.min(value, _min)
_max = math.max(value, _max)
this
}

@@ -58,7 +66,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
if (n == 0) {
mu = other.mu
m2 = other.m2
n = other.n
_min = other.min
_max = other.max
} else if (other.n != 0) {
val delta = other.mu - mu
if (other.n * 10 < n) {
@@ -70,6 +80,8 @@
}
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
n += other.n
_min = math.min(other.min, _min)
_max = math.max(other.max, _max)
}
this
}
@@ -81,6 +93,8 @@
other.n = n
other.mu = mu
other.m2 = m2
other._min = _min
other._max = _max
other
}

@@ -120,9 +134,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
*/
def sampleStdev: Double = math.sqrt(sampleVariance)

override def toString: String = {
"(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
}
override def toString: String =
s"count: $count, mean: $mean, stdev: $stdev , min: ${_min}, max: ${_max}"

}

object StatCounter {
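The bookkeeping added to StatCounter is small enough to show in isolation. Below is a self-contained sketch (illustrative only; MinMaxTracker is not a Spark class) of the same running min/max technique: each partition folds its values into a tracker, and trackers merge by taking the element-wise min/max, which is exactly what the reduce over per-partition StatCounters relies on.

```scala
// Standalone illustration of the running min/max logic this patch adds to StatCounter.
final class MinMaxTracker(
    var min: Double = Double.PositiveInfinity,
    var max: Double = Double.NegativeInfinity) {

  // Fold a single value into the running extremes (mirrors StatCounter.merge(value) above).
  def merge(value: Double): MinMaxTracker = {
    min = math.min(value, min)
    max = math.max(value, max)
    this
  }

  // Combine with another tracker (mirrors StatCounter.merge(other) above).
  def merge(other: MinMaxTracker): MinMaxTracker = {
    min = math.min(other.min, min)
    max = math.max(other.max, max)
    this
  }
}

object MinMaxDemo extends App {
  // Simulate two partitions, each reduced locally, then merged into a global result.
  val part1 = Seq(3.0, 1.0, 4.0).foldLeft(new MinMaxTracker())((t, v) => t.merge(v))
  val part2 = Seq(2.0, 8.0).foldLeft(new MinMaxTracker())((t, v) => t.merge(v))
  val global = part1.merge(part2)
  println(s"min=${global.min}, max=${global.max}") // min=1.0, max=8.0
}
```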
3 changes: 2 additions & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -395,7 +395,8 @@ public Boolean call(Double x) {
Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01);
Assert.assertEquals(2.49444, rdd.stdev(), 0.01);
Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01);

Assert.assertEquals(1.0, rdd.min(), 0.01);
Assert.assertEquals(8.0, rdd.max(), 0.01);
rdd.first();
rdd.take(5);
}
2 changes: 2 additions & 0 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -171,6 +171,8 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester
assert(abs(6.0/2 - rdd.mean) < 0.01)
assert(abs(1.0 - rdd.variance) < 0.01)
assert(abs(1.0 - rdd.stdev) < 0.01)
assert(abs(2.0 - stats.min) < 0.01)
assert(abs(4.0 - stats.max) < 0.01)

// Add other tests here for classes that should be able to handle empty partitions correctly
}
2 changes: 1 addition & 1 deletion project/build.properties
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
sbt.version=0.13.1
sbt.version=0.13.2-M1
Contributor
Is there a reason for this? I'd rather not depend on a milestone release of sbt because it's more likely to have bugs.

Member Author
It was accidental (sorry about that). I use this version of sbt locally since it's really fast with incremental builds.