From e3db56998d2915421e46b159b293c5abc9cc90d0 Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Wed, 9 Jul 2014 15:40:57 +0400 Subject: [PATCH] Addressing reviewers comments mengxr. Added true positive rate and false positive rate. Test suite code style. --- .../mllib/evaluation/MulticlassMetrics.scala | 32 ++++++++++++++++--- .../evaluation/MulticlassMetricsSuite.scala | 28 +++++++++------- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 8f25a3d0020d0..df30fe601604d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -43,9 +43,24 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log private lazy val fpByClass: Map[Double, Int] = predictionsAndLabels .map { case (prediction, label) => (prediction, if (prediction != label) 1 else 0) - }.reduceByKey(_ + _) + }.reduceByKey(_ + _) .collectAsMap() + /** + * Returns true positive rate for a given label (category) + * @param label the label. + */ + def truePositiveRate(label: Double): Double = recall(label) + + /** + * Returns false positive rate for a given label (category) + * @param label the label. + */ + def falsePositiveRate(label: Double): Double = { + val fp = fpByClass.getOrElse(label, 0) + fp.toDouble / (labelCount - labelCountByClass(label)) + } + /** * Returns precision for a given label (category) * @param label the label. @@ -65,6 +80,7 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log /** * Returns f-measure for a given label (category) * @param label the label. + * @param beta the beta parameter. */ def fMeasure(label: Double, beta: Double): Double = { val p = precision(label) @@ -113,15 +129,23 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log precision(category) * count.toDouble / labelCount }.sum + /** + * Returns weighted averaged f-measure + * @param beta the beta parameter. + */ + def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => + fMeasure(category, beta) * count.toDouble / labelCount + }.sum + /** * Returns weighted averaged f1-measure */ - lazy val weightedF1Measure: Double = labelCountByClass.map { case (category, count) => - fMeasure(category) * count.toDouble / labelCount + lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) => + fMeasure(category, 1.0) * count.toDouble / labelCount }.sum /** * Returns the sequence of labels in ascending order */ - lazy val labels:Array[Double] = tpByClass.keys.toArray.sorted + lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 9bdd5745677aa..e2dd57d698141 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -36,15 +36,21 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) val metrics = new MulticlassMetrics(scoreAndLabels) val delta = 0.0000001 - val precision0 = 2.0 / (2.0 + 1.0) - val precision1 = 3.0 / (3.0 + 1.0) - val precision2 = 1.0 / (1.0 + 1.0) - val recall0 = 2.0 / (2.0 + 2.0) - val recall1 = 3.0 / (3.0 + 1.0) - val recall2 = 1.0 / (1.0 + 0.0) + val fpRate0 = 1.0 / (9 - 4) + val fpRate1 = 1.0 / (9 - 4) + val fpRate2 = 1.0 / (9 - 1) + val precision0 = 2.0 / (2 + 1) + val precision1 = 3.0 / (3 + 1) + val precision2 = 1.0 / (1 + 1) + val recall0 = 2.0 / (2 + 2) + val recall1 = 3.0 / (3 + 1) + val recall2 = 1.0 / (1 + 0) val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) + assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) + assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) assert(math.abs(metrics.precision(0.0) - precision0) < delta) assert(math.abs(metrics.precision(1.0) - precision1) < delta) assert(math.abs(metrics.precision(2.0) - precision2) < delta) @@ -55,16 +61,16 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) assert(math.abs(metrics.recall - - (2.0 + 3.0 + 1.0) / ((2.0 + 3.0 + 1.0) + (1.0 + 1.0 + 1.0))) < delta) + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) assert(math.abs(metrics.recall - metrics.precision) < delta) assert(math.abs(metrics.recall - metrics.fMeasure) < delta) assert(math.abs(metrics.recall - metrics.weightedRecall) < delta) assert(math.abs(metrics.weightedPrecision - - ((4.0 / 9.0) * precision0 + (4.0 / 9.0) * precision1 + (1.0 / 9.0) * precision2)) < delta) + ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) assert(math.abs(metrics.weightedRecall - - ((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta) - assert(math.abs(metrics.weightedF1Measure - - ((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta) + ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) + assert(math.abs(metrics.weightedFMeasure - + ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) assert(metrics.labels.sameElements(labels)) } }