
[SPARK-18481][ML] ML 2.1 QA: Remove deprecated methods for ML #15913

Closed · wants to merge 12 commits
4 changes: 4 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -46,6 +46,10 @@ abstract class PipelineStage extends Params with Logging {
*
* Check transform validity and derive the output schema from the input schema.
*
* We check validity for interactions between parameters during `transformSchema` and
* raise an exception if any parameter value is invalid. Parameter value checks which
* do not depend on other parameters are handled by `Param.validate()`.
*
* A typical implementation should first verify the schema change and parameter
* validity, including complex parameter interaction checks.
*/
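As a concrete illustration, here is a minimal sketch of a custom stage performing such an interaction check inside transformSchema. The stage and its lowerBound/upperBound params are hypothetical, invented for this example; only the pattern reflects the PR:

import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.param.{DoubleParam, ParamMap}
import org.apache.spark.sql.types.StructType

// Hypothetical stage: each bound is valid on its own, but together they
// must satisfy lowerBound < upperBound -- an interaction check that now
// belongs in transformSchema rather than the removed validateParams().
class BoundedStage(override val uid: String) extends PipelineStage {
  final val lowerBound = new DoubleParam(this, "lowerBound", "lower bound")
  final val upperBound = new DoubleParam(this, "upperBound", "upper bound")

  override def transformSchema(schema: StructType): StructType = {
    require($(lowerBound) < $(upperBound),
      s"lowerBound (${$(lowerBound)}) must be < upperBound (${$(upperBound)}).")
    schema // this sketch does not change the schema
  }

  override def copy(extra: ParamMap): BoundedStage = defaultCopy(extra)
}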
@@ -201,6 +201,12 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees

/**
* Number of trees in ensemble
*/
@Since("2.0.0")
val getNumTrees: Int = trees.length

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

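A minimal usage sketch of the new accessor (trainingData is an assumed DataFrame with the usual label/features columns, not part of this diff):

import org.apache.spark.ml.classification.GBTClassifier

// A GBT trains one tree per boosting iteration, so after fitting,
// getNumTrees equals maxIter.
val gbt = new GBTClassifier().setMaxIter(10)
val model = gbt.fit(trainingData)
assert(model.getNumTrees == 10)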
@@ -40,7 +40,7 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils

@@ -176,8 +176,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
}
}

override def validateParams(): Unit = {
override protected def validateAndTransformSchema(
Contributor: Is this actually necessary? Before, validateParams() was never used. Seems like we could just remove it.

Contributor Author: I think it's a bug that validateParams was never used. It should validate parameter interactions before fitting (if necessary), which is why we deprecated validateParams and moved what it did into transformSchema. We had no corresponding test cases before, so no test broke when we deprecated validateParams. I added test cases in this PR.

schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
checkThresholdConsistency()
super.validateAndTransformSchema(schema, fitting, featuresDataType)
}
}

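A sketch of what the moved check now catches, mirroring the test added later in this PR (df is an assumed training DataFrame). The setThreshold/setThresholds setters keep the two params in sync, but a ParamMap passed to fit() bypasses them:

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
// thresholds = Array(0.3, 0.7) implies threshold = 1 / (1 + 0.3 / 0.7) = 0.7,
// which contradicts the explicit threshold = 0.5, so the following fails
// with IllegalArgumentException during transformSchema:
// lr.fit(df, lr.thresholds -> Array(0.3, 0.7), lr.threshold -> 0.5)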
@@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with MLWritable with Serializable {

require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
@@ -221,15 +221,6 @@ class RandomForestClassificationModel private[ml] (
}
}

/**
* Number of trees in ensemble
*
* @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
*/
// TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
@@ -216,13 +216,6 @@ final class ChiSqSelectorModel private[ml] (
@Since("1.6.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* @group setParam
*/
@Since("1.6.0")
@deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
15 changes: 0 additions & 15 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -546,21 +546,6 @@ trait Params extends Identifiable with Serializable {
.map(m => m.invoke(this).asInstanceOf[Param[_]])
}

/**
* Validates parameter values stored internally.
* Raise an exception if any parameter value is invalid.
*
* This only needs to check for interactions between parameters.
* Parameter value checks which do not depend on other parameters are handled by
* `Param.validate()`. This method does not handle input/output column parameters;
* those are checked during schema validation.
* @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
*/
@deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
def validateParams(): Unit = {
// Do nothing by default. Override to handle Param interactions.
}

/**
* Explains a param.
* @param param input param, must belong to this instance.
@@ -183,6 +183,12 @@ class GBTRegressionModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees

/**
* Number of trees in ensemble
*/
@Since("2.0.0")
val getNumTrees: Int = trees.length
Contributor: Should this have an @Since tag?

Contributor Author: Added.


@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

@@ -605,9 +605,6 @@ class LinearRegressionSummary private[regression] (
private val privateModel: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {

@deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0")
val model: LinearRegressionModel = privateModel

@transient private val metrics = new RegressionMetrics(
predictions
.select(col(predictionCol), col(labelCol).cast(DoubleType))
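Callers that reached the model through the summary need a one-line change; a migration sketch (df is an assumed training DataFrame):

import org.apache.spark.ml.regression.LinearRegression

// Keep a reference to the fitted model directly instead of using the
// removed summary.model field; the summary metrics are unchanged.
val model = new LinearRegression().fit(df)
val summary = model.summary            // training summary, set by fit()
println(summary.rootMeanSquaredError)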
@@ -144,7 +144,7 @@ class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {

require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
@@ -181,14 +181,6 @@ class RandomForestRegressionModel private[ml] (
_trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
}

/**
* Number of trees in ensemble
* @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
*/
// TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
val numTrees: Int = trees.length

@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
@@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
/** Trees in this ensemble. Warning: These have null parent Estimators. */
def trees: Array[M]

/**
* Number of trees in ensemble
*/
val getNumTrees: Int = trees.length

/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]

90 changes: 39 additions & 51 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -317,8 +317,32 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
}
}

/** Used for [[RandomForestParams]] */
private[ml] trait HasFeatureSubsetStrategy extends Params {
/**
* Parameters for Random Forest algorithms.
*/
private[ml] trait RandomForestParams extends TreeEnsembleParams {

/**
* Number of trees to train (>= 1).
* If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* TODO: Change to always do bootstrapping (simpler). SPARK-7130
* (default = 20)
*
* Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
* is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms
* are a bit different.
* @group param
*/
final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
Contributor: Note: the reason that we cannot add this to both GBT and RF (i.e., in TreeEnsembleParams) is that the param maxIter controls how many trees a GBT has. The semantics in the algorithms are a bit different.

Contributor Author: Good suggestion, added.

ParamValidators.gtEq(1))

setDefault(numTrees -> 20)

/** @group setParam */
def setNumTrees(value: Int): this.type = set(numTrees, value)
Member: Are these setter methods in traits Java-compatible?

Contributor Author: Yes; we already have setNumTrees, which calls super.setNumTrees, in RandomForestClassifier and RandomForestRegressor.


/** @group getParam */
final def getNumTrees: Int = $(numTrees)

/**
* The number of features to consider for splits at each tree node.
@@ -364,38 +388,6 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
}

/**
* Used for [[RandomForestParams]].
* This is separated out from [[RandomForestParams]] because of an issue with the
* `numTrees` method conflicting with this Param in the Estimator.
*/
private[ml] trait HasNumTrees extends Params {

/**
* Number of trees to train (>= 1).
* If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* TODO: Change to always do bootstrapping (simpler). SPARK-7130
* (default = 20)
* @group param
*/
final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
ParamValidators.gtEq(1))

setDefault(numTrees -> 20)

/** @group setParam */
def setNumTrees(value: Int): this.type = set(numTrees, value)

/** @group getParam */
final def getNumTrees: Int = $(numTrees)
}

/**
* Parameters for Random Forest algorithms.
*/
private[ml] trait RandomForestParams extends TreeEnsembleParams
with HasFeatureSubsetStrategy with HasNumTrees

private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
@@ -405,21 +397,15 @@ private[spark] object RandomForestParams {
private[ml] trait RandomForestClassifierParams
extends RandomForestParams with TreeClassifierParams

private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
with HasFeatureSubsetStrategy with TreeClassifierParams

private[ml] trait RandomForestRegressorParams
extends RandomForestParams with TreeRegressorParams

private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
with HasFeatureSubsetStrategy with TreeRegressorParams

/**
* Parameters for Gradient-Boosted Tree algorithms.
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {

/* TODO: Add this doc when we add this param. SPARK-7132
* Threshold for stopping early when runWithValidation is used.
Expand All @@ -432,24 +418,26 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
// final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
// validationTol -> 1e-5

setDefault(maxIter -> 20, stepSize -> 0.1)

/** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)

/**
* Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
* estimator.
* Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking
* the contribution of each estimator.
* (default = 0.1)
* @group setParam
* @group param
*/
final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " +
"(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.",
ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))

/** @group getParam */
final def getStepSize: Double = $(stepSize)

/** @group setParam */
def setStepSize(value: Double): this.type = set(stepSize, value)
Member: I asked about this being Java-friendly because Param setter methods in traits used to have the wrong type in Java. I wonder if this is no longer true. Still, it might make sense to remove the setter method from the trait, since it does not make sense to have it in the Model classes. We could put the setter method in each subclass and then deprecate it in the Model classes.

This issue also shows up in MimaExcludes. If setFeatureSubsetStrategy were put in the concrete classes from the outset, then you would not need to include it in MimaExcludes now.

I'm OK with doing the change now or in a follow-up PR.

I believe I was the one who incorrectly put the setters in the traits...

Contributor Author (@yanboliang, Nov 26, 2016): Yes, I understand what you mean. Correcting the setter methods in traits would involve changes to lots of traits, including DecisionTreeParams, TreeClassifierParams, TreeRegressorParams, RandomForestParams, GBTParams, etc. So I will merge this first, once it passes Jenkins, and address this issue in a separate follow-up PR. Thanks.

Contributor Author: @jkbradley I have sent #16017 to fix this issue; please feel free to comment there. Thanks.


override def validateParams(): Unit = {
require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
s"but it given invalid value $getStepSize.")
}
setDefault(maxIter -> 20, stepSize -> 0.1)

/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
private[ml] def getOldBoostingStrategy(
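A usage sketch of the consolidated params (assumed code, not part of the diff):

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.regression.GBTRegressor

// Random forests size the ensemble with numTrees; GBTs grow one tree per
// boosting iteration, so maxIter plays that role for them.
val rf = new RandomForestClassifier().setNumTrees(50)
assert(rf.getNumTrees == 50)
val gbt = new GBTRegressor().setMaxIter(30).setStepSize(0.05)
// setStepSize(10.0) would throw IllegalArgumentException at set time: the
// inRange(0, 1] validator on the param replaces the removed validateParams.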
@@ -46,7 +46,7 @@ private[util] sealed trait BaseReadWrite {
* Sets the Spark SQLContext to use for saving/loading.
*/
@Since("1.6.0")
@deprecated("Use session instead", "2.0.0")
@deprecated("Use session instead, This method will be removed in 2.2.0.", "2.0.0")
def context(sqlContext: SQLContext): this.type = {
optionSparkSession = Option(sqlContext.sparkSession)
this
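A migration sketch for callers of the deprecated method (spark is an assumed SparkSession and the path is hypothetical):

import org.apache.spark.ml.classification.LogisticRegressionModel

// Prefer session(...) over the deprecated context(...).
val model = LogisticRegressionModel.read
  .session(spark)            // was: .context(sqlContext)
  .load("/tmp/lr-model")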
@@ -70,6 +70,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
ParamsSuite.checkParams(model)
}

test("GBT parameter stepSize should be in interval (0, 1]") {
withClue("GBT parameter stepSize should be in interval (0, 1]") {
intercept[IllegalArgumentException] {
new GBTClassifier().setStepSize(10)
}
}
}

test("Binary classification with continuous features: Log Loss") {
val categoricalFeatures = Map.empty[Int, Int]
testCombinations.foreach {
@@ -190,6 +190,12 @@ class LogisticRegressionSuite
}
}
// thresholds and threshold must be consistent: values
withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
intercept[IllegalArgumentException] {
lr2.fit(smallBinaryDataset,
lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
}
}
withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
intercept[IllegalArgumentException] {
val lr2model = lr2.fit(smallBinaryDataset,
30 changes: 30 additions & 0 deletions project/MimaExcludes.scala
@@ -867,6 +867,36 @@ object MimaExcludes {
// [SPARK-12221] Add CPU time to metrics
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
) ++ Seq(
// [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassifier"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassificationModel"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressor"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressionModel"),
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"),
ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
)
}
