diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ed0cf825b1d50..1116d0c4f711e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -152,8 +152,8 @@ object DecisionTree extends Serializable with Logging { //Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.length logDebug("numFeatures = " + numFeatures) - val numSplits = strategy.numBins - logDebug("numSplits = " + numSplits) + val numBins = strategy.numBins + logDebug("numBins = " + numBins) /*Find the filters used before reaching the current code*/ def findParentFilters(nodeIndex: Int): List[Filter] = { @@ -161,8 +161,6 @@ object DecisionTree extends Serializable with Logging { List[Filter]() } else { val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex - //val parentFilterIndex = nodeFilterIndex / 2 - //TODO: Check left or right filter filters(nodeFilterIndex) } } @@ -204,9 +202,9 @@ object DecisionTree extends Serializable with Logging { } /*Finds the right bin for the given feature*/ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinous : Boolean) : Int = { + def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = { - if (isFeatureContinous){ + if (isFeatureContinuous){ //TODO: Do binary search for (binIndex <- 0 until strategy.numBins) { val bin = bins(featureIndex)(binIndex) @@ -245,11 +243,11 @@ object DecisionTree extends Serializable with Logging { // calculating bin index and label per feature per node val arr = new Array[Double](1+(numFeatures * numNodes)) arr(0) = labeledPoint.label - for (index <- 0 until numNodes) { - val parentFilters = findParentFilters(index) + for (nodeIndex <- 0 until numNodes) { + val parentFilters = findParentFilters(nodeIndex) //Find out whether the sample qualifies for the particular node val sampleValid = isSampleValid(parentFilters, labeledPoint) - val shift = 1 + numFeatures * index + val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { //Add to invalid bin index -1 for (featureIndex <- 0 until numFeatures) { @@ -274,11 +272,11 @@ object DecisionTree extends Serializable with Logging { val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false if (isSampleValidForNode) { val label = arr(0) - for (feature <- 0 until numFeatures) { + for (featureIndex <- 0 until numFeatures) { val arrShift = 1 + numFeatures * node - val aggShift = 2 * numSplits * numFeatures * node - val arrIndex = arrShift + feature - val aggIndex = aggShift + 2 * feature * numSplits + arr(arrIndex).toInt * 2 + val aggShift = 2 * numBins * numFeatures * node + val arrIndex = arrShift + featureIndex + val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 label match { case (0.0) => agg(aggIndex) = agg(aggIndex) + 1 case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 @@ -296,9 +294,9 @@ object DecisionTree extends Serializable with Logging { val label = arr(0) for (feature <- 0 until numFeatures) { val arrShift = 1 + numFeatures * node - val aggShift = 3 * numSplits * numFeatures * node + val aggShift = 3 * numBins * numFeatures * node val arrIndex = arrShift + feature - val aggIndex = aggShift + 3 * feature * numSplits + arr(arrIndex).toInt * 3 + val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3 //count, sum, sum^2 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label @@ -318,7 +316,6 @@ object DecisionTree extends Serializable with Logging { @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification */ def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { - //TODO: Requires logic for regressions strategy.algo match { case Classification => classificationBinSeqOp(arr, agg) //TODO: Implement this @@ -327,10 +324,9 @@ object DecisionTree extends Serializable with Logging { agg } - //TODO: This length is different for regression val binAggregateLength = strategy.algo match { - case Classification => 2*numSplits * numFeatures * numNodes - case Regression => 3*numSplits * numFeatures * numNodes + case Classification => 2*numBins * numFeatures * numNodes + case Regression => 3*numBins * numFeatures * numNodes } logDebug("binAggregateLength = " + binAggregateLength) @@ -453,52 +449,52 @@ object DecisionTree extends Serializable with Logging { strategy.algo match { case Classification => { - val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) + val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) for (featureIndex <- 0 until numFeatures) { - val shift = 2*featureIndex*numSplits + val shift = 2*featureIndex*numBins leftNodeAgg(featureIndex)(0) = binData(shift + 0) leftNodeAgg(featureIndex)(1) = binData(shift + 1) - rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) - rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) - for (splitIndex <- 1 until numSplits - 1) { + rightNodeAgg(featureIndex)(2 * (numBins - 2)) = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) = binData(shift + (2 * (numBins - 1)) + 1) + for (splitIndex <- 1 until numBins - 1) { leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) - rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex)) - = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex)) - rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1) - = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) + = binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) + = binData(shift + (2 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) } } (leftNodeAgg, rightNodeAgg) } case Regression => { - val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1)) + val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) for (featureIndex <- 0 until numFeatures) { - val shift = 3*featureIndex*numSplits + val shift = 3*featureIndex*numBins leftNodeAgg(featureIndex)(0) = binData(shift + 0) leftNodeAgg(featureIndex)(1) = binData(shift + 1) leftNodeAgg(featureIndex)(2) = binData(shift + 2) - rightNodeAgg(featureIndex)(3 * (numSplits - 2)) = binData(shift + (3 * (numSplits - 1))) - rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 1) = binData(shift + (3 * (numSplits - 1)) + 1) - rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 2) = binData(shift + (3 * (numSplits - 1)) + 2) - for (splitIndex <- 1 until numSplits - 1) { + rightNodeAgg(featureIndex)(3 * (numBins - 2)) = binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = binData(shift + (3 * (numBins - 1)) + 2) + for (splitIndex <- 1 until numBins - 1) { leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3*splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3) leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3*splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) - rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex)) - = binData(shift + (3 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex)) - rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 1) - = binData(shift + (3 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 1) - rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 2) - = binData(shift + (3 * (numSplits - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 2) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) + = binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) + = binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) + = binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) } } (leftNodeAgg, rightNodeAgg) @@ -509,10 +505,10 @@ object DecisionTree extends Serializable with Logging { def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double) : Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[InformationGainStats](numFeatures, numSplits - 1) + val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) for (featureIndex <- 0 until numFeatures) { - for (index <- 0 until numSplits -1) { + for (index <- 0 until numBins -1) { //logDebug("splitIndex = " + index) gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity) } @@ -521,10 +517,10 @@ object DecisionTree extends Serializable with Logging { } /* - Find the best split for a node given bin aggregate data + Find the best split for a node given bin aggregate data - @param binData Array[Double] of size 2*numSplits*numFeatures - */ + @param binData Array[Double] of size 2*numSplits*numFeatures + */ def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) @@ -536,7 +532,7 @@ object DecisionTree extends Serializable with Logging { //Initialization with infeasible values var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1) for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numSplits - 1){ + for (splitIndex <- 0 until numBins - 1){ val gainStats = gains(featureIndex)(splitIndex) if(gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -556,13 +552,13 @@ object DecisionTree extends Serializable with Logging { def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => { - val shift = 2 * node * numSplits * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures) + val shift = 2 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) binsForNode } case Regression => { - val shift = 3 * node * numSplits * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numSplits * numFeatures) + val shift = 3 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) binsForNode } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 40bb94e6794d7..8d5ed343e0eb4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -109,7 +109,20 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { //TODO: Test max feature value > num bins - test("stump with all categorical variables"){ + test("classification stump with all categorical variables"){ + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + strategy.numBins = 100 + val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) + println(bestSplits(0)._1) + println(bestSplits(0)._2) + //TODO: Add asserts + } + + test("regression stump with all categorical variables"){ val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) @@ -123,7 +136,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { } - test("stump with fixed label 0 for Gini"){ + test("stump with fixed label 0 for Gini"){ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length == 1000) val rdd = sc.parallelize(arr)