Skip to content

Commit

Permalink
Addressing reviewers comments mengxr. Added true positive rate and fa…
Browse files Browse the repository at this point in the history
…lse positive rate. Test suite code style.
  • Loading branch information
avulanov committed Jul 9, 2014
1 parent a7e8bf0 commit e3db569
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,24 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
// Per predicted class: the number of (prediction, label) pairs in which the
// prediction differs from the label. Each pair contributes 1 on a mismatch and
// 0 on a match, keyed by the prediction; the contributions are then summed per key.
// NOTE(review): the `.reduceByKey` line appears twice below — presumably the
// before/after pair of a whitespace-only change in this diff view; confirm.
private lazy val fpByClass: Map[Double, Int] = predictionsAndLabels
.map { case (prediction, label) =>
(prediction, if (prediction != label) 1 else 0)
}.reduceByKey(_ + _)
}.reduceByKey(_ + _)
.collectAsMap()

/**
 * Returns the true positive rate for a given label (category).
 * The value is obtained by delegating to `recall(label)`.
 * @param label the label.
 */
def truePositiveRate(label: Double): Double = {
  // The rate is taken directly from recall for the same label.
  recall(label)
}

/**
 * Returns the false positive rate for a given label (category): the false
 * positive count recorded for the label (0 when none is recorded) divided by
 * the total label count minus the count for this label.
 * @param label the label.
 */
def falsePositiveRate(label: Double): Double = {
  val falsePositives = fpByClass.getOrElse(label, 0).toDouble
  // NOTE(review): direct lookup — presumably fails for a label that never
  // occurs as a true label; confirm against labelCountByClass's definition.
  val otherLabelCount = labelCount - labelCountByClass(label)
  falsePositives / otherLabelCount
}

/**
* Returns precision for a given label (category)
* @param label the label.
Expand All @@ -65,6 +80,7 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
/**
* Returns f-measure for a given label (category)
* @param label the label.
* @param beta the beta parameter.
*/
def fMeasure(label: Double, beta: Double): Double = {
val p = precision(label)
Expand Down Expand Up @@ -113,15 +129,23 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
precision(category) * count.toDouble / labelCount
}.sum

/**
 * Returns the weighted averaged f-measure: the f-measure of every class,
 * scaled by that class's count over `labelCount`, summed across classes.
 * @param beta the beta parameter.
 */
def weightedFMeasure(beta: Double): Double = {
  val weightedScores =
    for ((category, count) <- labelCountByClass)
      yield fMeasure(category, beta) * count.toDouble / labelCount
  weightedScores.sum
}

/**
 * Returns weighted averaged f1-measure: the f-measure of every class (called
 * with beta = 1.0 in the `weightedFMeasure` form), scaled by that class's
 * count over `labelCount`, summed across classes.
 * NOTE(review): two headers and two body lines are shown here — presumably the
 * removed `weightedF1Measure` and the added `weightedFMeasure` from this diff; confirm.
 */
lazy val weightedF1Measure: Double = labelCountByClass.map { case (category, count) =>
fMeasure(category) * count.toDouble / labelCount
lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) =>
fMeasure(category, 1.0) * count.toDouble / labelCount
}.sum

/**
 * Returns the sequence of labels in ascending order.
 * The labels are the keys of `tpByClass`, copied into an array and sorted.
 * NOTE(review): the definition appears twice below, differing only in the
 * space after `labels:` — presumably the before/after pair of this diff; confirm.
 */
lazy val labels:Array[Double] = tpByClass.keys.toArray.sorted
lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,21 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
(1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
val metrics = new MulticlassMetrics(scoreAndLabels)
val delta = 0.0000001
val precision0 = 2.0 / (2.0 + 1.0)
val precision1 = 3.0 / (3.0 + 1.0)
val precision2 = 1.0 / (1.0 + 1.0)
val recall0 = 2.0 / (2.0 + 2.0)
val recall1 = 3.0 / (3.0 + 1.0)
val recall2 = 1.0 / (1.0 + 0.0)
val fpRate0 = 1.0 / (9 - 4)
val fpRate1 = 1.0 / (9 - 4)
val fpRate2 = 1.0 / (9 - 1)
val precision0 = 2.0 / (2 + 1)
val precision1 = 3.0 / (3 + 1)
val precision2 = 1.0 / (1 + 1)
val recall0 = 2.0 / (2 + 2)
val recall1 = 3.0 / (3 + 1)
val recall2 = 1.0 / (1 + 0)
val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
assert(math.abs(metrics.precision(0.0) - precision0) < delta)
assert(math.abs(metrics.precision(1.0) - precision1) < delta)
assert(math.abs(metrics.precision(2.0) - precision2) < delta)
Expand All @@ -55,16 +61,16 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
assert(math.abs(metrics.recall -
(2.0 + 3.0 + 1.0) / ((2.0 + 3.0 + 1.0) + (1.0 + 1.0 + 1.0))) < delta)
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
assert(math.abs(metrics.recall - metrics.precision) < delta)
assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
assert(math.abs(metrics.weightedPrecision -
((4.0 / 9.0) * precision0 + (4.0 / 9.0) * precision1 + (1.0 / 9.0) * precision2)) < delta)
((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta)
assert(math.abs(metrics.weightedRecall -
((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta)
assert(math.abs(metrics.weightedF1Measure -
((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta)
((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta)
assert(math.abs(metrics.weightedFMeasure -
((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta)
assert(metrics.labels.sameElements(labels))
}
}

0 comments on commit e3db569

Please sign in to comment.