Skip to content

Commit

Permalink
minor refactoring and tests
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <[email protected]>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent d504eb1 commit 6b7de78
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 54 deletions.
100 changes: 48 additions & 52 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,15 @@ object DecisionTree extends Serializable with Logging {
//Find the number of features by looking at the first sample
val numFeatures = input.take(1)(0).features.length
logDebug("numFeatures = " + numFeatures)
val numSplits = strategy.numBins
logDebug("numSplits = " + numSplits)
val numBins = strategy.numBins
logDebug("numBins = " + numBins)

/*Find the filters used before reaching the current code*/
def findParentFilters(nodeIndex: Int): List[Filter] = {
if (level == 0) {
List[Filter]()
} else {
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
//val parentFilterIndex = nodeFilterIndex / 2
//TODO: Check left or right filter
filters(nodeFilterIndex)
}
}
Expand Down Expand Up @@ -204,9 +202,9 @@ object DecisionTree extends Serializable with Logging {
}

/*Finds the right bin for the given feature*/
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinous : Boolean) : Int = {
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = {

if (isFeatureContinous){
if (isFeatureContinuous){
//TODO: Do binary search
for (binIndex <- 0 until strategy.numBins) {
val bin = bins(featureIndex)(binIndex)
Expand Down Expand Up @@ -245,11 +243,11 @@ object DecisionTree extends Serializable with Logging {
// calculating bin index and label per feature per node
val arr = new Array[Double](1+(numFeatures * numNodes))
arr(0) = labeledPoint.label
for (index <- 0 until numNodes) {
val parentFilters = findParentFilters(index)
for (nodeIndex <- 0 until numNodes) {
val parentFilters = findParentFilters(nodeIndex)
//Find out whether the sample qualifies for the particular node
val sampleValid = isSampleValid(parentFilters, labeledPoint)
val shift = 1 + numFeatures * index
val shift = 1 + numFeatures * nodeIndex
if (!sampleValid) {
//Add to invalid bin index -1
for (featureIndex <- 0 until numFeatures) {
Expand All @@ -274,11 +272,11 @@ object DecisionTree extends Serializable with Logging {
val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false
if (isSampleValidForNode) {
val label = arr(0)
for (feature <- 0 until numFeatures) {
for (featureIndex <- 0 until numFeatures) {
val arrShift = 1 + numFeatures * node
val aggShift = 2 * numSplits * numFeatures * node
val arrIndex = arrShift + feature
val aggIndex = aggShift + 2 * feature * numSplits + arr(arrIndex).toInt * 2
val aggShift = 2 * numBins * numFeatures * node
val arrIndex = arrShift + featureIndex
val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
label match {
case (0.0) => agg(aggIndex) = agg(aggIndex) + 1
case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
Expand All @@ -296,9 +294,9 @@ object DecisionTree extends Serializable with Logging {
val label = arr(0)
for (feature <- 0 until numFeatures) {
val arrShift = 1 + numFeatures * node
val aggShift = 3 * numSplits * numFeatures * node
val aggShift = 3 * numBins * numFeatures * node
val arrIndex = arrShift + feature
val aggIndex = aggShift + 3 * feature * numSplits + arr(arrIndex).toInt * 3
val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3
//count, sum, sum^2
agg(aggIndex) = agg(aggIndex) + 1
agg(aggIndex + 1) = agg(aggIndex + 1) + label
Expand All @@ -318,7 +316,6 @@ object DecisionTree extends Serializable with Logging {
@return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification
*/
def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
//TODO: Requires logic for regressions
strategy.algo match {
case Classification => classificationBinSeqOp(arr, agg)
//TODO: Implement this
Expand All @@ -327,10 +324,9 @@ object DecisionTree extends Serializable with Logging {
agg
}

//TODO: This length is different for regression
val binAggregateLength = strategy.algo match {
case Classification => 2*numSplits * numFeatures * numNodes
case Regression => 3*numSplits * numFeatures * numNodes
case Classification => 2*numBins * numFeatures * numNodes
case Regression => 3*numBins * numFeatures * numNodes
}
logDebug("binAggregateLength = " + binAggregateLength)

Expand Down Expand Up @@ -453,52 +449,52 @@ object DecisionTree extends Serializable with Logging {
strategy.algo match {
case Classification => {

val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
for (featureIndex <- 0 until numFeatures) {
val shift = 2*featureIndex*numSplits
val shift = 2*featureIndex*numBins
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
for (splitIndex <- 1 until numSplits - 1) {
rightNodeAgg(featureIndex)(2 * (numBins - 2)) = binData(shift + (2 * (numBins - 1)))
rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) = binData(shift + (2 * (numBins - 1)) + 1)
for (splitIndex <- 1 until numBins - 1) {
leftNodeAgg(featureIndex)(2 * splitIndex)
= binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2)
leftNodeAgg(featureIndex)(2 * splitIndex + 1)
= binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex))
= binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex))
rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1)
= binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1)
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
= binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1)
= binData(shift + (2 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
}
}
(leftNodeAgg, rightNodeAgg)
}
case Regression => {

val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1))
val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1))
val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
for (featureIndex <- 0 until numFeatures) {
val shift = 3*featureIndex*numSplits
val shift = 3*featureIndex*numBins
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
leftNodeAgg(featureIndex)(2) = binData(shift + 2)
rightNodeAgg(featureIndex)(3 * (numSplits - 2)) = binData(shift + (3 * (numSplits - 1)))
rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 1) = binData(shift + (3 * (numSplits - 1)) + 1)
rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 2) = binData(shift + (3 * (numSplits - 1)) + 2)
for (splitIndex <- 1 until numSplits - 1) {
rightNodeAgg(featureIndex)(3 * (numBins - 2)) = binData(shift + (3 * (numBins - 1)))
rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = binData(shift + (3 * (numBins - 1)) + 1)
rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = binData(shift + (3 * (numBins - 1)) + 2)
for (splitIndex <- 1 until numBins - 1) {
leftNodeAgg(featureIndex)(3 * splitIndex)
= binData(shift + 3*splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3)
leftNodeAgg(featureIndex)(3 * splitIndex + 1)
= binData(shift + 3*splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
leftNodeAgg(featureIndex)(3 * splitIndex + 2)
= binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex))
= binData(shift + (3 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex))
rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 1)
= binData(shift + (3 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 1)
rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 2)
= binData(shift + (3 * (numSplits - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 2)
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
= binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1)
= binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2)
= binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
}
}
(leftNodeAgg, rightNodeAgg)
Expand All @@ -509,10 +505,10 @@ object DecisionTree extends Serializable with Logging {
def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double)
: Array[Array[InformationGainStats]] = {

val gains = Array.ofDim[InformationGainStats](numFeatures, numSplits - 1)
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)

for (featureIndex <- 0 until numFeatures) {
for (index <- 0 until numSplits -1) {
for (index <- 0 until numBins -1) {
//logDebug("splitIndex = " + index)
gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
}
Expand All @@ -521,10 +517,10 @@ object DecisionTree extends Serializable with Logging {
}

/*
Find the best split for a node given bin aggregate data
Find the best split for a node given bin aggregate data
@param binData Array[Double] of size 2*numSplits*numFeatures
*/
@param binData Array[Double] of size 2*numSplits*numFeatures
*/
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = {
logDebug("node impurity = " + nodeImpurity)
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
Expand All @@ -536,7 +532,7 @@ object DecisionTree extends Serializable with Logging {
//Initialization with infeasible values
var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1)
for (featureIndex <- 0 until numFeatures) {
for (splitIndex <- 0 until numSplits - 1){
for (splitIndex <- 0 until numBins - 1){
val gainStats = gains(featureIndex)(splitIndex)
if(gainStats.gain > bestGainStats.gain) {
bestGainStats = gainStats
Expand All @@ -556,13 +552,13 @@ object DecisionTree extends Serializable with Logging {
def getBinDataForNode(node: Int): Array[Double] = {
strategy.algo match {
case Classification => {
val shift = 2 * node * numSplits * numFeatures
val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures)
val shift = 2 * node * numBins * numFeatures
val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
binsForNode
}
case Regression => {
val shift = 3 * node * numSplits * numFeatures
val binsForNode = binAggregates.slice(shift, shift + 3 * numSplits * numFeatures)
val shift = 3 * node * numBins * numFeatures
val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
binsForNode
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,20 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
//TODO: Test max feature value > num bins


test("stump with all categorical variables"){
test("classification stump with all categorical variables"){
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
strategy.numBins = 100
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
println(bestSplits(0)._1)
println(bestSplits(0)._2)
//TODO: Add asserts
}

test("regression stump with all categorical variables"){
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
Expand All @@ -123,7 +136,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
}


test("stump with fixed label 0 for Gini"){
test("stump with fixed label 0 for Gini"){
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
Expand Down

0 comments on commit 6b7de78

Please sign in to comment.