From dad0afc85aea64c06b4dd64504b3112c881ae4e6 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 15 Dec 2013 00:25:58 -0800 Subject: [PATCH] decison stump functionality working Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 123 +++++++++++++----- .../spark/mllib/tree/DecisionTreeSuite.scala | 28 ++-- 2 files changed, 108 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 43ede29ef6fd8..4f7324345e1d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -20,9 +20,10 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.tree.model._ -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.impurity.Gini class DecisionTree(val strategy : Strategy) { @@ -46,8 +47,13 @@ class DecisionTree(val strategy : Strategy) { //Find best split for all nodes at a level val numNodes= scala.math.pow(2,level).toInt //TODO: Change the input parent impurities values - val bestSplits = DecisionTree.findBestSplits(input, Array(0.0), strategy, level, filters,splits,bins) + val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters,splits,bins) + for (tmp <- splits_stats_for_level){ + println("final best split = " + tmp._1) + } //TODO: update filters and decision tree model + require(scala.math.pow(2,level)==splits_stats_for_level.length) + } return new DecisionTreeModel() @@ -77,7 +83,7 @@ object DecisionTree extends Serializable { level: Int, filters : Array[List[Filter]], splits : Array[Array[Split]], - bins : Array[Array[Bin]]) : Array[Split] = { + bins : Array[Array[Bin]]) : Array[(Split, Double, Long, Long)] = { //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt @@ -94,8 +100,9 @@ object DecisionTree extends Serializable { List[Filter]() } else { val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex - val parentFilterIndex = nodeFilterIndex / 2 - filters(parentFilterIndex) + //val parentFilterIndex = nodeFilterIndex / 2 + //TODO: Check left or right filter + filters(nodeFilterIndex) } } @@ -230,22 +237,26 @@ object DecisionTree extends Serializable { //binAggregates.foreach(x => println(x)) - def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = { + def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], + featureIndex: Int, + index: Int, + rightNodeAgg: Array[Array[Double]], + topImpurity: Double) : (Double, Long, Long) = { val left0Count = leftNodeAgg(featureIndex)(2 * index) val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) val leftCount = left0Count + left1Count - if (leftCount == 0) return 0 - - //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) - val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - val right0Count = rightNodeAgg(featureIndex)(2 * index) val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) val rightCount = right0Count + right1Count - if (rightCount == 0) return 0 + if (leftCount == 0) return (0, leftCount.toLong, rightCount.toLong) + + //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + + if (rightCount == 0) return (0, leftCount.toLong, rightCount.toLong) //println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) @@ -253,7 +264,7 @@ object DecisionTree extends Serializable { val leftWeight = leftCount.toDouble / (leftCount + rightCount) val rightWeight = rightCount.toDouble / (leftCount + rightCount) - topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity + (topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong) } @@ -295,9 +306,10 @@ object DecisionTree extends Serializable { (leftNodeAgg, rightNodeAgg) } - def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double): Array[Array[Double]] = { + def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double) + : Array[Array[(Double,Long,Long)]] = { - val gains = Array.ofDim[Double](numFeatures, numSplits - 1) + val gains = Array.ofDim[(Double,Long,Long)](numFeatures, numSplits - 1) for (featureIndex <- 0 until numFeatures) { for (index <- 0 until numSplits -1) { @@ -313,7 +325,7 @@ object DecisionTree extends Serializable { @param binData Array[Double] of size 2*numSplits*numFeatures */ - def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : Split = { + def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, Double, Long, Long) = { println("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) @@ -321,32 +333,36 @@ object DecisionTree extends Serializable { //println("gains.size = " + gains.size) //println("gains(0).size = " + gains(0).size) - val (bestFeatureIndex,bestSplitIndex) = { + val (bestFeatureIndex,bestSplitIndex, gain, leftCount, rightCount) = { var bestFeatureIndex = 0 var bestSplitIndex = 0 var maxGain = Double.MinValue + var leftSamples = Long.MinValue + var rightSamples = Long.MinValue for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numSplits - 1){ val gain = gains(featureIndex)(splitIndex) //println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) - if(gain > maxGain) { - maxGain = gain + if(gain._1 > maxGain) { + maxGain = gain._1 + leftSamples = gain._2 + rightSamples = gain._3 bestFeatureIndex = featureIndex bestSplitIndex = splitIndex - println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex + ", maxGain = " + maxGain) + println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex + + ", maxGain = " + maxGain + ", leftSamples = " + leftSamples + ",rightSamples = " + rightSamples) } } } - (bestFeatureIndex,bestSplitIndex) + (bestFeatureIndex,bestSplitIndex,maxGain,leftSamples,rightSamples) } - splits(bestFeatureIndex)(bestSplitIndex) - - //TODo: Return array of node stats with split and impurity information + (splits(bestFeatureIndex)(bestSplitIndex),gain,leftCount,rightCount) + //TODO: Return array of node stats with split and impurity information } //Calculate best splits for all nodes at a given level - val bestSplits = new Array[Split](numNodes) + val bestSplits = new Array[(Split, Double, Long, Long)](numNodes) for (node <- 0 until numNodes){ val shift = 2*node*numSplits*numFeatures val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) @@ -381,9 +397,6 @@ object DecisionTree extends Serializable { val sampledInput = input.sample(false, fraction, 42).collect() val numSamples = sampledInput.length - //TODO: Remove this requirement - require(numSamples > numSplits, "length of input samples should be greater than numSplits") - //Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.length @@ -395,14 +408,22 @@ object DecisionTree extends Serializable { //Find all splits for (featureIndex <- 0 until numFeatures){ val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride : Double = numSamples.toDouble/numSplits - - println("stride = " + stride) - for (index <- 0 until numSplits-1) { - val sampleIndex = (index+1)*stride.toInt - val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") - splits(featureIndex)(index) = split + if (numSamples < numSplits) { + //TODO: Test this + println("numSamples = " + numSamples + ", less than numSplits = " + numSplits) + for (index <- 0 until numSplits-1) { + val split = new Split(featureIndex,featureSamples(index),"continuous") + splits(featureIndex)(index) = split + } + } else { + val stride : Double = numSamples.toDouble/numSplits + println("stride = " + stride) + for (index <- 0 until numSplits-1) { + val sampleIndex = (index+1)*stride.toInt + val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") + splits(featureIndex)(index) = split + } } } @@ -430,4 +451,36 @@ object DecisionTree extends Serializable { } } + def main(args: Array[String]) { + + val sc = new SparkContext(args(0), "DecisionTree") + val data = loadLabeledData(sc, args(1)) + + val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569) + val model = new DecisionTree(strategy).train(data) + + sc.stop() + } + + /** + * Load labeled data from a file. The data format used here is + * , ... + * where , are feature values in Double and is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = parts.slice(1,parts.length).map(_.toDouble) + LabeledPoint(label, features) + } + } + + + } \ No newline at end of file diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2c9794371eb29..5ee61b0a5173c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -68,10 +68,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(1)(98)) val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) - println(bestSplits(0)) + assert(0==bestSplits(0)._1.feature) + assert(10==bestSplits(0)._1.threshold) + assert(0==bestSplits(0)._2) + assert(10==bestSplits(0)._3) + assert(990==bestSplits(0)._4) } test("stump with fixed label 1 for Gini"){ @@ -86,10 +89,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(1)(98)) val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) - println(bestSplits(0)) + assert(0==bestSplits(0)._1.feature) + assert(10==bestSplits(0)._1.threshold) + assert(0==bestSplits(0)._2) + assert(10==bestSplits(0)._3) + assert(990==bestSplits(0)._4) } @@ -105,10 +111,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(1)(98)) val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) - println(bestSplits(0)) + assert(0==bestSplits(0)._1.feature) + assert(10==bestSplits(0)._1.threshold) + assert(0==bestSplits(0)._2) + assert(10==bestSplits(0)._3) + assert(990==bestSplits(0)._4) } test("stump with fixed label 1 for Entropy"){ @@ -123,10 +132,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) assert(splits(0).length==99) assert(bins(0).length==100) - println(splits(1)(98)) val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) assert(bestSplits.length == 1) - println(bestSplits(0)) + assert(0==bestSplits(0)._1.feature) + assert(10==bestSplits(0)._1.threshold) + assert(0==bestSplits(0)._2) + assert(10==bestSplits(0)._3) + assert(990==bestSplits(0)._4) }