From 1e8c70483984d86a204e0377b2b043cc17c854ac Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 31 Mar 2014 20:31:18 -0700 Subject: [PATCH] remove numBins field in the Strategy class --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 5 +---- .../apache/spark/mllib/tree/configuration/Strategy.scala | 5 +---- .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 6 ------ 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 5e8fc70bd3c04..33205b919db8f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -56,9 +56,6 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) logDebug("numSplits = " + bins(0).length) - // Set number of bins for the input data. - strategy.numBins = bins(0).length - // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree @@ -300,7 +297,7 @@ object DecisionTree extends Serializable with Logging { // Find the number of features by looking at the first sample. val numFeatures = input.first().features.length logDebug("numFeatures = " + numFeatures) - val numBins = strategy.numBins + val numBins = bins(0).length logDebug("numBins = " + numBins) /** Find the filters used before reaching the current code. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7c9b4796ed62b..df565f3eb8859 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -40,7 +40,4 @@ class Strategy ( val maxDepth: Int, val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable { - - var numBins: Int = Int.MinValue -} + val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a359bf3a76ce1..4349c7000a0ae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -252,7 +252,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins) @@ -280,7 +279,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins) @@ -310,7 +308,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length === 99) assert(bins(0).length === 100) - strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins) assert(bestSplits.length === 1) @@ -334,7 +331,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length === 99) assert(bins(0).length === 100) - strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins) assert(bestSplits.length === 1) @@ -359,7 +355,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length === 99) assert(bins(0).length === 100) - strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins) assert(bestSplits.length === 1) @@ -384,7 +379,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length === 99) assert(bins(0).length === 100) - strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins) assert(bestSplits.length === 1)