Skip to content

Commit

Permalink
added regression support
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <[email protected]>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent c8f6d60 commit e23c2e5
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 76 deletions.
261 changes: 194 additions & 67 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.spark.mllib.tree

import org.apache.spark.SparkContext._
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance}
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -95,6 +96,9 @@ object DecisionTreeRunner extends Logging {
val accuracy = accuracyScore(model, testData)
logDebug("accuracy = " + accuracy)

val mse = meanSquaredError(model,testData)
logDebug("mean square error = " + mse)

sc.stop()
}

Expand Down Expand Up @@ -126,6 +130,14 @@ object DecisionTreeRunner extends Logging {
correctCount.toDouble / count
}

//TODO: Make these generic MLTable metrics
def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = {
val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean()
println("meanSumOfSquares = " + meanSumOfSquares)
meanSumOfSquares
}




}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException

object Entropy extends Impurity {

def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
Expand All @@ -31,4 +33,6 @@ object Entropy extends Impurity {
}
}

}
def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new OperationNotSupportedException("Entropy.calculate")
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException

object Gini extends Impurity {

def calculate(c0 : Double, c1 : Double): Double = {
Expand All @@ -29,4 +31,5 @@ object Gini extends Impurity {
}
}

}
def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new OperationNotSupportedException("Gini.calculate")
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ trait Impurity extends Serializable {

def calculate(c0 : Double, c1 : Double): Double

def calculate(count : Double, sum : Double, sumSquares : Double) : Double

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException
import org.apache.spark.Logging

object Variance extends Impurity {
object Variance extends Impurity with Logging {
def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate")
}

def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
val squaredLoss = sumSquares - (sum*sum)/count
squaredLoss/count
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Regression,Gini,3,100)
val strategy = new Strategy(Classification,Gini,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(bins.length==2)
Expand All @@ -62,7 +62,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Regression,Gini,3,100)
val strategy = new Strategy(Classification,Gini,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
Expand All @@ -88,7 +88,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Regression,Gini,3,100)
val strategy = new Strategy(Classification,Gini,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
Expand All @@ -114,7 +114,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Regression,Entropy,3,100)
val strategy = new Strategy(Classification,Entropy,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
Expand All @@ -139,7 +139,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Regression,Entropy,3,100)
val strategy = new Strategy(Classification,Entropy,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
Expand Down

0 comments on commit e23c2e5

Please sign in to comment.