From 1843f739a6fc5cc341ef9f455859276f26f2b3b9 Mon Sep 17 00:00:00 2001
From: Alexander Ulanov
Date: Wed, 16 Jul 2014 12:41:11 +0400
Subject: [PATCH] Scala style fix

---
 .../mllib/evaluation/MultilabelMetrics.scala  | 77 ++++++++-----------
 .../evaluation/MultilabelMetricsSuite.scala   | 16 +---
 2 files changed, 35 insertions(+), 58 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
index 8e1afdd0ae17e..432cabf1c8a4d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -17,94 +17,83 @@
 package org.apache.spark.mllib.evaluation
 
-import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 
 /**
  * Evaluator for multilabel classification.
- * NB: type Double both for prediction and label is retained
- * for compatibility with model.predict that returns Double
- * and MLUtils.loadLibSVMFile that loads class labels as Double
- *
  * @param predictionAndLabels an RDD of (predictions, labels) pairs, both are non-null sets.
  */
-class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) extends Logging{
+class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
 
-  private lazy val numDocs = predictionAndLabels.count
+  private lazy val numDocs: Long = predictionAndLabels.count
 
-  private lazy val numLabels = predictionAndLabels.flatMap{case(_, labels) => labels}.distinct.count
+  private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
+    labels}.distinct.count
 
   /**
    * Returns strict Accuracy
    * (for equal sets of labels)
-   * @return strictAccuracy.
    */
-  lazy val strictAccuracy = predictionAndLabels.filter{case(predictions, labels) =>
+  lazy val strictAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
     predictions == labels}.count.toDouble / numDocs
 
   /**
    * Returns Accuracy
-   * @return Accuracy.
    */
-  lazy val accuracy = predictionAndLabels.map{ case(predictions, labels) =>
+  lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
     labels.intersect(predictions).size.toDouble / labels.union(predictions).size}.sum / numDocs
 
   /**
    * Returns Hamming-loss
-   * @return hammingLoss.
    */
-  lazy val hammingLoss = (predictionAndLabels.map{ case(predictions, labels) =>
+  lazy val hammingLoss: Double = (predictionAndLabels.map { case (predictions, labels) =>
     labels.diff(predictions).size + predictions.diff(labels).size}.
     sum).toDouble / (numDocs * numLabels)
 
   /**
    * Returns Document-based Precision averaged by the number of documents
-   * @return macroPrecisionDoc.
    */
-  lazy val macroPrecisionDoc = (predictionAndLabels.map{ case(predictions, labels) =>
-    if(predictions.size >0)
-      predictions.intersect(labels).size.toDouble / predictions.size else 0}.sum) / numDocs
+  lazy val macroPrecisionDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
+    if (predictions.size > 0) {
+      predictions.intersect(labels).size.toDouble / predictions.size
+    } else 0
+  }.sum) / numDocs
 
   /**
    * Returns Document-based Recall averaged by the number of documents
-   * @return macroRecallDoc.
    */
-  lazy val macroRecallDoc = (predictionAndLabels.map{ case(predictions, labels) =>
+  lazy val macroRecallDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
     labels.intersect(predictions).size.toDouble / labels.size}.sum) / numDocs
 
   /**
    * Returns Document-based F1-measure averaged by the number of documents
-   * @return macroRecallDoc.
    */
-  lazy val macroF1MeasureDoc = (predictionAndLabels.map{ case(predictions, labels) =>
+  lazy val macroF1MeasureDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
     2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum) / numDocs
 
   /**
    * Returns micro-averaged document-based Precision
    * (equals to label-based microPrecision)
-   * @return microPrecisionDoc.
    */
-  lazy val microPrecisionDoc = microPrecisionClass
+  lazy val microPrecisionDoc: Double = microPrecisionClass
 
   /**
    * Returns micro-averaged document-based Recall
    * (equals to label-based microRecall)
-   * @return microRecallDoc.
    */
-  lazy val microRecallDoc = microRecallClass
+  lazy val microRecallDoc: Double = microRecallClass
 
   /**
    * Returns micro-averaged document-based F1-measure
    * (equals to label-based microF1measure)
-   * @return microF1MeasureDoc.
    */
-  lazy val microF1MeasureDoc = microF1MeasureClass
+  lazy val microF1MeasureDoc: Double = microF1MeasureClass
 
-  private lazy val tpPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
+  private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
     predictions.intersect(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
 
-  private lazy val fpPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
+  private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
     predictions.diff(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
 
   private lazy val fnPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
@@ -113,24 +102,26 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
   /**
    * Returns Precision for a given label (category)
    * @param label the label.
-   * @return Precision.
    */
-  def precisionClass(label: Double) = if((tpPerClass(label) + fpPerClass.getOrElse(label, 0)) == 0)
-    0 else tpPerClass(label).toDouble / (tpPerClass(label) + fpPerClass.getOrElse(label, 0))
+  def precisionClass(label: Double) = {
+    val tp = tpPerClass(label)
+    val fp = fpPerClass.getOrElse(label, 0)
+    if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
+  }
 
   /**
    * Returns Recall for a given label (category)
    * @param label the label.
-   * @return Recall.
    */
-  def recallClass(label: Double) = if((tpPerClass(label) + fnPerClass.getOrElse(label, 0)) == 0)
-    0 else
-    tpPerClass(label).toDouble / (tpPerClass(label) + fnPerClass.getOrElse(label, 0))
+  def recallClass(label: Double) = {
+    val tp = tpPerClass(label)
+    val fn = fnPerClass.getOrElse(label, 0)
+    if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
+  }
 
   /**
    * Returns F1-measure for a given label (category)
   * @param label the label.
-   * @return F1-measure.
    */
   def f1MeasureClass(label: Double) = {
     val precision = precisionClass(label)
@@ -138,13 +129,12 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
     if((precision + recall) == 0) 0 else 2 * precision * recall / (precision + recall)
   }
 
-  private lazy val sumTp = tpPerClass.foldLeft(0L){ case(sum, (_, tp)) => sum + tp}
-  private lazy val sumFpClass = fpPerClass.foldLeft(0L){ case(sum, (_, fp)) => sum + fp}
-  private lazy val sumFnClass = fnPerClass.foldLeft(0L){ case(sum, (_, fn)) => sum + fn}
+  private lazy val sumTp = tpPerClass.foldLeft(0L){ case (sum, (_, tp)) => sum + tp}
+  private lazy val sumFpClass = fpPerClass.foldLeft(0L){ case (sum, (_, fp)) => sum + fp}
+  private lazy val sumFnClass = fnPerClass.foldLeft(0L){ case (sum, (_, fn)) => sum + fn}
 
   /**
    * Returns micro-averaged label-based Precision
-   * @return microPrecisionClass.
    */
   lazy val microPrecisionClass = {
     val sumFp = fpPerClass.foldLeft(0L){ case(sumFp, (_, fp)) => sumFp + fp}
@@ -153,7 +143,6 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
 
   /**
    * Returns micro-averaged label-based Recall
-   * @return microRecallClass.
    */
   lazy val microRecallClass = {
     val sumFn = fnPerClass.foldLeft(0.0){ case(sumFn, (_, fn)) => sumFn + fn}
@@ -162,8 +151,6 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
 
   /**
    * Returns micro-averaged label-based F1-measure
-   * @return microRecallClass.
    */
   lazy val microF1MeasureClass = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
-
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
index d67abc4f4df3d..4d33aa3e5ed53 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -17,10 +17,10 @@
 package org.apache.spark.mllib.evaluation
 
-import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.rdd.RDD
 import org.scalatest.FunSuite
 
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.rdd.RDD
 class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
   test("Multilabel evaluation metrics") {
@@ -45,7 +45,7 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
      * class 2 - doc 0, 3, 4, 6 (total 4)
      *
      */
-    val scoreAndLabels:RDD[(Set[Double], Set[Double])] = sc.parallelize(
+    val scoreAndLabels: RDD[(Set[Double], Set[Double])] = sc.parallelize(
       Seq((Set(0.0, 1.0), Set(0.0, 2.0)),
         (Set(0.0, 2.0), Set(0.0, 1.0)),
         (Set(), Set(0.0)),
@@ -70,7 +70,6 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
     val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2)
     val microF1MeasureClass = 2.0 * sumTp.toDouble /
       (2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2))
-
     val macroPrecisionDoc = 1.0 / 7 *
       (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
     val macroRecallDoc = 1.0 / 7 *
@@ -78,12 +77,9 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
       (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1)
     val macroF1MeasureDoc = (1.0 / 7) * 2 * ( 1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 +
       1.0 / (1 + 1) + 2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2) )
-
     val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
-
     val strictAccuracy = 2.0 / 7
     val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
-
     assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta)
     assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta)
     assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta)
@@ -93,20 +89,14 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
     assert(math.abs(metrics.f1MeasureClass(0.0) - f1measure0) < delta)
     assert(math.abs(metrics.f1MeasureClass(1.0) - f1measure1) < delta)
     assert(math.abs(metrics.f1MeasureClass(2.0) - f1measure2) < delta)
-
     assert(math.abs(metrics.microPrecisionClass - microPrecisionClass) < delta)
     assert(math.abs(metrics.microRecallClass - microRecallClass) < delta)
     assert(math.abs(metrics.microF1MeasureClass - microF1MeasureClass) < delta)
-
     assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta)
     assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta)
     assert(math.abs(metrics.macroF1MeasureDoc - macroF1MeasureDoc) < delta)
-
     assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
     assert(math.abs(metrics.strictAccuracy - strictAccuracy) < delta)
     assert(math.abs(metrics.accuracy - accuracy) < delta)
-
-
   }
-
 }
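
--
A minimal usage sketch of the class this patch touches, for anyone who wants
to try it outside the test suite. It is not part of the patch: the object
name and the toy data below are illustrative, and only the constructor and
members defined above are assumed.

import org.apache.spark.mllib.evaluation.MultilabelMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object MultilabelMetricsExample {
  def main(args: Array[String]): Unit = {
    // Local SparkContext, for illustration only.
    val conf = new SparkConf().setMaster("local").setAppName("MultilabelMetricsExample")
    val sc = new SparkContext(conf)

    // Each pair is (predicted label set, true label set) for one document.
    val predictionAndLabels: RDD[(Set[Double], Set[Double])] = sc.parallelize(Seq(
      (Set(0.0, 1.0), Set(0.0, 2.0)), // 0.0 is a hit, 1.0 is spurious, 2.0 is missed
      (Set(0.0), Set(0.0))), 2)       // exact match

    val metrics = new MultilabelMetrics(predictionAndLabels)

    // Document-based metrics.
    println(s"accuracy = ${metrics.accuracy}")
    println(s"hammingLoss = ${metrics.hammingLoss}")
    println(s"strictAccuracy = ${metrics.strictAccuracy}")

    // Label-based, micro-averaged metrics.
    println(s"microPrecisionClass = ${metrics.microPrecisionClass}")
    println(s"microRecallClass = ${metrics.microRecallClass}")
    println(s"microF1MeasureClass = ${metrics.microF1MeasureClass}")

    // Per-label metrics for label 0.0.
    println(s"precisionClass(0.0) = ${metrics.precisionClass(0.0)}")
    println(s"recallClass(0.0) = ${metrics.recallClass(0.0)}")

    sc.stop()
  }
}

As a cross-check against the suite above: for its seven-document fixture,
hammingLoss = (2 + 2 + 1 + 0 + 0 + 1 + 1) / (7 * 3) = 7 / 21, i.e. one label
error per document on average, normalized by the three distinct labels.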