Skip to content

Commit

Permalink
Fixes to mutliclass metics
Browse files Browse the repository at this point in the history
  • Loading branch information
avulanov committed Jun 30, 2014
1 parent d5ce981 commit e2c91c3
Showing 1 changed file with 18 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,72 +60,75 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
* @param label the label.
* @return F1-measure.
*/
def f1Measure(label: Double): Double =
2 * precision(label) * recall(label) / (precision(label) + recall(label))
def f1Measure(label: Double): Double ={
val p = precision(label)
val r = recall(label)
if((p + r) == 0) 0 else 2 * p * r / (p + r)
}

/**
* Returns micro-averaged Recall
* (equals to microPrecision and microF1measure for multiclass classifier)
* @return microRecall.
*/
def microRecall: Double =
tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount.toDouble
lazy val microRecall: Double =
tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount

/**
* Returns micro-averaged Precision
* (equals to microPrecision and microF1measure for multiclass classifier)
* @return microPrecision.
*/
def microPrecision: Double = microRecall
lazy val microPrecision: Double = microRecall

/**
* Returns micro-averaged F1-measure
* (equals to microPrecision and microRecall for multiclass classifier)
* @return microF1measure.
*/
def microF1Measure: Double = microRecall
lazy val microF1Measure: Double = microRecall

/**
* Returns weighted averaged Recall
* @return weightedRecall.
*/
def weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) =>
wRecall + recall(category) * count.toDouble / labelCount.toDouble}
lazy val weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) =>
wRecall + recall(category) * count.toDouble / labelCount}

/**
* Returns weighted averaged Precision
* @return weightedPrecision.
*/
def weightedPrecision: Double =
lazy val weightedPrecision: Double =
labelCountByClass.foldLeft(0.0){case(wPrecision, (category, count)) =>
wPrecision + precision(category) * count.toDouble / labelCount.toDouble}
wPrecision + precision(category) * count.toDouble / labelCount}

/**
* Returns weighted averaged F1-measure
* @return weightedF1Measure.
*/
def weightedF1Measure: Double =
lazy val weightedF1Measure: Double =
labelCountByClass.foldLeft(0.0){case(wF1measure, (category, count)) =>
wF1measure + f1Measure(category) * count.toDouble / labelCount.toDouble}
wF1measure + f1Measure(category) * count.toDouble / labelCount}

/**
* Returns map with Precisions for individual classes
* @return precisionPerClass.
*/
def precisionPerClass =
lazy val precisionPerClass =
labelCountByClass.map{case (category, _) => (category, precision(category))}.toMap

/**
* Returns map with Recalls for individual classes
* @return recallPerClass.
*/
def recallPerClass =
lazy val recallPerClass =
labelCountByClass.map{case (category, _) => (category, recall(category))}.toMap

/**
* Returns map with F1-measures for individual classes
* @return f1MeasurePerClass.
*/
def f1MeasurePerClass =
lazy val f1MeasurePerClass =
labelCountByClass.map{case (category, _) => (category, f1Measure(category))}.toMap
}

0 comments on commit e2c91c3

Please sign in to comment.