diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 3b7ba7288c0f3..d48068719a851 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.Logging +import scala.collection.Map + import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Matrices, Matrix} import org.apache.spark.rdd.RDD -import scala.collection.Map - /** * ::Experimental:: * Evaluator for multiclass classification. @@ -57,12 +56,12 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * as in "labels" */ lazy val confusionMatrix: Matrix = { - val transposedMatrix = Array.ofDim[Double](labels.size, labels.size) + val transposedFlatMatrix = Array.ofDim[Double](labels.size * labels.size) for (i <- 0 to labels.size - 1; j <- 0 to labels.size - 1) { - transposedMatrix(i)(j) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + transposedFlatMatrix(i * labels.size + j) + = confusions.getOrElse((labels(i), labels(j)), 0).toDouble } - val flatMatrix = transposedMatrix.flatMap(arr => arr) - Matrices.dense(transposedMatrix.length, transposedMatrix(0).length, flatMatrix) + Matrices.dense(labels.size, labels.size, transposedFlatMatrix) } /**