diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 5e88109b5ffb5..a16bff2b5f4d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1034,8 +1034,12 @@ object DecisionTree extends Serializable with Logging { /** * Calculates the classifier accuracy. */ - def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - val correctCount = data.filter(y => model.predict(y.features) == y.label).count() + def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], + threshold: Double = 0.5): Double = { + def predictedValue(features: Array[Double]) = { + if (model.predict(features) < threshold) 0.0 else 1.0 + } + val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() val count = data.count() logDebug("correct prediction count = " + correctCount) logDebug("data count = " + count) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 94d77571dc22f..a056da77641ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -34,14 +34,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @return Double prediction from the trained model */ def predict(features: Array[Double]): Double = { - algo match { - case Classification => { - if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0 - } - case Regression => { - topNode.predictIfLeaf(features) - } - } + topNode.predictIfLeaf(features) } /**