Skip to content

Commit

Permalink
Params.setDefault taking a set of ParamPairs should be annotated with…
Browse files Browse the repository at this point in the history
… varargs. I thought it would not work before, but it apparently does.

CrossValidator.transform should call transformSchema since the underlying Model might be a PipelineModel
  • Loading branch information
jkbradley committed May 7, 2015
1 parent 71a452b commit 118b158
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 1 addition & 3 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,11 @@ trait Params extends Identifiable with Serializable {
/**
* Sets default values for a list of params.
*
* Note: Java developers should use the single-parameter [[setDefault()]].
* Annotating this with varargs causes compilation failures.
*
* @param paramPairs a list of param pairs that specify params and their default values to set
* respectively. Make sure that the params are initialized before this method
* gets called.
*/
@varargs
protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
setDefault(p.param.asInstanceOf[Param[Any]], p.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP

override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(dataset.schema, logging = true)
transformSchema(schema, logging = true)
val sqlCtx = dataset.sqlContext
val est = $(estimator)
val eval = $(evaluator)
Expand Down Expand Up @@ -159,6 +159,7 @@ class CrossValidatorModel private[ml] (
}

override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,6 @@ public JavaTestParams() {
ParamValidators.inArray(validStrings));
setDefault(myIntParam, 1);
setDefault(myDoubleParam, 0.5);
setDefault(myIntParam.w(1), myDoubleParam.w(0.5));
}
}

0 comments on commit 118b158

Please sign in to comment.