diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 3a3f868d4ed74..b3d8b17cecdfd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -34,12 +34,30 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
 @AlphaComponent
 final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
 
+  /**
+   * The given buckets must satisfy two conditions: 1) the array is non-empty; 2) the values are
+   * in non-decreasing order.
+   */
+  private def checkBuckets(buckets: Array[Double]): Boolean = {
+    if (buckets.size == 0) false
+    else if (buckets.size == 1) true
+    else {
+      buckets.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
+        if (validator && prevValue <= currValue) {
+          (true, currValue)
+        } else {
+          (false, currValue)
+        }
+      }._1
+    }
+  }
+
   /**
    * Parameter for mapping continuous features into buckets.
    * @group param
    */
   val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
-    "Map continuous features into buckets.")
+    "Split points for mapping continuous features into buckets.", checkBuckets)
 
   /** @group getParam */
   def getBuckets: Array[Double] = $(buckets)
@@ -55,7 +73,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double => binarySearchForBins($(buckets), feature) }
+    val bucketizer = udf { feature: Double => binarySearchForBuckets($(buckets), feature) }
     val outputColName = $(outputCol)
     val metadata = NominalAttribute.defaultAttr
       .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
@@ -65,7 +83,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
   /**
    * Binary searching in several buckets to place each data point.
    */
-  private def binarySearchForBins(splits: Array[Double], feature: Double): Double = {
+  private def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
    var left = 0
    var right = wrappedSplits.length - 2
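
The `checkBuckets` validator added above rejects an empty or unsorted split array before it can be set on the `buckets` param. Below is a minimal standalone sketch of the same rule that runs outside Spark; `BucketsCheckDemo` and its `main` are hypothetical names for illustration, not part of the patch.

object BucketsCheckDemo {
  // Mirrors the patch's foldLeft check: the splits must be non-empty and
  // non-decreasing; the accumulator carries (still valid?, previous value).
  def checkBuckets(buckets: Array[Double]): Boolean = {
    if (buckets.isEmpty) false
    else if (buckets.length == 1) true
    else {
      buckets.foldLeft((true, Double.MinValue)) { case ((valid, prev), curr) =>
        (valid && prev <= curr, curr)
      }._1
    }
  }

  def main(args: Array[String]): Unit = {
    println(checkBuckets(Array(0.0, 0.5, 1.0)))  // true: sorted ascending
    println(checkBuckets(Array(0.0, 1.0, 0.5)))  // false: out of order
    println(checkBuckets(Array.empty[Double]))   // false: empty
  }
}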
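
The diff is cut off before the body of the `binarySearchForBuckets` loop, so the completion below is only a plausible sketch, assuming a feature in [wrappedSplits(i), wrappedSplits(i + 1)) maps to bucket i; `BucketSearchDemo` is a hypothetical wrapper and the loop body is illustrative, not the PR's code.

object BucketSearchDemo {
  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
    // Pad both ends so every finite feature lands in some bucket.
    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
    var left = 0
    var right = wrappedSplits.length - 2
    // Invariant: the target bucket index is always within [left, right].
    while (left < right) {
      val mid = (left + right) / 2
      if (feature < wrappedSplits(mid + 1)) right = mid else left = mid + 1
    }
    left.toDouble
  }

  def main(args: Array[String]): Unit = {
    val splits = Array(0.0, 0.5, 1.0)
    println(binarySearchForBuckets(splits, -0.2)) // 0.0: below the first split
    println(binarySearchForBuckets(splits, 0.6))  // 2.0: between 0.5 and 1.0
    println(binarySearchForBuckets(splits, 2.0))  // 3.0: above the last split
  }
}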