Commit

Addressing reviewer's comments (mengxr)
avulanov committed Jul 11, 2014
1 parent 4811378 commit f0dadc9
Showing 2 changed files with 14 additions and 11 deletions.
mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

@@ -20,6 +20,7 @@ package org.apache.spark.mllib.evaluation
 import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.{Matrices, Matrix}
 import org.apache.spark.rdd.RDD
 
 import scala.collection.Map
@@ -31,19 +32,19 @@ import scala.collection.Map
  * @param predictionAndLabels an RDD of (prediction, label) pairs.
  */
 @Experimental
-class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) extends Logging {
+class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
 
   private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
   private lazy val labelCount: Long = labelCountByClass.values.sum
   private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
     .map { case (prediction, label) =>
-    (label, if (label == prediction) 1 else 0)
-  }.reduceByKey(_ + _)
+      (label, if (label == prediction) 1 else 0)
+    }.reduceByKey(_ + _)
     .collectAsMap()
   private lazy val fpByClass: Map[Double, Int] = predictionAndLabels
     .map { case (prediction, label) =>
-    (prediction, if (prediction != label) 1 else 0)
-  }.reduceByKey(_ + _)
+      (prediction, if (prediction != label) 1 else 0)
+    }.reduceByKey(_ + _)
     .collectAsMap()
   private lazy val confusions = predictionAndLabels.map {
     case (prediction, label) => ((prediction, label), 1)
@@ -55,12 +56,13 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) extends Logging {
   * they are ordered by class label ascending,
   * as in "labels"
   */
-  lazy val confusionMatrix: Array[Array[Int]] = {
-    val matrix = Array.ofDim[Int](labels.size, labels.size)
+  lazy val confusionMatrix: Matrix = {
+    val transposedMatrix = Array.ofDim[Double](labels.size, labels.size)
     for (i <- 0 to labels.size - 1; j <- 0 to labels.size - 1) {
-      matrix(j)(i) = confusions.getOrElse((labels(i), labels(j)), 0)
+      transposedMatrix(i)(j) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble
     }
-    matrix
+    val flatMatrix = transposedMatrix.flatMap(arr => arr)
+    Matrices.dense(transposedMatrix.length, transposedMatrix(0).length, flatMatrix)
   }
 
   /**
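A note on the confusionMatrix hunk above: Matrices.dense expects its values in column-major order, which is why the loop fills a row-major array of the transpose and then flattens it row by row; flattening the transpose row-wise yields exactly the column-major layout of the matrix itself. A minimal standalone sketch of that equivalence (not part of the patch; the object name is illustrative):

import org.apache.spark.mllib.linalg.Matrices

object ColumnMajorSketch {
  def main(args: Array[String]): Unit = {
    // Target matrix, written row by row:
    //   1 2 3
    //   4 5 6
    val rows = Array(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0))

    // Fill the transpose (3 x 2), as the patch does for the confusion counts.
    val transposed = Array.ofDim[Double](3, 2)
    for (i <- 0 until 3; j <- 0 until 2) {
      transposed(i)(j) = rows(j)(i)
    }

    // Row-wise flattening of the transpose gives 1 4 2 5 3 6,
    // i.e. the original matrix column by column.
    val flat = transposed.flatMap(arr => arr)
    val m = Matrices.dense(2, 3, flat)
    println(m.toArray.mkString(", ")) // 1.0, 4.0, 2.0, 5.0, 3.0, 6.0 (column-major)
  }
}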
mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala

@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.evaluation
 
+import org.apache.spark.mllib.linalg.Matrices
 import org.apache.spark.mllib.util.LocalSparkContext
 import org.scalatest.FunSuite
 
@@ -28,7 +29,7 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
      * |1|3|0| true class1 (4 instances)
      * |0|0|1| true class2 (1 instance)
      */
-    val confusionMatrix = Array(Array(2, 1, 1), Array(1, 3, 0), Array(0, 0, 1))
+    val confusionMatrix = Matrices.dense(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1))
     val labels = Array(0.0, 1.0, 2.0)
     val predictionAndLabels = sc.parallelize(
       Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
@@ -51,7 +52,7 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
     val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
     val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)
 
-    assert(metrics.confusionMatrix.deep == confusionMatrix.deep)
+    assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
     assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
     assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
     assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
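Because the public type is now Matrix rather than nested arrays, the .deep equality used before is no longer available, so the suite compares the flattened column-major contents instead. A minimal spark-shell style sketch of the new API (assuming sc is a live SparkContext, as LocalSparkContext provides in the suite):

import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Matrices

// Two (prediction, label) pairs, both correct:
// the confusion matrix should be the 2x2 identity.
val predictionAndLabels = sc.parallelize(Seq((0.0, 0.0), (1.0, 1.0)))
val metrics = new MulticlassMetrics(predictionAndLabels)

val expected = Matrices.dense(2, 2, Array(1, 0, 0, 1))
assert(metrics.confusionMatrix.toArray.sameElements(expected.toArray))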
