Skip to content

Commit

Permalink
Add companion object Imputer and refine unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hhbyyh committed Mar 9, 2016
1 parent c3d5d55 commit 1b39668
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
18 changes: 12 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,18 @@ class Imputer @Since("2.0.0")(override val uid: String)
case Double.NaN => "NaN"
case _ => $(missingValue).toString
}
val filteredDF = dataset.select(colName).where(s"$colName != '$missValue'")
val colStatistics = $(strategy) match {
case "mean" =>
dataset.where(s"$colName != '$missValue'").selectExpr(s"avg($colName)").first().getDouble(0)
filteredDF.selectExpr(s"avg($colName)").first().getDouble(0)
case "median" =>
// TODO: optimize the sort with quick-select or Percentile(Hive) if required
val rddDouble = dataset.select(colName).where(s"$colName != $missValue").rdd
.map(_.getDouble(0))
val rddDouble = filteredDF.rdd.map(_.getDouble(0))
rddDouble.sortBy(d => d).zipWithIndex().map {
case (v, idx) => (idx, v)
}.lookup(rddDouble.count()/2).head
case "most" =>
val input = dataset.where(s"$colName != $missValue").select(colName).rdd
.map(_.getDouble(0))
val input = filteredDF.rdd.map(_.getDouble(0))
val most = input.map(d => (d, 1)).reduceByKey(_ + _).sortBy(-_._2).first()._1
most
}
Expand All @@ -163,6 +162,13 @@ class Imputer @Since("2.0.0")(override val uid: String)
}
}

/**
 * Companion object for [[Imputer]], providing `Imputer.load(path)` support
 * through [[DefaultParamsReadable]].
 *
 * NOTE: the class `Imputer` is annotated `@Since("2.0.0")`, and this companion
 * object is introduced alongside it, so it must carry the same version —
 * `"1.6.0"` predates the feature and would mislead API-doc readers.
 */
@Since("2.0.0")
object Imputer extends DefaultParamsReadable[Imputer] {

  /** Loads a previously saved [[Imputer]] from the given path. */
  @Since("2.0.0")
  override def load(path: String): Imputer = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[Imputer]].
Expand Down Expand Up @@ -214,7 +220,7 @@ class ImputerModel private[ml] (
}
case s: SparseVector =>
var iter = 0
while(iter < s.values.size) {
while(iter < s.values.length) {
if (matchMissingValue(s.values(iter))) {
s.values(iter) = alternate(s.indices(iter))
}
Expand Down
25 changes: 15 additions & 10 deletions mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
Expand Down Expand Up @@ -64,19 +64,23 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default

test("Imputer for Vector column with NaN and null") {
val df = sqlContext.createDataFrame( Seq(
(0, Vector(1, 2), Vector(1, 2), Vector(1, 2), Vector(1, 2)),
(1, Vector(1, 2), Vector(1, 2), Vector(1, 2), Vector(1, 2)),
(2, Vector(3, 2), Vector(3, 2), Vector(3, 2), Vector(3, 2)),
(3, Vector(4, 2), Vector(4, 2), Vector(4, 2), Vector(4, 2)),
(4, Vector(Double.NaN, 2), Vector(2.25, 2), Vector(3.0, 2), Vector(1.0, 2)),
(4, null, Vector(2.25, 2), Vector(3.0, 2), Vector(1.0, 2))
(0, Vectors.dense(1, 2), Vectors.dense(1, 2), Vectors.dense(1, 2), Vectors.dense(1, 2)),
(1, Vectors.dense(1, 2), Vectors.dense(1, 2), Vectors.dense(1, 2), Vectors.dense(1, 2)),
(2, Vectors.dense(3, 2), Vectors.dense(3, 2), Vectors.dense(3, 2), Vectors.dense(3, 2)),
(3, Vectors.dense(4, 2), Vectors.dense(4, 2), Vectors.dense(4, 2), Vectors.dense(4, 2)),
(4, Vectors.dense(Double.NaN, 2), Vectors.dense(2.25, 2), Vectors.dense(3.0, 2),
Vectors.dense(1.0, 2)),
(5, Vectors.sparse(2, Array(0, 1), Array(Double.NaN, 2.0)), Vectors.dense(2.25, 2),
Vectors.dense(3.0, 2), Vectors.dense(1.0, 2)),
(6, null.asInstanceOf[Vector], Vectors.dense(2.25, 2), Vectors.dense(3.0, 2),
Vectors.dense(1.0, 2))
)).toDF("id", "value", "mean", "median", "most")
Seq("mean", "median", "most").foreach { strategy =>
val imputer = new Imputer().setInputCol("value").setOutputCol("out").setStrategy(strategy)
val model = imputer.fit(df)
model.transform(df).select(strategy, "out").collect()
.foreach { case Row(d1: Double, d2: Double) =>
assert(d1 ~== d2 absTol 1e-5, s"Imputer ut error: $d2 should be $d1")
.foreach { case Row(v1: Vector, v2: Vector) =>
assert(v1 == v2, s"$strategy Imputer ut error: $v2 should be $v1")
}
}
}
Expand All @@ -85,10 +89,11 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default
val t = new Imputer()
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setMissingValue(-1.0)
testDefaultReadWrite(t)
}

test("Imputer read/write") {
test("ImputerModel read/write") {
val instance = new ImputerModel(
"myImputer", Vectors.dense(1.0, 10.0))
.setInputCol("myInputCol")
Expand Down

0 comments on commit 1b39668

Please sign in to comment.