[SPARK-13568] [ML] Create feature transformer to impute missing values #11601

Closed
wants to merge 54 commits into from
54 commits
2999b26
initial commit for Imputer
hhbyyh Feb 29, 2016
8335cf2
adjust mean and most
hhbyyh Feb 29, 2016
7be5e9b
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 2, 2016
131f7d5
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 3, 2016
a72a3ea
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 5, 2016
78df589
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 7, 2016
b949be5
refine code and add ut
hhbyyh Mar 9, 2016
79b1c62
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 9, 2016
c3d5d55
minor change
hhbyyh Mar 9, 2016
1b39668
add object Imputer and ut refine
hhbyyh Mar 9, 2016
7f87ffb
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 10, 2016
4e45f81
add options validate and some small changes
hhbyyh Mar 10, 2016
e1dd0d2
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 22, 2016
12220eb
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 23, 2016
1b36deb
optimize mean for vectors
hhbyyh Mar 23, 2016
72d104d
style fix
hhbyyh Mar 23, 2016
c311b2e
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 10, 2016
d6b9421
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 11, 2016
d181b12
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 11, 2016
e211481
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 11, 2016
791533b
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 12, 2016
fdd6f94
refactor to support numeric only
hhbyyh Apr 12, 2016
8042cfb
Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer
hhbyyh Apr 12, 2016
4bdf595
change most to mode
hhbyyh Apr 12, 2016
e6ad69c
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 17, 2016
1718422
move filter to NaN
hhbyyh Apr 17, 2016
594c501
add transformSchema
hhbyyh Apr 20, 2016
3043e7d
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 27, 2016
b3633e8
remove mode and change input type
hhbyyh Apr 27, 2016
053d489
remove print
hhbyyh Apr 27, 2016
63e7032
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 28, 2016
4e1c34a
update document and remove a ut
hhbyyh Apr 28, 2016
051aec6
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 29, 2016
aef094b
fix ut
hhbyyh Apr 29, 2016
335ded7
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 29, 2016
949ed79
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 30, 2016
93bba63
Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer
hhbyyh Apr 30, 2016
cca8dd4
rename ut
hhbyyh May 1, 2016
eea8947
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh May 3, 2016
4e07431
update parameter doc
hhbyyh May 3, 2016
31556e6
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Sep 7, 2016
d4f92e4
Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer
hhbyyh Sep 7, 2016
544a65c
update version
hhbyyh Sep 7, 2016
910685e
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Oct 6, 2016
91d4cee
throw exception
YY-OnCall Oct 7, 2016
8744524
change data format
YY-OnCall Oct 7, 2016
ca45c33
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Feb 22, 2017
e86d919
add multi column support
YY-OnCall Feb 22, 2017
4f17c54
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Mar 2, 2017
ce59a5b
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Mar 3, 2017
41d91b9
change surrogateDF format and add ut for multi-columns
YY-OnCall Mar 3, 2017
9f6bd57
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Mar 6, 2017
e378db5
unit test refine and comments update
YY-OnCall Mar 6, 2017
c67afc1
fix exception message
YY-OnCall Mar 8, 2017
259 changes: 259 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -0,0 +1,259 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
Contributor

unused import

Contributor

Not applicable anymore as it's used below now.

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasInputCols
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
* Params for [[Imputer]] and [[ImputerModel]].
*/
private[feature] trait ImputerParams extends Params with HasInputCols {

/**
* The imputation strategy.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the feature.
* Default: mean
*
* @group param
*/
final val strategy: Param[String] = new Param(this, "strategy", "strategy for imputation. " +
s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
s"If ${Imputer.median}, then replace missing values using the approximate median value " +
"of the feature.",
ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))

/** @group getParam */
def getStrategy: String = $(strategy)

/**
* The placeholder for the missing values. All occurrences of missingValue will be imputed.
Member

Doc: Note that null values are always treated as missing.
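
To make the null note concrete, here is a plain-Scala sketch (no Spark) of which values count as missing, mirroring the filter in `fit` and the `when` chain in `transform` from this diff. `MissingSketch`, `excludedFromFit`, and `replacedInTransform` are hypothetical names. `Option[Double]` stands in for a nullable column, and the explicit NaN check models Spark SQL's rule that NaN equals NaN.

```scala
object MissingSketch {
  // Option[Double] models a nullable Double column: None plays the role of SQL null.

  // Mirrors the filter in fit(): only non-null, non-NaN values that differ from
  // missingValue contribute to the surrogate; everything else is excluded.
  def excludedFromFit(v: Option[Double], missingValue: Double): Boolean =
    v.forall(x => x.isNaN || x == missingValue)

  // Mirrors the when-chain in transform(): replace nulls and values equal to
  // missingValue. Spark SQL treats NaN = NaN as true; plain Scala does not,
  // so the NaN case is spelled out here (an assumption of this sketch).
  def replacedInTransform(v: Option[Double], missingValue: Double): Boolean =
    v.forall(x => x == missingValue || (x.isNaN && missingValue.isNaN))
}
```

With the default `missingValue` of `Double.NaN` the two predicates agree. With a custom placeholder such as `-1.0`, a NaN entry is left out of the surrogate computation but, as the diff reads, would not be replaced by `transform`.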

* Note that null values are always treated as missing.
* Default: Double.NaN
*
* @group param
*/
final val missingValue: DoubleParam = new DoubleParam(this, "missingValue",
"The placeholder for the missing values. All occurrences of missingValue will be imputed")

/** @group getParam */
def getMissingValue: Double = $(missingValue)

/**
* Param for output column names.
* @group param
*/
final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
"output column names")

/** @group getParam */
final def getOutputCols: Array[String] = $(outputCols)

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" +
s" duplicates: (${$(inputCols).mkString(", ")})")
require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" +
s" duplicates: (${$(outputCols).mkString(", ")})")
require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" +
s" and outputCols(${$(outputCols).length}) should have the same length")
val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
val inputField = schema(inputCol)
SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
StructField(outputCol, inputField.dataType, inputField.nullable)
}
StructType(schema ++ outputFields)
}
}

/**
* :: Experimental ::
* Imputation estimator for completing missing values, either using the mean or the median
* of the column in which the missing values are located. The input column should be of
Contributor

As mentioned above at https://github.com/apache/spark/pull/11601/files#r104403880, you can add the note about relative error here.

Something like "For computing median, approxQuantile is used with a relative error of X" (provide a ScalaDoc link to approxQuantile).

Contributor Author

I didn't add the link as it may break java doc generation.

Contributor

Ah right - perhaps just mention using approxQuantile?
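
For reference, a small plain-Scala sketch of what a relative error of 0.001 implies for the median. It encodes the rank bound documented for `DataFrameStatFunctions.approxQuantile`. `RankBoundSketch` and `rankBounds` are hypothetical names, and clamping the result to `[0, n]` is an addition of this sketch, not part of Spark.

```scala
object RankBoundSketch {
  // Documented bound for approxQuantile: with N elements, probability p and
  // relative error err, the returned sample x satisfies
  //   floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
  // Returns that (lower, upper) rank interval, clamped to [0, n].
  def rankBounds(n: Long, p: Double, err: Double): (Long, Long) = {
    val lower = math.floor((p - err) * n).toLong
    val upper = math.ceil((p + err) * n).toLong
    (math.max(lower, 0L), math.min(upper, n))
  }
}
```

For a column with one million non-missing rows, `rankBounds(1000000L, 0.5, 0.001)` allows the returned median to sit roughly 1,000 ranks on either side of the exact one.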

* DoubleType or FloatType. Currently Imputer does not support categorical features
* (SPARK-15041) and may create incorrect values for a categorical feature.
*
* Note that the mean/median value is computed after filtering out missing values.
* All Null values in the input column are treated as missing, and so are also imputed. For
* computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
*/
@Experimental
class Imputer @Since("2.2.0")(override val uid: String)
extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable {

@Since("2.2.0")
def this() = this(Identifiable.randomUID("imputer"))

/** @group setParam */
@Since("2.2.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
@Since("2.2.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

/**
* Imputation strategy. Available options are ["mean", "median"].
* @group setParam
*/
@Since("2.2.0")
def setStrategy(value: String): this.type = set(strategy, value)

/** @group setParam */
@Since("2.2.0")
def setMissingValue(value: Double): this.type = set(missingValue, value)

setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN)

override def fit(dataset: Dataset[_]): ImputerModel = {
transformSchema(dataset.schema, logging = true)
val spark = dataset.sparkSession
import spark.implicits._
val surrogates = $(inputCols).map { inputCol =>
val ic = col(inputCol)
val filtered = dataset.select(ic.cast(DoubleType))
.filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
if (filtered.take(1).length == 0) {
throw new SparkException("surrogate cannot be computed. " +
s"All the values in $inputCol are Null, NaN or missingValue(${$(missingValue)})")
}
$(strategy) match {
case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
}
}

val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
val schema = StructType($(inputCols).map(name => StructField(name, DoubleType, nullable = false)))
val surrogateDF = spark.createDataFrame(rows, schema)
copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
}

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): Imputer = defaultCopy(extra)
}

@Since("2.2.0")
object Imputer extends DefaultParamsReadable[Imputer] {

/** strategy names that Imputer currently supports. */
private[ml] val mean = "mean"
private[ml] val median = "median"

@Since("2.2.0")
override def load(path: String): Imputer = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[Imputer]].
*
* @param surrogateDF a DataFrame containing inputCols and their corresponding surrogates,
* which are used to replace the missing values in the input DataFrame.
*/
@Experimental
class ImputerModel private[ml](
override val uid: String,
val surrogateDF: DataFrame)
extends Model[ImputerModel] with ImputerParams with MLWritable {

import ImputerModel._

/** @group setParam */
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
var outputDF = dataset
val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq

$(inputCols).zip($(outputCols)).zip(surrogates).foreach {
case ((inputCol, outputCol), surrogate) =>
val inputType = dataset.schema(inputCol).dataType
val ic = col(inputCol)
outputDF = outputDF.withColumn(outputCol,
when(ic.isNull, surrogate)
.when(ic === $(missingValue), surrogate)
.otherwise(ic)
.cast(inputType))
}
outputDF.toDF()
}

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): ImputerModel = {
val copied = new ImputerModel(uid, surrogateDF)
copyValues(copied, extra).setParent(parent)
}

@Since("2.2.0")
override def write: MLWriter = new ImputerModelWriter(this)
}


@Since("2.2.0")
object ImputerModel extends MLReadable[ImputerModel] {

private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val dataPath = new Path(path, "data").toString
instance.surrogateDF.repartition(1).write.parquet(dataPath)
}
}

private class ImputerReader extends MLReader[ImputerModel] {

private val className = classOf[ImputerModel].getName

override def load(path: String): ImputerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val model = new ImputerModel(metadata.uid, surrogateDF)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("2.2.0")
override def read: MLReader[ImputerModel] = new ImputerReader

@Since("2.2.0")
override def load(path: String): ImputerModel = super.load(path)
}
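
A minimal usage sketch of the estimator as it appears in this diff, assuming a Spark 2.2+ build on the classpath. The local master, app name, column names, and the `ImputerUsageSketch` object are assumptions made for illustration.

```scala
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.{DataFrame, SparkSession}

object ImputerUsageSketch {
  // Fits a mean Imputer on a tiny two-column DataFrame and returns the transformed result.
  def run(spark: SparkSession): DataFrame = {
    import spark.implicits._

    // Double.NaN is the default missingValue; column names are hypothetical.
    val df = Seq(
      (1.0, Double.NaN),
      (2.0, 4.0),
      (Double.NaN, 6.0)
    ).toDF("a", "b")

    val imputer = new Imputer()
      .setInputCols(Array("a", "b"))
      .setOutputCols(Array("a_imputed", "b_imputed"))
      .setStrategy("mean")

    // fit() computes one surrogate per input column; transform() appends the output columns.
    imputer.fit(df).transform(df)
  }

  def main(args: Array[String]): Unit = {
    // Local session for illustration only.
    val spark = SparkSession.builder().master("local[1]").appName("imputer-sketch").getOrCreate()
    run(spark).show()
    spark.stop()
  }
}
```

Because the surrogate is computed after filtering out missing entries, the mean for `a` is taken over 1.0 and 2.0 only, and the mean for `b` over 4.0 and 6.0 only. Switching to `.setStrategy("median")` would use the approximate-quantile path instead.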