Skip to content

Commit

Permalink
more tests for categorical features
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <[email protected]>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent dbb7ac1 commit d504eb1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,12 @@ object DecisionTree extends Serializable with Logging {
}

/*Finds the right bin for the given feature*/
def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
//logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinous : Boolean) : Int = {

val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinous){
//TODO: Do binary search
for (binIndex <- 0 until strategy.numBins) {
val bin = bins(featureIndex)(binIndex)
//TODO: Remove this requirement post basic functional
val lowThreshold = bin.lowSplit.threshold
val highThreshold = bin.highSplit.threshold
val features = labeledPoint.features
Expand All @@ -222,9 +219,9 @@ object DecisionTree extends Serializable with Logging {
}
throw new UnknownError("no bin was found for continuous variable.")
} else {

for (binIndex <- 0 until strategy.numBins) {
val bin = bins(featureIndex)(binIndex)
//TODO: Remove this requirement post basic functional
val category = bin.category
val features = labeledPoint.features
if (category == features(featureIndex)) {
Expand Down Expand Up @@ -262,7 +259,8 @@ object DecisionTree extends Serializable with Logging {
} else {
for (featureIndex <- 0 until numFeatures) {
//logDebug("shift+featureIndex =" + (shift+featureIndex))
arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinous)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
println(splits(1)(0))
println(splits(1)(1))
println(bins(1)(0))
//TODO: Add asserts

}

test("split and bin calculations for categorical variables with no sample for one category"){
Expand All @@ -100,12 +102,28 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
println(bins(1)(1))
println(bins(0)(2))
println(bins(0)(3))
//TODO: Add asserts

}

//TODO: Test max feature value > num bins


test("stump with fixed label 0 for Gini"){
test("stump with all categorical variables"){
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
strategy.numBins = 100
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
println(bestSplits(0)._1)
println(bestSplits(0)._2)
//TODO: Add asserts
}


test("stump with fixed label 0 for Gini"){
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
Expand Down

0 comments on commit d504eb1

Please sign in to comment.