diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala new file mode 100644 index 0000000000000..a6f27e7fd1111 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.classification + +class ClassificationTree { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala new file mode 100644 index 0000000000000..fd9beb79ab88f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionTree.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.regression + +class RegressionTree { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala new file mode 100644 index 0000000000000..e3a52990bfb65 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.{Split, Bin, DecisionTreeModel} + + +class DecisionTree(val strategy : Strategy) { + + def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { + + //Cache input RDD for speedup during multiple passes + input.cache() + + //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass + val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) + + //TODO: Level-wise training of tree and obtain Decision Tree model + + + return new DecisionTreeModel() + } + +} + +object DecisionTree { + def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { + val numSplits = strategy.numSplits + //TODO: Justify this calculation + val requiredSamples : Long = numSplits*numSplits + val count : Long = input.count() + val numSamples : Long = if (requiredSamples < count) requiredSamples else count + val numFeatures = input.take(1)(0).features.length + (Array.ofDim[Split](numFeatures,numSplits),Array.ofDim[Bin](numFeatures,numSplits)) + } + +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md new file mode 100644 index 0000000000000..61cba281cccfc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md @@ -0,0 +1,15 @@ +This package contains the default implementation of the decision tree algorithm. + +The decision tree algorithm supports: ++ information loss calculation with entropy and gini for classification and variance for regression ++ node model pruning ++ printing to dot files ++ unit tests + +#Performance testing + +#Future Extensions + ++ Random forests ++ Boosting ++ Extremely randomized trees \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala new file mode 100644 index 0000000000000..12f0bb37d793c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree + +import org.apache.spark.mllib.tree.impurity.Impurity + +class Strategy ( + val kind : String, + val impurity : Impurity, + val maxDepth : Int, + val numSplits : Int, + val quantileCalculationStrategy : String = "sampleAndSort") { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala new file mode 100644 index 0000000000000..00feb25e25322 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.impurity + +object Entropy extends Impurity { + + def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + + def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + -(f0 * log2(f0)) - (f1 * log2(f1)) + } + } + + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala new file mode 100644 index 0000000000000..a95f0431c6e8f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.impurity + +object Gini extends Impurity { + + def calculate(c0 : Double, c1 : Double): Double = { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0*f0 - f1*f1 + } + + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala new file mode 100644 index 0000000000000..fadebb0c203eb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.impurity + +trait Impurity { + + def calculate(c0 : Double, c1 : Double): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala new file mode 100644 index 0000000000000..98f332122785e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.impurity + +import javax.naming.OperationNotSupportedException + +object Variance extends Impurity { + def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala new file mode 100644 index 0000000000000..638d65805a813 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.model + +case class Bin(kind : String, lowSplit : Split, highSplit : Split) { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala new file mode 100644 index 0000000000000..d0465d8c6fb6f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.model + +class DecisionTreeModel { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala new file mode 100644 index 0000000000000..4f8beb73bbb09 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.model + +case class Split( + val feature: Int, + val threshold : Double, + val kind : String) { + +} + +class dummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind) + +class dummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind) +