Commit 2466322: refactor Bucketizer

yinxusen committed May 7, 2015
1 parent 11fb00a
Showing 2 changed files with 88 additions and 51 deletions.
130 changes: 81 additions & 49 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -25,46 +25,55 @@ import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /**
  * :: AlphaComponent ::
  * `Bucketizer` maps a column of continuous features to a column of feature buckets.
  */
 @AlphaComponent
-final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
+private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
   extends Model[Bucketizer] with HasInputCol with HasOutputCol {
 
-  /**
-   * The given buckets should match 1) its size is larger than zero; 2) it is ordered in a non-DESC
-   * way.
-   */
-  private def checkBuckets(buckets: Array[Double]): Boolean = {
-    if (buckets.size == 0) false
-    else if (buckets.size == 1) true
-    else {
-      buckets.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
-        if (validator & prevValue <= currValue) {
-          (true, currValue)
-        } else {
-          (false, currValue)
-        }
-      }._1
-    }
-  }
+  def this() = this(null)

   /**
-   * Parameter for mapping continuous features into buckets.
+   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
+   * A bucket defined by splits x,y holds values in the range (x,y].
    * @group param
    */
-  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
-    "Split points for mapping continuous features into buckets.", checkBuckets)
+  val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
+    "Split points for mapping continuous features into buckets. With n splits, there are n+1" +
+      " buckets. A bucket defined by splits x,y holds values in the range (x,y].",
+    Bucketizer.checkSplits)
 
   /** @group getParam */
-  def getBuckets: Array[Double] = $(buckets)
+  def getSplits: Array[Double] = $(splits)
 
   /** @group setParam */
-  def setBuckets(value: Array[Double]): this.type = set(buckets, value)
+  def setSplits(value: Array[Double]): this.type = set(splits, value)

+  /** @group Param */
+  val lowerInclusive: BooleanParam = new BooleanParam(this, "lowerInclusive",
+    "An indicator of the inclusiveness of negative infinite.")
+  setDefault(lowerInclusive -> true)
+
+  /** @group getParam */
+  def getLowerInclusive: Boolean = $(lowerInclusive)
+
+  /** @group setParam */
+  def setLowerInclusive(value: Boolean): this.type = set(lowerInclusive, value)
+
+  /** @group Param */
+  val upperInclusive: BooleanParam = new BooleanParam(this, "upperInclusive",
+    "An indicator of the inclusiveness of positive infinite.")
+  setDefault(upperInclusive -> true)
+
+  /** @group getParam */
+  def getUpperInclusive: Boolean = $(upperInclusive)
+
+  /** @group setParam */
+  def setUpperInclusive(value: Boolean): this.type = set(upperInclusive, value)
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -74,45 +83,68 @@ final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double => binarySearchForBuckets($(buckets), feature) }
-    val outputColName = $(outputCol)
-    val metadata = NominalAttribute.defaultAttr
-      .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
-    dataset.select(col("*"), bucketizer(dataset($(inputCol))).as(outputColName, metadata))
+    val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
+    val bucketizer = udf { feature: Double =>
+      Bucketizer.binarySearchForBuckets(wrappedSplits, feature) }
+    val newCol = bucketizer(dataset($(inputCol)))
+    val newField = prepOutputField(dataset.schema)
+    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
   }
 
+  private def prepOutputField(schema: StructType): StructField = {
+    val attr = new NominalAttribute(
+      name = Some($(outputCol)),
+      isOrdinal = Some(true),
+      numValues = Some($(splits).size),
+      values = Some($(splits).map(_.toString)))
+
+    attr.toStructField()
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    require(schema.fields.forall(_.name != $(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    StructType(schema.fields :+ prepOutputField(schema))
+  }
 }

+object Bucketizer {
+  /**
+   * The given splits should match 1) its size is larger than zero; 2) it is ordered in a strictly
+   * increasing way.
+   */
+  private def checkSplits(splits: Array[Double]): Boolean = {
+    if (splits.size == 0) false
+    else if (splits.size == 1) true
+    else {
+      splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
+        if (validator && prevValue < currValue) {
+          (true, currValue)
+        } else {
+          (false, currValue)
+        }
+      }._1
+    }
+  }

   /**
    * Binary searching in several buckets to place each data point.
    */
-  private def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
+  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
     var left = 0
-    var right = wrappedSplits.length - 2
+    var right = splits.length - 2
     while (left <= right) {
       val mid = left + (right - left) / 2
-      val split = wrappedSplits(mid)
-      if ((feature > split) && (feature <= wrappedSplits(mid + 1))) {
+      val split = splits(mid)
+      if ((feature > split) && (feature <= splits(mid + 1))) {
         return mid
       } else if (feature <= split) {
         right = mid - 1
       } else {
         left = mid + 1
       }
     }
-    -1
-  }
-
-  override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-
-    val inputFields = schema.fields
-    val outputColName = $(outputCol)
-
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-
-    val attr = NominalAttribute.defaultAttr.withName(outputColName)
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    throw new Exception("Failed to find a bucket.")
   }
 }
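Taken together, the refactor renames the `buckets` param to `splits`, moves validation and the bucket lookup into a companion object, and has `transform` pad the splits with `Double.MinValue` and `Double.MaxValue` so that every finite input lands in one of the n+1 buckets. (The new `lowerInclusive`/`upperInclusive` params are declared and default to true, but `transform` does not consult them yet in this commit.) A minimal usage sketch of the API as it stands here; the data, column names, and the SparkContext `sc` are invented for illustration, and since the commit also marks the class `private[ml]`, code like this is assumed to run from inside the `org.apache.spark.ml` package:

import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)  // assumes an existing SparkContext `sc`
val df = sqlContext.createDataFrame(
  Seq(-0.5, -0.3, 0.0, 0.2).map(Tuple1.apply)).toDF("feature")

// Two splits yield three buckets once transform() pads them with sentinels:
//   bucket 0: (Double.MinValue, -0.4]
//   bucket 1: (-0.4, 0.0]
//   bucket 2: (0.0, Double.MaxValue]
val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("result")
  .setSplits(Array(-0.4, 0.0))

bucketizer.transform(df).show()
// expected "result" values: 0.0, 1.0, 1.0, 2.0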
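One behavioral change is easy to miss: the old `checkBuckets` used `validator & prevValue <= currValue`, so non-decreasing arrays with duplicate values passed, while the new `checkSplits` requires strictly increasing values. A standalone copy of the new validator (reproduced here because it is private) with a few hand-checked cases:

// Mirrors Bucketizer.checkSplits: non-empty and strictly increasing. Folding
// from Double.MinValue means a leading Double.MinValue is also rejected.
def checkSplits(splits: Array[Double]): Boolean = {
  if (splits.size == 0) false
  else if (splits.size == 1) true
  else {
    splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
      if (validator && prevValue < currValue) (true, currValue) else (false, currValue)
    }._1
  }
}

assert(checkSplits(Array(0.0, 0.5, 1.0)))   // strictly increasing: valid
assert(!checkSplits(Array(0.0, 0.0, 1.0)))  // duplicate split: rejected (old checkBuckets allowed this)
assert(!checkSplits(Array.empty[Double]))   // empty: rejected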
9 changes: 7 additions & 2 deletions mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -17,9 +17,10 @@

 package org.apache.spark.ml.feature
 
+import org.scalatest.FunSuite
+
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.scalatest.FunSuite
 
 class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

@@ -34,11 +35,15 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
     val bucketizer: Bucketizer = new Bucketizer()
       .setInputCol("feature")
       .setOutputCol("result")
-      .setBuckets(buckets)
+      .setSplits(buckets)
 
     bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
       case Row(x: Double, y: Double) =>
         assert(x === y, "The feature value is not correct after bucketing.")
     }
   }
+
+  test("Binary search for finding buckets") {
+
+  }
 }
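The new `Binary search for finding buckets` test is left as an empty stub in this commit. A sketch of what it could assert, with expected indices worked out by hand from the (x,y] bucket semantics; the values are illustrative, not the test that eventually landed:

  test("Binary search for finding buckets") {
    // Splits as transform() passes them in, already padded with the sentinels.
    val splits = Array(Double.MinValue, -0.5, 0.0, 0.5, Double.MaxValue)
    // Buckets: 0 -> (MinValue, -0.5], 1 -> (-0.5, 0.0], 2 -> (0.0, 0.5], 3 -> (0.5, MaxValue]
    val data = Array(-0.9, -0.5, -0.1, 0.0, 0.4, 0.5, 1.2)
    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0)
    data.zip(expectedBuckets).foreach { case (feature, expected) =>
      // binarySearchForBuckets is private[feature], so the suite can call it directly.
      assert(Bucketizer.binarySearchForBuckets(splits, feature) === expected,
        s"Feature $feature was put in the wrong bucket.")
    }
  }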
