Removed lowerInclusive, upperInclusive params from Bucketizer, and used splits instead.
jkbradley committed May 11, 2015
1 parent eacfcfa commit 34f124a
Showing 3 changed files with 139 additions and 128 deletions.
120 changes: 37 additions & 83 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -39,14 +39,17 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])

/**
* Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
- * A bucket defined by splits x,y holds values in the range [x,y). Note that the splits should be
- * strictly increasing.
+ * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
+ * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+ * otherwise, values outside the splits specified will be treated as errors.
* @group param
*/
val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
"Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
"buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
"should be strictly increasing.",
"should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
" all Double values; otherwise, values outside the splits specified will be treated as" +
" errors.",
Bucketizer.checkSplits)
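
With the new semantics, n splits bound n-1 buckets, and covering the full Double range now requires passing -inf and inf explicitly (the "n splits, n+1 buckets" wording kept above appears to be a holdover from the removed implicit outer bins). A sketch of the resulting bucket layout, editorial and not part of this diff:

// 5 splits => 4 buckets; values below -0.5 or at/above 0.5 are accepted here
// only because -inf and inf are explicit endpoints.
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
// bucket 0: [-inf, -0.5)
// bucket 1: [-0.5,  0.0)
// bucket 2: [ 0.0,  0.5)
// bucket 3: [ 0.5,  inf)   (+inf itself is placed in the last bucket; see binarySearchForBuckets)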

/** @group getParam */
@@ -55,40 +58,6 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
/** @group setParam */
def setSplits(value: Array[Double]): this.type = set(splits, value)

- /**
-  * An indicator of the inclusiveness of negative infinite. If true, then use implicit bin
-  * (-inf, getSplits.head). If false, then throw exception if values < getSplits.head are
-  * encountered.
-  * @group Param */
- val lowerInclusive: BooleanParam = new BooleanParam(this, "lowerInclusive",
-   "An indicator of the inclusiveness of negative infinite. If true, then use implicit bin " +
-   "(-inf, getSplits.head). If false, then throw exception if values < getSplits.head are " +
-   "encountered.")
- setDefault(lowerInclusive -> true)
-
- /** @group getParam */
- def getLowerInclusive: Boolean = $(lowerInclusive)
-
- /** @group setParam */
- def setLowerInclusive(value: Boolean): this.type = set(lowerInclusive, value)
-
- /**
-  * An indicator of the inclusiveness of positive infinite. If true, then use implicit bin
-  * [getSplits.last, inf). If false, then throw exception if values > getSplits.last are
-  * encountered.
-  * @group Param */
- val upperInclusive: BooleanParam = new BooleanParam(this, "upperInclusive",
-   "An indicator of the inclusiveness of positive infinite. If true, then use implicit bin " +
-   "[getSplits.last, inf). If false, then throw exception if values > getSplits.last are " +
-   "encountered.")
- setDefault(upperInclusive -> true)
-
- /** @group getParam */
- def getUpperInclusive: Boolean = $(upperInclusive)
-
- /** @group setParam */
- def setUpperInclusive(value: Boolean): this.type = set(upperInclusive, value)

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

@@ -97,81 +66,66 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])

override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema)
- val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
val bucketizer = udf { feature: Double =>
- Bucketizer
-   .binarySearchForBuckets(wrappedSplits, feature, $(lowerInclusive), $(upperInclusive)) }
+ Bucketizer.binarySearchForBuckets($(splits), feature)
+ }
val newCol = bucketizer(dataset($(inputCol)))
val newField = prepOutputField(dataset.schema)
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
}

private def prepOutputField(schema: StructType): StructField = {
- val innerRanges = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
- val values = ($(lowerInclusive), $(upperInclusive)) match {
-   case (true, true) =>
-     Array(s"-inf, ${$(splits).head}") ++ innerRanges ++ Array(s"${$(splits).last}, inf")
-   case (true, false) => Array(s"-inf, ${$(splits).head}") ++ innerRanges
-   case (false, true) => innerRanges ++ Array(s"${$(splits).last}, inf")
-   case _ => innerRanges
- }
- val attr =
-   new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true), values = Some(values))
+ val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
+ val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
+   values = Some(buckets))
attr.toStructField()
}

override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
require(schema.fields.forall(_.name != $(outputCol)),
s"Output column ${$(outputCol)} already exists.")
- StructType(schema.fields :+ prepOutputField(schema))
+ SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
}
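
For orientation, a minimal end-to-end sketch of the reworked API (editorial, not part of the commit; `new Bucketizer()` assumes ml-package visibility, since the constructor is still private[ml] at this commit, and `sqlContext` is a SQLContext as in the test suite below):

val data = Array(-0.9, -0.5, 0.2, 0.9)
val df = sqlContext.createDataFrame(data.zipWithIndex).toDF("feature", "idx")
val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("result")
  .setSplits(Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity))
// Expected "result" values: -0.9 -> 0.0, -0.5 -> 1.0, 0.2 -> 2.0, 0.9 -> 3.0
bucketizer.transform(df).show()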

private[feature] object Bucketizer {
- /**
-  * The given splits should match 1) its size is larger than zero; 2) it is ordered in a strictly
-  * increasing way.
-  */
- private def checkSplits(splits: Array[Double]): Boolean = {
-   if (splits.size == 0) false
-   else if (splits.size == 1) true
-   else {
-     splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
-       if (validator && prevValue < currValue) {
-         (true, currValue)
-       } else {
-         (false, currValue)
-       }
-     }._1
+ /** We require splits to be of length >= 3 and to be in strictly increasing order. */
+ def checkSplits(splits: Array[Double]): Boolean = {
+   if (splits.length < 3) {
+     false
+   } else {
+     var i = 0
+     while (i < splits.length - 1) {
+       if (splits(i) >= splits(i + 1)) return false
+       i += 1
+     }
+     true
}
}
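
A few concrete inputs for the new validation rule (an illustrative sketch, not part of the diff):

Bucketizer.checkSplits(Array(-0.5, 0.0, 0.5))                     // true: >= 3 splits, strictly increasing
Bucketizer.checkSplits(Array(Double.NegativeInfinity, 0.0, 1.0))  // true: infinities are valid endpoints
Bucketizer.checkSplits(Array(0.0, 1.0))                           // false: fewer than 3 splits
Bucketizer.checkSplits(Array(0.0, 1.0, 1.0))                      // false: not strictly increasing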

/**
* Binary searching in several buckets to place each data point.
+ * @throws RuntimeException if a feature is < splits.head or >= splits.last
*/
- private[feature] def binarySearchForBuckets(
+ def binarySearchForBuckets(
    splits: Array[Double],
-   feature: Double,
-   lowerInclusive: Boolean,
-   upperInclusive: Boolean): Double = {
- if ((feature < splits.head && !lowerInclusive) || (feature > splits.last && !upperInclusive)) {
-   throw new RuntimeException(s"Feature $feature out of bound, check your features or loosen " +
-     s"the lower/upper bound constraint.")
+   feature: Double): Double = {
+ // Check bounds. We make an exception for +inf so that it can exist in some bin.
+ if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
+   throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
+     s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
+     s"the lower/upper bound constraints.")
}
var left = 0
var right = splits.length - 2
- while (left <= right) {
-   val mid = left + (right - left) / 2
-   val split = splits(mid)
-   if ((feature >= split) && (feature < splits(mid + 1))) {
-     return mid
-   } else if (feature < split) {
-     right = mid - 1
+ while (left < right) {
+   val mid = (left + right) / 2
+   val split = splits(mid + 1)
+   if (feature < split) {
+     right = mid
} else {
left = mid + 1
}
}
- throw new RuntimeException(s"Unexpected error: failed to find a bucket for feature $feature.")
+ left
}
}
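
To make the new loop concrete: it maintains the invariant that the target bucket index lies in [left, right], where bucket i covers [splits(i), splits(i+1)), and it narrows the range by comparing against splits(mid + 1). A hand trace (editorial sketch):

// splits = Array(-0.5, 0.0, 0.5, 1.0)  => 3 buckets; feature = 0.2
// left = 0, right = 2
//   mid = 1, splits(mid + 1) = 0.5, 0.2 <  0.5 => right = 1
//   mid = 0, splits(mid + 1) = 0.0, 0.2 >= 0.0 => left  = 1
// left == right == 1, so 0.2 lands in bucket 1 = [0.0, 0.5)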
11 changes: 11 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -58,4 +58,15 @@ object SchemaUtils {
val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
StructType(outputFields)
}

+ /**
+  * Appends a new column to the input schema. This fails if the given output column already exists.
+  * @param schema input schema
+  * @param col New column schema
+  * @return new schema with the input column appended
+  */
+ def appendColumn(schema: StructType, col: StructField): StructType = {
+   require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.")
+   StructType(schema.fields :+ col)
+ }
}
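
A brief usage sketch of the new helper (editorial; it mirrors the call from Bucketizer.transformSchema above):

import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val schema = StructType(Seq(StructField("feature", DoubleType, nullable = false)))
val outCol = StructField("result", DoubleType, nullable = false)
val withResult = SchemaUtils.appendColumn(schema, outCol)  // ok: "result" is new
// SchemaUtils.appendColumn(withResult, outCol)            // would fail the require: column exists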
136 changes: 91 additions & 45 deletions mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -21,20 +21,28 @@ import scala.util.Random

import org.scalatest.FunSuite

+ import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

test("Bucket continuous features with setter") {
val sqlContext = new SQLContext(sc)
val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4, -0.9)
@transient private var sqlContext: SQLContext = _

override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
}

test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
val splits = Array(-0.5, 0.0, 0.5)
val bucketizedData = Array(2.0, 1.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0, 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
data.zip(bucketizedData)).toDF("feature", "expected")
val validData = Array(-0.5, -0.3, 0.0, 0.2)
val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
val dataFrame: DataFrame =
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")

val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
@@ -43,58 +51,96 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
assert(x === y, "The feature value is not correct after bucketing.")
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
}
}

test("Binary search correctness in contrast with linear search") {
val data = Array.fill(100)(Random.nextDouble())
val splits = Array.fill(10)(Random.nextDouble()).sorted
val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
val bsResult = Vectors.dense(
data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
assert(bsResult ~== lsResult absTol 1e-5)
// Check for exceptions when using a set of invalid feature values.
val invalidData1: Array[Double] = Array(-0.9) ++ validData
val invalidData2 = Array(0.5) ++ validData
val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
intercept[RuntimeException]{
bucketizer.transform(badDF1).collect()
println("Invalid feature value -0.9 was not caught as an invalid feature!")
}
val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
intercept[RuntimeException]{
bucketizer.transform(badDF2).collect()
println("Invalid feature value 0.5 was not caught as an invalid feature!")
}
}

test("Binary search of features at splits") {
val splits = Array.fill(10)(Random.nextDouble()).sorted
val data = splits
val expected = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)
val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
val result = Vectors.dense(
data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
assert(result ~== expected absTol 1e-5)
test("Bucket continuous features, with -inf,inf") {
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
val dataFrame: DataFrame =
sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")

val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
.setOutputCol("result")
.setSplits(splits)

bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
}
}

test("Binary search of features between splits") {
val data = Array.fill(10)(Random.nextDouble())
val splits = Array(-0.1, 1.1)
val expected = Vectors.dense(Array.fill(10)(1.0))
val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
val result = Vectors.dense(
data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
assert(result ~== expected absTol 1e-5)
test("Binary search correctness on hand-picked examples") {
import BucketizerSuite.checkBinarySearch
// length 3, with -inf
checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0))
// length 4
checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0))
// length 5
checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0, 1.5))
// length 3, with inf
checkBinarySearch(Array(0.0, 1.0, Double.PositiveInfinity))
// length 3, with -inf and inf
checkBinarySearch(Array(Double.NegativeInfinity, 1.0, Double.PositiveInfinity))
}

test("Binary search of features outside splits") {
val data = Array.fill(5)(Random.nextDouble() + 1.1) ++ Array.fill(5)(Random.nextDouble() - 1.1)
val splits = Array(0.0, 1.1)
val expected = Vectors.dense(Array.fill(5)(2.0) ++ Array.fill(5)(0.0))
val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
val result = Vectors.dense(
data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
assert(result ~== expected absTol 1e-5)
test("Binary search correctness in contrast with linear search, on random data") {
val data = Array.fill(100)(Random.nextDouble())
val splits: Array[Double] = Double.NegativeInfinity +:
Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity
val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x)))
val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
assert(bsResult ~== lsResult absTol 1e-5)
}
}
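
The companion object below supplies a brute-force oracle for the binary search; the testing idea in miniature (editorial sketch, assuming test-package visibility as in the suite):

// For any in-range feature, the O(log n) search and the O(n) scan must agree.
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
for (x <- Seq(-1.0, -0.5, 0.25, 3.0)) {
  assert(Bucketizer.binarySearchForBuckets(splits, x) ==
    BucketizerSuite.linearSearchForBuckets(splits, x))
}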

- private object BucketizerSuite {
-   private def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+ private object BucketizerSuite extends FunSuite {
+   /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
+   def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+     require(feature >= splits.head)
var i = 0
-     while (i < splits.size) {
-       if (feature < splits(i)) return i
+     while (i < splits.length - 1) {
+       if (feature < splits(i + 1)) return i
i += 1
}
-     i
+     throw new RuntimeException(
+       s"linearSearchForBuckets failed to find bucket for feature value $feature")
}

+ /** Check all values in splits, plus values between all splits. */
+ def checkBinarySearch(splits: Array[Double]): Unit = {
+   def testFeature(feature: Double, expectedBucket: Double): Unit = {
+     assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket,
+       s"Expected feature value $feature to be in bucket $expectedBucket with splits:" +
+       s" ${splits.mkString(", ")}")
+   }
+   var i = 0
+   while (i < splits.length - 1) {
+     testFeature(splits(i), i) // Split i should fall in bucket i.
+     testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
+     i += 1
+   }
+   if (splits.last === Double.PositiveInfinity) {
+     testFeature(Double.PositiveInfinity, splits.length - 2)
+   }
+ }
}
