diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 76fe96a5938c0..dff84224a0e31 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -23,18 +23,21 @@ import org.apache.spark.SparkContext._
 
 /**
  * Evaluator for multiclass classification.
+ * NB: type Double both for prediction and label is retained
+ * for compatibility with model.predict that returns Double
+ * and MLUtils.loadLibSVMFile that loads class labels as Double
  *
- * @param scoreAndLabels an RDD of (score, label) pairs.
+ * @param predictionsAndLabels an RDD of (prediction, label) pairs.
  */
-class MulticlassMetrics(scoreAndLabels: RDD[(Double, Double)]) extends Logging {
+class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging {
   /* class = category; label = instance of class; prediction = instance of class */
-  private lazy val labelCountByClass = scoreAndLabels.values.countByValue()
+  private lazy val labelCountByClass = predictionsAndLabels.values.countByValue()
   private lazy val labelCount = labelCountByClass.foldLeft(0L){case(sum, (_, count)) => sum + count}
-  private lazy val tpByClass = scoreAndLabels.map{ case (prediction, label) =>
+  private lazy val tpByClass = predictionsAndLabels.map{ case (prediction, label) =>
     (label, if(label == prediction) 1 else 0)
   }.reduceByKey{_ + _}.collectAsMap
-  private lazy val fpByClass = scoreAndLabels.map{ case (prediction, label) =>
+  private lazy val fpByClass = predictionsAndLabels.map{ case (prediction, label) =>
     (prediction, if(prediction != label) 1 else 0)
   }.reduceByKey{_ + _}.collectAsMap
 
   /**