From 0dc2b6361d61b7d94cba3dc83da2abb7e08ba6fe Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Sun, 28 Sep 2014 21:44:50 -0700
Subject: [PATCH 1/8] [SPARK-1545] [mllib] Add Random Forests

This PR adds RandomForest to MLlib. The implementation is basic, and future performance optimizations will be important. (Note: RFs = Random Forests.)

# Overview

## RandomForest
* trains multiple trees at once to reduce the number of passes over the data
* allows feature subsets at each node
* uses a queue of nodes instead of fixed groups for each level (a simplified sketch of this loop appears below, after the DecisionTree changes)

This implementation is based on an implementation by manishamde and on the [Alpine Labs Sequoia Forest](https://github.com/AlpineNow/SparkML2) by codedeft (in particular, the TreePoint, BaggedPoint, and node queue implementations). Thank you for your input!

## Testing

Correctness: This has been tested for correctness with the test suites and with DecisionTreeRunner on example datasets.

Performance: This has been performance-tested using [this branch of spark-perf](https://github.com/jkbradley/spark-perf/tree/rfs). Results are below.

### Regression tests for DecisionTree

Summary: For training 1 tree, there are small regressions, especially from feature subsampling.

In the table below, each row is a single (random) dataset. The two sets of result columns are for two different RF implementations:
* (numTrees): from an earlier commit, after implementing RandomForest to train multiple trees at once; it does not include any code for feature subsampling.
* (feature subsets): from this PR's current code, after implementing feature subsampling.

These tests were meant to identify regressions in DecisionTree, so they train 1 tree with all of the features (i.e., no feature subsampling). They were run on an EC2 cluster with 15 workers, training 1 tree with maxDepth = 5 (= 6 levels). Speedup values < 1 indicate slowdowns relative to the old DecisionTree implementation.

numInstances | numFeatures | runtime (sec), numTrees | speedup, numTrees | runtime (sec), feature subsets | speedup, feature subsets
---- | ---- | ---- | ---- | ---- | ----
20000 | 100 | 4.051 | 1.044433473 | 4.478 | 0.9448414471
20000 | 500 | 8.472 | 1.104461756 | 9.315 | 1.004508857
20000 | 1500 | 19.354 | 1.05854087 | 20.863 | 0.9819776638
20000 | 3500 | 43.674 | 1.072033704 | 45.887 | 1.020332556
200000 | 100 | 4.196 | 1.171830315 | 4.848 | 1.014232673
200000 | 500 | 8.926 | 1.082791844 | 9.771 | 0.989151571
200000 | 1500 | 20.58 | 1.068415938 | 22.134 | 0.9934038131
200000 | 3500 | 48.043 | 1.075203464 | 52.249 | 0.9886505005
2000000 | 100 | 4.944 | 1.01355178 | 5.796 | 0.8645617667
2000000 | 500 | 11.11 | 1.016831683 | 12.482 | 0.9050632911
2000000 | 1500 | 31.144 | 1.017852556 | 35.274 | 0.8986789136
2000000 | 3500 | 79.981 | 1.085382778 | 101.105 | 0.8586123337
20000000 | 100 | 8.304 | 0.9270231214 | 9.073 | 0.8484514494
20000000 | 500 | 28.174 | 1.083268262 | 34.236 | 0.8914592826
20000000 | 1500 | 143.97 | 0.9579634646 | 159.275 | 0.8659111599

### Tests for forests

I have run other tests with numTrees = 10 and with sqrt(numFeatures); those indicate that multi-model training and feature subsets can speed up training for forests, especially when training deeper trees.

# Details on specific classes

## Changes to DecisionTree
* Main train() method is now in RandomForest.
* findBestSplits() is no longer needed. (It split levels into groups, but we now use a queue of nodes.)
* Many small changes to support RFs.

(Note: These methods should be moved to RandomForest.scala in a later PR, but they are kept in DecisionTree.scala to make code comparison easier.)
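For reviewers, here is a simplified, hypothetical sketch of the queue-based training loop described in the overview. This is not the patch's code; `QueueTrainingSketch`, `SimpleNode`, `bytesPerNode`, and the other names are made up for illustration. Roots of all trees share one FIFO queue, and each pass over the data handles as many queued nodes as fit in the aggregation memory budget:

```scala
import scala.collection.mutable

object QueueTrainingSketch {
  // Placeholder node type; the real code uses mllib's Node and tracks (treeIndex, node) pairs.
  case class SimpleNode(treeIndex: Int, id: Int)

  def train(numTrees: Int, maxMemoryBytes: Long, bytesPerNode: Long): Unit = {
    val nodeQueue = new mutable.Queue[SimpleNode]()
    // Enqueue one root per tree; all trees share the same queue.
    (0 until numTrees).foreach(t => nodeQueue.enqueue(SimpleNode(t, id = 1)))

    while (nodeQueue.nonEmpty) {
      // Pull as many nodes as fit within the memory budget for bin aggregates.
      // (The real selectNodesToSplit computes a per-node size from the chosen feature subset.)
      val group = mutable.ArrayBuffer(nodeQueue.dequeue())
      var memUsage = bytesPerNode
      while (nodeQueue.nonEmpty && memUsage + bytesPerNode <= maxMemoryBytes) {
        group += nodeQueue.dequeue()
        memUsage += bytesPerNode
      }
      // One pass over the bagged data would compute best splits for every node in `group`
      // and enqueue any newly created children (omitted in this sketch).
    }
  }
}
```

Because a group can mix nodes from different trees and different depths, the number of passes over the data depends on the memory budget rather than on tree depth alone.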
## RandomForest
* Main train() method is from the old DecisionTree.
* selectNodesToSplit: Note that it selects nodes and feature subsets jointly in order to track memory usage.

## RandomForestModel
* Stores an Array[DecisionTreeModel].
* Prediction:
  * For classification, the most common label; for regression, the mean.
  * We could support other methods later.

## examples/.../DecisionTreeRunner
* This now takes numTrees and featureSubsetStrategy, to support RFs.

## DTStatsAggregator
* There are 2 types of functionality (with and without subsampling features), and these require different indexing methods. (We could treat both as subsampling, but that would be less efficient.)
* DTStatsAggregator is now abstract, and 2 child classes implement these 2 types of functionality.

## impurities
* These now take instance weights.

## Node
* Some vals changed to vars.
* This is unfortunately a public API change (DeveloperApi). It could be avoided by creating a LearningNode struct, but that would be awkward.

## RandomForestSuite
Please let me know if there are missing tests!

## BaggedPoint
This wraps TreePoint and holds bootstrap weights/counts.

# Design decisions

* BaggedPoint: BaggedPoint is separate from TreePoint since it may be useful for other bagging algorithms later on.
* RandomForest public API: What options should be easily supported by the train* methods? Should ALL options be in the Java-friendly constructors? Should there be a constructor taking Strategy? (A usage sketch of the current API follows this list.)
* Feature subsampling options: What options should be supported? scikit-learn supports the same options, except for "onethird." One option would be to allow users to specify fractions ("0.1"): the current options could be supported, and any unrecognized values would be parsed as Doubles in [0, 1].
* Splits and bins are computed before bootstrapping, so all trees use the same discretization.
* One queue, instead of one queue per tree.
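To make the public API questions above concrete, here is a minimal usage sketch. The method and parameter names are taken from this patch; the wrapper object, dataset, and parameter values are made up for illustration.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.rdd.RDD

object RandomForestUsageSketch {
  // `training` is assumed to be an RDD[LabeledPoint] with labels in {0, ..., numClasses - 1}.
  def trainAndPredict(training: RDD[LabeledPoint]): Unit = {
    val model = RandomForest.trainClassifier(
      training,
      2,                   // numClassesForClassification
      Map.empty[Int, Int], // categoricalFeaturesInfo: all features treated as continuous
      10,                  // numTrees
      "auto",              // featureSubsetStrategy: resolves to "sqrt" for a classification forest
      "gini",              // impurity
      4,                   // maxDepth
      100,                 // maxBins
      12345)               // seed
    // RandomForestModel predicts by majority vote over model.trees (mean for regression).
    val prediction: Double = model.predict(training.first().features)
    println(s"Prediction for the first training example: $prediction")
  }
}
```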
CC: mengxr manishamde codedeft chouqin

Please let me know if you have suggestions---thanks!

Author: Joseph K. Bradley
Author: qiping.lqp
Author: chouqin

Closes #2435 from jkbradley/rfs-new and squashes the following commits:

c694174 [Joseph K. Bradley] Fixed typo
cc59d78 [Joseph K. Bradley] fixed imports
e25909f [Joseph K. Bradley] Simplified node group maps. Specifically, created NodeIndexInfo to store node index in agg and feature subsets, and no longer create extra maps in findBestSplits
fbe9a1e [Joseph K. Bradley] Changed default featureSubsetStrategy to be sqrt for classification, onethird for regression. Updated docs with references.
ef7c293 [Joseph K. Bradley] Updates based on code review. Most substantial changes: * Simplified DTStatsAggregator * Made RandomForestModel.trees public * Added test for regression to RandomForestSuite
593b13c [Joseph K. Bradley] Fixed bug in metadata for computing log2(num features). Now it checks >= 1.
a1a08df [Joseph K. Bradley] Removed old comments
866e766 [Joseph K. Bradley] Changed RandomForestSuite randomized tests to use multiple fixed random seeds.
ff8bb96 [Joseph K. Bradley] removed usage of null from RandomForest and replaced with Option
bf1a4c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
6b79c07 [Joseph K. Bradley] Added RandomForestSuite, and fixed small bugs, style issues.
d7753d4 [Joseph K. Bradley] Added numTrees and featureSubsetStrategy to DecisionTreeRunner (to support RandomForest). Fixed bugs so that RandomForest now runs.
746d43c [Joseph K. Bradley] Implemented feature subsampling. Tested DecisionTree but not RandomForest.
6309d1d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new. Added RandomForestModel.toString
b7ae594 [Joseph K. Bradley] Updated docs. Small fix for bug which does not cause errors: No longer allocate unused child nodes for leaf nodes.
121c74e [Joseph K. Bradley] Basic random forests are implemented. Random features per node not yet implemented. Test suite not implemented.
325d18a [Joseph K. Bradley] Merge branch 'chouqin-dt-preprune' into rfs-new
4ef9bf1 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
6da8571 [Joseph K. Bradley] RFs partly implemented, not done yet
eddd1eb [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1
0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py
306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc
eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix
d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated
d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training.
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
2ab763b [Joseph K.
Bradley] Simplifications to DecisionTree code: efcc736 [qiping.lqp] fix bug 10b8012 [qiping.lqp] fix style 6728fad [qiping.lqp] minor fix: remove empty lines bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune cadd569 [qiping.lqp] add api docs 46b891f [qiping.lqp] fix bug e72c7e4 [qiping.lqp] add comments 845c6fa [qiping.lqp] fix style f195e83 [qiping.lqp] fix style 987cbf4 [qiping.lqp] fix bug ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree --- .../examples/mllib/DecisionTreeRunner.scala | 76 ++- .../spark/mllib/tree/DecisionTree.scala | 457 ++++++------------ .../spark/mllib/tree/RandomForest.scala | 451 +++++++++++++++++ .../spark/mllib/tree/impl/BaggedPoint.scala | 80 +++ .../mllib/tree/impl/DTStatsAggregator.scala | 219 +++++++-- .../tree/impl/DecisionTreeMetadata.scala | 47 +- .../spark/mllib/tree/impurity/Entropy.scala | 4 +- .../spark/mllib/tree/impurity/Gini.scala | 4 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 8 +- .../apache/spark/mllib/tree/model/Node.scala | 13 +- .../mllib/tree/model/RandomForestModel.scala | 105 ++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 210 ++++---- .../spark/mllib/tree/RandomForestSuite.scala | 245 ++++++++++ 14 files changed, 1410 insertions(+), 511 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 4683e6eb966be..96fb068e9e126 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -21,16 +21,18 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree, impurity} +import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** - * An example runner for decision tree. Run with + * An example runner for decision trees and random forests. 
Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} @@ -57,6 +59,8 @@ object DecisionTreeRunner { maxBins: Int = 32, minInstancesPerNode: Int = 1, minInfoGain: Double = 0.0, + numTrees: Int = 1, + featureSubsetStrategy: String = "auto", fracTest: Double = 0.2) def main(args: Array[String]) { @@ -79,11 +83,20 @@ object DecisionTreeRunner { .action((x, c) => c.copy(maxBins = x)) opt[Int]("minInstancesPerNode") .text(s"min number of instances required at child nodes to create the parent split," + - s" default: ${defaultParams.minInstancesPerNode}") + s" default: ${defaultParams.minInstancesPerNode}") .action((x, c) => c.copy(minInstancesPerNode = x)) opt[Double]("minInfoGain") .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("numTrees") + .text(s"number of trees (1 = decision tree, 2+ = random forest)," + + s" default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(numTrees = x)) + opt[String]("featureSubsetStrategy") + .text(s"feature subset sampling strategy" + + s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " + + s"default: ${defaultParams.featureSubsetStrategy}") + .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) @@ -191,18 +204,35 @@ object DecisionTreeRunner { numClassesForClassification = numClasses, minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain) - val model = DecisionTree.train(training, strategy) - - println(model) - - if (params.algo == Classification) { - val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy") - } - - if (params.algo == Regression) { - val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse") + if (params.numTrees == 1) { + val model = DecisionTree.train(training, strategy) + println(model) + if (params.algo == Classification) { + val accuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $accuracy") + } + if (params.algo == Regression) { + val mse = meanSquaredError(model, test) + println(s"Test mean squared error = $mse") + } + } else { + val randomSeed = Utils.random.nextInt() + if (params.algo == Classification) { + val model = RandomForest.trainClassifier(training, strategy, params.numTrees, + params.featureSubsetStrategy, randomSeed) + println(model) + val accuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $accuracy") + } + if (params.algo == Regression) { + val model = RandomForest.trainRegressor(training, strategy, params.numTrees, + params.featureSubsetStrategy, randomSeed) + println(model) + val mse = meanSquaredError(model, test) + println(s"Test mean squared error = $mse") + } } sc.stop() @@ -211,9 +241,7 @@ object DecisionTreeRunner { /** * Calculates the classifier accuracy. 
*/ - private def accuracyScore( - model: DecisionTreeModel, - data: RDD[LabeledPoint]): Double = { + private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { val correctCount = data.filter(y => model.predict(y.features) == y.label).count() val count = data.count() correctCount.toDouble / count @@ -228,4 +256,14 @@ object DecisionTreeRunner { err * err }.mean() } + + /** + * Calculates the mean squared error for regression. + */ + private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = tree.predict(y.features) - y.label + err * err + }.mean() + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c7f2576c822b1..b7dc373ebd9cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -18,12 +18,14 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -33,7 +35,6 @@ import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom @@ -56,99 +57,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - - val timer = new TimeTracker() - - timer.start("total") - - timer.start("init") - - val retaggedInput = input.retag(classOf[LabeledPoint]) - val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) - logDebug("algo = " + strategy.algo) - logDebug("maxBins = " + metadata.maxBins) - - // Find the splits and the corresponding bins (interval between the splits) using a sample - // of the input data. - timer.start("findSplitsBins") - val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) - timer.stop("findSplitsBins") - logDebug("numBins: feature: number of bins") - logDebug(Range(0, metadata.numFeatures).map { featureIndex => - s"\t$featureIndex\t${metadata.numBins(featureIndex)}" - }.mkString("\n")) - - // Bin feature values (TreePoint representation). - // Cache input RDD for speedup during multiple passes. 
- val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - .persist(StorageLevel.MEMORY_AND_DISK) - - // depth of the decision tree - val maxDepth = strategy.maxDepth - require(maxDepth <= 30, - s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") - - // Calculate level for single group construction - - // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L - logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - // TODO: Calculate memory usage more precisely. - val numElementsPerNode = DecisionTree.getElementsPerNode(metadata) - - logDebug("numElementsPerNode = " + numElementsPerNode) - val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array - val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) - logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) - // nodes at a level is 2^level. level is zero indexed. - val maxLevelForSingleGroup = math.max( - (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) - logDebug("max level for single group = " + maxLevelForSingleGroup) - - timer.stop("init") - - /* - * The main idea here is to perform level-wise training of the decision tree nodes thus - * reducing the passes over the data from l to log2(l) where l is the total number of nodes. - * Each data sample is handled by a particular node at that level (or it reaches a leaf - * beforehand and is not used in later levels. - */ - - var topNode: Node = null // set on first iteration - var level = 0 - var break = false - while (level <= maxDepth && !break) { - logDebug("#####################################") - logDebug("level = " + level) - logDebug("#####################################") - - // Find best split for all nodes at a level. - timer.start("findBestSplits") - val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput, - metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer) - timer.stop("findBestSplits") - - if (level == 0) { - topNode = tmpTopNode - } - if (doneTraining) { - break = true - logDebug("done training") - } - - level += 1 - } - - logDebug("#####################################") - logDebug("Extracting tree model") - logDebug("#####################################") - - timer.stop("total") - - logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") - - new DecisionTreeModel(topNode, strategy.algo) + // Note: random seed will not be used since numTrees = 1. + val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) + val rfModel = rf.train(input) + rfModel.trees(0) } } @@ -352,58 +264,10 @@ object DecisionTree extends Serializable with Logging { impurity, maxDepth, maxBins) } - /** - * Returns an array of optimal splits for all nodes at a given level. Splits the task into - * multiple groups if the level-wise training task could lead to memory overflow. - * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param metadata Learning and dataset metadata - * @param level Level of the tree - * @param topNode Root node of the tree (or invalid node when training first level). - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. 
- * @return (root, doneTraining) where: - * root = Root node (which is newly created on the first iteration), - * doneTraining = true if no more internal nodes were created. - */ - private[tree] def findBestSplits( - input: RDD[TreePoint], - metadata: DecisionTreeMetadata, - level: Int, - topNode: Node, - splits: Array[Array[Split]], - bins: Array[Array[Bin]], - maxLevelForSingleGroup: Int, - timer: TimeTracker = new TimeTracker): (Node, Boolean) = { - - // split into groups to avoid memory overflow during aggregation - if (level > maxLevelForSingleGroup) { - // When information for all nodes at a given level cannot be stored in memory, - // the nodes are divided into multiple groups at each level with the number of groups - // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, - // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = 1 << level - maxLevelForSingleGroup - logDebug("numGroups = " + numGroups) - // Iterate over each group of nodes at a level. - var groupIndex = 0 - var doneTraining = true - while (groupIndex < numGroups) { - val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, - topNode, splits, bins, timer, numGroups, groupIndex) - doneTraining = doneTraining && doneTrainingGroup - groupIndex += 1 - } - (topNode, doneTraining) // Not first iteration, so topNode was already set. - } else { - findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer) - } - } - /** * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a node - * at the current level being trained; that node's index is returned. + * This function mimics prediction, passing an example from the root node down to a leaf + * or unsplit node; that node's index is returned. * * @param node Node in tree from which to classify the given data point. * @param binnedFeatures Binned feature vector for data point. @@ -413,14 +277,15 @@ object DecisionTree extends Serializable with Logging { * Otherwise, last node reachable in tree matching this example. * Note: This is the global node index, i.e., the index used in the tree. * This index is different from the index used during training a particular - * set of nodes in a (level, group). + * group of nodes on one call to [[findBestSplits()]]. */ private def predictNodeIndex( node: Node, binnedFeatures: Array[Int], bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Int = { - if (node.isLeaf) { + if (node.isLeaf || node.split.isEmpty) { + // Node is either leaf, or has not yet been split. node.id } else { val featureIndex = node.split.get.feature @@ -465,43 +330,60 @@ object DecisionTree extends Serializable with Logging { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (node, feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. */ private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, nodeIndex: Int, bins: Array[Array[Bin]], - unorderedFeatures: Set[Int]): Unit = { - // Iterate over all features. 
- val numFeatures = treePoint.binnedFeatures.size + unorderedFeatures: Set[Int], + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { + val numFeaturesPerNode = if (featuresForNode.nonEmpty) { + // Use subsampled features + featuresForNode.get.size + } else { + // Use all features + agg.metadata.numFeatures + } val nodeOffset = agg.getNodeOffset(nodeIndex) - var featureIndex = 0 - while (featureIndex < numFeatures) { + // Iterate over features. + var featureIndexIdx = 0 + while (featureIndexIdx < numFeaturesPerNode) { + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) // Update the left or right bin for each split. - val numSplits = agg.numSplits(featureIndex) + val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { - agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label) + agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, + instanceWeight) } else { - agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label) + agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, + instanceWeight) } splitIndex += 1 } } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label) + agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label, + instanceWeight) } - featureIndex += 1 + featureIndexIdx += 1 } } @@ -513,66 +395,77 @@ object DecisionTree extends Serializable with Logging { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (node, feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * @return agg + * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). + * @param instanceWeight Weight (importance) of instance in dataset. */ private def orderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - nodeIndex: Int): Unit = { + nodeIndex: Int, + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { val label = treePoint.label val nodeOffset = agg.getNodeOffset(nodeIndex) - // Iterate over all features. - val numFeatures = agg.numFeatures - var featureIndex = 0 - while (featureIndex < numFeatures) { - val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label) - featureIndex += 1 + // Iterate over features. 
+ if (featuresForNode.nonEmpty) { + // Use subsampled features + var featureIndexIdx = 0 + while (featureIndexIdx < featuresForNode.get.size) { + val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) + agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight) + featureIndexIdx += 1 + } + } else { + // Use all features + val numFeatures = agg.metadata.numFeatures + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight) + featureIndex += 1 + } } } /** - * Returns an array of optimal splits for a group of nodes at a given level + * Given a group of nodes, this finds the best split for each node. * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param metadata Learning and dataset metadata - * @param level Level of the tree - * @param topNode Root node of the tree (or invalid node when training first level). + * @param topNodes Root node for each tree. Used for matching instances with nodes. + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param numGroups total number of node groups at the current level. Default value is set to 1. - * @param groupIndex index of the node group being processed. Default value is set to 0. - * @return (root, doneTraining) where: - * root = Root node (which is newly created on the first iteration), - * doneTraining = true if no more internal nodes were created. + * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. */ - private def findBestSplitsPerGroup( - input: RDD[TreePoint], + private[tree] def findBestSplits( + input: RDD[BaggedPoint[TreePoint]], metadata: DecisionTreeMetadata, - level: Int, - topNode: Node, + topNodes: Array[Node], + nodesForGroup: Map[Int, Array[Node]], + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], splits: Array[Array[Split]], bins: Array[Array[Bin]], - timer: TimeTracker, - numGroups: Int = 1, - groupIndex: Int = 0): (Node, Boolean) = { + nodeQueue: mutable.Queue[(Int, Node)], + timer: TimeTracker = new TimeTracker): Unit = { /* * The high-level descriptions of the best split optimizations are noted here. * - * *Level-wise training* - * We perform bin calculations for all nodes at the given level to avoid making multiple - * passes over the data. Thus, for a slightly increased computation and storage cost we save - * several iterations over the data especially at higher levels of the decision tree. + * *Group-wise training* + * We perform bin calculations for groups of nodes to reduce the number of + * passes over the data. Each iteration requires more computation and storage, + * but saves several iterations over the data. * * *Bin-wise computation* * We use a bin-wise best split computation strategy instead of a straightforward best split * computation strategy. Instead of analyzing each sample for contribution to the left/right * child node impurity of every split, we first categorize each feature of a sample into a - * bin. 
Each bin is an interval between a low and high split. Since each split, and thus bin, - * is ordered (read ordering for categorical variables in the findSplitsBins method), - * we exploit this structure to calculate aggregates for bins and then use these aggregates + * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates * to calculate information gain for each split. * * *Aggregation over partitions* @@ -582,42 +475,15 @@ object DecisionTree extends Serializable with Logging { * drastically reduce the communication overhead. */ - // Common calculations for multiple nested methods: - - // numNodes: Number of nodes in this (level of tree, group), - // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = Node.maxNodesInLevel(level) / numGroups + // numNodes: Number of nodes in this group + val numNodes = nodesForGroup.values.map(_.size).sum logDebug("numNodes = " + numNodes) - logDebug("numFeatures = " + metadata.numFeatures) logDebug("numClasses = " + metadata.numClasses) logDebug("isMulticlass = " + metadata.isMulticlass) logDebug("isMulticlassWithCategoricalFeatures = " + metadata.isMulticlassWithCategoricalFeatures) - // shift when more than one group is used at deep tree level - val groupShift = numNodes * groupIndex - - // Used for treePointToNodeIndex to get an index for this (level, group). - // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level. - // - groupShift corrects for groups in this level before the current group. - val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift - - /** - * Find the node index for the given example. - * Nodes are indexed from 0 at the start of this (level, group). - * If the example does not reach this level, returns a value < 0. - */ - def treePointToNodeIndex(treePoint: TreePoint): Int = { - if (level == 0) { - 0 - } else { - val globalNodeIndex = - predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures) - globalNodeIndex - globalNodeIndexOffset - } - } - /** * Performs a sequential aggregation over a partition. * @@ -626,21 +492,27 @@ object DecisionTree extends Serializable with Logging { * * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (node, feature, bin). - * @param treePoint Data point being aggregated. + * @param baggedPoint Data point being aggregated. * @return agg */ def binSeqOp( agg: DTStatsAggregator, - treePoint: TreePoint): DTStatsAggregator = { - val nodeIndex = treePointToNodeIndex(treePoint) - // If the example does not reach this level, then nodeIndex < 0. - // If the example reaches this level but is handled in a different group, - // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). - if (nodeIndex >= 0 && nodeIndex < numNodes) { - if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg, treePoint, nodeIndex) - } else { - mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures) + baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, + bins, metadata.unorderedFeatures) + val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null) + // If the example does not reach a node in this group, then nodeIndex = null. 
+ if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures, + instanceWeight, featuresForNode) + } } } agg @@ -649,71 +521,62 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregates. timer.start("aggregation") val binAggregates: DTStatsAggregator = { - val initAgg = new DTStatsAggregator(metadata, numNodes) + val initAgg = if (metadata.subsamplingFeatures) { + new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo) + } else { + new DTStatsAggregatorFixedFeatures(metadata, numNodes) + } input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp) } timer.stop("aggregation") - // Calculate best splits for all nodes at a given level + // Calculate best splits for all nodes in the group timer.start("chooseSplits") - // On the first iteration, we need to get and return the newly created root node. - var newTopNode: Node = topNode - - // Iterate over all nodes at this level - var nodeIndex = 0 - var internalNodeCount = 0 - while (nodeIndex < numNodes) { - val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits) - logDebug("best split = " + split) - - val globalNodeIndex = globalNodeIndexOffset + nodeIndex - // Extract info for this node at the current level. - val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth) - val node = - new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) - logDebug("Node = " + node) - - if (!isLeaf) { - internalNodeCount += 1 - } - if (level == 0) { - newTopNode = node - } else { - // Set parent. - val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode) - if (Node.isLeftChild(globalNodeIndex)) { - parentNode.leftNode = Some(node) - } else { - parentNode.rightNode = Some(node) + // Iterate over all nodes in this group. + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val (split: Split, stats: InformationGainStats, predict: Predict) = + binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode) + logDebug("best split = " + split) + + // Extract info for this node. Create children if not leaf. 
+ val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) + assert(node.id == nodeIndex) + node.predict = predict.predict + node.isLeaf = isLeaf + node.stats = Some(stats) + logDebug("Node = " + node) + + if (!isLeaf) { + node.split = Some(split) + node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) + node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) + nodeQueue.enqueue((treeIndex, node.leftNode.get)) + nodeQueue.enqueue((treeIndex, node.rightNode.get)) + logDebug("leftChildIndex = " + node.leftNode.get.id + + ", impurity = " + stats.leftImpurity) + logDebug("rightChildIndex = " + node.rightNode.get.id + + ", impurity = " + stats.rightImpurity) } } - if (level < metadata.maxDepth) { - logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) + - ", impurity = " + stats.leftImpurity) - logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) + - ", impurity = " + stats.rightImpurity) - } - - nodeIndex += 1 } timer.stop("chooseSplits") - - val doneTraining = internalNodeCount == 0 - (newTopNode, doneTraining) } /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for all splits + * @return information gain and statistics for split */ private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - level: Int, metadata: DecisionTreeMetadata): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -753,7 +616,7 @@ object DecisionTree extends Serializable with Logging { * Calculate predict value for current node, given stats of any split. * Note that this function is called only once for each node. * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a node + * @param rightImpurityCalculator right node aggregates for a split * @return predict value for current node */ private def calculatePredict( @@ -770,27 +633,33 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. * @param binAggregates Bin statistics. - * @param nodeIndex Index for node to split in this (level, group). - * @return tuple for best split: (Split, information gain) + * @param nodeIndex Index into aggregates for node to split in this group. + * @return tuple for best split: (Split, information gain, prediction at node) */ private def binsToBestSplit( binAggregates: DTStatsAggregator, nodeIndex: Int, - level: Int, - metadata: DecisionTreeMetadata, - splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = { + splits: Array[Array[Split]], + featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { + + val metadata: DecisionTreeMetadata = binAggregates.metadata // calculate predict only once var predict: Option[Predict] = None // For each (feature, split), calculate the gain, and select the best (feature, split). 
- val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex => + val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx => + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } val numSplits = metadata.numSplits(featureIndex) if (metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. // Afterwards, binAggregates for a bin is the sum of aggregates for // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) var splitIndex = 0 while (splitIndex < numSplits) { binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) @@ -803,26 +672,26 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (metadata.isUnordered(featureIndex)) { // Unordered categorical feature val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) val numBins = metadata.numBins(featureIndex) /* Each bin is one category (feature value). @@ -887,7 +756,7 @@ object DecisionTree extends Serializable with Logging { binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -903,18 +772,6 @@ object DecisionTree extends Serializable with Logging { (bestSplit, bestSplitStats, predict.get) } - /** - * Get the number of values to be stored per node in the bin aggregates. 
- */ - private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = { - val totalBins = metadata.numBins.map(_.toLong).sum - if (metadata.isClassification) { - metadata.numClasses * totalBins - } else { - 3 * totalBins - } - } - /** * Returns splits and bins for decision tree calculation. * Continuous and categorical features are handled differently. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala new file mode 100644 index 0000000000000..7fa7725e79e46 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -0,0 +1,451 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker} +import org.apache.spark.mllib.tree.impurity.Impurities +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * A class which implements a random forest learning algorithm for classification and regression. + * It supports both continuous and categorical features. + * + * The settings for featureSubsetStrategy are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. + * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] + * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for + * random forests]] + * + * @param strategy The configuration parameters for the random forest algorithm which specify + * the type of algorithm (classification, regression, etc.), feature type + * (continuous, categorical), depth of the tree, quantile calculation strategy, + * etc. + * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". 
+ * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param seed Random seed for bootstrapping and choosing feature subsets. + */ +@Experimental +private class RandomForest ( + private val strategy: Strategy, + private val numTrees: Int, + featureSubsetStrategy: String, + private val seed: Int) + extends Serializable with Logging { + + strategy.assertValid() + require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") + require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), + s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + + s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.") + + /** + * Method to train a decision tree model over an RDD + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return RandomForestModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint]): RandomForestModel = { + + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = + DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + logDebug("algo = " + strategy.algo) + logDebug("numTrees = " + numTrees) + logDebug("seed = " + seed) + logDebug("maxBins = " + metadata.maxBins) + logDebug("featureSubsetStrategy = " + featureSubsetStrategy) + logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + timer.start("findSplitsBins") + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) + timer.stop("findSplitsBins") + logDebug("numBins: feature: number of bins") + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) + + // Bin feature values (TreePoint representation). + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) + val baggedInput = if (numTrees > 1) { + BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed) + } else { + BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + }.persist(StorageLevel.MEMORY_AND_DISK) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + require(maxDepth <= 30, + s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") + + // Max memory usage for aggregates + // TODO: Calculate memory usage more precisely. + val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + val maxMemoryPerNode = { + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + // Find numFeaturesPerNode largest bins to get an upper bound on memory usage. + Some(metadata.numBins.zipWithIndex.sortBy(- _._1) + .take(metadata.numFeaturesPerNode).map(_._2)) + } else { + None + } + RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + } + require(maxMemoryPerNode <= maxMemoryUsage, + s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," + + " which is too small for the given features." 
+ + s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") + + timer.stop("init") + + /* + * The main idea here is to perform group-wise training of the decision tree nodes thus + * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). + * Each data sample is handled by a particular node (or it reaches a leaf and is not used + * in lower levels). + */ + + // FIFO queue of nodes to train: (treeIndex, node) + val nodeQueue = new mutable.Queue[(Int, Node)]() + + val rng = new scala.util.Random() + rng.setSeed(seed) + + // Allocate and queue root nodes. + val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1)) + Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + + while (nodeQueue.nonEmpty) { + // Collect some nodes to split, and choose features for each node (if subsampling). + // Each group of nodes may come from one or multiple trees, and at multiple levels. + val (nodesForGroup, treeToNodeToIndexInfo) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + // Sanity check (should never occur): + assert(nodesForGroup.size > 0, + s"RandomForest selected empty nodesForGroup. Error for unknown reason.") + + // Choose node splits, and enqueue new nodes as needed. + timer.start("findBestSplits") + DecisionTree.findBestSplits(baggedInput, + metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) + timer.stop("findBestSplits") + } + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) + RandomForestModel.build(trees) + } + +} + +object RandomForest extends Serializable with Logging { + + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param strategy Parameters for training each tree in the forest. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int): RandomForestModel = { + require(strategy.algo == Classification, + s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") + val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) + rf.train(input) + } + + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClassesForClassification number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. 
+ * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()): RandomForestModel = { + val impurityType = Impurities.fromString(impurity) + val strategy = new Strategy(Classification, impurityType, maxDepth, + numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo) + trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]] + */ + def trainClassifier( + input: JavaRDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + trainClassifier(input.rdd, numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param strategy Parameters for training each tree in the forest. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int): RandomForestModel = { + require(strategy.algo == Regression, + s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") + val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) + rf.train(input) + } + + /** + * Method to train a decision tree model for regression. 
+ * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param impurity Criterion used for information gain calculation. + * Supported values: "variance". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()): RandomForestModel = { + val impurityType = Impurities.fromString(impurity) + val strategy = new Strategy(Regression, impurityType, maxDepth, + 0, maxBins, Sort, categoricalFeaturesInfo) + trainRegressor(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]] + */ + def trainRegressor( + input: JavaRDD[LabeledPoint], + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + trainRegressor(input.rdd, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + /** + * List of supported feature subset sampling strategies. + */ + val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "sqrt", "log2", "onethird") + + private[tree] class NodeIndexInfo( + val nodeIndexInGroup: Int, + val featureSubset: Option[Array[Int]]) extends Serializable + + /** + * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. + * This tracks the memory usage for aggregates and stops adding nodes when too much memory + * will be needed; this allows an adaptive number of nodes since different nodes may require + * different amounts of memory (if featureSubsetStrategy is not "all"). + * + * @param nodeQueue Queue of nodes to split. + * @param maxMemoryUsage Bound on size of aggregate statistics. + * @return (nodesForGroup, treeToNodeToIndexInfo). + * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. + * treeToNodeToIndexInfo holds indices selected features for each node: + * treeIndex --> (global) node index --> (node index in group, feature indices). + * The (global) node index is the index in the tree; the node index in group is the + * index in [0, numNodesInGroup) of the node in this group. 
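For reference, a minimal usage sketch of the train* entry points defined in this object. It only exercises the signatures shown above; the input path, parameter values, and object name are placeholders, not part of this patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils

object RandomForestUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("RandomForestUsageSketch"))
    // RDD[LabeledPoint]; labels in {0, 1, ..., numClasses - 1}. Path is a placeholder.
    val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")

    // Classification: with numTrees > 1, "auto" resolves to "sqrt".
    val classifier = RandomForest.trainClassifier(data,
      numClassesForClassification = 2,
      categoricalFeaturesInfo = Map.empty[Int, Int],
      numTrees = 10,
      featureSubsetStrategy = "auto",
      impurity = "gini",
      maxDepth = 4,
      maxBins = 100,
      seed = 12345)

    // Regression: with numTrees > 1, "auto" resolves to "onethird".
    // (The same labels are reused here purely for illustration.)
    val regressor = RandomForest.trainRegressor(data,
      categoricalFeaturesInfo = Map.empty[Int, Int],
      numTrees = 10,
      featureSubsetStrategy = "auto",
      impurity = "variance",
      maxDepth = 4,
      maxBins = 100,
      seed = 12345)

    println(s"Forest sizes: ${classifier.numTrees}, ${regressor.numTrees}")
    sc.stop()
  }
}
```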
+ * The feature indices are None if not subsampling features. + */ + private[tree] def selectNodesToSplit( + nodeQueue: mutable.Queue[(Int, Node)], + maxMemoryUsage: Long, + metadata: DecisionTreeMetadata, + rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = { + // Collect some nodes to split: + // nodesForGroup(treeIndex) = nodes to split + val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]() + val mutableTreeToNodeToIndexInfo = + new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() + var memUsage: Long = 0L + var numNodesInGroup = 0 + while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) { + val (treeIndex, node) = nodeQueue.head + // Choose subset of features for node (if subsampling). + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir) + Some(rng.shuffle(Range(0, metadata.numFeatures).toList) + .take(metadata.numFeaturesPerNode).toArray) + } else { + None + } + // Check if enough memory remains to add this node to the group. + val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + if (memUsage + nodeMemUsage <= maxMemoryUsage) { + nodeQueue.dequeue() + mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node + mutableTreeToNodeToIndexInfo + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) + = new NodeIndexInfo(numNodesInGroup, featureSubset) + } + numNodesInGroup += 1 + memUsage += nodeMemUsage + } + // Convert mutable maps to immutable ones. + val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap + val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap + (nodesForGroup, treeToNodeToIndexInfo) + } + + /** + * Get the number of values to be stored for this node in the bin aggregates. + * @param featureSubset Indices of features which may be split at this node. + * If None, then use all features. + */ + private[tree] def aggregateSizeForNode( + metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]): Long = { + val totalBins = if (featureSubset.nonEmpty) { + featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum + } else { + metadata.numBins.map(_.toLong).sum + } + if (metadata.isClassification) { + metadata.numClasses * totalBins + } else { + 3 * totalBins + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala new file mode 100644 index 0000000000000..937c8a2ac5836 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
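selectNodesToSplit stops adding nodes once the estimated aggregate size (aggregateSizeForNode times 8 bytes per Double) would exceed maxMemoryUsage. A standalone sketch of that sizing arithmetic follows; the bin counts and class count are made up for illustration and the object is not part of this patch.

```scala
// Sketch of the per-node aggregate sizing used by the memory check in selectNodesToSplit.
object AggregateSizingSketch {
  def aggregateSizeForNode(
      numBins: Array[Int],            // bins per feature, over all features
      numClasses: Int,                // used for classification only
      isClassification: Boolean,
      featureSubset: Option[Array[Int]]): Long = {
    val totalBins = featureSubset match {
      case Some(features) => features.map(f => numBins(f).toLong).sum
      case None => numBins.map(_.toLong).sum
    }
    // Regression stores 3 stats per bin (count, sum, sum of squares).
    if (isClassification) numClasses * totalBins else 3L * totalBins
  }

  def main(args: Array[String]): Unit = {
    val numBins = Array.fill(100)(32)  // 100 features, 32 bins each (illustrative)
    val allFeatures =
      aggregateSizeForNode(numBins, numClasses = 2, isClassification = true, None)
    val subset =
      aggregateSizeForNode(numBins, 2, isClassification = true, Some(Array.range(0, 10)))
    // Multiply by 8 bytes per Double, matching the nodeMemUsage computation.
    println(s"all features: ${allFeatures * 8L} bytes; 10-feature subset: ${subset * 8L} bytes")
  }
}
```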
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +/** + * Internal representation of a datapoint which belongs to several subsamples of the same dataset, + * particularly for bagging (e.g., for random forests). + * + * This holds one instance, as well as an array of weights which represent the (weighted) + * number of times which this instance appears in each subsample. + * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that + * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. + * + * @param datum Data instance + * @param subsampleWeights Weight of this instance in each subsampled dataset. + * + * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted + * dataset support, update. (We store subsampleWeights as Double for this future extension.) + */ +private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) + extends Serializable + +private[tree] object BaggedPoint { + + /** + * Convert an input dataset into its BaggedPoint representation, + * choosing subsample counts for each instance. + * Each subsample has the same number of instances as the original dataset, + * and is created by subsampling with replacement. + * @param input Input dataset. + * @param numSubsamples Number of subsamples of this RDD to take. + * @param seed Random seed. + * @return BaggedPoint dataset representation + */ + def convertToBaggedRDD[Datum]( + input: RDD[Datum], + numSubsamples: Int, + seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // TODO: Support different sampling rates, and sampling without replacement. + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. + val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1)) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + subsampleWeights(subsampleIndex) = poisson.nextInt() + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { + input.map(datum => new BaggedPoint(datum, Array(1.0))) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 61a94246711bf..d49df7a016375 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -17,16 +17,17 @@ package org.apache.spark.mllib.tree.impl +import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.impurity._ /** * DecisionTree statistics aggregator. * This holds a flat array of statistics for a set of (nodes, features, bins) * and helps with indexing. + * This class is abstract to support learning with and without feature subsampling. 
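The idea behind convertToBaggedRDD is that drawing each instance's count from Poisson(1) approximates sampling with replacement up to the original dataset size, without a shuffle. A self-contained sketch of that resampling step follows; it uses a hand-rolled Poisson sampler (Knuth's algorithm) as a stand-in for the colt Poisson used in the patch, and the data values are placeholders.

```scala
import scala.util.Random

object PoissonBaggingSketch {
  // Knuth's Poisson sampler; illustrative substitute for cern.jet.random.Poisson.
  def poisson(mean: Double, rng: Random): Int = {
    val l = math.exp(-mean)
    var k = 0
    var p = 1.0
    do {
      k += 1
      p *= rng.nextDouble()
    } while (p > l)
    k - 1
  }

  def main(args: Array[String]): Unit = {
    val rng = new Random(42)
    val numSubsamples = 3
    val data = Seq("a", "b", "c", "d")
    // Each datum gets one (approximate) bootstrap count per subsample,
    // analogous to BaggedPoint.subsampleWeights.
    val bagged = data.map { datum =>
      val weights = Array.fill(numSubsamples)(poisson(1.0, rng).toDouble)
      (datum, weights)
    }
    bagged.foreach { case (d, w) => println(s"$d -> ${w.mkString("[", ", ", "]")}") }
  }
}
```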
*/ -private[tree] class DTStatsAggregator( - val metadata: DecisionTreeMetadata, - val numNodes: Int) extends Serializable { +private[tree] abstract class DTStatsAggregator( + val metadata: DecisionTreeMetadata) extends Serializable { /** * [[ImpurityAggregator]] instance specifying the impurity type. @@ -43,49 +44,21 @@ private[tree] class DTStatsAggregator( */ val statsSize: Int = impurityAggregator.statsSize - val numFeatures: Int = metadata.numFeatures - - /** - * Number of bins for each feature. This is indexed by the feature index. - */ - val numBins: Array[Int] = metadata.numBins - - /** - * Number of splits for the given feature. - */ - def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex) - /** * Indicator for each feature of whether that feature is an unordered feature. * TODO: Is Array[Boolean] any faster? */ def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) - /** - * Offset for each feature for calculating indices into the [[allStats]] array. - */ - private val featureOffsets: Array[Int] = { - numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - - /** - * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. - */ - private val nodeStride: Int = featureOffsets.last - /** * Total number of elements stored in this aggregator. */ - val allStatsSize: Int = numNodes * nodeStride + def allStatsSize: Int /** - * Flat array of elements. - * Index for start of stats for a (node, feature, bin) is: - * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex)) - * and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex)) + * Get flat array of elements stored in this aggregator. */ - val allStats: Array[Double] = new Array[Double](allStatsSize) + protected def allStats: Array[Double] /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). @@ -102,36 +75,39 @@ private[tree] class DTStatsAggregator( /** * Update the stats for a given (node, feature, bin) for ordered features, using the given label. */ - def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { - val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label) + def update( + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) } /** * Pre-compute node offset for use with [[nodeUpdate]]. */ - def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride + def getNodeOffset(nodeIndex: Int): Int /** * Faster version of [[update]]. * Update the stats for a given (node, feature, bin) for ordered features, using the given label. * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. */ - def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { - val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label) - } + def nodeUpdate( + nodeOffset: Int, + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit /** * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. * For ordered features only. 
*/ - def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { - require(!isUnordered(featureIndex), - s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" + - s" for unordered feature $featureIndex.") - nodeIndex * nodeStride + featureOffsets(featureIndex) - } + def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int /** * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. @@ -140,9 +116,9 @@ private[tree] class DTStatsAggregator( def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = { require(isUnordered(featureIndex), s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + - s" but was called for ordered feature $featureIndex.") - val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) - (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) + s" but was called for ordered feature $featureIndex.") + val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex) + (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize) } /** @@ -154,8 +130,13 @@ private[tree] class DTStatsAggregator( * (node, feature, left/right child) offset from * [[getLeftRightNodeFeatureOffsets]]. */ - def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = { - impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label) + def nodeFeatureUpdate( + nodeFeatureOffset: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label, + instanceWeight) } /** @@ -189,7 +170,139 @@ private[tree] class DTStatsAggregator( } this } +} + +/** + * DecisionTree statistics aggregator. + * This holds a flat array of statistics for a set of (nodes, features, bins) + * and helps with indexing. + * + * This instance of [[DTStatsAggregator]] is used when not subsampling features. + * + * @param numNodes Number of nodes to collect statistics for. + */ +private[tree] class DTStatsAggregatorFixedFeatures( + metadata: DecisionTreeMetadata, + numNodes: Int) extends DTStatsAggregator(metadata) { + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + * Mapping: featureIndex --> offset + */ + private val featureOffsets: Array[Int] = { + metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) + } + + /** + * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. + */ + private val nodeStride: Int = featureOffsets.last + override val allStatsSize: Int = numNodes * nodeStride + + /** + * Flat array of elements. + * Index for start of stats for a (node, feature, bin) is: + * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, the left child stats precede the right child stats + * in the binIndex order. 
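The fixed-features aggregator addresses its flat stats array with index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize. A small standalone sketch of that arithmetic follows, with made-up bin counts and statsSize; it mirrors the featureOffsets/nodeStride computation but is not the patch's class.

```scala
object FlatIndexingSketch {
  def main(args: Array[String]): Unit = {
    val numBins = Array(4, 8, 2)   // 3 features with different bin counts (illustrative)
    val statsSize = 3              // e.g., variance stats: count, sum, sum of squares
    val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
    val nodeStride = featureOffsets.last

    def index(nodeIndex: Int, featureIndex: Int, binIndex: Int): Int =
      nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize

    // Stats for (node 1, feature 2, bin 1) start at:
    // 1 * (12 + 24 + 6) + (12 + 24) + 1 * 3 = 42 + 36 + 3 = 81
    println(index(nodeIndex = 1, featureIndex = 2, binIndex = 1))
  }
}
```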
+ */ + override protected val allStats: Array[Double] = new Array[Double](allStatsSize) + + override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride + + override def nodeUpdate( + nodeOffset: Int, + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) + } + + override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { + nodeIndex * nodeStride + featureOffsets(featureIndex) + } +} + +/** + * DecisionTree statistics aggregator. + * This holds a flat array of statistics for a set of (nodes, features, bins) + * and helps with indexing. + * + * This instance of [[DTStatsAggregator]] is used when subsampling features. + * + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). + */ +private[tree] class DTStatsAggregatorSubsampledFeatures( + metadata: DecisionTreeMetadata, + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) { + + /** + * For each node, offset for each feature for calculating indices into the [[allStats]] array. + * Mapping: nodeIndex --> featureIndex --> offset + */ + private val featureOffsets: Array[Array[Int]] = { + val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum + val offsets = new Array[Array[Int]](numNodes) + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) => + nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) => + offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_)) + .scanLeft(0)((total, nBins) => total + statsSize * nBins) + } + } + offsets + } + + /** + * For each node, offset for each feature for calculating indices into the [[allStats]] array. + */ + protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _) + + override val allStatsSize: Int = nodeOffsets.last + + /** + * Flat array of elements. + * Index for start of stats for a (node, feature, bin) is: + * index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, the left child stats precede the right child stats + * in the binIndex order. + */ + override protected val allStats: Array[Double] = new Array[Double](allStatsSize) + + override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex) + + /** + * Faster version of [[update]]. + * Update the stats for a given (node, feature, bin) for ordered features, using the given label. + * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. + * @param featureIndex Index of feature in featuresForNodes(nodeIndex). + * Note: This is NOT the original feature index. + */ + override def nodeUpdate( + nodeOffset: Int, + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) + } + + /** + * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * For ordered features only. + * @param featureIndex Index of feature in featuresForNodes(nodeIndex). + * Note: This is NOT the original feature index. 
+ */ + override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { + nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex) + } } private[tree] object DTStatsAggregator extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index b6d49e5555b1a..212dce25236e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -48,7 +48,9 @@ private[tree] class DecisionTreeMetadata( val quantileStrategy: QuantileStrategy, val maxDepth: Int, val minInstancesPerNode: Int, - val minInfoGain: Double) extends Serializable { + val minInfoGain: Double, + val numTrees: Int, + val numFeaturesPerNode: Int) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -73,6 +75,11 @@ private[tree] class DecisionTreeMetadata( numBins(featureIndex) - 1 } + /** + * Indicates if feature subsampling is being used. + */ + def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode + } private[tree] object DecisionTreeMetadata { @@ -82,7 +89,11 @@ private[tree] object DecisionTreeMetadata { * This computes which categorical features will be ordered vs. unordered, * as well as the number of splits and bins for each feature. */ - def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String): DecisionTreeMetadata = { val numFeatures = input.take(1)(0).features.size val numExamples = input.count() @@ -128,13 +139,43 @@ private[tree] object DecisionTreeMetadata { } } + // Set number of features to use per node (for random forests). + val _featureSubsetStrategy = featureSubsetStrategy match { + case "auto" => + if (numTrees == 1) { + "all" + } else { + if (strategy.algo == Classification) { + "sqrt" + } else { + "onethird" + } + } + case _ => featureSubsetStrategy + } + val numFeaturesPerNode: Int = _featureSubsetStrategy match { + case "all" => numFeatures + case "sqrt" => math.sqrt(numFeatures).ceil.toInt + case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) + case "onethird" => (numFeatures / 3.0).ceil.toInt + } + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain) + strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) } /** + * Version of [[buildMetadata()]] for DecisionTree. + */ + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy): DecisionTreeMetadata = { + buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") + } + + /** * Given the arity of a categorical feature (arity = number of categories), * return the number of bins for the feature if it is to be treated as an unordered feature. 
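The featureSubsetStrategy resolution in buildMetadata can be summarized as a pure function. The sketch below is a standalone reimplementation for illustration only; the helper name and the example parameter values are not from the patch.

```scala
object FeatureSubsetSketch {
  def numFeaturesPerNode(
      strategy: String,
      numFeatures: Int,
      numTrees: Int,
      isClassification: Boolean): Int = {
    // "auto" resolves to "all" for a single tree, otherwise "sqrt" (classification)
    // or "onethird" (regression), as in DecisionTreeMetadata.buildMetadata.
    val resolved = strategy match {
      case "auto" =>
        if (numTrees == 1) "all"
        else if (isClassification) "sqrt"
        else "onethird"
      case other => other
    }
    resolved match {
      case "all" => numFeatures
      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
      case "onethird" => (numFeatures / 3.0).ceil.toInt
    }
  }

  def main(args: Array[String]): Unit = {
    // 50 features, 10-tree classifier: "auto" -> "sqrt" -> ceil(sqrt(50)) = 8.
    println(numFeaturesPerNode("auto", numFeatures = 50, numTrees = 10, isClassification = true))
  }
}
```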
* There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 1c8afc2d0f4bc..0e02345aa3774 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -89,12 +89,12 @@ private[tree] class EntropyAggregator(numClasses: Int) * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { if (label >= statsSize) { throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } - allStats(offset + label.toInt) += 1 + allStats(offset + label.toInt) += instanceWeight } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 5cfdf345d163c..7c83cd48e16a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -85,12 +85,12 @@ private[tree] class GiniAggregator(numClasses: Int) * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { if (label >= statsSize) { throw new IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } - allStats(offset + label.toInt) += 1 + allStats(offset + label.toInt) += instanceWeight } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 5a047d6cb5480..60e2ab2bb829e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -78,7 +78,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit /** * Get an [[ImpurityCalculator]] for a (node, feature, bin). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index e9ccecb1b8067..df9eafa5da16a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -75,10 +75,10 @@ private[tree] class VarianceAggregator() * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). 
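The instanceWeight parameter added to the impurity aggregators means each bagged instance contributes its subsample count rather than 1 to the bin statistics. A minimal sketch of the weighted class-count update and the Gini impurity computed from it follows; the label/weight pairs are made up.

```scala
object WeightedGiniSketch {
  def main(args: Array[String]): Unit = {
    val numClasses = 2
    val stats = new Array[Double](numClasses)

    // Weighted update, mirroring allStats(offset + label.toInt) += instanceWeight.
    def update(label: Double, instanceWeight: Double): Unit =
      stats(label.toInt) += instanceWeight

    // Three bagged instances: (label, subsample weight).
    Seq((0.0, 2.0), (1.0, 1.0), (0.0, 0.0)).foreach { case (label, w) => update(label, w) }

    val total = stats.sum
    val gini = 1.0 - stats.map(c => (c / total) * (c / total)).sum
    // counts = [2.0, 1.0]; gini = 1 - (4/9 + 1/9) = 4/9 ≈ 0.444
    println(s"weighted counts = ${stats.mkString("[", ", ", "]")}, gini = $gini")
  }
}
```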
*/ - def update(allStats: Array[Double], offset: Int, label: Double): Unit = { - allStats(offset) += 1 - allStats(offset + 1) += label - allStats(offset + 2) += label * label + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { + allStats(offset) += instanceWeight + allStats(offset + 1) += instanceWeight * label + allStats(offset + 2) += instanceWeight * label * label } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 5f0095d23c7ed..56c3e25d9285f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -41,12 +41,12 @@ import org.apache.spark.mllib.linalg.Vector @DeveloperApi class Node ( val id: Int, - val predict: Double, - val isLeaf: Boolean, - val split: Option[Split], + var predict: Double, + var isLeaf: Boolean, + var split: Option[Split], var leftNode: Option[Node], var rightNode: Option[Node], - val stats: Option[InformationGainStats]) extends Serializable with Logging { + var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + "split = " + split + ", stats = " + stats @@ -167,6 +167,11 @@ class Node ( private[tree] object Node { + /** + * Return a node with the given node id (but nothing else set). + */ + def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) + /** * Return the index of the left child of this node. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala new file mode 100644 index 0000000000000..538c0e233202a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Random forest model for classification or regression. + * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make + * aggregate predictions. + * @param trees Trees which make up this forest. This cannot be empty. 
+ * @param algo algorithm type -- classification or regression + */ +@Experimental +class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable { + + require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") + + /** + * Predict values for a single data point. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + def predict(features: Vector): Double = { + algo match { + case Classification => + val predictionToCount = new mutable.HashMap[Int, Int]() + trees.foreach { tree => + val prediction = tree.predict(features).toInt + predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 + } + predictionToCount.maxBy(_._2)._1 + case Regression => + trees.map(_.predict(features)).sum / trees.size + } + } + + /** + * Predict values for the given data set. + * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = { + features.map(x => predict(x)) + } + + /** + * Get number of trees in forest. + */ + def numTrees: Int = trees.size + + /** + * Print full model. + */ + override def toString: String = { + val header = algo match { + case Classification => + s"RandomForestModel classifier with $numTrees trees\n" + case Regression => + s"RandomForestModel regressor with $numTrees trees\n" + case _ => throw new IllegalArgumentException( + s"RandomForestModel given unknown algo parameter: $algo.") + } + header + trees.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + +} + +private[tree] object RandomForestModel { + + def build(trees: Array[DecisionTreeModel]): RandomForestModel = { + require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") + val algo: Algo = trees(0).algo + require(trees.forall(_.algo == algo), + "RandomForestModel cannot combine trees which have different output types" + + " (classification/regression).") + new RandomForestModel(trees, algo) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2b2e579b992f6..a48ed71a1c5fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ +import scala.collection.mutable import org.scalatest.FunSuite @@ -26,39 +27,13 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} import org.apache.spark.mllib.util.LocalSparkContext class DecisionTreeSuite extends FunSuite with LocalSparkContext { - def validateClassifier( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredAccuracy: Double) { - val predictions = 
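RandomForestModel combines per-tree predictions by majority vote for classification and by averaging for regression. The standalone sketch below isolates that aggregation step; the prediction values are made up and the helper names are not from the patch.

```scala
object ForestAggregationSketch {
  // Majority vote over integer class predictions (ties resolved arbitrarily).
  def classify(treePredictions: Seq[Double]): Double =
    treePredictions
      .groupBy(_.toInt)
      .mapValues(_.size)
      .maxBy(_._2)
      ._1
      .toDouble

  // Mean of per-tree regression predictions.
  def regress(treePredictions: Seq[Double]): Double =
    treePredictions.sum / treePredictions.size

  def main(args: Array[String]): Unit = {
    println(classify(Seq(1.0, 0.0, 1.0, 1.0)))  // 1.0 (3 votes vs. 1)
    println(regress(Seq(0.5, 1.5, 1.0)))        // 1.0
  }
}
```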
input.map(x => model.predict(x.features)) - val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => - prediction != expected.label - } - val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy, - s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") - } - - def validateRegressor( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredMSE: Double) { - val predictions = input.map(x => model.predict(x.features)) - val squaredError = predictions.zip(input).map { case (prediction, expected) => - val err = prediction - expected.label - err * err - }.sum - val mse = squaredError / input.length - assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") - } - test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -233,7 +208,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - // 2^10 - 1 > 100, so categorical features will be ordered + // 2^(10-1) - 1 > 100, so categorical features will be ordered val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -269,9 +244,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 0) assert(bins(0).length === 0) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode: Node, doneTraining: Boolean) = - DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.categories === List(1.0)) @@ -299,10 +272,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.categories.length === 1) @@ -331,7 +301,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) - validateRegressor(model, arr, 0.0) + DecisionTreeSuite.validateRegressor(model, arr, 0.0) assert(model.numNodes === 3) assert(model.depth === 1) } @@ -352,12 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -381,12 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = 
DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -411,12 +371,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -441,12 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -471,25 +421,39 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, numClassesForClassification = 2, maxBins = 100) val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNodeCopy1 = modelOneNode.topNode.deepCopy() - val rootNodeCopy2 = modelOneNode.topNode.deepCopy() + val rootNode1 = modelOneNode.topNode.deepCopy() + val rootNode2 = modelOneNode.topNode.deepCopy() + assert(rootNode1.leftNode.nonEmpty) + assert(rootNode1.rightNode.nonEmpty) - // Single group second level tree construction. val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, - rootNodeCopy1, splits, bins, 10) - assert(rootNode.leftNode.nonEmpty) - assert(rootNode.rightNode.nonEmpty) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + // Single group second level tree construction. + val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) + val treeToNodeToIndexInfo = Map((0, Map( + (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), + (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) val children1 = new Array[Node](2) - children1(0) = rootNode.leftNode.get - children1(1) = rootNode.rightNode.get - - // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second - // level tree construction. - val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, - rootNodeCopy2, splits, bins, 0) - assert(rootNode2.leftNode.nonEmpty) - assert(rootNode2.rightNode.nonEmpty) + children1(0) = rootNode1.leftNode.get + children1(1) = rootNode1.rightNode.get + + // Train one second-level node at a time. 
+ val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) + val treeToNodeToIndexInfoA = Map((0, Map( + (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) + val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) + val treeToNodeToIndexInfoB = Map((0, Map( + (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) val children2 = new Array[Node](2) children2(0) = rootNode2.leftNode.get children2(1) = rootNode2.rightNode.get @@ -521,10 +485,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.feature === 0) @@ -544,7 +505,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) } @@ -561,7 +522,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) assert(model.topNode.split.get.feature === 1) @@ -581,14 +542,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = model.topNode val split = rootNode.split.get assert(split.feature === 0) @@ -610,12 +568,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 0.9) + DecisionTreeSuite.validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = model.topNode val split = rootNode.split.get assert(split.feature === 1) @@ -635,12 +590,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 0.9) + 
DecisionTreeSuite.validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = model.topNode val split = rootNode.split.get assert(split.feature === 1) @@ -660,10 +612,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.feature === 0) @@ -682,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(strategy.isMulticlassClassification) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 0.6) + DecisionTreeSuite.validateClassifier(model, arr, 0.6) } test("split must satisfy min instances per node requirements") { @@ -691,24 +640,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) assert(model.topNode.isLeaf) assert(model.topNode.predict == 0.0) - val predicts = input.map(p => model.predict(p.features)).collect() + val predicts = rdd.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) } - // test for findBestSplits when no valid split can be found - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + // test when no valid split can be found + val rootNode = model.topNode val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) @@ -723,15 +668,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), numClassesForClassification = 2, minInstancesPerNode = 2) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get val gain = rootNode.stats.get @@ -757,12 +699,8 @@ class DecisionTreeSuite extends 
FunSuite with LocalSparkContext { assert(predict == 0.0) } - // test for findBestSplits when no valid split can be found - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + // test when no valid split can be found + val rootNode = model.topNode val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) @@ -771,6 +709,32 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { object DecisionTreeSuite { + def validateClassifier( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } + + def validateRegressor( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + val err = prediction - expected.label + err * err + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + } + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala new file mode 100644 index 0000000000000..30669fcd1c75b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.tree + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} +import org.apache.spark.mllib.tree.impurity.{Gini, Variance} +import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.util.StatCounter + +/** + * Test suite for [[RandomForest]]. + */ +class RandomForestSuite extends FunSuite with LocalSparkContext { + + test("BaggedPoint RDD: without subsampling") { + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) + val rdd = sc.parallelize(arr) + val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd) + baggedRDD.collect().foreach { baggedPoint => + assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) + } + } + + test("BaggedPoint RDD: with subsampling") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 1.0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + val numTrees = 1 + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + + val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, + featureSubsetStrategy = "auto", seed = 123) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) + + val dt = DecisionTree.train(rdd, strategy) + + RandomForestSuite.validateClassifier(rf, arr, 0.9) + DecisionTreeSuite.validateClassifier(dt, arr, 0.9) + + // Make sure trees are the same. + assert(rfTree.toString == dt.toString) + } + + test("Regression with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + val numTrees = 1 + + val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + + val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, + featureSubsetStrategy = "auto", seed = 123) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) + + val dt = DecisionTree.train(rdd, strategy) + + RandomForestSuite.validateRegressor(rf, arr, 0.01) + DecisionTreeSuite.validateRegressor(dt, arr, 0.01) + + // Make sure trees are the same. 
+ assert(rfTree.toString == dt.toString) + } + + test("Binary classification with continuous features: subsampling features") { + val numFeatures = 50 + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + + // Select feature subset for top nodes. Return true if OK. + def checkFeatureSubsetStrategy( + numTrees: Int, + featureSubsetStrategy: String, + numFeaturesPerNode: Int): Unit = { + val seeds = Array(123, 5354, 230, 349867, 23987) + val maxMemoryUsage: Long = 128 * 1024L * 1024L + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) + seeds.foreach { seed => + val failString = s"Failed on test with:" + + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + + s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" + val nodeQueue = new mutable.Queue[(Int, Node)]() + val topNodes: Array[Node] = new Array[Node](numTrees) + Range(0, numTrees).foreach { treeIndex => + topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1) + nodeQueue.enqueue((treeIndex, topNodes(treeIndex))) + } + val rng = new scala.util.Random(seed = seed) + val (nodesForGroup: Map[Int, Array[Node]], + treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + + assert(nodesForGroup.size === numTrees, failString) + assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree + if (numFeaturesPerNode == numFeatures) { + // featureSubset values should all be None + assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), + failString) + } else { + // Check number of features. + assert(treeToNodeToIndexInfo.values.forall(_.values.forall( + _.featureSubset.get.size === numFeaturesPerNode)), failString) + } + } + } + + checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) + + checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) + } + +} + +object RandomForestSuite { + + /** + * Aggregates all values in data, and tests whether the empirical mean and stddev are within + * epsilon of the expected values. + * @param data Every element of the data should be an i.i.d. sample from some distribution. 
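The (expectedMean, expectedStddev) = (1.0, 1.0) values in the subsampling test come from the fact that a Poisson(λ) variable has mean λ and variance λ, so with λ = 1 both moments are 1. A quick standalone check of this, not part of the suite and using the same hand-rolled Poisson sampler as the earlier bagging sketch:

```scala
import org.apache.spark.util.StatCounter

object PoissonMomentsCheck {
  def main(args: Array[String]): Unit = {
    val rng = new scala.util.Random(123)
    // Knuth's Poisson sampler with mean 1 (illustrative stand-in for colt's Poisson).
    def poisson1(): Int = {
      val l = math.exp(-1.0)
      var k = 0
      var p = 1.0
      do { k += 1; p *= rng.nextDouble() } while (p > l)
      k - 1
    }
    val samples = Array.fill(100000)(poisson1().toDouble)
    val stats = new StatCounter(samples)
    println(s"mean = ${stats.mean}, stddev = ${stats.stdev}")  // both should be close to 1.0
  }
}
```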
+ */ + def testRandomArrays( + data: Array[Array[Double]], + numCols: Int, + expectedMean: Double, + expectedStddev: Double, + epsilon: Double) { + val values = new mutable.ArrayBuffer[Double]() + data.foreach { row => + assert(row.size == numCols) + values ++= row + } + val stats = new StatCounter(values) + assert(math.abs(stats.mean - expectedMean) < epsilon) + assert(math.abs(stats.stdev - expectedStddev) < epsilon) + } + + def validateClassifier( + model: RandomForestModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } + + def validateRegressor( + model: RandomForestModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + val err = prediction - expected.label + err * err + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + } + + def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = { + val numInstances = 1000 + val arr = new Array[LabeledPoint](numInstances) + for (i <- 0 until numInstances) { + val label = if (i < numInstances / 10) { + 0.0 + } else if (i < numInstances / 2) { + 1.0 + } else if (i < numInstances * 0.9) { + 0.0 + } else { + 1.0 + } + val features = Array.fill[Double](numFeatures)(i.toDouble) + arr(i) = new LabeledPoint(label, Vectors.dense(features)) + } + arr + } + +} From 1651cc117d73f0af6ec9f55b0c6c9b2bd565906c Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Sun, 28 Sep 2014 21:55:09 -0700 Subject: [PATCH 2/8] [EC2] Cleanup Python parens and disk dict Minor fixes: * Remove unnecessary parens (Python style) * Sort `disks_by_instance` dict and remove duplicate `t1.micro` key Author: Nicholas Chammas Closes #2571 from nchammas/ec2-polish and squashes the following commits: 9d203d5 [Nicholas Chammas] paren and dict cleanup --- ec2/spark_ec2.py | 60 ++++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 7f2cd7d94de39..5776d0b519309 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -508,7 +508,7 @@ def tag_instance(instance, name): break except: print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) - if (i == 5): + if i == 5: raise "Error - failed max attempts to add name tag" time.sleep(5) @@ -530,7 +530,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): for res in reservations: active = [i for i in res.instances if is_active(i)] for instance in active: - if (instance.tags.get(u'Name') is None): + if instance.tags.get(u'Name') is None: tag_instance(instance, name) # Now proceed to detect master and slaves instances. 
reservations = conn.get_all_instances() @@ -545,7 +545,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): elif name.startswith(cluster_name + "-slave"): slave_nodes.append(inst) if any((master_nodes, slave_nodes)): - print ("Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes))) + print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) if master_nodes != [] or not die_on_error: return (master_nodes, slave_nodes) else: @@ -626,43 +626,43 @@ def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): def get_num_disks(instance_type): # From http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html # Updated 2014-6-20 + # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { - "m1.small": 1, - "m1.medium": 1, - "m1.large": 2, - "m1.xlarge": 4, - "t1.micro": 1, "c1.medium": 1, "c1.xlarge": 4, - "m2.xlarge": 1, - "m2.2xlarge": 1, - "m2.4xlarge": 2, + "c3.2xlarge": 2, + "c3.4xlarge": 2, + "c3.8xlarge": 2, + "c3.large": 2, + "c3.xlarge": 2, "cc1.4xlarge": 2, "cc2.8xlarge": 4, "cg1.4xlarge": 2, - "hs1.8xlarge": 24, "cr1.8xlarge": 2, + "g2.2xlarge": 1, "hi1.4xlarge": 2, - "m3.medium": 1, - "m3.large": 1, - "m3.xlarge": 2, - "m3.2xlarge": 2, - "i2.xlarge": 1, + "hs1.8xlarge": 24, "i2.2xlarge": 2, "i2.4xlarge": 4, "i2.8xlarge": 8, - "c3.large": 2, - "c3.xlarge": 2, - "c3.2xlarge": 2, - "c3.4xlarge": 2, - "c3.8xlarge": 2, - "r3.large": 1, - "r3.xlarge": 1, + "i2.xlarge": 1, + "m1.large": 2, + "m1.medium": 1, + "m1.small": 1, + "m1.xlarge": 4, + "m2.2xlarge": 1, + "m2.4xlarge": 2, + "m2.xlarge": 1, + "m3.2xlarge": 2, + "m3.large": 1, + "m3.medium": 1, + "m3.xlarge": 2, "r3.2xlarge": 1, "r3.4xlarge": 1, "r3.8xlarge": 2, - "g2.2xlarge": 1, - "t1.micro": 0 + "r3.large": 1, + "r3.xlarge": 1, + "t1.micro": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] @@ -785,7 +785,7 @@ def ssh(host, opts, command): ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), stringify_command(command)]) except subprocess.CalledProcessError as e: - if (tries > 5): + if tries > 5: # If this was an ssh failure, provide the user with hints. 
if e.returncode == 255: raise UsageError( @@ -820,18 +820,18 @@ def ssh_read(host, opts, command): ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) -def ssh_write(host, opts, command, input): +def ssh_write(host, opts, command, arguments): tries = 0 while True: proc = subprocess.Popen( ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], stdin=subprocess.PIPE) - proc.stdin.write(input) + proc.stdin.write(arguments) proc.stdin.close() status = proc.wait() if status == 0: break - elif (tries > 5): + elif tries > 5: raise RuntimeError("ssh_write failed with error %s" % proc.returncode) else: print >> stderr, \ From 657bdff41a27568a981b3e342ad380fe92aa08a0 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Mon, 29 Sep 2014 01:13:15 -0700 Subject: [PATCH 3/8] [CORE] Bugfix: LogErr format in DAGScheduler.scala Author: Zhang, Liye Closes #2572 from liyezhang556520/DAGLogErr and squashes the following commits: 5be2491 [Zhang, Liye] Bugfix: LogErr format in DAGScheduler.scala --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 70c235dffff70..5a96f52a10cd4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1209,7 +1209,7 @@ class DAGScheduler( .format(job.jobId, stageId)) } else if (jobsForStage.get.size == 1) { if (!stageIdToStage.contains(stageId)) { - logError("Missing Stage for stage with id $stageId") + logError(s"Missing Stage for stage with id $stageId") } else { // This is the only job that uses this stage, so fail the stage if it is running. val stage = stageIdToStage(stageId) From aedd251c54fd130fe6e2f28d7587d39136e7ad1c Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 29 Sep 2014 10:45:08 -0700 Subject: [PATCH 4/8] [EC2] Sort long, manually-inputted dictionaries Similar to the work done in #2571, this PR just sorts the remaining manually-inputted dicts in the EC2 script so they are easier to maintain. Author: Nicholas Chammas Closes #2578 from nchammas/ec2-dict-sort and squashes the following commits: f55c692 [Nicholas Chammas] sort long dictionaries --- ec2/spark_ec2.py | 69 ++++++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 5776d0b519309..941dfb988b9fb 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -217,8 +217,15 @@ def is_active(instance): # Return correct versions of Spark and Shark, given the supplied Spark version def get_spark_shark_version(opts): spark_shark_map = { - "0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0", "0.9.1": "0.9.1", - "1.0.0": "1.0.0", "1.0.1": "1.0.1", "1.0.2": "1.0.2", "1.1.0": "1.1.0" + "0.7.3": "0.7.1", + "0.8.0": "0.8.0", + "0.8.1": "0.8.1", + "0.9.0": "0.9.0", + "0.9.1": "0.9.1", + "1.0.0": "1.0.0", + "1.0.1": "1.0.1", + "1.0.2": "1.0.2", + "1.1.0": "1.1.0", } version = opts.spark_version.replace("v", "") if version not in spark_shark_map: @@ -227,49 +234,49 @@ def get_spark_shark_version(opts): return (version, spark_shark_map[version]) -# Attempt to resolve an appropriate AMI given the architecture and -# region of the request. 
-# Information regarding Amazon Linux AMI instance type was update on 2014-6-20: -# http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ +# Attempt to resolve an appropriate AMI given the architecture and region of the request. +# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ +# Last Updated: 2014-06-20 +# For easy maintainability, please keep this manually-inputted dictionary sorted by key. def get_spark_ami(opts): instance_types = { - "m1.small": "pvm", - "m1.medium": "pvm", - "m1.large": "pvm", - "m1.xlarge": "pvm", - "t1.micro": "pvm", "c1.medium": "pvm", "c1.xlarge": "pvm", - "m2.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", + "c3.2xlarge": "pvm", + "c3.4xlarge": "pvm", + "c3.8xlarge": "pvm", + "c3.large": "pvm", + "c3.xlarge": "pvm", "cc1.4xlarge": "hvm", "cc2.8xlarge": "hvm", "cg1.4xlarge": "hvm", - "hs1.8xlarge": "pvm", - "hi1.4xlarge": "pvm", - "m3.medium": "hvm", - "m3.large": "hvm", - "m3.xlarge": "hvm", - "m3.2xlarge": "hvm", "cr1.8xlarge": "hvm", - "i2.xlarge": "hvm", + "hi1.4xlarge": "pvm", + "hs1.8xlarge": "pvm", "i2.2xlarge": "hvm", "i2.4xlarge": "hvm", "i2.8xlarge": "hvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", + "i2.xlarge": "hvm", + "m1.large": "pvm", + "m1.medium": "pvm", + "m1.small": "pvm", + "m1.xlarge": "pvm", + "m2.2xlarge": "pvm", + "m2.4xlarge": "pvm", + "m2.xlarge": "pvm", + "m3.2xlarge": "hvm", + "m3.large": "hvm", + "m3.medium": "hvm", + "m3.xlarge": "hvm", "r3.2xlarge": "hvm", "r3.4xlarge": "hvm", "r3.8xlarge": "hvm", + "r3.large": "hvm", + "r3.xlarge": "hvm", + "t1.micro": "pvm", + "t2.medium": "hvm", "t2.micro": "hvm", "t2.small": "hvm", - "t2.medium": "hvm" } if opts.instance_type in instance_types: instance_type = instance_types[opts.instance_type] @@ -624,8 +631,8 @@ def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): - # From http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Updated 2014-6-20 + # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html + # Last Updated: 2014-06-20 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { "c1.medium": 1, From 587a0cd7ed964ebfca2c97924c4f1e363f1fd3cb Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Mon, 29 Sep 2014 11:15:09 -0700 Subject: [PATCH 5/8] [MLlib] [SPARK-2885] DIMSUM: All-pairs similarity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # All-pairs similarity via DIMSUM Compute all pairs of similar vectors using brute force approach, and also DIMSUM sampling approach. Laying down some notation: we are looking for all pairs of similar columns in an m x n RowMatrix whose entries are denoted a_ij, with the i’th row denoted r_i and the j’th column denoted c_j. There is an oversampling parameter labeled ɣ that should be set to 4 log(n)/s to get provably correct results (with high probability), where s is the similarity threshold. The algorithm is stated with a Map and Reduce, with proofs of correctness and efficiency in published papers [1] [2]. The reducer is simply the summation reducer. The mapper is more interesting, and is also the heart of the scheme. As an exercise, you should try to see why in expectation, the map-reduce below outputs cosine similarities. 
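For illustration, here is a minimal standalone sketch of that map step for one dense row. It is not the RowMatrix code added in this patch (which works on SparseVector/DenseVector rows inside mapPartitionsWithIndex); the object and method names and the plain Array[Double] row are assumptions made only for this sketch. Each pair of non-zero columns is kept with probability governed by the oversampling parameter ɣ and the column magnitudes, and the kept product is rescaled so that summing the emitted values per (i, j) key estimates the cosine similarity:

```scala
import scala.collection.mutable.ListBuffer
import scala.util.Random

// Illustrative sketch only. colMags(j) is assumed to be the magnitude ||c_j|| of
// column j; columns with a non-zero entry in this row are assumed to have
// non-zero magnitude.
object DimsumSketch {
  def dimsumMap(
      row: Array[Double],
      colMags: Array[Double],
      gamma: Double,
      rand: Random): Seq[((Int, Int), Double)] = {
    val sg = math.sqrt(gamma)
    // Keep column j with probability about sqrt(gamma) / ||c_j|| and scale its
    // entries by min(sqrt(gamma), ||c_j||), so the summation reducer recovers
    // (in expectation) the normalized dot product of each column pair.
    val p = colMags.map(m => math.min(1.0, sg / m))
    val q = colMags.map(m => math.min(sg, m))
    val buf = ListBuffer.empty[((Int, Int), Double)]
    for (i <- row.indices; j <- (i + 1) until row.length) {
      if (row(i) != 0 && row(j) != 0 &&
          rand.nextDouble() < p(i) && rand.nextDouble() < p(j)) {
        buf += (((i, j), (row(i) / q(i)) * (row(j) / q(j))))
      }
    }
    buf.toList
  }
}
```

Summing these contributions per (i, j) key is the reduce step; as ɣ grows (threshold approaching 0) every pair is kept and the estimate becomes exact, which is why the zero-threshold case matches the brute-force columnSimilarities() path.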
![dimsumv2](https://cloud.githubusercontent.com/assets/3220351/3807272/d1d9514e-1c62-11e4-9f12-3cfdb1d78b3a.png) [1] Bosagh-Zadeh, Reza and Carlsson, Gunnar (2013), Dimension Independent Matrix Square using MapReduce, arXiv:1304.1467 http://arxiv.org/abs/1304.1467 [2] Bosagh-Zadeh, Reza and Goel, Ashish (2012), Dimension Independent Similarity Computation, arXiv:1206.2082 http://arxiv.org/abs/1206.2082 # Testing Tests for all invocations included. Added L1 and L2 norm computation to MultivariateStatisticalSummary since it was needed. Added tests for both of them. Author: Reza Zadeh Author: Xiangrui Meng Closes #1778 from rezazadeh/dimsumv2 and squashes the following commits: 404c64c [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 4eb71c6 [Reza Zadeh] Add excludes for normL1 and normL2 ee8bd65 [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 976ddd4 [Reza Zadeh] Broadcast colMags. Avoid div by zero. 3467cff [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 aea0247 [Reza Zadeh] Allow large thresholds to promote sparsity 9fe17c0 [Xiangrui Meng] organize imports 2196ba5 [Xiangrui Meng] Merge branch 'rezazadeh-dimsumv2' into dimsumv2 254ca08 [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 f2947e4 [Xiangrui Meng] some optimization 3c4cf41 [Xiangrui Meng] Merge branch 'master' into rezazadeh-dimsumv2 0e4eda4 [Reza Zadeh] Use partition index for RNG 251bb9c [Reza Zadeh] Documentation 25e9d0d [Reza Zadeh] Line length for style fb296f6 [Reza Zadeh] renamed to normL1 and normL2 3764983 [Reza Zadeh] Documentation e9c6791 [Reza Zadeh] New interface and documentation 613f261 [Reza Zadeh] Column magnitude summary 75a0b51 [Reza Zadeh] Use Ints instead of Longs in the shuffle 0f12ade [Reza Zadeh] Style changes eb1dc20 [Reza Zadeh] Use Double.PositiveInfinity instead of Double.Max f56a882 [Reza Zadeh] Remove changes to MultivariateOnlineSummarizer dbc55ba [Reza Zadeh] Make colMagnitudes a method in RowMatrix 41e8ece [Reza Zadeh] style changes 139c8e1 [Reza Zadeh] Syntax changes 029aa9c [Reza Zadeh] javadoc and new test 75edb25 [Reza Zadeh] All tests passing! 
05e59b8 [Reza Zadeh] Add test 502ce52 [Reza Zadeh] new interface 654c4fb [Reza Zadeh] default methods 3726ca9 [Reza Zadeh] Remove MatrixAlgebra 6bebabb [Reza Zadeh] remove changes to MatrixSuite 5b8cd7d [Reza Zadeh] Initial files --- .../mllib/linalg/distributed/RowMatrix.scala | 171 +++++++++++++++++- .../stat/MultivariateOnlineSummarizer.scala | 38 +++- .../stat/MultivariateStatisticalSummary.scala | 10 + .../linalg/distributed/RowMatrixSuite.scala | 37 ++++ project/MimaExcludes.scala | 9 +- 5 files changed, 259 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 4174f45d231c7..8380058cf9b41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -19,17 +19,21 @@ package org.apache.spark.mllib.linalg.distributed import java.util.Arrays -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} -import breeze.linalg.{svd => brzSvd, axpy => brzAxpy} +import scala.collection.mutable.ListBuffer + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, + svd => brzSvd} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ -import org.apache.spark.rdd.RDD -import org.apache.spark.Logging import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.storage.StorageLevel /** @@ -411,6 +415,165 @@ class RowMatrix( new RowMatrix(AB, nRows, B.numCols) } + /** + * Compute all cosine similarities between columns of this matrix using the brute-force + * approach of computing normalized dot products. + * + * @return An n x n sparse upper-triangular matrix of cosine similarities between + * columns of this matrix. + */ + def columnSimilarities(): CoordinateMatrix = { + columnSimilarities(0.0) + } + + /** + * Compute similarities between columns of this matrix using a sampling approach. + * + * The threshold parameter is a trade-off knob between estimate quality and computational cost. + * + * Setting a threshold of 0 guarantees deterministic correct results, but comes at exactly + * the same cost as the brute-force approach. Setting the threshold to positive values + * incurs strictly less computational cost than the brute-force approach, however the + * similarities computed will be estimates. + * + * The sampling guarantees relative-error correctness for those pairs of columns that have + * similarity greater than the given similarity threshold. + * + * To describe the guarantee, we set some notation: + * Let A be the smallest in magnitude non-zero element of this matrix. + * Let B be the largest in magnitude non-zero element of this matrix. + * Let L be the maximum number of non-zeros per row. + * + * For example, for {0,1} matrices: A=B=1. 
+ * Another example, for the Netflix matrix: A=1, B=5 + * + * For those column pairs that are above the threshold, + * the computed similarity is correct to within 20% relative error with probability + * at least 1 - (0.981)^10/B^ + * + * The shuffle size is bounded by the *smaller* of the following two expressions: + * + * O(n log(n) L / (threshold * A)) + * O(m L^2^) + * + * The latter is the cost of the brute-force approach, so for non-zero thresholds, + * the cost is always cheaper than the brute-force approach. + * + * @param threshold Set to 0 for deterministic guaranteed correctness. + * Similarities above this threshold are estimated + * with the cost vs estimate quality trade-off described above. + * @return An n x n sparse upper-triangular matrix of cosine similarities + * between columns of this matrix. + */ + def columnSimilarities(threshold: Double): CoordinateMatrix = { + require(threshold >= 0, s"Threshold cannot be negative: $threshold") + + if (threshold > 1) { + logWarning(s"Threshold is greater than 1: $threshold " + + "Computation will be more efficient with promoted sparsity, " + + " however there is no correctness guarantee.") + } + + val gamma = if (threshold < 1e-6) { + Double.PositiveInfinity + } else { + 10 * math.log(numCols()) / threshold + } + + columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma) + } + + /** + * Find all similar columns using the DIMSUM sampling algorithm, described in two papers + * + * http://arxiv.org/abs/1206.2082 + * http://arxiv.org/abs/1304.1467 + * + * @param colMags A vector of column magnitudes + * @param gamma The oversampling parameter. For provable results, set to 10 * log(n) / s, + * where s is the smallest similarity score to be estimated, + * and n is the number of columns + * @return An n x n sparse upper-triangular matrix of cosine similarities + * between columns of this matrix. 
+ */ + private[mllib] def columnSimilaritiesDIMSUM( + colMags: Array[Double], + gamma: Double): CoordinateMatrix = { + require(gamma > 1.0, s"Oversampling should be greater than 1: $gamma") + require(colMags.size == this.numCols(), "Number of magnitudes didn't match column dimension") + val sg = math.sqrt(gamma) // sqrt(gamma) used many times + + // Don't divide by zero for those columns with zero magnitude + val colMagsCorrected = colMags.map(x => if (x == 0) 1.0 else x) + + val sc = rows.context + val pBV = sc.broadcast(colMagsCorrected.map(c => sg / c)) + val qBV = sc.broadcast(colMagsCorrected.map(c => math.min(sg, c))) + + val sims = rows.mapPartitionsWithIndex { (indx, iter) => + val p = pBV.value + val q = qBV.value + + val rand = new XORShiftRandom(indx) + val scaled = new Array[Double](p.size) + iter.flatMap { row => + val buf = new ListBuffer[((Int, Int), Double)]() + row match { + case sv: SparseVector => + val nnz = sv.indices.size + var k = 0 + while (k < nnz) { + scaled(k) = sv.values(k) / q(sv.indices(k)) + k += 1 + } + k = 0 + while (k < nnz) { + val i = sv.indices(k) + val iVal = scaled(k) + if (iVal != 0 && rand.nextDouble() < p(i)) { + var l = k + 1 + while (l < nnz) { + val j = sv.indices(l) + val jVal = scaled(l) + if (jVal != 0 && rand.nextDouble() < p(j)) { + buf += (((i, j), iVal * jVal)) + } + l += 1 + } + } + k += 1 + } + case dv: DenseVector => + val n = dv.values.size + var i = 0 + while (i < n) { + scaled(i) = dv.values(i) / q(i) + i += 1 + } + i = 0 + while (i < n) { + val iVal = scaled(i) + if (iVal != 0 && rand.nextDouble() < p(i)) { + var j = i + 1 + while (j < n) { + val jVal = scaled(j) + if (jVal != 0 && rand.nextDouble() < p(j)) { + buf += (((i, j), iVal * jVal)) + } + j += 1 + } + } + i += 1 + } + } + buf + } + }.reduceByKey(_ + _).map { case ((i, j), sim) => + MatrixEntry(i.toLong, j.toLong, sim) + } + new CoordinateMatrix(sims, numCols(), numCols()) + } + private[mllib] override def toBreeze(): BDM[Double] = { val m = numRows().toInt val n = numCols().toInt diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 7d845c44365dd..3025d4837cab4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -42,6 +42,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var n = 0 private var currMean: BDV[Double] = _ private var currM2n: BDV[Double] = _ + private var currM2: BDV[Double] = _ + private var currL1: BDV[Double] = _ private var totalCnt: Long = 0 private var nnz: BDV[Double] = _ private var currMax: BDV[Double] = _ @@ -60,6 +62,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMean = BDV.zeros[Double](n) currM2n = BDV.zeros[Double](n) + currM2 = BDV.zeros[Double](n) + currL1 = BDV.zeros[Double](n) nnz = BDV.zeros[Double](n) currMax = BDV.fill(n)(Double.MinValue) currMin = BDV.fill(n)(Double.MaxValue) @@ -81,6 +85,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val tmpPrevMean = currMean(i) currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) + currM2(i) += value * value + currL1(i) += math.abs(value) nnz(i) += 1.0 } @@ -97,7 +103,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S * @return This MultivariateOnlineSummarizer object. */ def merge(other: MultivariateOnlineSummarizer): this.type = { - if (this.totalCnt != 0 && other.totalCnt != 0) { + if (this.totalCnt != 0 && other.totalCnt != 0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt @@ -114,6 +120,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / (nnz(i) + other.nnz(i)) } + // merge m2 together + if (nnz(i) + other.nnz(i) != 0.0) { + currM2(i) += other.currM2(i) + } + // merge l1 together + if (nnz(i) + other.nnz(i) != 0.0) { + currL1(i) += other.currL1(i) + } + if (currMax(i) < other.currMax(i)) { currMax(i) = other.currMax(i) } @@ -127,6 +142,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this.n = other.n this.currMean = other.currMean.copy this.currM2n = other.currM2n.copy + this.currM2 = other.currM2.copy + this.currL1 = other.currL1.copy this.totalCnt = other.totalCnt this.nnz = other.nnz.copy this.currMax = other.currMax.copy @@ -198,4 +215,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } Vectors.fromBreeze(currMin) } + + override def normL2: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + + val realMagnitude = BDV.zeros[Double](n) + + var i = 0 + while (i < currM2.size) { + realMagnitude(i) = math.sqrt(currM2(i)) + i += 1 + } + + Vectors.fromBreeze(realMagnitude) + } + + override def normL1: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + Vectors.fromBreeze(currL1) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index f9eb343da2b82..6a364c93284af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -53,4 +53,14 @@ trait MultivariateStatisticalSummary { * Minimum value of each column. 
*/ def min: Vector + + /** + * Euclidean magnitude of each column + */ + def normL2: Vector + + /** + * L1 norm of each column + */ + def normL1: Vector } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 1d3a3221365cc..63f3ed58c0d4d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -95,6 +95,40 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { } } + test("similar columns") { + val colMags = Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94)) + val expected = BDM( + (0.0, 54.0, 72.0), + (0.0, 0.0, 78.0), + (0.0, 0.0, 0.0)) + + for (i <- 0 until n; j <- 0 until n) { + expected(i, j) /= (colMags(i) * colMags(j)) + } + + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.columnSimilarities(0.11).toBreeze() + for (i <- 0 until n; j <- 0 until n) { + if (expected(i, j) > 0) { + val actual = expected(i, j) + val estimate = G(i, j) + assert(math.abs(actual - estimate) / actual < 0.2, + s"Similarities not close enough: $actual vs $estimate") + } + } + } + + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.columnSimilarities() + assert(closeToZero(G.toBreeze() - expected)) + } + + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.columnSimilaritiesDIMSUM(colMags.toArray, 150.0) + assert(closeToZero(G.toBreeze() - expected)) + } + } + test("svd of a full-rank matrix") { for (mat <- Seq(denseMat, sparseMat)) { for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { @@ -190,6 +224,9 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch") assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch") assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.") + assert(summary.normL2 === Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94)), + "magnitude mismatch.") + assert(summary.normL1 === Vectors.dense(18.0, 12.0, 16.0), "L1 norm mismatch") } } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3280e662fa0b1..1adfaa18c6202 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -39,7 +39,14 @@ object MimaExcludes { MimaBuild.excludeSparkPackage("graphx") ) ++ MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ - MimaBuild.excludeSparkClass("mllib.linalg.Vector") + MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ + Seq( + // Added normL1 and normL2 to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2") + ) case v if v.startsWith("1.1") => Seq( From dab1b0ae29a6d3017bdca23464f22a51d51eaae1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 29 Sep 2014 11:25:32 -0700 Subject: [PATCH 6/8] [SPARK-3032][Shuffle] Fix key comparison integer overflow introduced sorting exception Previous key comparison in `ExternalSorter` will get wrong sorting result or exception when key comparison overflows, details can be seen in [SPARK-3032](https://issues.apache.org/jira/browse/SPARK-3032). Here fix this and add a unit test to prove it. 
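To see why the subtraction-based comparator violates the ordering contract, consider hash codes of opposite sign and large magnitude. The snippet below is an illustrative sketch, not part of the patch (the object name is made up); the sign-based comparison is the same fix applied to `ExternalSorter`:

```scala
object HashCompareSketch extends App {
  // Subtraction overflows when the hash codes differ by more than Int.MaxValue,
  // so the comparator can report the wrong order.
  def brokenCompare(h1: Int, h2: Int): Int = h1 - h2
  def fixedCompare(h1: Int, h2: Int): Int =
    if (h1 < h2) -1 else if (h1 == h2) 0 else 1

  val (h1, h2) = (Int.MinValue + 1, 2)  // h1 is clearly less than h2
  println(brokenCompare(h1, h2))        // positive: the subtraction wrapped around
  println(fixedCompare(h1, h2))         // -1, the expected result
}
```

Because TimSort checks transitivity, such an inconsistent comparator can also surface as the "Comparison method violates its general contract" IllegalArgumentException exercised by the new unit test.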
Author: jerryshao Closes #2514 from jerryshao/SPARK-3032 and squashes the following commits: 6f3c302 [jerryshao] Improve the unit test according to comments 01911e6 [jerryshao] Change the test to show the contract violate exception 83acb38 [jerryshao] Minor changes according to comments fa2a08f [jerryshao] Fix key comparison integer overflow introduced sorting exception --- .../util/collection/ExternalSorter.scala | 2 +- .../util/collection/ExternalSorterSuite.scala | 55 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 0a152cb97ad9e..644fa36818647 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -144,7 +144,7 @@ private[spark] class ExternalSorter[K, V, C]( override def compare(a: K, b: K): Int = { val h1 = if (a == null) 0 else a.hashCode() val h2 = if (b == null) 0 else b.hashCode() - h1 - h2 + if (h1 < h2) -1 else if (h1 == h2) 0 else 1 } }) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 706faed980f31..f26e40fbd4b36 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -24,6 +24,8 @@ import org.scalatest.{PrivateMethodTester, FunSuite} import org.apache.spark._ import org.apache.spark.SparkContext._ +import scala.util.Random + class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { private def createSparkConf(loadDefaults: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) @@ -707,4 +709,57 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) assertDidNotBypassMergeSort(sorter4) } + + test("sort without breaking sorting contracts") { + val conf = createSparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.01") + conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + // Using wrongOrdering to show integer overflow introduced exception. + val rand = new Random(100L) + val wrongOrdering = new Ordering[String] { + override def compare(a: String, b: String) = { + val h1 = if (a == null) 0 else a.hashCode() + val h2 = if (b == null) 0 else b.hashCode() + h1 - h2 + } + } + + val testData = Array.tabulate(100000) { _ => rand.nextInt().toString } + + val sorter1 = new ExternalSorter[String, String, String]( + None, None, Some(wrongOrdering), None) + val thrown = intercept[IllegalArgumentException] { + sorter1.insertAll(testData.iterator.map(i => (i, i))) + sorter1.iterator + } + + assert(thrown.getClass() === classOf[IllegalArgumentException]) + assert(thrown.getMessage().contains("Comparison method violates its general contract")) + sorter1.stop() + + // Using aggregation and external spill to make sure ExternalSorter using + // partitionKeyComparator. 
+ def createCombiner(i: String) = ArrayBuffer(i) + def mergeValue(c: ArrayBuffer[String], i: String) = c += i + def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]) = c1 ++= c2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner, mergeValue, mergeCombiners) + + val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]]( + Some(agg), None, None, None) + sorter2.insertAll(testData.iterator.map(i => (i, i))) + + // To validate the hash ordering of key + var minKey = Int.MinValue + sorter2.iterator.foreach { case (k, v) => + val h = k.hashCode() + assert(h >= minKey) + minKey = h + } + + sorter2.stop() + } } From e43c72fe04d4fbf2a108b456d533e641b71b0a2a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 12:38:24 -0700 Subject: [PATCH 7/8] Add more debug message for ManagedBuffer This is to help debug the error reported at http://apache-spark-user-list.1001560.n3.nabble.com/SQL-queries-fail-in-1-2-0-SNAPSHOT-td15327.html Author: Reynold Xin Closes #2580 from rxin/buffer-debug and squashes the following commits: 5814292 [Reynold Xin] Logging close() in case close() fails. 323dfec [Reynold Xin] Add more debug message. --- .../apache/spark/network/ManagedBuffer.scala | 43 ++++++++++++++++--- .../scala/org/apache/spark/util/Utils.scala | 14 ++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index e990c1da6730f..a4409181ec907 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -17,15 +17,17 @@ package org.apache.spark.network -import java.io.{FileInputStream, RandomAccessFile, File, InputStream} +import java.io._ import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode +import scala.util.Try + import com.google.common.io.ByteStreams import io.netty.buffer.{ByteBufInputStream, ByteBuf} -import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.{ByteBufferInputStream, Utils} /** @@ -71,18 +73,47 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt try { channel = new RandomAccessFile(file, "r").getChannel channel.map(MapMode.READ_ONLY, offset, length) + } catch { + case e: IOException => + Try(channel.size).toOption match { + case Some(fileLen) => + throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) + case None => + throw new IOException(s"Error in opening $this", e) + } } finally { if (channel != null) { - channel.close() + Utils.tryLog(channel.close()) } } } override def inputStream(): InputStream = { - val is = new FileInputStream(file) - is.skip(offset) - ByteStreams.limit(is, length) + var is: FileInputStream = null + try { + is = new FileInputStream(file) + is.skip(offset) + ByteStreams.limit(is, length) + } catch { + case e: IOException => + if (is != null) { + Utils.tryLog(is.close()) + } + Try(file.length).toOption match { + case Some(fileLen) => + throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) + case None => + throw new IOException(s"Error in opening $this", e) + } + case e: Throwable => + if (is != null) { + Utils.tryLog(is.close()) + } + throw e + } } + + override def toString: String = s"${getClass.getName}($file, $offset, $length)" } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala 
b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2755887feeeff..10d440828e323 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1304,6 +1304,20 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block in a Try, logging any uncaught exceptions. */ + def tryLog[T](f: => T): Try[T] = { + try { + val res = f + scala.util.Success(res) + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + scala.util.Failure(t) + } + } + /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */ def isFatalError(e: Throwable): Boolean = { e match { From 0bbe7faeffa17577ae8a33dfcd8c4c783db5c909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?baishuo=28=E7=99=BD=E7=A1=95=29?= Date: Mon, 29 Sep 2014 15:51:55 -0700 Subject: [PATCH 8/8] [SPARK-3007][SQL]Add Dynamic Partition support to Spark Sql hive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit a new PR base on new master. changes are the same as https://github.com/apache/spark/pull/1919 Author: baishuo(白硕) Author: baishuo Author: Cheng Lian Closes #2226 from baishuo/patch-3007 and squashes the following commits: e69ce88 [Cheng Lian] Adds tests to verify dynamic partitioning folder layout b20a3dc [Cheng Lian] Addresses @yhuai's comments 096bbbc [baishuo(白硕)] Merge pull request #1 from liancheng/refactor-dp 1093c20 [Cheng Lian] Adds more tests 5004542 [Cheng Lian] Minor refactoring fae9eff [Cheng Lian] Refactors InsertIntoHiveTable to a Command 528e84c [Cheng Lian] Fixes typo in test name, regenerated golden answer files c464b26 [Cheng Lian] Refactors dynamic partitioning support 5033928 [baishuo] pass check style 2201c75 [baishuo] use HiveConf.DEFAULTPARTITIONNAME to replace hive.exec.default.partition.name b47c9bf [baishuo] modify according micheal's advice c3ab36d [baishuo] modify for some bad indentation 7ce2d9f [baishuo] modify code to pass scala style checks 37c1c43 [baishuo] delete a empty else branch 66e33fc [baishuo] do a little modify 88d0110 [baishuo] update file after test a3961d9 [baishuo(白硕)] Update Cast.scala f7467d0 [baishuo(白硕)] Update InsertIntoHiveTable.scala c1a59dd [baishuo(白硕)] Update Cast.scala 0e18496 [baishuo(白硕)] Update HiveQuerySuite.scala 60f70aa [baishuo(白硕)] Update InsertIntoHiveTable.scala 0a50db9 [baishuo(白硕)] Update HiveCompatibilitySuite.scala 491c7d0 [baishuo(白硕)] Update InsertIntoHiveTable.scala a2374a8 [baishuo(白硕)] Update InsertIntoHiveTable.scala 701a814 [baishuo(白硕)] Update SparkHadoopWriter.scala dc24c41 [baishuo(白硕)] Update HiveQl.scala --- .../execution/HiveCompatibilitySuite.scala | 17 ++ .../org/apache/spark/SparkHadoopWriter.scala | 195 ---------------- .../org/apache/spark/sql/hive/HiveQl.scala | 5 - .../hive/execution/InsertIntoHiveTable.scala | 207 +++++++++-------- .../spark/sql/hive/hiveWriterContainers.scala | 217 ++++++++++++++++++ ...rtition-0-be33aaa7253c8f248ff3921cd7dae340 | 0 ...rtition-1-640552dd462707563fd255a713f83b41 | 0 ...rtition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 | 1 + ...rtition-3-b7f7fa7ebf666f4fee27e149d8c6961f | 0 ...rtition-4-8bdb71ad8cb3cc3026043def2525de3a | 0 ...rtition-5-c630dce438f3792e7fb0f523fbbb3e1e | 0 ...rtition-6-7abc9ec8a36cdc5e89e955265a7fd7cf | 0 ...rtition-7-be33aaa7253c8f248ff3921cd7dae340 | 0 .../sql/hive/execution/HiveQuerySuite.scala | 100 +++++++- 14 files changed, 443 insertions(+), 
299 deletions(-) delete mode 100644 sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 556c984ad392b..35e9c9939d4b7 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -220,6 +220,23 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { */ override def whiteList = Seq( "add_part_exist", + "dynamic_partition_skip_default", + "infer_bucket_sort_dyn_part", + "load_dyn_part1", + "load_dyn_part2", + "load_dyn_part3", + "load_dyn_part4", + "load_dyn_part5", + "load_dyn_part6", + "load_dyn_part7", + "load_dyn_part8", + "load_dyn_part9", + "load_dyn_part10", + "load_dyn_part11", + "load_dyn_part12", + "load_dyn_part13", + "load_dyn_part14", + "load_dyn_part14_win", "add_part_multiple", "add_partition_no_whitelist", "add_partition_with_whitelist", diff --git a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala deleted file mode 100644 index ab7862f4f9e06..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive - -import java.io.IOException -import java.text.NumberFormat -import java.util.Date - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} -import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} -import org.apache.hadoop.hive.ql.plan.FileSinkDesc -import org.apache.hadoop.mapred._ -import org.apache.hadoop.io.Writable - -import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} - -/** - * Internal helper class that saves an RDD using a Hive OutputFormat. - * It is based on [[SparkHadoopWriter]]. - */ -private[hive] class SparkHiveHadoopWriter( - @transient jobConf: JobConf, - fileSinkConf: FileSinkDesc) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { - - private val now = new Date() - private val conf = new SerializableWritable(jobConf) - - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null - - @transient private var writer: FileSinkOperator.RecordWriter = null - @transient private var format: HiveOutputFormat[AnyRef, Writable] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null - - def preSetup() { - setIDs(0, 0, 0) - setConfParams() - - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) - } - - - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - setConfParams() - } - - def open() { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val extension = Utilities.getFileExtension( - conf.value, - fileSinkConf.getCompressed, - getOutputFormat()) - - val outputName = "part-" + numfmt.format(splitID) + extension - val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName) - - getOutputCommitter().setupTask(getTaskContext()) - writer = HiveFileFormatUtils.getHiveRecordWriter( - conf.value, - fileSinkConf.getTableInfo, - conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], - fileSinkConf, - path, - null) - } - - def write(value: Writable) { - if (writer != null) { - writer.write(value) - } else { - throw new IOException("Writer is null, open() has not been called") - } - } - - def close() { - // Seems the boolean value passed into close does not matter. - writer.close(false) - } - - def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } else { - logWarning ("No need to commit output of task: " + taID.value) - } - } - - def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? 
- val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) - } - - // ********* Private Functions ********* - - private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = { - if (format == null) { - format = conf.value.getOutputFormat() - .asInstanceOf[HiveOutputFormat[AnyRef,Writable]] - } - format - } - - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter - } - committer - } - - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) - } - jobContext - } - - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) - } - taskContext - } - - private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { - jobID = jobId - splitID = splitId - attemptID = attemptId - - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) - } - - private def setConfParams() { - conf.value.set("mapred.job.id", jID.value.toString) - conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) - conf.value.set("mapred.task.id", taID.value.toString) - conf.value.setBoolean("mapred.task.is.map", true) - conf.value.setInt("mapred.task.partition", splitID) - } -} - -private[hive] object SparkHiveHadoopWriter { - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0aa6292c0184e..4e30e6e06fe21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -837,11 +837,6 @@ private[hive] object HiveQl { cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - if (partitionKeys.values.exists(p => p.isEmpty)) { - throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" + - s"dynamic partitioning.") - } - InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite) case a: ASTNode => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index a284a91a91e31..3d2ee010696f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,27 +19,25 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConversions._ -import java.util.{HashMap => JHashMap} - import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import 
org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector -import org.apache.hadoop.io.Writable +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{JavaHiveDecimalObjectInspector, JavaHiveVarcharObjectInspector} import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, SparkHiveHadoopWriter} +import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode} +import org.apache.spark.sql.hive._ +import org.apache.spark.{SerializableWritable, SparkException, TaskContext} /** * :: DeveloperApi :: @@ -51,7 +49,7 @@ case class InsertIntoHiveTable( child: SparkPlan, overwrite: Boolean) (@transient sc: HiveContext) - extends UnaryNode { + extends UnaryNode with Command { @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @@ -101,66 +99,74 @@ case class InsertIntoHiveTable( } def saveAsHiveFile( - rdd: RDD[Writable], + rdd: RDD[Row], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: JobConf, - isCompressed: Boolean) { - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - conf.setOutputValueClass(valueClass) - if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { - throw new SparkException("Output format class not set") - } - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? - // conf.setOutputFormat(outputFormatClass) - conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) + conf: SerializableWritable[JobConf], + writerContainer: SparkHiveWriterContainer) { + assert(valueClass != null, "Output value class not set") + conf.value.setOutputValueClass(valueClass) + + val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName + assert(outputFileFormatClassName != null, "Output format class not set") + conf.value.set("mapred.output.format.class", outputFileFormatClassName) + + val isCompressed = conf.value.getBoolean( + ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) + if (isCompressed) { // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", // and "mapred.output.compression.type" have no impact on ORC because it uses table properties // to store compression information. 
- conf.set("mapred.output.compress", "true") + conf.value.set("mapred.output.compress", "true") fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) + fileSinkConf.setCompressCodec(conf.value.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(conf.value.get("mapred.output.compression.type")) } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath( - conf, - SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) + conf.value.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath( + conf.value, + SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) - writer.preSetup() + writerContainer.driverSideSetup() + sc.sparkContext.runJob(rdd, writeToFile _) + writerContainer.commitJob() + + // Note that this function is executed on executor side + def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) - def writeToFile(context: TaskContext, iter: Iterator[Writable]) { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt + writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber) - writer.setup(context.stageId, context.partitionId, attemptNumber) - writer.open() + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } - var count = 0 - while(iter.hasNext) { - val record = iter.next() - count += 1 - writer.write(record) + val writer = writerContainer.getLocalFileWriter(row) + writer.write(serializer.serialize(outputData, standardOI)) } - writer.close() - writer.commit() + writerContainer.close() } - - sc.sparkContext.runJob(rdd, writeToFile _) - writer.commitJob() } - override def execute() = result - /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the @@ -168,50 +174,57 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - private lazy val result: RDD[Row] = { - val childRdd = child.execute() - assert(childRdd != null) - + override protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. 
val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val rdd = childRdd.mapPartitions { iter => - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] + val numDynamicPartitions = partition.values.count(_.isEmpty) + val numStaticPartitions = partition.values.count(_.nonEmpty) + val partitionSpec = partition.map { + case (key, Some(value)) => key -> value + case (key, None) => key -> "" + } - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - val outputData = new Array[Any](fieldOIs.length) - iter.map { row => - var i = 0 - while (i < row.length) { - // Casts Strings to HiveVarchars when necessary. - outputData(i) = wrap(row(i), fieldOIs(i)) - i += 1 - } + // All partition column names in the format of "//..." + val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns") + val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull - serializer.serialize(outputData, standardOI) + // Validate partition spec if there exist any dynamic partitions + if (numDynamicPartitions > 0) { + // Report error if dynamic partitioning is not enabled + if (!sc.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) + } + + // Report error if dynamic partition strict mode is on but no static partition is found + if (numStaticPartitions == 0 && + sc.hiveconf.getVar(HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) + } + + // Report error if any static partition appears after a dynamic partition + val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) + isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } - // ORC stores compression information in table properties. While, there are other formats - // (e.g. RCFile) that rely on hadoop configurations to store compression information. val jobConf = new JobConf(sc.hiveconf) - saveAsHiveFile( - rdd, - outputClass, - fileSinkConf, - jobConf, - sc.hiveconf.getBoolean("hive.exec.compress.output", false)) - - // TODO: Handle dynamic partitioning. + val jobConfSer = new SerializableWritable(jobConf) + + val writerContainer = if (numDynamicPartitions > 0) { + val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) + new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) + } else { + new SparkHiveWriterContainer(jobConf, fileSinkConf) + } + + saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer) + val outputPath = FileOutputFormat.getOutputPath(jobConf) // Have to construct the format of dbname.tablename. val qualifiedTableName = s"${table.databaseName}.${table.tableName}" @@ -220,10 +233,6 @@ case class InsertIntoHiveTable( // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. 
val holdDDLTime = false if (partition.nonEmpty) { - val partitionSpec = partition.map { - case (key, Some(value)) => key -> value - case (key, None) => key -> "" // Should not reach here right now. - } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) db.validatePartitionNameCharacters(partVals) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query @@ -231,14 +240,26 @@ case class InsertIntoHiveTable( val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false - db.loadPartition( - outputPath, - qualifiedTableName, - partitionSpec, - overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + if (numDynamicPartitions > 0) { + db.loadDynamicPartitions( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir + ) + } else { + db.loadPartition( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } } else { db.loadTable( outputPath, @@ -251,6 +272,6 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - sc.sparkContext.makeRDD(Nil, 1) + Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala new file mode 100644 index 0000000000000..a667188fa53bd --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.IOException +import java.text.NumberFormat +import java.util.Date + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} +import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} +import org.apache.hadoop.hive.ql.plan.FileSinkDesc +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred._ + +import org.apache.spark.sql.Row +import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} + +/** + * Internal helper class that saves an RDD using a Hive OutputFormat. + * It is based on [[SparkHadoopWriter]]. 
+ */ +private[hive] class SparkHiveWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { + + private val now = new Date() + protected val conf = new SerializableWritable(jobConf) + + private var jobID = 0 + private var splitID = 0 + private var attemptID = 0 + private var jID: SerializableWritable[JobID] = null + private var taID: SerializableWritable[TaskAttemptID] = null + + @transient private var writer: FileSinkOperator.RecordWriter = null + @transient private lazy val committer = conf.value.getOutputCommitter + @transient private lazy val jobContext = newJobContext(conf.value, jID.value) + @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) + @transient private lazy val outputFormat = + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + + def driverSideSetup() { + setIDs(0, 0, 0) + setConfParams() + committer.setupJob(jobContext) + } + + def executorSideSetup(jobId: Int, splitId: Int, attemptId: Int) { + setIDs(jobId, splitId, attemptId) + setConfParams() + committer.setupTask(taskContext) + initWriters() + } + + protected def getOutputName: String = { + val numberFormat = NumberFormat.getInstance() + numberFormat.setMinimumIntegerDigits(5) + numberFormat.setGroupingUsed(false) + val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat) + "part-" + numberFormat.format(splitID) + extension + } + + def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + + def close() { + // Seems the boolean value passed into close does not matter. + writer.close(false) + commit() + } + + def commitJob() { + committer.commitJob(jobContext) + } + + protected def initWriters() { + // NOTE this method is executed at the executor side. + // For Hive tables without partitions or with only static partitions, only 1 writer is needed. 
+ writer = HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + fileSinkConf, + FileOutputFormat.getTaskOutputPath(conf.value, getOutputName), + Reporter.NULL) + } + + protected def commit() { + if (committer.needsTaskCommit(taskContext)) { + try { + committer.commitTask(taskContext) + logInfo (taID + ": Committed") + } catch { + case e: IOException => + logError("Error committing the output of task: " + taID.value, e) + committer.abortTask(taskContext) + throw e + } + } else { + logInfo("No need to commit output of task: " + taID.value) + } + } + + // ********* Private Functions ********* + + private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { + jobID = jobId + splitID = splitId + attemptID = attemptId + + jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) + taID = new SerializableWritable[TaskAttemptID]( + new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + } + + private def setConfParams() { + conf.value.set("mapred.job.id", jID.value.toString) + conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) + conf.value.set("mapred.task.id", taID.value.toString) + conf.value.setBoolean("mapred.task.is.map", true) + conf.value.setInt("mapred.task.partition", splitID) + } +} + +private[hive] object SparkHiveWriterContainer { + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (outputPath == null || fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } +} + +private[spark] class SparkHiveDynamicPartitionWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc, + dynamicPartColNames: Array[String]) + extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + + private val defaultPartName = jobConf.get( + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + + @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ + + override protected def initWriters(): Unit = { + // NOTE: This method is executed at the executor side. + // Actual writers are created for each dynamic partition on the fly. 
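For the dynamic case, getLocalFileWriter (below) derives a partition directory suffix from the trailing row values and lazily caches one RecordWriter per distinct suffix. A standalone illustration of that path construction, assuming two dynamic partition columns and Hive's stock default partition name:

```scala
// The last N row values (N = number of dynamic partition columns) become partition
// directories; null/empty values fall back to the default partition name, matching the
// folder layout verified in the "Dynamic partition folder layout" test.
val dynamicPartColNames = Array("partcol1", "partcol2")
val defaultPartName = "__HIVE_DEFAULT_PARTITION__"   // ConfVars.DEFAULTPARTITIONNAME's default
val rowTail: Seq[Any] = Seq(1, null)                 // trailing values of a Row
val dynamicPartPath = dynamicPartColNames.zip(rowTail).map { case (col, rawVal) =>
  val string = String.valueOf(rawVal)
  s"/$col=${if (rawVal == null || string.isEmpty) defaultPartName else string}"
}.mkString
assert(dynamicPartPath == "/partcol1=1/partcol2=__HIVE_DEFAULT_PARTITION__")
// Each distinct dynamicPartPath then gets its own cached writer via writers.getOrElseUpdate.
```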
+ writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter] + } + + override def close(): Unit = { + writers.values.foreach(_.close(false)) + commit() + } + + override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { + val dynamicPartPath = dynamicPartColNames + .zip(row.takeRight(dynamicPartColNames.length)) + .map { case (col, rawVal) => + val string = String.valueOf(rawVal) + s"/$col=${if (rawVal == null || string.isEmpty) defaultPartName else string}" + } + .mkString + + def newWriter = { + val newFileSinkDesc = new FileSinkDesc( + fileSinkConf.getDirName + dynamicPartPath, + fileSinkConf.getTableInfo, + fileSinkConf.getCompressed) + newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) + newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) + + val path = { + val outputPath = FileOutputFormat.getOutputPath(conf.value) + assert(outputPath != null, "Undefined job output-path") + val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) + new Path(workPath, getOutputName) + } + + HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + newFileSinkDesc, + path, + Reporter.NULL) + } + + writers.getOrElseUpdate(dynamicPartPath, newWriter) + } +} diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 b/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f b/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a b/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e b/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf b/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2da8a6fac3d99..5d743a51b47c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.SparkException import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -380,7 +383,7 @@ class HiveQuerySuite extends HiveComparisonTest { def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.exists(_ == "== Physical Plan ==") + explanation.contains("== Physical Plan ==") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -568,6 +571,91 @@ class HiveQuerySuite extends HiveComparisonTest { case class LogEntry(filename: String, message: String) case class LogFile(name: String) + createQueryTest("dynamic_partition", + """ + |DROP TABLE IF EXISTS dynamic_part_table; + |CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT); + | + |SET hive.exec.dynamic.partition.mode=nonstrict; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, NULL FROM src WHERE key=150; + | + |INSERT INTO TABLe dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, NULL FROM src WHERE key=150; + | + |DROP TABLE IF EXISTS dynamic_part_table; + """.stripMargin) + + test("Dynamic partition folder layout") { + sql("DROP TABLE IF EXISTS dynamic_part_table") + sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + val data = Map( + Seq("1", "1") -> 1, + Seq("1", "NULL") -> 2, + Seq("NULL", "1") -> 3, + Seq("NULL", "NULL") -> 4) + + data.foreach { case (parts, value) => + sql( + s"""INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT $value, ${parts.mkString(", ")} FROM src WHERE key=150 + """.stripMargin) + + val partFolder = Seq("partcol1", "partcol2") + .zip(parts) + .map { case (k, v) => + if (v == "NULL") { + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + } else { + s"$k=$v" + } + } + .mkString("/") + + // Loads partition data to a temporary table to verify contents + val path = s"$warehousePath/dynamic_part_table/$partFolder/part-00000" + + sql("DROP TABLE IF EXISTS dp_verify") + sql("CREATE TABLE dp_verify(intcol INT)") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") + + assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + } + } + + test("Partition spec validation") { + sql("DROP TABLE IF EXISTS dp_test") + sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") + sql("SET hive.exec.dynamic.partition.mode=strict") + + // Should throw when using strict dynamic partition mode without any static partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + + sql("SET 
hive.exec.dynamic.partition.mode=nonstrict") + + // Should throw when a static partition appears after a dynamic partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + } + test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") @@ -625,27 +713,27 @@ class HiveQuerySuite extends HiveComparisonTest { assert(sql("SET").collect().size == 0) assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + collectResults(sql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal)) { - collectResults(hql("SET")) + collectResults(sql("SET")) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + collectResults(sql("SET")) } // "set key" assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + collectResults(sql(s"SET $testKey")) } assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + collectResults(sql(s"SET $nonexistentKey")) } // Assert that sql() should have the same effects as sql() by repeating the above using sql().