Skip to content

Commit

Permalink
fixing code style based on feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Mar 7, 2014
1 parent 63e786b commit cd2c2b4
Show file tree
Hide file tree
Showing 17 changed files with 365 additions and 352 deletions.
467 changes: 291 additions & 176 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.configuration

/**
Expand All @@ -22,4 +23,4 @@ package org.apache.spark.mllib.tree.configuration
object Algo extends Enumeration {
  // Supported learning tasks for a decision tree: classification or regression.
  type Algo = Value
  val Classification, Regression = Value
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.configuration

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.configuration

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.configuration

import org.apache.spark.mllib.tree.impurity.Impurity
Expand All @@ -34,13 +35,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* zero-indexed.
*/
/**
 * Stores all the configuration options for the decision-tree construction algorithm.
 * @param algo learning goal -- classification or regression
 * @param impurity criterion used for information-gain calculation
 * @param maxDepth maximum depth of the tree
 * @param maxBins maximum number of bins used for splitting features (default 100)
 * @param quantileCalculationStrategy algorithm for computing quantiles (default Sort)
 * @param categoricalFeaturesInfo map from categorical feature index to its arity;
 *        features absent from the map are treated as continuous
 */
class Strategy (
    val algo: Algo,
    val impurity: Impurity,
    val maxDepth: Int,
    val maxBins: Int = 100,
    val quantileCalculationStrategy: QuantileStrategy = Sort,
    val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable {

  // Number of bins actually used during training; Int.MinValue marks "not yet computed".
  var numBins: Int = Int.MinValue

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,29 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.impurity

/**
 * Trait for calculating information gain.
 */
trait Impurity extends Serializable {

  /**
   * Information calculation for binary classification.
   * @param c0 count of instances with label 0
   * @param c1 count of instances with label 1
   * @return information value
   */
  def calculate(c0: Double, c1: Double): Double

  /**
   * Information calculation for regression.
   * @param count number of instances
   * @param sum sum of labels
   * @param sumSquares summation of squares of the labels
   * @return information value
   */
  def calculate(count: Double, sum: Double, sumSquares: Double): Double

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException
Expand All @@ -23,7 +24,8 @@ import org.apache.spark.Logging
* Class for calculating variance during regression
*/
object Variance extends Impurity with Logging {
def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate")
def calculate(c0: Double, c1: Double): Double
= throw new OperationNotSupportedException("Variance.calculate")

/**
* variance calculation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType._
Expand All @@ -29,6 +30,6 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin
*/
// Empty body braces dropped: Bin is a pure data carrier with no members of its own.
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.Algo._
Expand All @@ -24,15 +25,15 @@ import org.apache.spark.rdd.RDD
* @param topNode root node
* @param algo algorithm type -- classification or regression
*/
class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable {
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {

/**
* Predict values for a single data point using the model trained.
*
* @param features array representing a single data point
* @return Double prediction from the trained model
*/
def predict(features : Array[Double]) : Double = {
def predict(features: Array[Double]): Double = {
algo match {
case Classification => {
if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.model

/**
* Filter specifying a split and type of comparison to be applied on features
* @param split split specifying the feature index, type and threshold
* @param comparison integer specifying <,=,>
*/
case class Filter(split: Split, comparison: Int) {
  // comparison of -1, 0, 1 signifies <, =, > respectively
  // NOTE: added the missing space before "comparison" so the rendered string reads
  // " split = ...<space>comparison = ..." instead of running the two fields together.
  override def toString = " split = " + split + " comparison = " + comparison
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.model

/**
Expand All @@ -25,11 +26,11 @@ package org.apache.spark.mllib.tree.model
* @param predict predicted value
*/
class InformationGainStats(
val gain : Double,
val gain: Double,
val impurity: Double,
val leftImpurity : Double,
val rightImpurity : Double,
val predict : Double) extends Serializable {
val leftImpurity: Double,
val rightImpurity: Double,
val predict: Double) extends Serializable {

override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
Expand Down
Loading

0 comments on commit cd2c2b4

Please sign in to comment.