[SPARK-13568] [ML] Create feature transformer to impute missing values #11601

Status: Closed. Wants to merge 54 commits (diff shown from 12 commits).

Commits
2999b26  initial commit for Imputer (hhbyyh, Feb 29, 2016)
8335cf2  adjust mean and most (hhbyyh, Feb 29, 2016)
7be5e9b  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 2, 2016)
131f7d5  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 3, 2016)
a72a3ea  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 5, 2016)
78df589  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 7, 2016)
b949be5  refine code and add ut (hhbyyh, Mar 9, 2016)
79b1c62  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 9, 2016)
c3d5d55  minor change (hhbyyh, Mar 9, 2016)
1b39668  add object Imputer and ut refine (hhbyyh, Mar 9, 2016)
7f87ffb  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 10, 2016)
4e45f81  add options validate and some small changes (hhbyyh, Mar 10, 2016)
e1dd0d2  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 22, 2016)
12220eb  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 23, 2016)
1b36deb  optimize mean for vectors (hhbyyh, Mar 23, 2016)
72d104d  style fix (hhbyyh, Mar 23, 2016)
c311b2e  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 10, 2016)
d6b9421  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 11, 2016)
d181b12  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 11, 2016)
e211481  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 11, 2016)
791533b  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 12, 2016)
fdd6f94  refactor to support numeric only (hhbyyh, Apr 12, 2016)
8042cfb  Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer (hhbyyh, Apr 12, 2016)
4bdf595  change most to mode (hhbyyh, Apr 12, 2016)
e6ad69c  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 17, 2016)
1718422  move filter to NaN (hhbyyh, Apr 17, 2016)
594c501  add transformSchema (hhbyyh, Apr 20, 2016)
3043e7d  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 27, 2016)
b3633e8  remove mode and change input type (hhbyyh, Apr 27, 2016)
053d489  remove print (hhbyyh, Apr 27, 2016)
63e7032  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 28, 2016)
4e1c34a  update document and remove a ut (hhbyyh, Apr 28, 2016)
051aec6  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 29, 2016)
aef094b  fix ut (hhbyyh, Apr 29, 2016)
335ded7  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 29, 2016)
949ed79  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 30, 2016)
93bba63  Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer (hhbyyh, Apr 30, 2016)
cca8dd4  rename ut (hhbyyh, May 1, 2016)
eea8947  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, May 3, 2016)
4e07431  update parameter doc (hhbyyh, May 3, 2016)
31556e6  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Sep 7, 2016)
d4f92e4  Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer (hhbyyh, Sep 7, 2016)
544a65c  update version (hhbyyh, Sep 7, 2016)
910685e  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Oct 6, 2016)
91d4cee  throw exception (YY-OnCall, Oct 7, 2016)
8744524  change data format (YY-OnCall, Oct 7, 2016)
ca45c33  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Feb 22, 2017)
e86d919  add multi column support (YY-OnCall, Feb 22, 2017)
4f17c54  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Mar 2, 2017)
ce59a5b  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Mar 3, 2017)
41d91b9  change surrogateDF format and add ut for multi-columns (YY-OnCall, Mar 3, 2017)
9f6bd57  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Mar 6, 2017)
e378db5  unit test refine and comments update (YY-OnCall, Mar 6, 2017)
c67afc1  fix exception message (YY-OnCall, Mar 8, 2017)
290 changes: 290 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -0,0 +1,290 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/**
* Params for [[Imputer]] and [[ImputerModel]].
Review (Member): indentation errors here and elsewhere

*/
private[feature] trait ImputerParams extends Params with HasInputCol with HasOutputCol {

/**
* The imputation strategy.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the median value of the feature.
* If "most", then replace missing values using the most frequent value of the feature.
Review (Member): Rename to "mode".
Review (Author): Hi Joseph, do you mean to change "strategy" to "mode"? "strategy" is from http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.Imputer.html
Review (Contributor): Rename "most" to "mode" (mode is the most frequent value in the dataset).
Review (Author): I see. Thanks.

* Default: mean
*
* @group param
*/
val strategy: Param[String] = new Param(this, "strategy", "strategy for imputation. " +
Review (Member): Make all Param vals final.

"If mean, then replace missing values using the mean value of the feature." +
Review (Contributor, sethah, Apr 28, 2016): There needs to be a space after "... feature." here.

"If median, then replace missing values using the median value of the feature." +
Review (Contributor): Using median here is fine, I think, but can we document that it's actually an approximation?
Review (Author): Sure, thanks.

"If most, then replace missing values using the most frequent value of the feature.",
ParamValidators.inArray[String](Imputer.supportedStrategyNames.toArray))

/** @group getParam */
def getStrategy: String = $(strategy)

/**
* The placeholder for the missing values. All occurrences of missingValue will be imputed.
* Default: Double.NaN
*
* @group param
*/
val missingValue: DoubleParam = new DoubleParam(this, "missingValue",
Review (Member): What about null values? Should we treat all null values as missing? I could imagine cases in which people want to handle both NaN and null.
Review (Author): Yes, added support for null values.

"The placeholder for the missing values. All occurrences of missingValue will be imputed")

/** @group getParam */
def getMissingValue: Double = $(missingValue)

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT] || inputType.isInstanceOf[DoubleType],
s"Input column ${$(inputCol)} must be of type Vector or Double")
require(!schema.fieldNames.contains($(outputCol)),
Review (Member): This is already checked in appendColumn.

s"Output column ${$(outputCol)} already exists.")
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}

}

/**
* :: Experimental ::
* Imputation estimator for completing missing values, either using the mean, the median or
Review (Member): Document the accepted input schema.

* the most frequent value of the column in which the missing values are located. This class
* also allows configuring which placeholder value is treated as missing.
*/
@Experimental
class Imputer @Since("2.0.0")(override val uid: String)
extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable {

@Since("2.0.0")
def this() = this(Identifiable.randomUID("imputer"))

/** @group setParam */
Review (Member): Shall we add Since annotations for the setters?
Review (Contributor): I've heard an argument that everything in the class is implicitly since 2.1.0, since the class itself is, unless otherwise stated. Which does make sense. But I do slightly favour being explicit about it (even if it is a bit pedantic), so yes, let's add the annotation to all the setters.

def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* Imputation strategy. Available options are "mean", "median" and "most".
* @group setParam
*/
def setStrategy(value: String): this.type = set(strategy, value)

/** @group setParam */
def setMissingValue(value: Double): this.type = set(missingValue, value)

setDefault(strategy -> "mean", missingValue -> Double.NaN)

override def fit(dataset: DataFrame): ImputerModel = {
Review (Member): +1 for using existing implementations in DataFrame or spark.mllib.stats. I don't think we need to implement anything new here, right?
Review (Contributor): For a single numerical column, we can indeed use existing stat functions for computing mean and median (after filtering the missing values), i.e. mean and approxQuantile. If we decide to support vector columns, then we need:
  1. Statistics.colStats to handle NaN / missing values (null?) (SPARK-13639)
  2. to check whether the approxQuantile function can be used for this. My sense is it can't directly; we would need a version that can operate on Array or Vector columns.

val alternate = dataset.select($(inputCol)).schema.fields(0).dataType match {
case DoubleType =>
Vectors.dense(getColStatistics(dataset, $(inputCol)))
case _: VectorUDT =>
val vl = dataset.first().getAs[Vector]($(inputCol)).size
val statisticsArray = new Array[Double](vl)
(0 until vl).foreach(i => {
Review (Contributor): This could be a big performance issue with large vectors, as we could be running hundreds (or millions!) of SQL queries sequentially. For vectors I favour the colStats approach of using MultivariateOnlineSummarizer for efficiency. I think that if we support vectors here, we should rather enable ignoring the NaNs in colStats or a deeper API (even if we do it as private or DeveloperApi for now).
Review (Author): I see. I'll run some performance benchmarks to compare the solutions and get back to you. @MLnick, thanks for the review.

val getI = udf((v: Vector) => v(i))
val tempColName = $(inputCol) + i
val tempData = dataset.where(s"${$(inputCol)} IS NOT NULL")
.select($(inputCol)).withColumn(tempColName, getI(col($(inputCol))))
statisticsArray(i) = getColStatistics(tempData, tempColName)
})
Vectors.dense(statisticsArray)
}
copyValues(new ImputerModel(uid, alternate).setParent(this))
}
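The performance concern raised above (one SQL query per vector dimension) can be avoided by accumulating per-column sums and counts in a single pass, the idea behind using a streaming summarizer. A minimal, Spark-free Scala sketch of that idea (the object and method names here are illustrative, not from the patch):

```scala
object ColumnMeans {
  // Compute per-column means over rows of equal length in a single pass,
  // skipping NaN entries (treated here as the missing-value placeholder).
  def columnMeans(rows: Seq[Array[Double]]): Array[Double] = {
    require(rows.nonEmpty, "need at least one row")
    val dim = rows.head.length
    val sums = new Array[Double](dim)
    val counts = new Array[Long](dim)
    for (row <- rows; i <- 0 until dim) {
      val v = row(i)
      if (!v.isNaN) {   // ignore missing entries in both sum and count
        sums(i) += v
        counts(i) += 1
      }
    }
    Array.tabulate(dim)(i => sums(i) / counts(i))
  }
}
```

The same accumulate-then-divide shape distributes naturally (it is a commutative, associative aggregation), which is why a summarizer-based implementation scales where per-column queries do not.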

/** Extract the statistics info from a Double column according to the strategy */
private def getColStatistics(dataset: DataFrame, colName: String): Double = {
// Note: a `case Double.NaN =>` pattern never matches, because NaN != NaN,
// so the NaN placeholder must be detected with isNaN.
val missValue = if ($(missingValue).isNaN) "NaN" else $(missingValue).toString
val filteredDF = dataset.select(colName).where(s"$colName != '$missValue'")
val colStatistics = $(strategy) match {
case "mean" =>
filteredDF.selectExpr(s"avg($colName)").first().getDouble(0)
case "median" =>
Review (Contributor): I think we should favour using the new approxQuantile SQL stat function here rather than computing the median exactly.

// TODO: optimize the sort with quick-select or Percentile(Hive) if required
val rddDouble = filteredDF.rdd.map(_.getDouble(0))
rddDouble.sortBy(d => d).zipWithIndex().map {
case (v, idx) => (idx, v)
}.lookup(rddDouble.count() / 2).head
case "most" =>
Review (Contributor): I'm a little worried about performance here; on huge columns of Double this could be a problem. I struggle to see an actual use case for "most", though perhaps it makes sense for categorical or ordinal columns. Will think about it a bit more.

val input = filteredDF.rdd.map(_.getDouble(0))
val most = input.map(d => (d, 1)).reduceByKey(_ + _).sortBy(-_._2).first()._1
most
}
colStatistics
}
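All three strategies reduce to a simple statistic over the non-missing values. A local, Spark-free sketch of that reduction (mirroring the RDD logic above, including the `count / 2` upper-median indexing; the `surrogate` name is illustrative):

```scala
object Strategies {
  // "mean"   : average of the values
  // "median" : middle element of the sorted values (upper median for even
  //            counts, matching the count/2 lookup in the RDD version)
  // "most"   : most frequent value (ties resolved arbitrarily)
  def surrogate(values: Seq[Double], strategy: String): Double = strategy match {
    case "mean"   => values.sum / values.size
    case "median" => values.sorted.apply(values.size / 2)
    case "most"   => values.groupBy(identity).maxBy(_._2.size)._1
    case other    => throw new IllegalArgumentException(s"Unknown strategy: $other")
  }
}
```

For the dataset used in the test suite below, `Seq(1.0, 1.0, 3.0, 4.0)`, this yields 2.25 for "mean", 3.0 for "median", and 1.0 for "most", matching the expected columns there.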

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): Imputer = {
val copied = new Imputer(uid)
Review (Member): Use defaultCopy.

copyValues(copied, extra)
}
}

@Since("1.6.0")
Review (Member): Update Since versions.

object Imputer extends DefaultParamsReadable[Imputer] {

/** Set of strategy names that Imputer currently supports. */
private[ml] val supportedStrategyNames = Set("mean", "median", "most")

@Since("1.6.0")
override def load(path: String): Imputer = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[Imputer]].
*
* @param alternate statistics value for each feature during fitting
Review (Member): Is "alternate" the commonly used name for imputation? I have heard "surrogate" used more frequently, but I may be wrong.

*/
@Experimental
class ImputerModel private[ml] (
override val uid: String,
val alternate: Vector)
extends Model[ImputerModel] with ImputerParams with MLWritable {

import ImputerModel._

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

private def isMissingValue(value: Double): Boolean = {
val miss = $(missingValue)
value == miss || (value.isNaN && miss.isNaN)
}
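The explicit `value.isNaN && miss.isNaN` branch is needed because IEEE 754 NaN is not equal to itself, so `value == miss` alone would never match a NaN placeholder. A standalone demonstration of the same check (the `NaNCheck` object is illustrative):

```scala
object NaNCheck {
  // Mirrors the model's missing-value test: plain == handles ordinary
  // doubles, while the isNaN branch handles the NaN placeholder, which
  // == can never match.
  def isMissing(value: Double, placeholder: Double): Boolean =
    value == placeholder || (value.isNaN && placeholder.isNaN)
}
```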

override def transform(dataset: DataFrame): DataFrame = {
dataset.select($(inputCol)).schema.fields(0).dataType match {
case DoubleType =>
val impute = udf { (d: Double) =>
if (isMissingValue(d)) alternate(0) else d
}
dataset.withColumn($(outputCol), impute(col($(inputCol))))
case _: VectorUDT =>
val impute = udf { (vector: Vector) =>
if (vector == null) {
alternate
}
else {
val vCopy = vector.copy
// TODO replace with update() since this hacks the internal implementation of Vector.
vCopy match {
case d: DenseVector =>
var iter = 0
while(iter < d.size) {
if (isMissingValue(vCopy(iter))) {
d.values(iter) = alternate(iter)
}
iter += 1
}
case s: SparseVector =>
var iter = 0
while(iter < s.values.length) {
if (isMissingValue(s.values(iter))) {
s.values(iter) = alternate(s.indices(iter))
}
iter += 1
}
}
vCopy
}
}
dataset.withColumn($(outputCol), impute(col($(inputCol))))
}
}
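Note that the SparseVector branch scans only the explicitly stored values, so implicit zeros can never be imputed. That is correct for a NaN placeholder, but worth keeping in mind if missingValue were ever set to 0.0. The dense case is just an in-place scan, sketched here with plain arrays (names illustrative, NaN assumed as the placeholder):

```scala
object ImputeDense {
  // In-place imputation over a dense array, mirroring the DenseVector
  // branch above: every NaN entry is replaced with the surrogate for
  // its position.
  def imputeInPlace(values: Array[Double], surrogates: Array[Double]): Array[Double] = {
    require(values.length == surrogates.length, "one surrogate per position")
    var i = 0
    while (i < values.length) {
      if (values(i).isNaN) values(i) = surrogates(i)
      i += 1
    }
    values
  }
}
```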

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): ImputerModel = {
val copied = new ImputerModel(uid, alternate)
copyValues(copied, extra).setParent(parent)
}

@Since("2.0.0")
override def write: MLWriter = new ImputerModelWriter(this)
}


@Since("2.0.0")
object ImputerModel extends MLReadable[ImputerModel] {

private[ImputerModel]
class ImputerModelWriter(instance: ImputerModel) extends MLWriter {
Review (Contributor): Minor: this fits on one line.


private case class Data(alternate: Vector)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = new Data(instance.alternate)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class ImputerReader extends MLReader[ImputerModel] {

private val className = classOf[ImputerModel].getName

override def load(path: String): ImputerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val Row(alternate: Vector) = sqlContext.read.parquet(dataPath)
.select("alternate")
.head()
val model = new ImputerModel(metadata.uid, alternate)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("2.0.0")
override def read: MLReader[ImputerModel] = new ImputerReader

@Since("2.0.0")
override def load(path: String): ImputerModel = super.load(path)
}
105 changes: 105 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row

class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
Review (Contributor): We need tests for multiple columns too.


test("Imputer for Double column") {
val df = sqlContext.createDataFrame( Seq(
(0, 1.0, 1.0, 1.0, 1.0),
(1, 1.0, 1.0, 1.0, 1.0),
(2, 3.0, 3.0, 3.0, 3.0),
(3, 4.0, 4.0, 4.0, 4.0),
(4, Double.NaN, 2.25, 3.0, 1.0 )
)).toDF("id", "value", "mean", "median", "most")
Seq("mean", "median", "most").foreach { strategy =>
val imputer = new Imputer().setInputCol("value").setOutputCol("out").setStrategy(strategy)
val model = imputer.fit(df)
model.transform(df).select(strategy, "out").collect()
.foreach { case Row(d1: Double, d2: Double) =>
assert(d1 ~== d2 absTol 1e-5, s"Imputer ut error: $d2 should be $d1")
Review (Contributor): Could we make this error message more informative, like "Imputed values differ. Expected: $d1, actual: $d2".
Review (Author): Sure, changed.

}
}
}

Review (Contributor): We need to add tests for the case where the entire column is null or NaN. I just checked the NaN case, and it will throw an NPE in the fit method.
Review (Contributor): Good catch, yes. Obviously the imputer can't do anything useful in that case, but it should either throw a useful error or return the dataset unchanged. I would favor an error here: if a user explicitly wants to impute missing data and all their data is missing, better to blow up now than later in the pipeline.
Review (Contributor): Yeah, actually this also fails if the entire input column is the missing value. We need to beef up the test suite :)

test("Imputer for Double with missing Value -1.0") {
Review (Contributor): I'd argue this test is not necessary because of the test right below, where we test that it works for integer missing values and leaves NaNs alone.

val df = sqlContext.createDataFrame( Seq(
(0, 1.0, 1.0, 1.0, 1.0),
(1, 1.0, 1.0, 1.0, 1.0),
(2, 3.0, 3.0, 3.0, 3.0),
(3, 4.0, 4.0, 4.0, 4.0),
(4, -1.0, 2.25, 3.0, 1.0 )
)).toDF("id", "value", "mean", "median", "most")
Seq("mean", "median", "most").foreach { strategy =>
val imputer = new Imputer().setInputCol("value").setOutputCol("out").setStrategy(strategy)
.setMissingValue(-1.0)
val model = imputer.fit(df)
model.transform(df).select(strategy, "out").collect()
.foreach { case Row(d1: Double, d2: Double) =>
assert(d1 ~== d2 absTol 1e-5, s"Imputer ut error: $d2 should be $d1")
}
}
}

test("Imputer for Vector column with NaN and null") {
val df = sqlContext.createDataFrame( Seq(
(0, Vectors.dense(1, 2), Vectors.dense(1, 2), Vectors.dense(1, 2), Vectors.dense(1, 2)),
(1, Vectors.dense(1, 2), Vectors.dense(1, 2), Vectors.dense(1, 2), Vectors.dense(1, 2)),
(2, Vectors.dense(3, 2), Vectors.dense(3, 2), Vectors.dense(3, 2), Vectors.dense(3, 2)),
(3, Vectors.dense(4, 2), Vectors.dense(4, 2), Vectors.dense(4, 2), Vectors.dense(4, 2)),
(4, Vectors.dense(Double.NaN, 2), Vectors.dense(2.25, 2), Vectors.dense(3.0, 2),
Vectors.dense(1.0, 2)),
(5, Vectors.sparse(2, Array(0, 1), Array(Double.NaN, 2.0)), Vectors.dense(2.25, 2),
Vectors.dense(3.0, 2), Vectors.dense(1.0, 2)),
(6, null.asInstanceOf[Vector], Vectors.dense(2.25, 2), Vectors.dense(3.0, 2),
Vectors.dense(1.0, 2))
)).toDF("id", "value", "mean", "median", "most")
Seq("mean", "median", "most").foreach { strategy =>
val imputer = new Imputer().setInputCol("value").setOutputCol("out").setStrategy(strategy)
val model = imputer.fit(df)
model.transform(df).select(strategy, "out").collect()
.foreach { case Row(v1: Vector, v2: Vector) =>
assert(v1 == v2, s"$strategy Imputer ut error: $v2 should be $v1")
}
}
}

Review (Contributor): We should also have a test for a non-NaN missing value, but with NaN in the dataset, to check that "mean" and "median" behave as we expect.
Review (Author): Added.

test("Imputer read/write") {
val t = new Imputer()
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setMissingValue(-1.0)
testDefaultReadWrite(t)
}

test("ImputerModel read/write") {
val instance = new ImputerModel(
"myImputer", Vectors.dense(1.0, 10.0))
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.alternate === instance.alternate)
}

}