diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 28e2f992e65b8..b8164f64a7b04 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Algo._
 
-/*
+/**
 A class that implements a decision tree algorithm for classification and regression. It
 supports both continuous and categorical features.
 
@@ -40,7 +40,7 @@ quantile calculation strategy, etc.
 */
 class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
-  /*
+  /**
   Method to train a decision tree model over an RDD
 
   @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
@@ -157,14 +157,14 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
 object DecisionTree extends Serializable with Logging {
 
-  /*
+  /**
   Returns an Array[Split] of optimal splits for all nodes at a given level
 
   @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
-  for DecisionTree
+               for DecisionTree
   @param parentImpurities Impurities for all parent nodes for the current level
   @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
-  parameters for construction the DecisionTree
+                  parameters for constructing the DecisionTree
  @param level Level of the tree
  @param filters Filter for all nodes at a given level
  @param splits possible splits for all features
@@ -200,7 +200,7 @@ object DecisionTree extends Serializable with Logging {
     }
   }
 
-  /*
+  /**
   Find whether the sample is valid input for the current node. In other words,
   does it pass through all the filters for the current node.
   */
@@ -236,7 +236,9 @@ object DecisionTree extends Serializable with Logging {
     true
   }
 
-  /*Finds the right bin for the given feature*/
+  /**
+   Finds the right bin for the given feature
+   */
   def findBin(featureIndex: Int, labeledPoint: LabeledPoint,
       isFeatureContinuous : Boolean) : Int = {
     if (isFeatureContinuous){
@@ -266,7 +268,8 @@ object DecisionTree extends Serializable with Logging {
 
   }
 
-  /*Finds bins for all nodes (and all features) at a given level
+  /**
+   Finds bins for all nodes (and all features) at a given level
   k features, l nodes (level = log2(l))
   Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk
   Denotes invalid sample for tree by noting bin for feature 1 as -1
@@ -343,7 +346,8 @@ object DecisionTree extends Serializable with Logging {
       }
     }
 
-    /*Performs a sequential aggregation over a partition.
+    /**
+     Performs a sequential aggregation over a partition.
     for p bins, k features, l nodes (level = log2(l)) storage is of the form:
     b111_left_count,b111_right_count, .... , ..
@@ -370,7 +374,8 @@ object DecisionTree extends Serializable with Logging {
     }
     logDebug("binAggregateLength = " + binAggregateLength)
 
-    /*Combines the aggregates from partitions
+    /**
+     Combines the aggregates from partitions
 
     @param agg1 Array containing aggregates from one or more partitions
     @param agg2 Array containing aggregates from one or more partitions
@@ -507,7 +512,7 @@ object DecisionTree extends Serializable with Logging {
       }
     }
 
-    /*
+    /**
     Extracts left and right split aggregates
 
     @param binData Array[Double] of size 2*numFeatures*numSplits
@@ -604,7 +609,7 @@ object DecisionTree extends Serializable with Logging {
       gains
     }
 
-    /*
+    /**
    Find the best split for a node given bin aggregate data
 
    @param binData Array[Double] of size 2*numSplits*numFeatures
@@ -669,7 +674,7 @@ object DecisionTree extends Serializable with Logging {
     bestSplits
   }
 
-  /*
+  /**
  Returns split and bins for decision tree calculation.
 
  @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
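For context, here is a hypothetical end-to-end usage sketch (not part of the patch) showing how the classes documented in this diff fit together. It assumes the training method described above is named `train`, takes an `RDD[LabeledPoint]`, and returns a `DecisionTreeModel`, and that `Strategy`'s constructor takes its parameters in the order its Scaladoc lists them; the `sc` setup and toy data are illustrative only.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort
import org.apache.spark.mllib.tree.impurity.Gini

object DecisionTreeSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "DecisionTreeSketch")

    // Toy training set: two continuous features, binary labels.
    val training = sc.parallelize(Seq(
      LabeledPoint(0.0, Array(1.0, 5.0)),
      LabeledPoint(0.0, Array(2.0, 6.0)),
      LabeledPoint(1.0, Array(8.0, 1.0)),
      LabeledPoint(1.0, Array(9.0, 2.0))))

    // algo, impurity, maxDepth, maxBins, quantile strategy, categorical info
    // (ordering assumed from the Strategy Scaladoc in this patch).
    val strategy = new Strategy(Classification, Gini, 3, 100, Sort, Map[Int, Int]())

    val model = new DecisionTree(strategy).train(training)

    // Single-point and RDD prediction, per DecisionTreeModel later in this patch.
    println(model.predict(Array(8.5, 1.5)))
    model.predict(sc.parallelize(Seq(Array(1.5, 5.5)))).collect().foreach(println)
  }
}
```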
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
index 7cd128e381e8f..14ec47ce014e7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -16,6 +16,9 @@
  */
 package org.apache.spark.mllib.tree.configuration
 
+/**
+ * Enum to select the algorithm for the decision tree
+ */
 object Algo extends Enumeration {
   type Algo = Value
   val Classification, Regression = Value
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
index a725bf388fe29..b4e8ae4ac39dd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
@@ -16,6 +16,9 @@
  */
 package org.apache.spark.mllib.tree.configuration
 
+/**
+ * Enum to describe whether a feature is "continuous" or "categorical"
+ */
 object FeatureType extends Enumeration {
   type FeatureType = Value
   val Continuous, Categorical = Value
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
index 1bbd2d8c1fe92..dae73ab52dea7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
@@ -16,6 +16,9 @@
  */
 package org.apache.spark.mllib.tree.configuration
 
+/**
+ * Enum for selecting the quantile calculation strategy
+ */
 object QuantileStrategy extends Enumeration {
   type QuantileStrategy = Value
   val Sort, MinMax, ApproxHist = Value
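Since all three configuration objects are plain Scala `Enumeration`s, their values can be imported, printed, and pattern matched directly. A small illustration (hypothetical, not part of the patch):

```scala
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.configuration.QuantileStrategy

object ConfigurationEnums {
  def main(args: Array[String]) {
    // The `type Algo = Value` alias lets signatures refer to Algo.Algo.
    def describe(algo: Algo.Algo): String = algo match {
      case Algo.Classification => "predicts a discrete label"
      case Algo.Regression => "predicts a continuous value"
    }
    println(describe(Algo.Classification))

    println(FeatureType.Categorical)  // prints "Categorical"
    println(QuantileStrategy.values)  // ValueSet(Sort, MinMax, ApproxHist)
  }
}
```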
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 281dabd3364d8..973aaee49e5fb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -20,6 +20,19 @@ import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 
+/**
+ * Stores all the configuration options for tree construction
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth maximum depth of the tree
+ * @param maxBins maximum number of bins used for splitting features
+ * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
+ *                                number of discrete values they take. For example, an entry (n ->
+ *                                k) implies feature n is categorical with k categories 0,
+ *                                1, 2, ... , k-1. It's important to note that features are
+ *                                zero-indexed.
+ */
 class Strategy (
     val algo : Algo,
     val impurity : Impurity,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 350627e9de1dd..c1b2972f9c25b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -18,10 +18,20 @@ package org.apache.spark.mllib.tree.impurity
 
 import javax.naming.OperationNotSupportedException
 
+/**
+ * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
+ * binary classification.
+ */
 object Entropy extends Impurity {
 
   def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
 
+  /**
+   * entropy calculation
+   * @param c0 count of instances with label 0
+   * @param c1 count of instances with label 1
+   * @return entropy value
+   */
   def calculate(c0: Double, c1: Double): Double = {
     if (c0 == 0 || c1 == 0) {
       0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 3c7615f684525..099c7e33dd39a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -18,8 +18,18 @@ package org.apache.spark.mllib.tree.impurity
 
 import javax.naming.OperationNotSupportedException
 
+/**
+ * Class for calculating the [[http://en.wikipedia.org/wiki/Gini_coefficient Gini
+ * coefficient]] during binary classification
+ */
 object Gini extends Impurity {
 
+  /**
+   * gini coefficient calculation
+   * @param c0 count of instances with label 0
+   * @param c1 count of instances with label 1
+   * @return gini coefficient value
+   */
   def calculate(c0 : Double, c1 : Double): Double = {
     if (c0 == 0 || c1 == 0) {
       0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 65f5b3702779a..b313b8d48eadf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -19,9 +19,19 @@ package org.apache.spark.mllib.tree.impurity
 import javax.naming.OperationNotSupportedException
 import org.apache.spark.Logging
 
+/**
+ * Class for calculating variance during regression
+ */
 object Variance extends Impurity with Logging {
   def calculate(c0: Double, c1: Double): Double =
     throw new OperationNotSupportedException("Variance.calculate")
 
+  /**
+   * variance calculation
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   * @return variance value
+   */
   def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
     val squaredLoss = sumSquares - (sum*sum)/count
     squaredLoss/count
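A quick numeric check of the three impurity measures documented above. The variance body is shown in full in the hunk; for entropy and Gini the method bodies are elided, so the expected values below assume the standard binary entropy and Gini impurity formulas:

```scala
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}

object ImpurityCheck {
  def main(args: Array[String]) {
    // 30 instances labeled 0 and 10 labeled 1:
    // entropy = -0.75*log2(0.75) - 0.25*log2(0.25) ≈ 0.811
    println(Entropy.calculate(30, 10))

    // Gini impurity = 1 - 0.75^2 - 0.25^2 = 0.375
    println(Gini.calculate(30, 10))

    // Regression labels 1.0, 2.0, 3.0 arrive as running sums:
    // count = 3, sum = 6.0, sumSquares = 14.0
    // variance = (14 - 6*6/3) / 3 ≈ 0.667
    println(Variance.calculate(3, 6.0, 14.0))
  }
}
```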
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index 6664f084a7d8d..0b4b7d2e5b2df 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -18,6 +18,17 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 
+/**
+ * Used for "binning" the feature values for faster best split calculation. For a continuous
+ * feature, a bin is determined by a low and a high "split". For a categorical feature,
+ * a bin is determined by a single label value (category).
+ * @param lowSplit signifying the lower threshold for the continuous feature to be
+ *                 accepted in the bin
+ * @param highSplit signifying the upper threshold for the continuous feature to be
+ *                  accepted in the bin
+ * @param featureType type of feature -- categorical or continuous
+ * @param category categorical label value accepted in the bin
+ */
 case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) {
 
 }
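To make the low/high split bookkeeping concrete, here is a hypothetical sketch of how the bins for one continuous feature could be assembled from consecutive splits. The dummy splits come from Split.scala at the end of this patch; the (low, high] membership rule and the placeholder category are my assumptions, since the body of findBin is elided above:

```scala
import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.model.{Bin, DummyHighSplit, DummyLowSplit, Split}

object BinSketch {
  def main(args: Array[String]) {
    val featureIndex = 0

    // Candidate thresholds 10.0 and 20.0, bounded by the dummy splits.
    val splits = Array[Split](
      new DummyLowSplit(featureIndex, FeatureType.Continuous),
      Split(featureIndex, 10.0, FeatureType.Continuous, List()),
      Split(featureIndex, 20.0, FeatureType.Continuous, List()),
      new DummyHighSplit(featureIndex, FeatureType.Continuous))

    // Each pair of consecutive splits bounds one bin; category is unused
    // for continuous features, so a placeholder is passed.
    val bins = splits.sliding(2).map { pair =>
      Bin(pair(0), pair(1), FeatureType.Continuous, Double.NaN)
    }.toArray

    // A value lands in the bin whose (low, high] interval contains it.
    def binIndex(value: Double): Int =
      bins.indexWhere(b => value > b.lowSplit.threshold && value <= b.highSplit.threshold)

    println(binIndex(15.0))  // 1
    println(binIndex(25.0))  // 2 (still below DummyHighSplit's Double.MaxValue)
  }
}
```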
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 0da42e826984c..0e94827b0af70 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -16,12 +16,23 @@
  */
 package org.apache.spark.mllib.tree.model
 
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.rdd.RDD
 
+/**
+ * Model to store the decision tree parameters
+ * @param topNode root node
+ * @param algo algorithm type -- classification or regression
+ */
 class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable {
 
-  def predict(features : Array[Double]) = {
+  /**
+   * Predict values for a single data point using the model trained.
+   *
+   * @param features array representing a single data point
+   * @return Double prediction from the trained model
+   */
+  def predict(features : Array[Double]) : Double = {
     algo match {
       case Classification => {
         if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0
@@ -32,4 +43,15 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl
     }
   }
 
+  /**
+   * Predict values for the given data set using the model trained.
+   *
+   * @param features RDD representing data points to be predicted
+   * @return RDD[Double] where each entry contains the corresponding prediction
+   */
+  def predict(features: RDD[Array[Double]]): RDD[Double] = {
+    features.map(x => predict(x))
+  }
+
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
index 62e5006c80c1b..9fc794c87398d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
@@ -16,6 +16,11 @@
  */
 package org.apache.spark.mllib.tree.model
 
+/**
+ * Filter specifying a split and type of comparison to be applied on features
+ * @param split split specifying the feature index, type and threshold
+ * @param comparison integer specifying <,=,>
+ */
 case class Filter(split : Split, comparison : Int) {
   // Comparison -1,0,1 signifies <.=,>
   override def toString = " split = " + split + "comparison = " + comparison
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index f410a5a2cf812..0f8d7a36d208f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -16,6 +16,14 @@
  */
 package org.apache.spark.mllib.tree.model
 
+/**
+ * Information gain statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param leftImpurity left node impurity
+ * @param rightImpurity right node impurity
+ * @param predict predicted value
+ */
 class InformationGainStats(
     val gain : Double,
     val impurity: Double,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index fb7e0db9c9dd2..374f065a09032 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -20,6 +20,16 @@ import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 
+/**
+ * Node in a decision tree
+ * @param id integer node id
+ * @param predict predicted value at the node
+ * @param isLeaf whether the node is a leaf
+ * @param split split to calculate left and right nodes
+ * @param leftNode left child
+ * @param rightNode right child
+ * @param stats information gain stats
+ */
 class Node (
     val id : Int,
     val predict : Double,
     val isLeaf : Boolean,
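The comparison convention in Filter (-1, 0, 1 encoding <, =, >) can be read as follows. A self-contained illustration; the `passes` function is my own sketch of the check, not code from this patch (isSampleValid in DecisionTree.scala performs the real version):

```scala
import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.model.{Filter, Split}

object FilterSketch {
  def main(args: Array[String]) {
    // "Feature 0 is below 5.0": comparison -1 encodes "<".
    val belowFive = Filter(Split(0, 5.0, FeatureType.Continuous, List()), -1)

    // Sketch of how a node's filter might be checked against a sample.
    def passes(features: Array[Double], filter: Filter): Boolean = {
      val value = features(filter.split.feature)
      filter.comparison match {
        case -1 => value < filter.split.threshold
        case 0  => value == filter.split.threshold
        case 1  => value > filter.split.threshold
      }
    }

    println(passes(Array(3.0), belowFive))  // true
    println(passes(Array(7.0), belowFive))  // false
  }
}
```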
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index 1604996091597..81e57dbf5e521 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -18,6 +18,13 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
 
+/**
+ * Split applied to a feature
+ * @param feature feature index
+ * @param threshold threshold for continuous feature
+ * @param featureType type of feature -- categorical or continuous
+ * @param categories accepted values for categorical variables
+ */
 case class Split(
     feature: Int,
     threshold : Double,
@@ -29,12 +36,28 @@ case class Split(
     ", categories = " + categories
 }
 
-class DummyLowSplit(feature: Int, kind : FeatureType)
-  extends Split(feature, Double.MinValue, kind, List())
+/**
+ * Split with minimum threshold for continuous features. Helps with the smallest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyLowSplit(feature: Int, featureType : FeatureType)
+  extends Split(feature, Double.MinValue, featureType, List())
 
-class DummyHighSplit(feature: Int, kind : FeatureType)
-  extends Split(feature, Double.MaxValue, kind, List())
+/**
+ * Split with maximum threshold for continuous features. Helps with the highest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyHighSplit(feature: Int, featureType : FeatureType)
+  extends Split(feature, Double.MaxValue, featureType, List())
 
-class DummyCategoricalSplit(feature: Int, kind : FeatureType)
-  extends Split(feature, Double.MaxValue, kind, List())
+/**
+ * Split with no acceptable feature values for categorical features. Helps with the first bin
+ * creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyCategoricalSplit(feature: Int, featureType : FeatureType)
+  extends Split(feature, Double.MaxValue, featureType, List())
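And a matching sketch for categorical features: a Split carries the category values routed to one side, and DummyCategoricalSplit accepts nothing, anchoring the first bin. The routing rule and the unused-threshold placeholder are my assumptions, not code from this patch:

```scala
import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.model.{DummyCategoricalSplit, Split}

object CategoricalSplitSketch {
  def main(args: Array[String]) {
    // Feature 2 takes categories 0.0, 1.0, 2.0; send 0.0 and 2.0 left.
    // The threshold is meaningless for categorical splits, hence NaN.
    val split = Split(2, Double.NaN, FeatureType.Categorical, List(0.0, 2.0))

    // A sample goes left iff its category is among the accepted values.
    def goesLeft(category: Double, s: Split): Boolean = s.categories.contains(category)

    println(goesLeft(2.0, split))  // true
    println(goesLeft(1.0, split))  // false

    // The dummy split accepts no categories, so nothing routes through it.
    val dummy = new DummyCategoricalSplit(2, FeatureType.Categorical)
    println(goesLeft(1.0, dummy))  // false
  }
}
```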