Skip to content

Commit

Permalink
removing threshold for classification predict method
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Mar 13, 2014
1 parent 2116360 commit 632818f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1034,8 +1034,12 @@ object DecisionTree extends Serializable with Logging {
/**
* Calculates the classifier accuracy.
*/
def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
threshold: Double = 0.5): Double = {
def predictedValue(features: Array[Double]) = {
if (model.predict(features) < threshold) 0.0 else 1.0
}
val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
val count = data.count()
logDebug("correct prediction count = " + correctCount)
logDebug("data count = " + count)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* @return Double prediction from the trained model
*/
def predict(features: Array[Double]): Double = {
// NOTE(review): the lines below appear to interleave the removed and added sides of a diff
// hunk; confirm against the repository before relying on this text as the method body.
// Read literally, the value of this match is discarded because it is not the last expression.
algo match {
case Classification => {
// Maps the leaf prediction to a 0.0/1.0 label using a fixed 0.5 cutoff.
if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0
}
case Regression => {
// Passes the leaf prediction through unchanged.
topNode.predictIfLeaf(features)
}
}
// Result of the method: the value from topNode.predictIfLeaf, with no thresholding applied.
topNode.predictIfLeaf(features)
}

/**
Expand Down

0 comments on commit 632818f

Please sign in to comment.