diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4f7324345e1d8..4fd030e3a3c05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -83,7 +83,7 @@ object DecisionTree extends Serializable { level: Int, filters : Array[List[Filter]], splits : Array[Array[Split]], - bins : Array[Array[Bin]]) : Array[(Split, Double, Long, Long)] = { + bins : Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = { //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt @@ -241,7 +241,7 @@ object DecisionTree extends Serializable { featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], - topImpurity: Double) : (Double, Long, Long) = { + topImpurity: Double) : InformationGainStats = { val left0Count = leftNodeAgg(featureIndex)(2 * index) val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) @@ -251,12 +251,12 @@ object DecisionTree extends Serializable { val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) val rightCount = right0Count + right1Count - if (leftCount == 0) return (0, leftCount.toLong, rightCount.toLong) + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) + if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - if (rightCount == 0) return (0, leftCount.toLong, rightCount.toLong) //println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) @@ -264,7 +264,9 @@ object DecisionTree extends Serializable { val leftWeight = leftCount.toDouble / (leftCount + rightCount) val rightWeight = rightCount.toDouble / (leftCount + rightCount) - (topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong) + val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + new InformationGainStats(gain,topImpurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) } @@ -307,9 +309,9 @@ object DecisionTree extends Serializable { } def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double) - : Array[Array[(Double,Long,Long)]] = { + : Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[(Double,Long,Long)](numFeatures, numSplits - 1) + val gains = Array.ofDim[InformationGainStats](numFeatures, numSplits - 1) for (featureIndex <- 0 until numFeatures) { for (index <- 0 until numSplits -1) { @@ -325,7 +327,7 @@ object DecisionTree extends Serializable { @param binData Array[Double] of size 2*numSplits*numFeatures */ - def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, Double, Long, Long) = { + def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = { println("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) @@ -333,36 +335,35 @@ object DecisionTree extends Serializable { //println("gains.size = " + gains.size) //println("gains(0).size = " + gains(0).size) - val (bestFeatureIndex,bestSplitIndex, gain, leftCount, rightCount) = { + val (bestFeatureIndex,bestSplitIndex, gainStats) = { var bestFeatureIndex = 0 var bestSplitIndex = 0 - var maxGain = Double.MinValue - var leftSamples = Long.MinValue - var rightSamples = Long.MinValue + //Initialization with infeasible values + var bestGainStats = new InformationGainStats(-1.0,-1.0,-1.0,0,-1.0,0) +// var maxGain = Double.MinValue +// var leftSamples = Long.MinValue +// var rightSamples = Long.MinValue for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numSplits - 1){ - val gain = gains(featureIndex)(splitIndex) + val gainStats = gains(featureIndex)(splitIndex) //println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) - if(gain._1 > maxGain) { - maxGain = gain._1 - leftSamples = gain._2 - rightSamples = gain._3 + if(gainStats.gain > bestGainStats.gain) { + bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex - + ", maxGain = " + maxGain + ", leftSamples = " + leftSamples + ",rightSamples = " + rightSamples) + + ", gain stats = " + bestGainStats) } } } - (bestFeatureIndex,bestSplitIndex,maxGain,leftSamples,rightSamples) + (bestFeatureIndex,bestSplitIndex,bestGainStats) } - (splits(bestFeatureIndex)(bestSplitIndex),gain,leftCount,rightCount) - //TODO: Return array of node stats with split and impurity information + (splits(bestFeatureIndex)(bestSplitIndex),gainStats) } //Calculate best splits for all nodes at a given level - val bestSplits = new Array[(Split, Double, Long, Long)](numNodes) + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) for (node <- 0 until numNodes){ val shift = 2*node*numSplits*numFeatures val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5ee61b0a5173c..2b5988bb3c6a3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -72,9 +72,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2) - assert(10==bestSplits(0)._3) - assert(990==bestSplits(0)._4) + assert(0==bestSplits(0)._2.gain) + assert(10==bestSplits(0)._2.leftSamples) + assert(0==bestSplits(0)._2.leftImpurity) + assert(990==bestSplits(0)._2.rightSamples) + assert(0==bestSplits(0)._2.rightImpurity) } test("stump with fixed label 1 for Gini"){ @@ -93,9 +95,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2) - assert(10==bestSplits(0)._3) - assert(990==bestSplits(0)._4) + assert(0==bestSplits(0)._2.gain) + assert(10==bestSplits(0)._2.leftSamples) + assert(0==bestSplits(0)._2.leftImpurity) + assert(990==bestSplits(0)._2.rightSamples) + assert(0==bestSplits(0)._2.rightImpurity) } @@ -115,9 +119,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2) - assert(10==bestSplits(0)._3) - assert(990==bestSplits(0)._4) + assert(0==bestSplits(0)._2.gain) + assert(10==bestSplits(0)._2.leftSamples) + assert(0==bestSplits(0)._2.leftImpurity) + assert(990==bestSplits(0)._2.rightSamples) + assert(0==bestSplits(0)._2.rightImpurity) } test("stump with fixed label 1 for Entropy"){ @@ -136,9 +142,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2) - assert(10==bestSplits(0)._3) - assert(990==bestSplits(0)._4) + assert(0==bestSplits(0)._2.gain) + assert(10==bestSplits(0)._2.leftSamples) + assert(0==bestSplits(0)._2.leftImpurity) + assert(990==bestSplits(0)._2.rightSamples) + assert(0==bestSplits(0)._2.rightImpurity) }