diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
new file mode 100644
index 0000000000000..b1a802ee13fc4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.HasInputCols
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/**
+ * Params for [[Imputer]] and [[ImputerModel]].
+ */
+private[feature] trait ImputerParams extends Params with HasInputCols {
+
+  /**
+   * The imputation strategy.
+   * If "mean", then replace missing values using the mean value of the feature.
+   * If "median", then replace missing values using the approximate median value of the feature.
+   * Default: mean
+   *
+   * @group param
+   */
+  final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " +
+    s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
+    s"If ${Imputer.median}, then replace missing values using the median value of the feature.",
+    ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))
+
+  /** @group getParam */
+  def getStrategy: String = $(strategy)
+
+  /**
+   * The placeholder for the missing values. All occurrences of missingValue will be imputed.
+   * Note that null values are always treated as missing.
+   * Default: Double.NaN
+   *
+   * @group param
+   */
+  final val missingValue: DoubleParam = new DoubleParam(this, "missingValue",
+    "The placeholder for the missing values. All occurrences of missingValue will be imputed")
+
+  /** @group getParam */
+  def getMissingValue: Double = $(missingValue)
+
+  /**
+   * Param for output column names.
+   * @group param
+   */
+  final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
+    "output column names")
+
+  /** @group getParam */
+  final def getOutputCols: Array[String] = $(outputCols)
+
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" +
+      s" duplicates: (${$(inputCols).mkString(", ")})")
+    require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" +
+      s" duplicates: (${$(outputCols).mkString(", ")})")
+    require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" +
+      s" and outputCols(${$(outputCols).length}) should have the same length")
+    val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
+      val inputField = schema(inputCol)
+      SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
+      StructField(outputCol, inputField.dataType, inputField.nullable)
+    }
+    StructType(schema ++ outputFields)
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Imputation estimator for completing missing values, either using the mean or the median
+ * of the columns in which the missing values are located. The input columns should be of
+ * DoubleType or FloatType. Currently Imputer does not support categorical features
+ * (SPARK-15041) and possibly creates incorrect values for a categorical feature.
+ *
+ * Note that the mean/median value is computed after filtering out missing values.
+ * All Null values in the input columns are treated as missing, and so are also imputed. For
+ * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
+ */
+@Experimental
+class Imputer @Since("2.2.0")(override val uid: String)
+  extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable {
+
+  @Since("2.2.0")
+  def this() = this(Identifiable.randomUID("imputer"))
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  /**
+   * Imputation strategy. Available options are ["mean", "median"].
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setStrategy(value: String): this.type = set(strategy, value)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setMissingValue(value: Double): this.type = set(missingValue, value)
+
+  setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN)
+
+  override def fit(dataset: Dataset[_]): ImputerModel = {
+    transformSchema(dataset.schema, logging = true)
+    val spark = dataset.sparkSession
+    import spark.implicits._
+    val surrogates = $(inputCols).map { inputCol =>
+      val ic = col(inputCol)
+      val filtered = dataset.select(ic.cast(DoubleType))
+        .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
+      if (filtered.take(1).length == 0) {
+        throw new SparkException(s"surrogate cannot be computed. " +
+          s"All the values in $inputCol are Null, NaN or missingValue(${$(missingValue)})")
+      }
+      val surrogate = $(strategy) match {
+        case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
+        case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
+      }
+      surrogate
+    }
+
+    val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
+    val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
+    val surrogateDF = spark.createDataFrame(rows, schema)
+    copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  override def copy(extra: ParamMap): Imputer = defaultCopy(extra)
+}
+
+@Since("2.2.0")
+object Imputer extends DefaultParamsReadable[Imputer] {
+
+  /** strategy names that Imputer currently supports. */
+  private[ml] val mean = "mean"
+  private[ml] val median = "median"
+
+  @Since("2.2.0")
+  override def load(path: String): Imputer = super.load(path)
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by [[Imputer]].
+ *
+ * @param surrogateDF a DataFrame containing inputCols and their corresponding surrogates,
+ *                    which are used to replace the missing values in the input DataFrame.
+ */
+@Experimental
+class ImputerModel private[ml](
+    override val uid: String,
+    val surrogateDF: DataFrame)
+  extends Model[ImputerModel] with ImputerParams with MLWritable {
+
+  import ImputerModel._
+
+  /** @group setParam */
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+    var outputDF = dataset
+    val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
+
+    $(inputCols).zip($(outputCols)).zip(surrogates).foreach {
+      case ((inputCol, outputCol), surrogate) =>
+        val inputType = dataset.schema(inputCol).dataType
+        val ic = col(inputCol)
+        outputDF = outputDF.withColumn(outputCol,
+          when(ic.isNull, surrogate)
+            .when(ic === $(missingValue), surrogate)
+            .otherwise(ic)
+            .cast(inputType))
+    }
+    outputDF.toDF()
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  override def copy(extra: ParamMap): ImputerModel = {
+    val copied = new ImputerModel(uid, surrogateDF)
+    copyValues(copied, extra).setParent(parent)
+  }
+
+  @Since("2.2.0")
+  override def write: MLWriter = new ImputerModelWriter(this)
+}
+
+
+@Since("2.2.0")
+object ImputerModel extends MLReadable[ImputerModel] {
+
+  private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val dataPath = new Path(path, "data").toString
+      instance.surrogateDF.repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class ImputerReader extends MLReader[ImputerModel] {
+
+    private val className = classOf[ImputerModel].getName
+
+    override def load(path: String): ImputerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val surrogateDF = sqlContext.read.parquet(dataPath)
+      val model = new ImputerModel(metadata.uid, surrogateDF)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("2.2.0")
+  override def read: MLReader[ImputerModel] = new ImputerReader
+
+  @Since("2.2.0")
+  override def load(path: String): ImputerModel = super.load(path)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
new file mode 100644
index 0000000000000..ee2ba73fa96d5
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row}
+
+class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  test("Imputer for Double with default missing value NaN") {
+    val df = spark.createDataFrame(Seq(
+      (0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0),
+      (1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0),
+      (2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0),
+      (3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0)
+    )).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1",
+      "expected_mean_value2", "expected_median_value2")
+    val imputer = new Imputer()
+      .setInputCols(Array("value1", "value2"))
+      .setOutputCols(Array("out1", "out2"))
+    ImputerSuite.iterateStrategyTest(imputer, df)
+  }
+
+  test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") {
+    val df = spark.createDataFrame(Seq(
+      (0, 1.0, 1.0, 1.0),
+      (1, 3.0, 3.0, 3.0),
+      (2, Double.NaN, Double.NaN, Double.NaN),
+      (3, -1.0, 2.0, 3.0)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
+      .setMissingValue(-1.0)
+    ImputerSuite.iterateStrategyTest(imputer, df)
+  }
+
+  test("Imputer for Float with missing value -1.0") {
+    val df = spark.createDataFrame(Seq(
+      (0, 1.0F, 1.0F, 1.0F),
+      (1, 3.0F, 3.0F, 3.0F),
+      (2, 10.0F, 10.0F, 10.0F),
+      (3, 10.0F, 10.0F, 10.0F),
+      (4, -1.0F, 6.0F, 3.0F)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
+      .setMissingValue(-1)
+    ImputerSuite.iterateStrategyTest(imputer, df)
+  }
+
+  test("Imputer should impute null as well as 'missingValue'") {
+    val rawDf = spark.createDataFrame(Seq(
+      (0, 4.0, 4.0, 4.0),
+      (1, 10.0, 10.0, 10.0),
+      (2, 10.0, 10.0, 10.0),
+      (3, Double.NaN, 8.0, 10.0),
+      (4, -1.0, 8.0, 10.0)
+    )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value")
+    val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value")
+    val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
+    ImputerSuite.iterateStrategyTest(imputer, df)
+  }
+
+  test("Imputer throws exception when surrogate cannot be computed") {
+    val df = spark.createDataFrame(Seq(
+      (0, Double.NaN, 1.0, 1.0),
+      (1, Double.NaN, 3.0, 3.0),
+      (2, Double.NaN, Double.NaN, Double.NaN)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    Seq("mean", "median").foreach { strategy =>
+      val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
+        .setStrategy(strategy)
+      withClue("Imputer should fail when all the values are invalid") {
+        val e: SparkException = intercept[SparkException] {
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("surrogate cannot be computed"))
+      }
+    }
+  }
+
+  test("Imputer input & output column validation") {
+    val df = spark.createDataFrame(Seq(
+      (0, 1.0, 1.0, 1.0),
+      (1, Double.NaN, 3.0, 3.0),
+      (2, Double.NaN, Double.NaN, Double.NaN)
+    )).toDF("id", "value1", "value2", "value3")
+    Seq("mean", "median").foreach { strategy =>
+      withClue("Imputer should fail if inputCols and outputCols have different lengths") {
+        val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+          val imputer = new Imputer().setStrategy(strategy)
+            .setInputCols(Array("value1", "value2"))
+            .setOutputCols(Array("out1"))
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("should have the same length"))
+      }
+
+      withClue("Imputer should fail if inputCols contains duplicates") {
+        val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+          val imputer = new Imputer().setStrategy(strategy)
+            .setInputCols(Array("value1", "value1"))
+            .setOutputCols(Array("out1", "out2"))
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("inputCols contains duplicates"))
+      }
+
+      withClue("Imputer should fail if outputCols contains duplicates") {
+        val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+          val imputer = new Imputer().setStrategy(strategy)
+            .setInputCols(Array("value1", "value2"))
+            .setOutputCols(Array("out1", "out1"))
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("outputCols contains duplicates"))
+      }
+    }
+  }
+
+  test("Imputer read/write") {
+    val t = new Imputer()
+      .setInputCols(Array("myInputCol"))
+      .setOutputCols(Array("myOutputCol"))
+      .setMissingValue(-1.0)
+    testDefaultReadWrite(t)
+  }
+
+  test("ImputerModel read/write") {
+    val spark = this.spark
+    import spark.implicits._
+    val surrogateDF = Seq(1.234).toDF("myInputCol")
+
+    val instance = new ImputerModel(
+      "myImputer", surrogateDF)
+      .setInputCols(Array("myInputCol"))
+      .setOutputCols(Array("myOutputCol"))
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns)
+    assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
+  }
+
+}
+
+object ImputerSuite {
+
+  /**
+   * Fits the given Imputer with both the "mean" and "median" strategies and checks imputed values.
+   * @param df DataFrame with an "expected_<strategy>_<inputCol>" column for every input column.
+   */
+  def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = {
+    val inputCols = imputer.getInputCols
+
+    Seq("mean", "median").foreach { strategy =>
+      imputer.setStrategy(strategy)
+      val model = imputer.fit(df)
+      val resultDF = model.transform(df)
+      imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) =>
+        resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach {
+          case Row(exp: Float, out: Float) =>
+            assert((exp.isNaN && out.isNaN) || (exp == out),
+              s"Imputed values differ. Expected: $exp, actual: $out")
+          case Row(exp: Double, out: Double) =>
+            assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
+              s"Imputed values differ. Expected: $exp, actual: $out")
+        }
+      }
+    }
+  }
+}
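Usage note (not part of the patch): a minimal sketch of how the estimator introduced above would be driven, assuming a SparkSession named `spark` is already in scope. The DataFrame, the column names "f1"/"f2", and the output column names are illustrative only; the calls themselves (setInputCols, setOutputCols, setStrategy, setMissingValue, fit, transform, surrogateDF) are the API added in this diff.

import org.apache.spark.ml.feature.Imputer

// Toy input: Double.NaN is the default missingValue placeholder; nulls are imputed as well.
val df = spark.createDataFrame(Seq(
  (1.0, Double.NaN),
  (2.0, 3.0),
  (Double.NaN, 5.0)
)).toDF("f1", "f2")

val imputer = new Imputer()
  .setInputCols(Array("f1", "f2"))
  .setOutputCols(Array("f1_out", "f2_out"))
  .setStrategy("mean")          // "median" would go through approxQuantile with 0.001 relative error

val model = imputer.fit(df)     // one surrogate per input column, computed over non-missing rows
model.surrogateDF.show()        // single row: f1 -> 1.5, f2 -> 4.0
model.transform(df).show()      // appends f1_out / f2_out with NaN replaced by the surrogates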