From 62dc723fac20409a04b3a47bb6d6a86be03bad37 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sun, 23 Mar 2014 18:53:42 -0700
Subject: [PATCH] updating javadoc and converting helper methods to package private to allow unit testing

---
 .../spark/mllib/tree/DecisionTree.scala | 26 ++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b003f6fe54f3b..3ab644e74df1b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -172,7 +172,11 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
 object DecisionTree extends Serializable with Logging {
 
   /**
-   * Method to train a decision tree model over an RDD
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The method supports binary classification and regression. For
+   * binary classification, the label for each instance should either be 0 or 1 to denote the two
+   * classes. The parameters for the algorithm are specified using the strategy parameter.
+   *
    * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
    *              for DecisionTree
    * @param strategy The configuration parameters for the tree algorithm which specify the type
@@ -185,7 +189,11 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Method to train a decision tree model over an RDD
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The method supports binary classification and regression. For
+   * binary classification, the label for each instance should either be 0 or 1 to denote the two
+   * classes.
+   *
    * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
    *              training data
    * @param algo algorithm, classification or regression
@@ -204,8 +212,13 @@ object DecisionTree extends Serializable with Logging {
 
 
   /**
-   * Method to train a decision tree model over an RDD
-   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The decision tree method supports binary classification and
+   * regression. For binary classification, the label for each instance should either be 0 or
+   * 1 to denote the two classes. The method also supports categorical feature inputs where the
+   * number of categories can be specified using the categoricalFeaturesInfo option.
+   *
+   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
    *              training data for DecisionTree
    * @param algo classification or regression
    * @param impurity criterion used for information gain calculation
@@ -236,6 +249,7 @@ object DecisionTree extends Serializable with Logging {
 
   /**
    * Returns an array of optimal splits for all nodes at a given level
+   *
    * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
    *              for DecisionTree
    * @param parentImpurities Impurities for all parent nodes for the current level
@@ -247,7 +261,7 @@ object DecisionTree extends Serializable with Logging {
    * @param bins possible bins for all features
    * @return array of splits with best splits for all nodes at a given level.
   */
-  private def findBestSplits(
+  protected[tree] def findBestSplits(
       input: RDD[LabeledPoint],
       parentImpurities: Array[Double],
       strategy: Strategy,
@@ -885,7 +899,7 @@ object DecisionTree extends Serializable with Logging {
    * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache
    * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
    */
-  private def findSplitsBins(
+  protected[tree] def findSplitsBins(
       input: RDD[LabeledPoint],
       strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
     val count = input.count()
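
A usage sketch (not part of the patch itself): the javadoc added above describes the train overloads but gives no call site. The snippet below shows one way the simplest overload, DecisionTree.train(input, strategy), might be driven. It assumes the MLlib API of this era, where LabeledPoint carries an Array[Double] of features and Strategy is built from an algo, an impurity and a maxDepth; the object name DecisionTreeExample and the toy data are purely illustrative.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Hypothetical driver object; names and data are illustrative only.
object DecisionTreeExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "DecisionTreeExample")

    // Binary classification input: labels are 0 or 1, as the new javadoc requires.
    // Assumes the LabeledPoint(label, Array[Double]) representation of this period.
    val input = sc.parallelize(Seq(
      LabeledPoint(0.0, Array(0.0, 0.0)),
      LabeledPoint(0.0, Array(0.0, 1.0)),
      LabeledPoint(1.0, Array(1.0, 0.0)),
      LabeledPoint(1.0, Array(1.0, 1.0))))

    // Strategy bundles the algorithm type, impurity criterion and tree-growing limits;
    // parameter names (algo, impurity, maxDepth) are assumed from configuration.Strategy,
    // with maxBins and categoricalFeaturesInfo left at their defaults.
    val strategy = new Strategy(algo = Algo.Classification, impurity = Gini, maxDepth = 3)

    // The first overload documented in the patch: train(input, strategy).
    val model = DecisionTree.train(input, strategy)

    // Predict the label of a single feature array (Array-based predict assumed here).
    println(model.predict(Array(1.0, 0.0)))

    sc.stop()
  }
}

Passing a single Strategy object, rather than the long positional parameter list of the third overload, keeps the call site stable as further tuning options such as categoricalFeaturesInfo are added.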