From 34f124a1b7069d643c3496168b867f2fdde87257 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Mon, 11 May 2015 12:10:38 -0700
Subject: [PATCH] Removed lowerInclusive, upperInclusive params from Bucketizer, and used splits instead.

---
 .../apache/spark/ml/feature/Bucketizer.scala  | 120 +++++-----------
 .../apache/spark/ml/util/SchemaUtils.scala    |  11 ++
 .../spark/ml/feature/BucketizerSuite.scala    | 136 ++++++++++++------
 3 files changed, 139 insertions(+), 128 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 1a476bacf1043..7dba64bc3506f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -39,14 +39,17 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
   /**
    * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
-   * A bucket defined by splits x,y holds values in the range [x,y). Note that the splits should be
-   * strictly increasing.
+   * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
+   * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+   * otherwise, values outside the splits specified will be treated as errors.
    * @group param
    */
   val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
     "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
       "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
-      "should be strictly increasing.",
+      "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
+      " all Double values; otherwise, values outside the splits specified will be treated as" +
+      " errors.",
     Bucketizer.checkSplits)
 
   /** @group getParam */
@@ -55,40 +58,6 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
   def getSplits: Array[Double] = $(splits)
 
   /** @group setParam */
   def setSplits(value: Array[Double]): this.type = set(splits, value)
 
-  /**
-   * An indicator of the inclusiveness of negative infinite. If true, then use implicit bin
-   * (-inf, getSplits.head). If false, then throw exception if values < getSplits.head are
-   * encountered.
-   * @group Param */
-  val lowerInclusive: BooleanParam = new BooleanParam(this, "lowerInclusive",
-    "An indicator of the inclusiveness of negative infinite. If true, then use implicit bin " +
-    "(-inf, getSplits.head). If false, then throw exception if values < getSplits.head are " +
-    "encountered.")
-  setDefault(lowerInclusive -> true)
-
-  /** @group getParam */
-  def getLowerInclusive: Boolean = $(lowerInclusive)
-
-  /** @group setParam */
-  def setLowerInclusive(value: Boolean): this.type = set(lowerInclusive, value)
-
-  /**
-   * An indicator of the inclusiveness of positive infinite. If true, then use implicit bin
-   * [getSplits.last, inf). If false, then throw exception if values > getSplits.last are
-   * encountered.
-   * @group Param */
-  val upperInclusive: BooleanParam = new BooleanParam(this, "upperInclusive",
-    "An indicator of the inclusiveness of positive infinite. If true, then use implicit bin " +
-    "[getSplits.last, inf). If false, then throw exception if values > getSplits.last are " +
-    "encountered.")
-  setDefault(upperInclusive -> true)
-
-  /** @group getParam */
-  def getUpperInclusive: Boolean = $(upperInclusive)
-
-  /** @group setParam */
-  def setUpperInclusive(value: Boolean): this.type = set(upperInclusive, value)
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -97,81 +66,66 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
     val bucketizer = udf { feature: Double =>
-      Bucketizer
-        .binarySearchForBuckets(wrappedSplits, feature, $(lowerInclusive), $(upperInclusive)) }
+      Bucketizer.binarySearchForBuckets($(splits), feature)
+    }
     val newCol = bucketizer(dataset($(inputCol)))
     val newField = prepOutputField(dataset.schema)
     dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
   }
 
   private def prepOutputField(schema: StructType): StructField = {
-    val innerRanges = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
-    val values = ($(lowerInclusive), $(upperInclusive)) match {
-      case (true, true) =>
-        Array(s"-inf, ${$(splits).head}") ++ innerRanges ++ Array(s"${$(splits).last}, inf")
-      case (true, false) => Array(s"-inf, ${$(splits).head}") ++ innerRanges
-      case (false, true) => innerRanges ++ Array(s"${$(splits).last}, inf")
-      case _ => innerRanges
-    }
-    val attr =
-      new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true), values = Some(values))
+    val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
+    val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
+      values = Some(buckets))
     attr.toStructField()
   }
 
   override def transformSchema(schema: StructType): StructType = {
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-    require(schema.fields.forall(_.name != $(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    StructType(schema.fields :+ prepOutputField(schema))
+    SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }
 }
 
 private[feature] object Bucketizer {
-  /**
-   * The given splits should match 1) its size is larger than zero; 2) it is ordered in a strictly
-   * increasing way.
-   */
-  private def checkSplits(splits: Array[Double]): Boolean = {
-    if (splits.size == 0) false
-    else if (splits.size == 1) true
-    else {
-      splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
-        if (validator && prevValue < currValue) {
-          (true, currValue)
-        } else {
-          (false, currValue)
-        }
-      }._1
+  /** We require splits to be of length >= 3 and to be in strictly increasing order. */
+  def checkSplits(splits: Array[Double]): Boolean = {
+    if (splits.length < 3) {
+      false
+    } else {
+      var i = 0
+      while (i < splits.length - 1) {
+        if (splits(i) >= splits(i + 1)) return false
+        i += 1
+      }
+      true
     }
   }
 
   /**
    * Binary searching in several buckets to place each data point.
+   * @throws RuntimeException if a feature is < splits.head or >= splits.last
    */
-  private[feature] def binarySearchForBuckets(
+  def binarySearchForBuckets(
       splits: Array[Double],
-      feature: Double,
-      lowerInclusive: Boolean,
-      upperInclusive: Boolean): Double = {
-    if ((feature < splits.head && !lowerInclusive) || (feature > splits.last && !upperInclusive)) {
-      throw new RuntimeException(s"Feature $feature out of bound, check your features or loosen " +
-        s"the lower/upper bound constraint.")
+      feature: Double): Double = {
+    // Check bounds. We make an exception for +inf so that it can exist in some bin.
+    if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
+      throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
+        s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
+        s"the lower/upper bound constraints.")
     }
     var left = 0
     var right = splits.length - 2
-    while (left <= right) {
-      val mid = left + (right - left) / 2
-      val split = splits(mid)
-      if ((feature >= split) && (feature < splits(mid + 1))) {
-        return mid
-      } else if (feature < split) {
-        right = mid - 1
+    while (left < right) {
+      val mid = (left + right) / 2
+      val split = splits(mid + 1)
+      if (feature < split) {
+        right = mid
       } else {
         left = mid + 1
       }
     }
-    throw new RuntimeException(s"Unexpected error: failed to find a bucket for feature $feature.")
+    left
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 0383bf0b382b7..11592b77eb356 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -58,4 +58,15 @@ object SchemaUtils {
     val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
     StructType(outputFields)
   }
+
+  /**
+   * Appends a new column to the input schema. This fails if the given output column already exists.
+   * @param schema input schema
+   * @param col New column schema
+   * @return new schema with the input column appended
+   */
+  def appendColumn(schema: StructType, col: StructField): StructType = {
+    require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.")
+    StructType(schema.fields :+ col)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 8be5421bfff64..77b0b75391201 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
@@ -28,13 +29,20 @@ import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
-  test("Bucket continuous features with setter") {
-    val sqlContext = new SQLContext(sc)
-    val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4, -0.9)
+  @transient private var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  test("Bucket continuous features, without -inf,inf") {
+    // Check a set of valid feature values.
     val splits = Array(-0.5, 0.0, 0.5)
-    val bucketizedData = Array(2.0, 1.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0, 0.0)
-    val dataFrame: DataFrame = sqlContext.createDataFrame(
-      data.zip(bucketizedData)).toDF("feature", "expected")
+    val validData = Array(-0.5, -0.3, 0.0, 0.2)
+    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
+    val dataFrame: DataFrame =
+      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
 
     val bucketizer: Bucketizer = new Bucketizer()
       .setInputCol("feature")
       .setOutputCol("result")
       .setSplits(splits)
@@ -43,58 +51,96 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
     bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
       case Row(x: Double, y: Double) =>
-        assert(x === y, "The feature value is not correct after bucketing.")
+        assert(x === y,
+          s"The feature value is not correct after bucketing. Expected $y but found $x")
     }
-  }
 
-  test("Binary search correctness in contrast with linear search") {
-    val data = Array.fill(100)(Random.nextDouble())
-    val splits = Array.fill(10)(Random.nextDouble()).sorted
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val bsResult = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
-    assert(bsResult ~== lsResult absTol 1e-5)
+    // Check for exceptions when using a set of invalid feature values.
+    val invalidData1: Array[Double] = Array(-0.9) ++ validData
+    val invalidData2 = Array(0.5) ++ validData
+    val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
+    intercept[RuntimeException]{
+      bucketizer.transform(badDF1).collect()
+      println("Invalid feature value -0.9 was not caught as an invalid feature!")
+    }
+    val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
+    intercept[RuntimeException]{
+      bucketizer.transform(badDF2).collect()
+      println("Invalid feature value 0.5 was not caught as an invalid feature!")
+    }
   }
 
-  test("Binary search of features at splits") {
-    val splits = Array.fill(10)(Random.nextDouble()).sorted
-    val data = splits
-    val expected = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val result = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    assert(result ~== expected absTol 1e-5)
+  test("Bucket continuous features, with -inf,inf") {
+    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
+    val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
+    val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
+    val dataFrame: DataFrame =
+      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setSplits(splits)
+
+    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+      case Row(x: Double, y: Double) =>
+        assert(x === y,
+          s"The feature value is not correct after bucketing. Expected $y but found $x")
+    }
   }
 
-  test("Binary search of features between splits") {
-    val data = Array.fill(10)(Random.nextDouble())
-    val splits = Array(-0.1, 1.1)
-    val expected = Vectors.dense(Array.fill(10)(1.0))
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val result = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    assert(result ~== expected absTol 1e-5)
+  test("Binary search correctness on hand-picked examples") {
+    import BucketizerSuite.checkBinarySearch
+    // length 3, with -inf
+    checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0))
+    // length 4
+    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0))
+    // length 5
+    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0, 1.5))
+    // length 3, with inf
+    checkBinarySearch(Array(0.0, 1.0, Double.PositiveInfinity))
+    // length 3, with -inf and inf
+    checkBinarySearch(Array(Double.NegativeInfinity, 1.0, Double.PositiveInfinity))
   }
 
-  test("Binary search of features outside splits") {
-    val data = Array.fill(5)(Random.nextDouble() + 1.1) ++ Array.fill(5)(Random.nextDouble() - 1.1)
-    val splits = Array(0.0, 1.1)
-    val expected = Vectors.dense(Array.fill(5)(2.0) ++ Array.fill(5)(0.0))
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val result = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    assert(result ~== expected absTol 1e-5)
+  test("Binary search correctness in contrast with linear search, on random data") {
+    val data = Array.fill(100)(Random.nextDouble())
+    val splits: Array[Double] = Double.NegativeInfinity +:
+      Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity
+    val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x)))
+    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
+    assert(bsResult ~== lsResult absTol 1e-5)
   }
 }
 
-private object BucketizerSuite {
-  private def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+private object BucketizerSuite extends FunSuite {
+  /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
+  def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    require(feature >= splits.head)
     var i = 0
-    while (i < splits.size) {
-      if (feature < splits(i)) return i
+    while (i < splits.length - 1) {
+      if (feature < splits(i + 1)) return i
       i += 1
     }
-    i
+    throw new RuntimeException(
+      s"linearSearchForBuckets failed to find bucket for feature value $feature")
+  }
+
+  /** Check all values in splits, plus values between all splits. */
+  def checkBinarySearch(splits: Array[Double]): Unit = {
+    def testFeature(feature: Double, expectedBucket: Double): Unit = {
+      assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket,
+        s"Expected feature value $feature to be in bucket $expectedBucket with splits:" +
+        s" ${splits.mkString(", ")}")
+    }
+    var i = 0
+    while (i < splits.length - 1) {
+      testFeature(splits(i), i) // Split i should fall in bucket i.
+      testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
+      i += 1
+    }
+    if (splits.last === Double.PositiveInfinity) {
+      testFeature(Double.PositiveInfinity, splits.length - 2)
+    }
  }
 }
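
Reviewer note (not part of the patch): a minimal usage sketch of the Bucketizer API after this change, mirroring the calls exercised in BucketizerSuite. The SparkContext `sc`, the sample values, and the output column name "bucket" are assumptions for illustration only.

  import org.apache.spark.ml.feature.Bucketizer
  import org.apache.spark.sql.SQLContext

  val sqlContext = new SQLContext(sc)  // assumes an existing SparkContext `sc`

  // Four buckets: [-inf, -0.5), [-0.5, 0.0), [0.0, 0.5), [0.5, inf).
  // With this patch, the -inf/inf endpoints must be given explicitly to cover all
  // Double values; a value outside the provided splits raises a RuntimeException.
  val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)

  val df = sqlContext.createDataFrame(
    Seq(-0.9, -0.5, 0.2, 0.7).map(Tuple1.apply)).toDF("feature")

  val bucketizer = new Bucketizer()
    .setInputCol("feature")
    .setOutputCol("bucket")
    .setSplits(splits)

  // Maps the rows above to bucket indices 0.0, 1.0, 2.0, 3.0.
  bucketizer.transform(df).show()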