From cd2c2b436bfac172cbfeb115220d988042080915 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Fri, 7 Mar 2014 00:38:00 -0800 Subject: [PATCH] fixing code style based on feedback --- .../spark/mllib/tree/DecisionTree.scala | 467 +++++++++++------- .../spark/mllib/tree/DecisionTreeRunner.scala | 143 ------ .../spark/mllib/tree/configuration/Algo.scala | 3 +- .../tree/configuration/FeatureType.scala | 1 + .../tree/configuration/QuantileStrategy.scala | 1 + .../mllib/tree/configuration/Strategy.scala | 15 +- .../spark/mllib/tree/impurity/Entropy.scala | 1 + .../spark/mllib/tree/impurity/Gini.scala | 1 + .../spark/mllib/tree/impurity/Impurity.scala | 19 +- .../spark/mllib/tree/impurity/Variance.scala | 4 +- .../apache/spark/mllib/tree/model/Bin.scala | 3 +- .../mllib/tree/model/DecisionTreeModel.scala | 5 +- .../spark/mllib/tree/model/Filter.scala | 3 +- .../tree/model/InformationGainStats.scala | 9 +- .../apache/spark/mllib/tree/model/Node.scala | 28 +- .../apache/spark/mllib/tree/model/Split.scala | 7 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 7 +- 17 files changed, 365 insertions(+), 352 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1c813244e5630..d57cb6dc4c91d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -18,37 +18,35 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkContext._ +import scala.util.control.Breaks._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.tree.model._ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.Split -import scala.util.control.Breaks._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} /** -A class that implements a decision tree algorithm for classification and regression. -It supports both continuous and categorical features. - -@param strategy The configuration parameters for the tree algorithm which specify the type of -algorithm (classification, -regression, etc.), feature type (continuous, categorical), depth of the tree, -quantile calculation strategy, etc. - */ -class DecisionTree private (val strategy : Strategy) extends Serializable with Logging { + * A class that implements a decision tree algorithm for classification and regression. It + * supports both continuous and categorical features. + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), + * depth of the tree, quantile calculation strategy, etc. 
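 + *
 + * For illustration only -- this usage sketch is an editor's addition, not part of the
 + * original patch. It assumes `trainingData` is an RDD[LabeledPoint] and that
 + * Algo.Classification and the Gini impurity are already imported:
 + * {{{
 + *   val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 3)
 + *   val model = DecisionTree.train(trainingData, strategy)
 + *   val prediction = model.predict(Array(1.0, 0.0))
 + * }}}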
 + */ +class DecisionTree private(val strategy: Strategy) extends Serializable with Logging { /** - Method to train a decision tree model over an RDD - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @return a DecisionTreeModel that can be used for prediction + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @return a DecisionTreeModel that can be used for prediction */ - def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { + def train(input: RDD[LabeledPoint]): DecisionTreeModel = { //Cache input RDD for speedup during multiple passes input.cache() @@ -59,7 +57,7 @@ class DecisionTree private (val strategy : Strategy) extends Serializable with L val maxDepth = strategy.maxDepth - val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1 + val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 val filters = new Array[List[Filter]](maxNumNodes) filters(0) = List() val parentImpurities = new Array[Double](maxNumNodes) @@ -70,7 +68,7 @@ class DecisionTree private (val strategy : Strategy) extends Serializable with L logDebug("algo = " + strategy.algo) breakable { - for (level <- 0 until maxDepth){ + for (level <- 0 until maxDepth) { logDebug("#####################################") logDebug("level = " + level) @@ -78,19 +76,19 @@ class DecisionTree private (val strategy : Strategy) extends Serializable with L //Find best split for all nodes at a level val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters,splits,bins) + level, filters, splits, bins) - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){ + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - extractNodeInfo(nodeSplitStats, level, index, nodes) + extractNodeInfo(nodeSplitStats, level, index, nodes) extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) logDebug("final best split = " + nodeSplitStats._1) } - require(scala.math.pow(2,level)==splitsStatsForLevel.length) + require(scala.math.pow(2, level) == splitsStatsForLevel.length) - val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 ) + val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) if (allLeaf) break @@ -159,92 +157,89 @@ class DecisionTree private (val strategy : Strategy) extends Serializable with L object DecisionTree extends Serializable with Logging { /** - Method to train a decision tree model over an RDD - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param strategy The configuration parameters for the tree algorithm which specify the type of algorithm - (classification, regression, etc.), feature type (continuous, categorical), - depth of the tree, quantile calculation strategy, etc. - @return a DecisionTreeModel that can be used for prediction + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc.
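 + *
 + * A minimal regression-flavored sketch (editor's illustration, not from the patch;
 + * assumes `data` is an RDD[LabeledPoint] and that Algo.Regression and the Variance
 + * impurity are imported):
 + * {{{
 + *   val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 5)
 + *   val model = DecisionTree.train(data, strategy)
 + * }}}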
 + * @return a DecisionTreeModel that can be used for prediction */ - def train(input : RDD[LabeledPoint], strategy : Strategy) : DecisionTreeModel = { - new DecisionTree(strategy).train(input : RDD[LabeledPoint]) + def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } /** - Method to train a decision tree model over an RDD - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param algo classification or regression - @param impurity criterion used for information gain calculation - @param maxDepth maximum depth of the tree - @return a DecisionTreeModel that can be used for prediction - */ + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @return a DecisionTreeModel that can be used for prediction + */ def train( - input : RDD[LabeledPoint], - algo : Algo, - impurity : Impurity, - maxDepth : Int - ) : DecisionTreeModel = { + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int + ): DecisionTreeModel = { val strategy = new Strategy(algo,impurity,maxDepth) - new DecisionTree(strategy).train(input : RDD[LabeledPoint]) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } /** - Method to train a decision tree model over an RDD - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param algo classification or regression - @param impurity criterion used for information gain calculation - @param maxDepth maximum depth of the tree - @param maxBins maximum number of bins used for splitting features - @param quantileCalculationStrategy algorithm for calculating quantiles - @param categoricalFeaturesInfo A map storing information about the categorical variables and the number of discrete - values they take. For example, an entry (n -> k) implies the feature n is - categorical with k categories 0, 1, 2, ... , k-1. It's important to note that - features are zero-indexed. - @return a DecisionTreeModel that can be used for prediction - */ + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data for DecisionTree + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, + * an entry (n -> k) implies the feature n is categorical with k + * categories 0, 1, 2, ... , k-1. It's important to note that + * features are zero-indexed.
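 + *
 + * Hypothetical call with two categorical features (feature 0 with 2 categories and
 + * feature 3 with 4); the values below are illustrative placeholders chosen by the
 + * editor, assuming `data` is an RDD[LabeledPoint] and Sort is imported from
 + * QuantileStrategy:
 + * {{{
 + *   val model = DecisionTree.train(data, Classification, Gini, maxDepth = 4,
 + *     maxBins = 100, quantileCalculationStrategy = Sort,
 + *     categoricalFeaturesInfo = Map(0 -> 2, 3 -> 4))
 + * }}}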
 + * @return a DecisionTreeModel that can be used for prediction + */ def train( - input : RDD[LabeledPoint], - algo : Algo, - impurity : Impurity, - maxDepth : Int, - maxBins : Int, - quantileCalculationStrategy : QuantileStrategy, - categoricalFeaturesInfo : Map[Int,Int] - ) : DecisionTreeModel = { - val strategy = new Strategy(algo,impurity,maxDepth,maxBins,quantileCalculationStrategy,categoricalFeaturesInfo) - new DecisionTree(strategy).train(input : RDD[LabeledPoint]) + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + maxBins: Int, + quantileCalculationStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int,Int] + ): DecisionTreeModel = { + val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } /** - Returns an Array[Split] of optimal splits for all nodes at a given level - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param parentImpurities Impurities for all parent nodes for the current level - @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - parameters for construction the DecisionTree - @param level Level of the tree - @param filters Filter for all nodes at a given level - @param splits possible splits for all features - @param bins possible bins for all features - - @return Array[Split] instance for best splits for all nodes at a given level. - */ + * Returns an array of optimal splits for all nodes at a given level + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for constructing the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @return array of best splits with information gain stats for all nodes at a given level. + */ def findBestSplits( - input : RDD[LabeledPoint], - parentImpurities : Array[Double], - strategy: Strategy, - level: Int, - filters : Array[List[Filter]], - splits : Array[Array[Split]], - bins : Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = { + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt @@ -266,8 +261,8 @@ object DecisionTree extends Serializable with Logging { } /** - Find whether the sample is valid input for the current node. - In other words, does it pass through all the filters for the current node. + * Find whether the sample is valid input for the current node. In other words, + * does it pass through all the filters for the current node.
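 + *
 + * Editor's gloss on the Filter semantics this check relies on (see the Filter doc's
 + * comparison encoding of -1, 0, 1); `continuousSplit` below is a hypothetical value:
 + * {{{
 + *   // a sample stays valid for the left child of a continuous split when its
 + *   // feature value does not exceed split.threshold, i.e. it satisfies comparison -1;
 + *   // for a categorical split, validity means membership in split.categories
 + *   val leftChildFilter = Filter(continuousSplit, -1)
 + *   val rightChildFilter = Filter(continuousSplit, 1)
 + * }}}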
*/ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { @@ -302,9 +297,12 @@ object DecisionTree extends Serializable with Logging { } /** - Finds the right bin for the given feature + * Finds the right bin for the given feature */ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = { + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean): Int = { if (isFeatureContinuous){ for (binIndex <- 0 until strategy.numBins) { @@ -334,16 +332,18 @@ object DecisionTree extends Serializable with Logging { } /** - Finds bins for all nodes (and all features) at a given level - k features, l nodes (level = log2(l)) - Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk - Denotes invalid sample for tree by noting bin for feature 1 as -1 + * Finds bins for all nodes (and all features) at a given level k features, + * l nodes (level = log2(l)). + * Storage label, b11, b12, b13, .., b1k, + * b21, b22, .. , b2k, + * bl1, bl2, .. , blk + * Denotes invalid sample for tree by noting bin for feature 1 as -1 */ - def findBinsForLevel(labeledPoint : LabeledPoint) : Array[Double] = { + def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // calculating bin index and label per feature per node - val arr = new Array[Double](1+(numFeatures * numNodes)) + val arr = new Array[Double](1 + (numFeatures * numNodes)) arr(0) = labeledPoint.label for (nodeIndex <- 0 until numNodes) { val parentFilters = findParentFilters(nodeIndex) @@ -354,7 +354,7 @@ object DecisionTree extends Serializable with Logging { //Add to invalid bin index -1 breakable { for (featureIndex <- 0 until numFeatures) { - arr(shift+featureIndex) = -1 + arr(shift + featureIndex) = -1 //Breaking since marking one bin is sufficient break() } @@ -440,20 +440,19 @@ object DecisionTree extends Serializable with Logging { } /** - Performs a sequential aggregation over a partition. - - for p bins, k features, l nodes (level = log2(l)) storage is of the form: - b111_left_count,b111_right_count, .... , .. - .. bpk1_left_count, bpk1_right_count, .... , .. - .. bpkl_left_count, bpkl_right_count - - @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification - and 3*numSplits*numFeatures*numNodes for regression - @param arr Array[Double] of size 1+(numFeatures*numNodes) - @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification - and 3*numSplits*numFeatures*numNodes for regression - */ - def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { + * Performs a sequential aggregation over a partition. + * for p bins, k features, l nodes (level = log2(l)) storage is of the form: + * b111_left_count,b111_right_count, .... , .... + * bpk1_left_count, bpk1_right_count, .... 
, ...., bpkl_left_count, bpkl_right_count + * @param agg Array[Double] storing aggregate calculation of size + * 2*numSplits*numFeatures*numNodes for classification and + * 3*numSplits*numFeatures*numNodes for regression + * @param arr Array[Double] of size 1+(numFeatures*numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2*numSplits*numFeatures*numNodes for classification and + * 3*numSplits*numFeatures*numNodes for regression + */ + def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { case Classification => classificationBinSeqOp(arr, agg) case Regression => regressionBinSeqOp(arr, agg) @@ -468,13 +467,12 @@ object DecisionTree extends Serializable with Logging { logDebug("binAggregateLength = " + binAggregateLength) /** - Combines the aggregates from partitions - @param agg1 Array containing aggregates from one or more partitions - @param agg2 Array containing aggregates from one or more partitions - - @return Combined aggregate from agg1 and agg2 + * Combines the aggregates from partitions + * @param agg1 Array containing aggregates from one or more partitions + * @param agg2 Array containing aggregates from one or more partitions + * @return Combined aggregate from agg1 and agg2 */ - def binCombOp(agg1 : Array[Double], agg2: Array[Double]) : Array[Double] = { + def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { strategy.algo match { case Classification => { val combinedAggregate = new Array[Double](binAggregateLength) @@ -513,11 +511,13 @@ object DecisionTree extends Serializable with Logging { * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ - def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], - featureIndex: Int, - splitIndex: Int, - rightNodeAgg: Array[Array[Double]], - topImpurity: Double) : InformationGainStats = { + def calculateGainForSplit( + leftNodeAgg: Array[Array[Double]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Double]], + topImpurity: Double): InformationGainStats = { + strategy.algo match { case Classification => { @@ -606,19 +606,18 @@ object DecisionTree extends Serializable with Logging { } } - val predict = (leftSum + rightSum)/(leftCount+rightCount) - new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict) + val predict = (leftSum + rightSum)/(leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) } } } - /** - Extracts left and right split aggregates - - @param binData Array[Double] of size 2*numFeatures*numSplits - @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], Array[Double]) where - each array is of size(numFeature,2*(numSplits-1)) + /** + * Extracts left and right split aggregates + * @param binData Array[Double] of size 2*numFeatures*numSplits + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], + * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { @@ -696,8 +695,7 @@ object DecisionTree extends Serializable with Logging { def calculateGainsForAllNodeSplits( leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], - nodeImpurity: Double) - : Array[Array[InformationGainStats]] = { + nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) @@ -710,14 
 +708,15 @@ object DecisionTree extends Serializable with Logging { gains } - /** - Find the best split for a node given bin aggregate data - - @param binData Array[Double] of size 2*numSplits*numFeatures - */ + /** + * Find the best split for a node given bin aggregate data + * @param binData Array[Double] of size 2*numSplits*numFeatures + * @param nodeImpurity impurity of the top node + * @return tuple of best split and information gain stats + */ def binsToBestSplit( - binData : Array[Double], - nodeImpurity : Double) : (Split, InformationGainStats) = { + binData: Array[Double], + nodeImpurity: Double): (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) @@ -771,24 +770,24 @@ object DecisionTree extends Serializable with Logging { logDebug("node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) } - bestSplits } + + /** - Returns split and bins for decision tree calculation. - - @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - for DecisionTree - @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - parameters for construction the DecisionTree - @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree.model.Split] of - size (numFeatures,numSplits-1) and bins is an Array of [org.apache.spark.mllib.tree.model.Bin] of - size (numFeatures,numSplits1) + * Returns split and bins for decision tree calculation. + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for constructing the DecisionTree + * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree + * .model.Split] of size (numFeatures,numSplits-1) and bins is an Array of [org.apache + * .spark.mllib.tree.model.Bin] of size (numFeatures,numSplits) */ def findSplitsBins( - input : RDD[LabeledPoint], - strategy : Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + input: RDD[LabeledPoint], + strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() @@ -807,7 +806,7 @@ object DecisionTree extends Serializable with Logging { val sampledInput = input.sample(false, fraction, 42).collect() val numSamples = sampledInput.length - val stride : Double = numSamples.toDouble/numBins + val stride: Double = numSamples.toDouble/numBins logDebug("stride = " + stride) strategy.quantileCalculationStrategy match { @@ -821,11 +820,11 @@ object DecisionTree extends Serializable with Logging { if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride : Double = numSamples.toDouble/numBins + val stride: Double = numSamples.toDouble/numBins logDebug("stride = " + stride) for (index <- 0 until numBins-1) { - val sampleIndex = (index+1)*stride.toInt - val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List()) + val sampleIndex = (index + 1)*stride.toInt + val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) splits(featureIndex)(index) = split } } else { @@ -856,17 +855,17 @@ object DecisionTree extends Serializable with Logging { var categoriesForSplit = List[Double]() categoriesSortedByCentriod.iterator.zipWithIndex foreach { case((key, value), index) => { - categoriesForSplit = key ::
categoriesForSplit - splits(featureIndex)(index) = new Split(featureIndex,Double.MinValue,Categorical, + categoriesForSplit = key:: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) bins(featureIndex)(index) = { if(index == 0) { - new Bin(new DummyCategoricalSplit(featureIndex,Categorical), - splits(featureIndex)(0),Categorical,key) + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) } else { - new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index), - Categorical,key) + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) } } } @@ -882,19 +881,19 @@ object DecisionTree extends Serializable with Logging { = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0), Continuous,Double.MinValue) for (index <- 1 until numBins - 1){ - val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index), - Continuous,Double.MinValue) + val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Continuous, Double.MinValue) bins(featureIndex)(index) = bin } bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, - Continuous),Continuous,Double.MinValue) + Continuous), Continuous, Double.MinValue) } else { val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) for (i <- maxFeatureValue until numBins){ bins(featureIndex)(i) - = new Bin(new DummyCategoricalSplit(featureIndex,Categorical), - new DummyCategoricalSplit(featureIndex,Categorical),Categorical,Double.MaxValue) + = new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + new DummyCategoricalSplit(featureIndex, Categorical), Categorical, Double.MaxValue) } } } @@ -906,10 +905,126 @@ object DecisionTree extends Serializable with Logging { case ApproxHist => { throw new UnsupportedOperationException("approximate histogram not supported yet.") } + } + } + + + val usage = """ + Usage: DecisionTreeRunner[slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] + """ + + + def main(args: Array[String]) { + + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + + val sc = new SparkContext(args(0), "DecisionTree") + + + val arglist = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]): OptionMap = { + def isSwitch(s : String) = (s(0) == '-') + list match { + case Nil => map + case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string) + , tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), + tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => logError("Unknown option " + option) + sys.exit(1) + } + } + val options = nextOption(Map(),arglist) + logDebug(options.toString()) + //TODO: Add validation for input parameters + + //Load training data + val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) + + //Figure out the type of algorithm + val 
algoStr = options.get('algo).get.toString + val algo = algoStr match { + case "Classification" => Classification + case "Regression" => Regression + } + //Identify the type of impurity + val impurityStr = options.getOrElse('impurity, + if (algo == Classification) "Gini" else "Variance").toString + val impurity = impurityStr match { + case "Gini" => Gini + case "Entropy" => Entropy + case "Variance" => Variance } + + val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt + val maxBins = options.getOrElse('maxBins,"100").toString.toInt + + val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, + maxBins = maxBins) + val model = DecisionTree.train(trainData,strategy) + + //Load test data + val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) + + //Measure algorithm accuracy + val accuracy = accuracyScore(model, testData) + logDebug("accuracy = " + accuracy) + + val mse = meanSquaredError(model,testData) + logDebug("mean square error = " + mse) + + sc.stop() + } + + /** + * Load labeled data from a file. The data format used here is + * , ... + * where , are feature values in Double and is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = parts.slice(1,parts.length).map(_.toDouble) + LabeledPoint(label, features) + } + } + + //TODO: Port them to a metrics package + def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + val correctCount = data.filter(y => model.predict(y.features) == y.label).count() + val count = data.count() + logDebug("correct prediction count = " + correctCount) + logDebug("data count = " + count) + correctCount.toDouble / count + } + + //TODO: Make these generic MLTable metrics + def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + val meanSumOfSquares = + data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)) + .mean() + meanSumOfSquares } -} \ No newline at end of file +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala deleted file mode 100644 index d93633d26228d..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.mllib.tree - -import org.apache.spark.SparkContext._ -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance} -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.configuration.Algo._ - - -object DecisionTreeRunner extends Logging { - - val usage = """ - Usage: DecisionTreeRunner[slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] - """ - - - def main(args: Array[String]) { - - if (args.length < 2) { - System.err.println(usage) - System.exit(1) - } - - val sc = new SparkContext(args(0), "DecisionTree") - - - val arglist = args.toList.drop(1) - type OptionMap = Map[Symbol, Any] - - def nextOption(map : OptionMap, list: List[String]) : OptionMap = { - def isSwitch(s : String) = (s(0) == '-') - list match { - case Nil => map - case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) - case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) - case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) - case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) - case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail) - case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail) - case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) - case option :: tail => logError("Unknown option "+option) - sys.exit(1) - } - } - val options = nextOption(Map(),arglist) - logDebug(options.toString()) - //TODO: Add validation for input parameters - - //Load training data - val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) - - //Figure out the type of algorithm - val algoStr = options.get('algo).get.toString - val algo = algoStr match { - case "Classification" => Classification - case "Regression" => Regression - } - - //Identify the type of impurity - val impurityStr = options.getOrElse('impurity,if (algo == Classification) "Gini" else "Variance").toString - val impurity = impurityStr match { - case "Gini" => Gini - case "Entropy" => Entropy - case "Variance" => Variance - } - - val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt - val maxBins = options.getOrElse('maxBins,"100").toString.toInt - - val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, maxBins = maxBins) - val model = DecisionTree.train(trainData,strategy) - - //Load test data - val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) - - //Measure algorithm accuracy - val accuracy = accuracyScore(model, testData) - logDebug("accuracy = " + accuracy) - - val mse = meanSquaredError(model,testData) - logDebug("mean square error = " + mse) - - sc.stop() - } - - /** - * Load labeled data from a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. - * - * @param sc SparkContext - * @param dir Directory to the input data files. - * @return An RDD of LabeledPoint. 
Each labeled point has two elements: the first element is - * the label, and the second element represents the feature values (an array of Double). - */ - def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { - sc.textFile(dir).map { line => - val parts = line.trim().split(",") - val label = parts(0).toDouble - val features = parts.slice(1,parts.length).map(_.toDouble) - LabeledPoint(label, features) - } - } - - //TODO: Port them to a metrics package - def accuracyScore(model : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = { - val correctCount = data.filter(y => model.predict(y.features) == y.label).count() - val count = data.count() - logDebug("correct prediction count = " + correctCount) - logDebug("data count = " + count) - correctCount.toDouble / count - } - - //TODO: Make these generic MLTable metrics - def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = { - val meanSumOfSquares = - data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean() - meanSumOfSquares - } - - - - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 14ec47ce014e7..2dd1f0f27b8f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.configuration /** @@ -22,4 +23,4 @@ package org.apache.spark.mllib.tree.configuration object Algo extends Enumeration { type Algo = Value val Classification, Regression = Value -} \ No newline at end of file +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index b4e8ae4ac39dd..09ee0586c58fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.configuration /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index dae73ab52dea7..2457a480c2a14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.configuration /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 88dfa76fc284f..9e461cfdbbd08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.spark.mllib.tree.configuration import org.apache.spark.mllib.tree.impurity.Impurity @@ -34,13 +35,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * zero-indexed. */ class Strategy ( - val algo : Algo, - val impurity : Impurity, - val maxDepth : Int, - val maxBins : Int = 100, - val quantileCalculationStrategy : QuantileStrategy = Sort, - val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable { + val algo: Algo, + val impurity: Impurity, + val maxDepth: Int, + val maxBins: Int = 100, + val quantileCalculationStrategy: QuantileStrategy = Sort, + val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable { - var numBins : Int = Int.MinValue + var numBins: Int = Int.MinValue } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index c1b2972f9c25b..9018821abc875 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 099c7e33dd39a..20af8f6c1c2cd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index cda534b462234..97092c85aea61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -14,12 +14,29 @@ * See the License for the specific language governing permissions and * limitations under the License. 
 */ + package org.apache.spark.mllib.tree.impurity +/** + * Trait for calculating information gain + */ trait Impurity extends Serializable { + /** + * information calculation for binary classification + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return information value + */ def calculate(c0 : Double, c1 : Double): Double - def calculate(count : Double, sum : Double, sumSquares : Double) : Double + /** + * information calculation for regression + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return information value + */ + def calculate(count: Double, sum: Double, sumSquares: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index b313b8d48eadf..85b7be560fecb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException @@ -23,7 +24,8 @@ import org.apache.spark.Logging * Class for calculating variance during regression */ object Variance extends Impurity with Logging { - def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") + def calculate(c0: Double, c1: Double): Double + = throw new OperationNotSupportedException("Variance.calculate") /** * variance calculation diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 0b4b7d2e5b2df..47afe3aed2b1b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -29,6 +30,6 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param featureType type of feature -- categorical or continuous * @param category categorical label value accepted in the bin */ -case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) { +case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 0e94827b0af70..94d77571dc22f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ + package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.Algo._ @@ -24,7 +25,7 @@ import org.apache.spark.rdd.RDD * @param topNode root node * @param algo algorithm type -- classification or regression */ -class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable { +class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { /** * Predict values for a single data point using the model trained. @@ -32,7 +33,7 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl * @param features array representing a single data point * @return Double prediction from the trained model */ - def predict(features : Array[Double]) : Double = { + def predict(features: Array[Double]): Double = { algo match { case Classification => { if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala index 9fc794c87398d..ebc9595eafef3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.model /** @@ -21,7 +22,7 @@ package org.apache.spark.mllib.tree.model * @param split split specifying the feature index, type and threshold * @param comparison integer specifying <,=,> */ -case class Filter(split : Split, comparison : Int) { +case class Filter(split: Split, comparison: Int) { // Comparison -1,0,1 signifies <.=,> override def toString = " split = " + split + "comparison = " + comparison } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 0f8d7a36d208f..64ff826486f5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.model /** @@ -25,11 +26,11 @@ package org.apache.spark.mllib.tree.model * @param predict predicted value */ class InformationGainStats( - val gain : Double, + val gain: Double, val impurity: Double, - val leftImpurity : Double, - val rightImpurity : Double, - val predict : Double) extends Serializable { + val leftImpurity: Double, + val rightImpurity: Double, + val predict: Double) extends Serializable { override def toString = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 374f065a09032..4a2c876a51b54 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
 */ + package org.apache.spark.mllib.tree.model import org.apache.spark.Logging @@ -30,19 +31,23 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param rightNode right child * @param stats information gain stats */ -class Node ( val id : Int, - val predict : Double, - val isLeaf : Boolean, - val split : Option[Split], - var leftNode : Option[Node], - var rightNode : Option[Node], - val stats : Option[InformationGainStats] - ) extends Serializable with Logging{ +class Node ( + val id: Int, + val predict: Double, + val isLeaf: Boolean, + val split: Option[Split], + var leftNode: Option[Node], + var rightNode: Option[Node], + val stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + "split = " + split + ", stats = " + stats - def build(nodes : Array[Node]) : Unit = { + /** + * build the left and right child nodes if this node is not a leaf + * @param nodes array of nodes + */ + def build(nodes: Array[Node]): Unit = { logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) @@ -59,6 +64,11 @@ class Node ( val id : Int, } } + /** + * predict the value at this node, recursing into a child node if this is not a leaf + * @param feature array of feature values for a single sample + * @return predicted value + */ def predictIfLeaf(feature : Array[Double]) : Double = { if (isLeaf) { predict diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 81e57dbf5e521..fffd68d7a64b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType @@ -27,9 +28,9 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType */ case class Split( feature: Int, - threshold : Double, - featureType : FeatureType, - categories : List[Double]){ + threshold: Double, + featureType: FeatureType, + categories: List[Double]){ override def toString = "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 39635a7e654a2..a299b087dfda8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ + package org.apache.spark.mllib.tree import scala.util.Random @@ -393,7 +394,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { object DecisionTreeSuite { - def generateOrderedLabeledPointsWithLabel0() : Array[LabeledPoint] = { + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i)) @@ -403,7 +404,7 @@ object DecisionTreeSuite { } - def generateOrderedLabeledPointsWithLabel1() : Array[LabeledPoint] = { + def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i)) @@ -412,7 +413,7 @@ object DecisionTreeSuite { arr } - def generateCategoricalDataPoints() : Array[LabeledPoint] = { + def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ if (i < 600){