Skip to content

Commit

Permalink
added multiple train methods for java compatability
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Mar 6, 2014
1 parent d3023b3 commit 63e786b
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.Impurity

/**
A class that implements a decision tree algorithm for classification and regression.
Expand All @@ -38,7 +39,7 @@ algorithm (classification,
regression, etc.), feature type (continuous, categorical), depth of the tree,
quantile calculation strategy, etc.
*/
class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
class DecisionTree private (val strategy : Strategy) extends Serializable with Logging {

/**
Method to train a decision tree model over an RDD
Expand Down Expand Up @@ -157,6 +158,70 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {

object DecisionTree extends Serializable with Logging {

/**
Method to train a decision tree model over an RDD
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
for DecisionTree
@param strategy The configuration parameters for the tree algorithm which specify the type of algorithm
(classification, regression, etc.), feature type (continuous, categorical),
depth of the tree, quantile calculation strategy, etc.
@return a DecisionTreeModel that can be used for prediction
*/
def train(input : RDD[LabeledPoint], strategy : Strategy) : DecisionTreeModel = {
new DecisionTree(strategy).train(input : RDD[LabeledPoint])
}

/**
Method to train a decision tree model over an RDD
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
for DecisionTree
@param algo classification or regression
@param impurity criterion used for information gain calculation
@param maxDepth maximum depth of the tree
@return a DecisionTreeModel that can be used for prediction
*/
def train(
input : RDD[LabeledPoint],
algo : Algo,
impurity : Impurity,
maxDepth : Int
) : DecisionTreeModel = {
val strategy = new Strategy(algo,impurity,maxDepth)
new DecisionTree(strategy).train(input : RDD[LabeledPoint])
}


/**
Method to train a decision tree model over an RDD
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
for DecisionTree
@param algo classification or regression
@param impurity criterion used for information gain calculation
@param maxDepth maximum depth of the tree
@param maxBins maximum number of bins used for splitting features
@param quantileCalculationStrategy algorithm for calculating quantiles
@param categoricalFeaturesInfo A map storing information about the categorical variables and the number of discrete
values they take. For example, an entry (n -> k) implies the feature n is
categorical with k categories 0, 1, 2, ... , k-1. It's important to note that
features are zero-indexed.
@return a DecisionTreeModel that can be used for prediction
*/
def train(
input : RDD[LabeledPoint],
algo : Algo,
impurity : Impurity,
maxDepth : Int,
maxBins : Int,
quantileCalculationStrategy : QuantileStrategy,
categoricalFeaturesInfo : Map[Int,Int]
) : DecisionTreeModel = {
val strategy = new Strategy(algo,impurity,maxDepth,maxBins,quantileCalculationStrategy,categoricalFeaturesInfo)
new DecisionTree(strategy).train(input : RDD[LabeledPoint])
}

/**
Returns an Array[Split] of optimal splits for all nodes at a given level
Expand Down Expand Up @@ -717,13 +782,13 @@ object DecisionTree extends Serializable with Logging {
for DecisionTree
@param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
parameters for construction the DecisionTree
@return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,
numSplits-1) and bins is an
Array[Array[Bin]] of size (numFeatures,numSplits1)
@return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree.model.Split] of
size (numFeatures,numSplits-1) and bins is an Array of [org.apache.spark.mllib.tree.model.Bin] of
size (numFeatures,numSplits1)
*/
def findSplitsBins(
input : RDD[LabeledPoint],
strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
strategy : Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {

val count = input.count()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ object DecisionTreeRunner extends Logging {
val maxBins = options.getOrElse('maxBins,"100").toString.toInt

val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth, maxBins = maxBins)
val model = new DecisionTree(strategy).train(trainData)
val model = DecisionTree.train(trainData,strategy)

//Load test data
val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Strategy (
val algo : Algo,
val impurity : Impurity,
val maxDepth : Int,
val maxBins : Int,
val maxBins : Int = 100,
val quantileCalculationStrategy : QuantileStrategy = Sort,
val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable {

Expand Down

0 comments on commit 63e786b

Please sign in to comment.