diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index d99b63a414481..fe511bd904e70 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -84,7 +84,7 @@ class PipelineModel(
    * Gets the model produced by the input estimator. Throws a NoSuchElementException if the input
    * estimator does not exist in the pipeline.
    */
-  def getModel[M](estimator: Estimator[M]): M = {
+  def getModel[M <: Model](estimator: Estimator[M]): M = {
     val matched = transformers.filter {
       case m: Model => m.parent.eq(estimator)
       case _ => false
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index cb67d7c4b2af3..eba1b9c36c7ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -84,7 +84,7 @@ abstract class UnaryTransformer[IN, OUT: TypeTag, SELF <: UnaryTransformer[IN, O
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
     import dataset.sqlContext._
     val map = this.paramMap ++ paramMap
-    val udf: IN => OUT = this.createTransformFunc(map)
+    val udf = this.createTransformFunc(map)
     dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 1ab8f2e55c852..b7ce0e644c887 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -28,11 +28,14 @@ import org.apache.spark.storage.StorageLevel
  * Evaluator for binary classification, which expects two input columns: score and label.
  */
 class BinaryClassificationEvaluator extends Evaluator with Params
-  with HasScoreCol with HasLabelCol with HasMetricName {
-
-  setMetricName("areaUnderROC")
+  with HasScoreCol with HasLabelCol {
+  /** param for metric name in evaluation */
+  val metricName: Param[String] = new Param(this, "metricName",
+    "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
+  def getMetricName: String = get(metricName)
 
   def setMetricName(value: String): this.type = { set(metricName, value); this }
+
   def setScoreCol(value: String): this.type = { set(scoreCol, value); this }
 
   def setLabelCol(value: String): this.type = { set(labelCol, value); this }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 17edec7f5b3c3..4999729780b0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -25,7 +25,8 @@ import scala.collection.mutable
 import org.apache.spark.ml.Identifiable
 
 /**
- * A param with self-contained documentation and optionally default value.
+ * A param with self-contained documentation and optionally a default value. Primitive-typed
+ * params should use the specialized versions, which are more friendly to Java users.
  *
  * @param parent parent object
  * @param name param name
@@ -59,26 +60,31 @@ class Param[T] (
 
 // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
 
+/** Specialized version of [[Param[Double]]] for Java. */
 class DoubleParam(parent: Params, name: String, doc: String, default: Option[Double] = None)
   extends Param[Double](parent, name, doc, default) {
   override def w(value: Double): ParamPair[Double] = ParamPair(this, value)
 }
 
+/** Specialized version of [[Param[Int]]] for Java. */
 class IntParam(parent: Params, name: String, doc: String, default: Option[Int] = None)
   extends Param[Int](parent, name, doc, default) {
   override def w(value: Int): ParamPair[Int] = ParamPair(this, value)
 }
 
+/** Specialized version of [[Param[Float]]] for Java. */
 class FloatParam(parent: Params, name: String, doc: String, default: Option[Float] = None)
   extends Param[Float](parent, name, doc, default) {
   override def w(value: Float): ParamPair[Float] = ParamPair(this, value)
 }
 
+/** Specialized version of [[Param[Long]]] for Java. */
 class LongParam(parent: Params, name: String, doc: String, default: Option[Long] = None)
   extends Param[Long](parent, name, doc, default) {
   override def w(value: Long): ParamPair[Long] = ParamPair(this, value)
 }
 
+/** Specialized version of [[Param[Boolean]]] for Java. */
 class BooleanParam(parent: Params, name: String, doc: String, default: Option[Boolean] = None)
   extends Param[Boolean](parent, name, doc, default) {
   override def w(value: Boolean): ParamPair[Boolean] = ParamPair(this, value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared.scala
index 6656802601878..dad4f3372644e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared.scala
@@ -63,12 +63,6 @@ private[ml] trait HasThreshold extends Params {
   def getThreshold: Double = get(threshold)
 }
 
-private[ml] trait HasMetricName extends Params {
-  /** param for metric name in evaluation */
-  val metricName: Param[String] = new Param(this, "metricName", "metric name in evaluation")
-  def getMetricName: String = get(metricName)
-}
-
 private[ml] trait HasInputCol extends Params {
   /** param for input column name */
   val inputCol: Param[String] = new Param(this, "inputCol", "input column name")
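
For context, a minimal usage sketch of the evaluator as configured by this patch. It is illustrative only: evaluateAreaUnderPR and its scoreAndLabels argument are hypothetical names, and it assumes the spark.ml API at this revision (Evaluator.evaluate(dataset, paramMap) and ParamMap.empty).

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.SchemaRDD

// `scoreAndLabels` stands in for any dataset carrying the configured score/label columns.
def evaluateAreaUnderPR(scoreAndLabels: SchemaRDD): Double = {
  val evaluator = new BinaryClassificationEvaluator()
    .setScoreCol("score")          // column of raw scores
    .setLabelCol("label")          // column of 0/1 labels
    .setMetricName("areaUnderPR")  // moved-in param; defaults to "areaUnderROC" when unset
  evaluator.evaluate(scoreAndLabels, ParamMap.empty)
}

A related effect of the tightened bound on getModel[M <: Model]: callers still get the concrete model type back (e.g. pipelineModel.getModel(lr) returns the model fitted by lr), while estimators whose output type is not a Model are now rejected at compile time.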