Skip to content

Commit

Permalink
change surrogateDF format and add ut for multi-columns
Browse files Browse the repository at this point in the history
  • Loading branch information
YY-OnCall committed Mar 3, 2017
1 parent ce59a5b commit 41d91b9
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 82 deletions.
118 changes: 59 additions & 59 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol}
import org.apache.spark.ml.param.shared.HasInputCols
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
Expand All @@ -32,20 +32,21 @@ import org.apache.spark.sql.types._
/**
* Params for [[Imputer]] and [[ImputerModel]].
*/
private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCol {
private[feature] trait ImputerParams extends Params with HasInputCols {

/**
* The imputation strategy.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the feature.
* If "median", then replace missing values using the approximate median value of the
* feature (relative error less than 0.001).
* Default: mean
*
* @group param
*/
final val strategy: Param[String] = new Param(this, "strategy", "strategy for imputation. " +
"If mean, then replace missing values using the mean value of the feature. " +
"If median, then replace missing values using the median value of the feature.",
ParamValidators.inArray[String](Imputer.supportedStrategyNames.toArray))
final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " +
s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
s"If ${Imputer.median}, then replace missing values using the median value of the feature.",
ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))

/** @group getParam */
def getStrategy: String = $(strategy)
Expand All @@ -63,7 +64,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu
/** @group getParam */
def getMissingValue: Double = $(missingValue)

/**
/**
* Param for output column names.
* @group param
*/
Expand All @@ -75,20 +76,18 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
require($(inputCols).length == $(outputCols).length, "inputCols and outputCols should have" +
"the same length")
val localInputCols = $(inputCols)
val localOutputCols = $(outputCols)
var outputSchema = schema

$(inputCols).indices.foreach { i =>
val inputCol = localInputCols(i)
val outputCol = localOutputCols(i)
val inputType = schema(inputCol).dataType
require($(inputCols).length == $(inputCols).distinct.length, s"inputCols duplicates:" +
s" (${$(inputCols).mkString(", ")})")
require($(outputCols).length == $(outputCols).distinct.length, s"outputCols duplicates:" +
s" (${$(outputCols).mkString(", ")})")
require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" +
s" and outputCols(${$(outputCols).length}) should have the same length")
val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
val inputField = schema(inputCol)
SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
outputSchema = SchemaUtils.appendColumn(outputSchema, outputCol, inputType)
StructField(outputCol, inputField.dataType, inputField.nullable)
}
outputSchema
StructType(schema ++ outputFields)
}
}

Expand All @@ -103,53 +102,56 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu
* All Null values in the input column are treated as missing, and so are also imputed.
*/
@Experimental
class Imputer @Since("2.1.0")(override val uid: String)
class Imputer @Since("2.2.0")(override val uid: String)
extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable {

@Since("2.1.0")
@Since("2.2.0")
def this() = this(Identifiable.randomUID("imputer"))

/** @group setParam */
@Since("2.1.0")
@Since("2.2.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
@Since("2.1.0")
@Since("2.2.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

/**
* Imputation strategy. Available options are ["mean", "median"].
* @group setParam
*/
@Since("2.1.0")
@Since("2.2.0")
def setStrategy(value: String): this.type = set(strategy, value)

/** @group setParam */
@Since("2.1.0")
@Since("2.2.0")
def setMissingValue(value: Double): this.type = set(missingValue, value)

setDefault(strategy -> "mean", missingValue -> Double.NaN)
import org.apache.spark.ml.feature.Imputer._
setDefault(strategy -> mean, missingValue -> Double.NaN)

override def fit(dataset: Dataset[_]): ImputerModel = {
transformSchema(dataset.schema, logging = true)
val spark = dataset.sparkSession
import spark.implicits._
val surrogates = $(inputCols).map { inputCol =>
val ic = col(inputCol)
val filtered = dataset.select(ic.cast(DoubleType))
.filter(ic.isNotNull && ic =!= $(missingValue))
.filter(!ic.isNaN)
.filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
if(filtered.rdd.isEmpty()) {
throw new SparkException(s"surrogate cannot be computed. " +
s"All the values in ${inputCol} are Null, Nan or missingValue ($missingValue)")
s"All the values in $inputCol are Null, Nan or missingValue ($missingValue)")
}
val surrogate = $(strategy) match {
case "mean" => filtered.select(avg(inputCol)).first().getDouble(0)
case "median" => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001)(0)
case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
}
surrogate.asInstanceOf[Double]
surrogate
}

import dataset.sparkSession.implicits._
val surrogateDF = Seq(surrogates).toDF("surrogates")
val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
val surrogateDF = spark.createDataFrame(rows, schema)
copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
}

Expand All @@ -160,22 +162,23 @@ class Imputer @Since("2.1.0")(override val uid: String)
override def copy(extra: ParamMap): Imputer = defaultCopy(extra)
}

