[SPARK-7127] [MLLIB] Adding broadcast of model before prediction for ensembles

Broadcasts each ensemble model in transformImpl before the call to predict, so the model is shipped to each executor once rather than serialized into every task closure (see the sketch after the commit list).

Author: Bryan Cutler <[email protected]>

Closes #6300 from BryanCutler/bcast-ensemble-models-7127 and squashes the following commits:

86e73de [Bryan Cutler] [SPARK-7127] Replaced deprecated callUDF with udf
40a139d [Bryan Cutler] Merge branch 'master' into bcast-ensemble-models-7127
9afad56 [Bryan Cutler] [SPARK-7127] Simplified calls by overriding transformImpl and using broadcasted model in callUDF to make prediction
1f34be4 [Bryan Cutler] [SPARK-7127] Removed accidental newline
171a6ce [Bryan Cutler] [SPARK-7127] Used modelAccessor parameter in predictImpl to access broadcasted model
6fd153c [Bryan Cutler] [SPARK-7127] Applied broadcasting to remaining ensemble models
aaad77b [Bryan Cutler] [SPARK-7127] Removed abstract class for broadcasting model, instead passing a prediction function as param to transform
83904bb [Bryan Cutler] [SPARK-7127] Adding broadcast of model before prediction in RandomForestClassifier
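
For context, a minimal sketch of the pattern being adopted here. ToyEnsembleModel, transformWithBroadcast, and the column names are illustrative stand-ins, not part of the patch; the point is that broadcasting ships the model to each executor once, whereas capturing it directly in the UDF closure would re-serialize it with every task:

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical stand-in for an ensemble model: any Serializable class
// with a driver-side predict() follows the same pattern.
class ToyEnsembleModel(val weights: Array[Double]) extends Serializable {
  def predict(features: Vector): Double =
    features.toArray.zip(weights).map { case (x, w) => x * w }.sum
}

def transformWithBroadcast(model: ToyEnsembleModel, dataset: DataFrame): DataFrame = {
  // Broadcast sends the model to each executor once; a closure-captured
  // `model` would instead be serialized into every task.
  val bcastModel = dataset.sqlContext.sparkContext.broadcast(model)
  val predictUDF = udf { (features: Any) =>
    bcastModel.value.predict(features.asInstanceOf[Vector])
  }
  dataset.withColumn("prediction", predictUDF(col("features")))
}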
BryanCutler authored and jkbradley committed Jul 17, 2015
1 parent 830666f commit 8b8be1f
Showing 5 changed files with 48 additions and 8 deletions.
12 changes: 8 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -169,17 +169,21 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     if ($(predictionCol).nonEmpty) {
-      val predictUDF = udf { (features: Any) =>
-        predict(features.asInstanceOf[FeaturesType])
-      }
-      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+      transformImpl(dataset)
     } else {
       this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
         " since no output columns were set.")
       dataset
     }
   }
 
+  protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val predictUDF = udf { (features: Any) =>
+      predict(features.asInstanceOf[FeaturesType])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   /**
    * Predict label for the given features.
    * This internal method is used to implement [[transform()]] and output [[predictionCol]].
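
In effect, transform() in Predictor.scala is now a template method: it keeps schema validation and the no-output guard, and delegates only the production of the prediction column to an overridable hook. A standalone sketch of that shape (ToyPredictor and ScalingModel are illustrative, not Spark API):

abstract class ToyPredictor {
  // Fixed skeleton, as in PredictionModel.transform(): validate,
  // then either delegate to the hook or return the input untouched.
  final def transform(rows: Seq[Double], outputEnabled: Boolean): Seq[Double] = {
    require(rows != null, "stand-in for transformSchema()")
    if (outputEnabled) transformImpl(rows)
    else rows // mirrors the NOOP branch above
  }

  // The overridable step, mirroring the new protected transformImpl.
  protected def transformImpl(rows: Seq[Double]): Seq[Double]
}

class ScalingModel(weight: Double) extends ToyPredictor {
  // An ensemble model overrides this to broadcast itself before predicting,
  // as the four models below do.
  override protected def transformImpl(rows: Seq[Double]): Seq[Double] =
    rows.map(_ * weight)
}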
11 changes: 10 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -177,8 +179,15 @@ final class GBTClassificationModel(
 
   override def treeWeights: Array[Double] = _treeWeights
 
+  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val predictUDF = udf { (features: Any) =>
+      bcastModel.value.predict(features.asInstanceOf[Vector])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   override protected def predict(features: Vector): Double = {
-    // TODO: Override transform() to broadcast model: SPARK-7127
     // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
     val treePredictions = _trees.map(_.rootNode.predict(features))
11 changes: 10 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -31,6 +31,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -143,8 +145,15 @@ final class RandomForestClassificationModel private[ml] (
 
   override def treeWeights: Array[Double] = _treeWeights
 
+  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val predictUDF = udf { (features: Any) =>
+      bcastModel.value.predict(features.asInstanceOf[Vector])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   override protected def predict(features: Vector): Double = {
-    // TODO: Override transform() to broadcast model. SPARK-7127
     // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
     // Classifies using majority votes.
     // Ignore the weights since all are 1.0 for now.
11 changes: 10 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, SquaredError => OldSquaredError}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -167,8 +169,15 @@ final class GBTRegressionModel(
 
   override def treeWeights: Array[Double] = _treeWeights
 
+  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val predictUDF = udf { (features: Any) =>
+      bcastModel.value.predict(features.asInstanceOf[Vector])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   override protected def predict(features: Vector): Double = {
-    // TODO: Override transform() to broadcast model. SPARK-7127
     // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
     val treePredictions = _trees.map(_.rootNode.predict(features))
11 changes: 10 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -29,6 +29,8 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * :: Experimental ::
@@ -129,8 +131,15 @@ final class RandomForestRegressionModel private[ml] (
 
   override def treeWeights: Array[Double] = _treeWeights
 
+  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val predictUDF = udf { (features: Any) =>
+      bcastModel.value.predict(features.asInstanceOf[Vector])
+    }
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
   override protected def predict(features: Vector): Double = {
-    // TODO: Override transform() to broadcast model. SPARK-7127
     // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
     // Predict average of tree predictions.
     // Ignore the weights since all are 1.0 for now.
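
A minimal end-to-end sketch of how the change is exercised from user code (the data and local-mode setup are illustrative; the column names are the ml defaults). After this commit, transform() on any of the four ensemble models broadcasts the model before the prediction UDF runs:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

object BroadcastTransformExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("rf-broadcast").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Toy (label, features) rows using the default ml column names.
    val df = sc.parallelize(Seq(
      (1.0, Vectors.dense(0.0, 1.1)),
      (2.0, Vectors.dense(2.0, 1.0)),
      (3.0, Vectors.dense(4.1, 0.9))
    )).toDF("label", "features")

    val model = new RandomForestRegressor().fit(df)

    // transform() now routes through the overridden transformImpl,
    // which broadcasts the model before running the prediction UDF.
    model.transform(df).show()

    sc.stop()
  }
}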
