Skip to content

Commit

Permalink
add options validate and some small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
hhbyyh committed Mar 10, 2016
1 parent 7f87ffb commit 4e45f81
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, Params}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg._
Expand All @@ -36,29 +36,30 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasOut

/**
* The imputation strategy.
* If "mean", then replace missing values using the mean along the axis.
* If "median", then replace missing values using the median along the axis.
* If "most", then replace missing using the most frequent value along the axis.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the median value of the feature.
* If "most", then replace missing using the most frequent value of the feature.
* Default: mean
*
* @group param
*/
val strategy: Param[String] = new Param(this, "strategy", "strategy for imputation. " +
"If mean, then replace missing values using the mean along the axis." +
"If median, then replace missing values using the median along the axis." +
"If most, then replace missing using the most frequent value along the axis.")
"If mean, then replace missing values using the mean value of the feature." +
"If median, then replace missing values using the median value of the feature." +
"If most, then replace missing using the most frequent value of the feature.",
ParamValidators.inArray[String](Imputer.supportedStrategyNames.toArray))

/** @group getParam */
def getStrategy: String = $(strategy)

/**
* The placeholder for the missing values. All occurrences of missingvalues will be imputed.
* The placeholder for the missing values. All occurrences of missingValue will be imputed.
* Default: Double.NaN
*
* @group param
*/
val missingValue: DoubleParam = new DoubleParam(this, "missingValue",
"The placeholder for the missing values. All occurrences of missingvalues will be imputed")
"The placeholder for the missing values. All occurrences of missingValue will be imputed")

/** @group getParam */
def getMissingValue: Double = $(missingValue)
Expand All @@ -75,18 +76,13 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasOut
StructType(outputFields)
}

override def validateParams(): Unit = {
require(Seq("mean", "median", "most").contains($(strategy)),
s"${$(strategy)} is not supported. Options are mean, median and most")
}
}

/**
* :: Experimental ::
* Imputation estimator for completing missing values, either using the mean, the median or
* the most frequent value of the column in which the missing values are located. This class
* also allows for different missing values encodings.
*
* also allows for different missing values.
*/
@Experimental
class Imputer @Since("2.0.0")(override val uid: String)
Expand All @@ -101,7 +97,10 @@ class Imputer @Since("2.0.0")(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
/**
* Imputation strategy. Available options are "mean", "median" and "most".
* @group setParam
*/
def setStrategy(value: String): this.type = set(strategy, value)

/** @group setParam */
Expand All @@ -112,15 +111,14 @@ class Imputer @Since("2.0.0")(override val uid: String)
override def fit(dataset: DataFrame): ImputerModel = {
val alternate = dataset.select($(inputCol)).schema.fields(0).dataType match {
case DoubleType =>
val colStatistics = getColStatistics(dataset, $(inputCol))
Vectors.dense(Array(colStatistics))
Vectors.dense(getColStatistics(dataset, $(inputCol)))
case _: VectorUDT =>
val vl = dataset.first().getAs[Vector]($(inputCol)).size
val statisticsArray = new Array[Double](vl)
(0 until vl).foreach(i => {
val getI = udf((v: Vector) => v(i))
val tempColName = $(inputCol) + i
val tempData = dataset.where(s"${$(inputCol)} is not null")
val tempData = dataset.where(s"${$(inputCol)} IS NOT NULL")
.select($(inputCol)).withColumn(tempColName, getI(col($(inputCol))))
statisticsArray(i) = getColStatistics(tempData, tempColName)
})
Expand All @@ -129,6 +127,7 @@ class Imputer @Since("2.0.0")(override val uid: String)
copyValues(new ImputerModel(uid, alternate).setParent(this))
}

/** Extract the statistics info from a Double column according to the strategy */
private def getColStatistics(dataset: DataFrame, colName: String): Double = {
val missValue = $(missingValue) match {
case Double.NaN => "NaN"
Expand All @@ -143,7 +142,7 @@ class Imputer @Since("2.0.0")(override val uid: String)
val rddDouble = filteredDF.rdd.map(_.getDouble(0))
rddDouble.sortBy(d => d).zipWithIndex().map {
case (v, idx) => (idx, v)
}.lookup(rddDouble.count()/2).head
}.lookup(rddDouble.count() / 2).head
case "most" =>
val input = filteredDF.rdd.map(_.getDouble(0))
val most = input.map(d => (d, 1)).reduceByKey(_ + _).sortBy(-_._2).first()._1
Expand All @@ -165,6 +164,9 @@ class Imputer @Since("2.0.0")(override val uid: String)
@Since("1.6.0")
object Imputer extends DefaultParamsReadable[Imputer] {

/** Set of strategy names that Imputer currently supports. */
private[ml] val supportedStrategyNames = Set("mean", "median", "most")

@Since("1.6.0")
override def load(path: String): Imputer = super.load(path)
}
Expand All @@ -173,7 +175,7 @@ object Imputer extends DefaultParamsReadable[Imputer] {
* :: Experimental ::
* Model fitted by [[Imputer]].
*
* @param alternate statistics value for each original column during fitting
* @param alternate statistics value for each feature during fitting
*/
@Experimental
class ImputerModel private[ml] (
Expand All @@ -189,7 +191,7 @@ class ImputerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

private def matchMissingValue(value: Double): Boolean = {
private def isMissingValue(value: Double): Boolean = {
val miss = $(missingValue)
value == miss || (value.isNaN && miss.isNaN)
}
Expand All @@ -198,7 +200,7 @@ class ImputerModel private[ml] (
dataset.select($(inputCol)).schema.fields(0).dataType match {
case DoubleType =>
val impute = udf { (d: Double) =>
if (matchMissingValue(d)) alternate(0) else d
if (isMissingValue(d)) alternate(0) else d
}
dataset.withColumn($(outputCol), impute(col($(inputCol))))
case _: VectorUDT =>
Expand All @@ -208,20 +210,20 @@ class ImputerModel private[ml] (
}
else {
val vCopy = vector.copy
// TODO replace with update() since this hacks the internal implementation of Vector.
vCopy match {
case d: DenseVector =>
var iter = 0
while(iter < d.size) {
if (matchMissingValue(vCopy(iter))) {
if (isMissingValue(vCopy(iter))) {
d.values(iter) = alternate(iter)
}

iter += 1
}
case s: SparseVector =>
var iter = 0
while(iter < s.values.length) {
if (matchMissingValue(s.values(iter))) {
if (isMissingValue(s.values(iter))) {
s.values(iter) = alternate(s.indices(iter))
}
iter += 1
Expand Down

0 comments on commit 4e45f81

Please sign in to comment.