@Since("2.1.0")
@Since("2.2.0")
object Imputer extends DefaultParamsReadable[Imputer] {

/** Set of strategy names that Imputer currently supports. */
private[ml] val supportedStrategyNames = Set("mean", "median")
/** strategy names that Imputer currently supports. */
private[ml] val mean = "mean"
private[ml] val median = "median"

@Since("2.1.0")
@Since("2.2.0")
override def load(path: String): Imputer = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[Imputer]].
*
* @param surrogateDF Value by which missing values in the input columns will be replaced. This
* is stored using DataFrame with input column names and the corresponding surrogates.
* @param surrogateDF a DataFrame contains inputCols and their corresponding surrogates, which are
* used to replace the missing values in the input DataFrame.
*/
@Experimental
class ImputerModel private[ml](
Expand All @@ -193,21 +196,18 @@ class ImputerModel private[ml](

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val localInputCols = $(inputCols)
val localOutputCols = $(outputCols)
var outputDF = dataset
val surrogates = surrogateDF.head().getSeq[Double](0)

$(inputCols).indices.foreach { i =>
val inputCol = localInputCols(i)
val outputCol = localOutputCols(i)
val inputType = dataset.schema(inputCol).dataType
val ic = col(inputCol)
val icSurrogate = surrogates(i)
outputDF = outputDF.withColumn(outputCol, when(ic.isNull, icSurrogate)
.when(ic === $(missingValue), icSurrogate)
.otherwise(ic)
.cast(inputType))
val surrogates = surrogateDF.select($(inputCols).head, $(inputCols).tail: _*).head().toSeq

$(inputCols).zip($(outputCols)).zip(surrogates).foreach {
case ((inputCol, outputCol), surrogate) =>
val inputType = dataset.schema(inputCol).dataType
val ic = col(inputCol)
outputDF = outputDF.withColumn(outputCol,
when(ic.isNull, surrogate)
.when(ic === $(missingValue), surrogate)
.otherwise(ic)
.cast(inputType))
}
outputDF.toDF()
}
Expand All @@ -221,12 +221,12 @@ class ImputerModel private[ml](
copyValues(copied, extra).setParent(parent)
}

@Since("2.1.0")
@Since("2.2.0")
override def write: MLWriter = new ImputerModelWriter(this)
}


@Since("2.1.0")
@Since("2.2.0")
object ImputerModel extends MLReadable[ImputerModel] {

private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {
Expand All @@ -252,9 +252,9 @@ object ImputerModel extends MLReadable[ImputerModel] {
}
}

@Since("2.1.0")
@Since("2.2.0")
override def read: MLReader[ImputerModel] = new ImputerReader

@Since("2.1.0")
@Since("2.2.0")
override def load(path: String): ImputerModel = super.load(path)
}
76 changes: 53 additions & 23 deletions mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default

test("Imputer for Double with default missing Value NaN") {
val df = spark.createDataFrame( Seq(
(0, 1.0, 1.0, 1.0),
(1, 1.0, 1.0, 1.0),
(2, 3.0, 3.0, 3.0),
(3, 4.0, 4.0, 4.0),
(4, Double.NaN, 2.25, 1.0)
)).toDF("id", "value", "expected_mean", "expected_median")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
(0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0),
(1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0),
(2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0),
(3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0)
)).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1",
"expected_mean_value2", "expected_median_value2")
val imputer = new Imputer()
.setInputCols(Array("value1", "value2"))
.setOutputCols(Array("out1", "out2"))
ImputerSuite.iterateStrategyTest(imputer, df)
}

