From b09dc983f4f05da61479c87617526064b0e3dde8 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 26 Jan 2014 14:54:43 -0800 Subject: [PATCH] minor refactoring Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1116d0c4f711e..ab2c9011dd93b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -367,18 +367,18 @@ object DecisionTree extends Serializable with Logging { def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, - index: Int, + splitIndex: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double) : InformationGainStats = { strategy.algo match { case Classification => { - val left0Count = leftNodeAgg(featureIndex)(2 * index) - val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) + val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) + val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) val leftCount = left0Count + left1Count - val right0Count = rightNodeAgg(featureIndex)(2 * index) - val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) + val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) + val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) val rightCount = right0Count + right1Count val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) @@ -405,13 +405,13 @@ object DecisionTree extends Serializable with Logging { new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict) } case Regression => { - val leftCount = leftNodeAgg(featureIndex)(3 * index) - val leftSum = leftNodeAgg(featureIndex)(3 * index + 1) - val leftSumSquares = leftNodeAgg(featureIndex)(3 * index + 2) + val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) + val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) + val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) - val rightCount = rightNodeAgg(featureIndex)(3 * index) - val rightSum = rightNodeAgg(featureIndex)(3 * index + 1) - val rightSumSquares = rightNodeAgg(featureIndex)(3 * index + 2) + val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex) + val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) + val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares) @@ -463,9 +463,9 @@ object DecisionTree extends Serializable with Logging { leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) - = binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) + = binData(shift + (2 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) - = binData(shift + (2 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) + = binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) } } (leftNodeAgg, rightNodeAgg) @@ -490,11 +490,11 @@ object DecisionTree extends Serializable with Logging { leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) - = binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) + = binData(shift + (3 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) - = binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) + = binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) - = binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) } } (leftNodeAgg, rightNodeAgg) @@ -508,9 +508,9 @@ object DecisionTree extends Serializable with Logging { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) for (featureIndex <- 0 until numFeatures) { - for (index <- 0 until numBins -1) { + for (splitIndex <- 0 until numBins -1) { //logDebug("splitIndex = " + index) - gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity) + gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity) } } gains @@ -544,6 +544,8 @@ object DecisionTree extends Serializable with Logging { (bestFeatureIndex,bestSplitIndex,bestGainStats) } + logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) + logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) (splits(bestFeatureIndex)(bestSplitIndex),gainStats) } @@ -614,13 +616,14 @@ object DecisionTree extends Serializable with Logging { //Find all splits for (featureIndex <- 0 until numFeatures){ - val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinous) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride : Double = numSamples.toDouble/numBins logDebug("stride = " + stride) for (index <- 0 until numBins-1) { + //TODO: Investigate this val sampleIndex = (index+1)*stride.toInt val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List()) splits(featureIndex)(index) = split