diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1f99f28e991f7..c3cbe2c63ab03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -77,12 +77,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = { + val numElementsPerNode = strategy.algo match { case Classification => 2 * numBins * numFeatures case Regression => 3 * numBins * numFeatures } - } + logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1)