Expand All @@ -42,7 +44,7 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default
(1, 3.0, 3.0, 3.0),
(2, Double.NaN, Double.NaN, Double.NaN),
(3, -1.0, 2.0, 3.0)
)).toDF("id", "value", "expected_mean", "expected_median")
)).toDF("id", "value", "expected_mean_value", "expected_median_value")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setMissingValue(-1.0)
ImputerSuite.iterateStrategyTest(imputer, df)
Expand All @@ -55,32 +57,31 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default
(2, 10.0F, 10.0F, 10.0F),
(3, 10.0F, 10.0F, 10.0F),
(4, -1.0F, 6.0F, 3.0F)
)).toDF("id", "value", "expected_mean", "expected_median")
)).toDF("id", "value", "expected_mean_value", "expected_median_value")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setMissingValue(-1)
ImputerSuite.iterateStrategyTest(imputer, df)
}

test("Imputer should impute null as well as 'missingValue'") {
val df = spark.createDataFrame( Seq(
val rawDf = spark.createDataFrame( Seq(
(0, 4.0, 4.0, 4.0),
(1, 10.0, 10.0, 10.0),
(2, 10.0, 10.0, 10.0),
(3, Double.NaN, 8.0, 10.0),
(4, -1.0, 8.0, 10.0)
)).toDF("id", "value", "expected_mean", "expected_median")
val df2 = df.selectExpr("*", "IF(value=-1.0, null, value) as nullable_value")
val imputer = new Imputer().setInputCols(Array("nullable_value")).setOutputCols(Array("out"))
ImputerSuite.iterateStrategyTest(imputer, df2)
)).toDF("id", "rawValue", "expected_mean_value", "expected_median_value")
val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
ImputerSuite.iterateStrategyTest(imputer, df)
}


test("Imputer throws exception when surrogate cannot be computed") {
val df = spark.createDataFrame( Seq(
(0, Double.NaN, 1.0, 1.0),
(1, Double.NaN, 3.0, 3.0),
(2, Double.NaN, Double.NaN, Double.NaN)
)).toDF("id", "value", "expected_mean", "expected_median")
)).toDF("id", "value", "expected_mean_value", "expected_median_value")
Seq("mean", "median").foreach { strategy =>
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setStrategy(strategy)
Expand All @@ -90,6 +91,30 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default
}
}

test("Imputer throws exception when inputCols does not match outputCols") {
val df = spark.createDataFrame( Seq(
(0, 1.0, 1.0, 1.0),
(1, Double.NaN, 3.0, 3.0),
(2, Double.NaN, Double.NaN, Double.NaN)
)).toDF("id", "value1", "value2", "value3")
Seq("mean", "median").foreach { strategy =>
// inputCols and outCols length different
val imputer = new Imputer()
.setInputCols(Array("value1", "value2"))
.setOutputCols(Array("out1"))
.setStrategy(strategy)
intercept[IllegalArgumentException] {
val model = imputer.fit(df)
}
// duplicate name in inputCols
imputer.setInputCols(Array("value1", "value1")).setOutputCols(Array("out1, out2"))
intercept[IllegalArgumentException] {
val model = imputer.fit(df)
}

}
}

test("Imputer read/write") {
val t = new Imputer()
.setInputCols(Array("myInputCol"))
Expand Down Expand Up @@ -120,16 +145,21 @@ object ImputerSuite{
* @param df DataFrame with columns "id", "value", "expected_mean", "expected_median"
*/
def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = {
val inputCols = imputer.getInputCols

Seq("mean", "median").foreach { strategy =>
imputer.setStrategy(strategy)
val model = imputer.fit(df)
model.transform(df).select("expected_" + strategy, "out").collect().foreach {
case Row(exp: Float, out: Float) =>
assert((exp.isNaN && out.isNaN) || (exp == out),
s"Imputed values differ. Expected: $exp, actual: $out")
case Row(exp: Double, out: Double) =>
assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
s"Imputed values differ. Expected: $exp, actual: $out")
val resultDF = model.transform(df)
imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) =>
resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach {
case Row(exp: Float, out: Float) =>
assert((exp.isNaN && out.isNaN) || (exp == out),
s"Imputed values differ. Expected: $exp, actual: $out")
case Row(exp: Double, out: Double) =>
assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
s"Imputed values differ. Expected: $exp, actual: $out")
}
}
}
}
Expand Down

0 comments on commit 41d91b9

Please sign in to comment.