From 118b158fac017a125dbc644597c4bf7eb58a0016 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 6 May 2015 19:32:28 -0700 Subject: [PATCH] Params.setDefault taking a set of ParamPairs should be annotated with varargs. I thought it would not work before, but it apparently does. CrossValidator.transform should call transformSchema since the underlying Model might be a PipelineModel --- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 4 +--- .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 3 ++- .../test/java/org/apache/spark/ml/param/JavaTestParams.java | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 51ce19d29cd29..6d09962fe6ee2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -366,13 +366,11 @@ trait Params extends Identifiable with Serializable { /** * Sets default values for a list of params. * - * Note: Java developers should use the single-parameter [[setDefault()]]. - * Annotating this with varargs causes compilation failures. - * * @param paramPairs a list of param pairs that specify params and their default values to set * respectively. Make sure that the params are initialized before this method * gets called. */ + @varargs protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => setDefault(p.param.asInstanceOf[Param[Any]], p.value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 9208127eb1d79..ac0d1fed84b2e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -105,7 +105,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema - transformSchema(dataset.schema, logging = true) + transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext val est = $(estimator) val eval = $(evaluator) @@ -159,6 +159,7 @@ class CrossValidatorModel private[ml] ( } override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 8abe575610d19..532eca47918fc 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -59,5 +59,6 @@ public JavaTestParams() { ParamValidators.inArray(validStrings)); setDefault(myIntParam, 1); setDefault(myDoubleParam, 0.5); + setDefault(myIntParam.w(1), myDoubleParam.w(0.5)); } }