
Commit: check buckets

yinxusen committed May 7, 2015
1 parent 4024cf1 commit 998bc87
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -34,12 +34,30 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
 @AlphaComponent
 final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
 
+  /**
+   * The given buckets are valid only if 1) the array is non-empty, and 2) its values are
+   * sorted in non-descending order.
+   */
+  private def checkBuckets(buckets: Array[Double]): Boolean = {
+    if (buckets.size == 0) false
+    else if (buckets.size == 1) true
+    else {
+      buckets.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
+        if (validator && prevValue <= currValue) {
+          (true, currValue)
+        } else {
+          (false, currValue)
+        }
+      }._1
+    }
+  }
+
   /**
    * Parameter for mapping continuous features into buckets.
    * @group param
   */
   val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
-    "Map continuous features into buckets.")
+    "Split points for mapping continuous features into buckets.", checkBuckets)
 
   /** @group getParam */
   def getBuckets: Array[Double] = $(buckets)
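
As an aside, the validator accepts ties: non-descending means prevValue <= currValue, so repeated split points pass. A minimal standalone sketch of the same check (hypothetical names, outside the Spark codebase), with a few sample inputs:

// Illustrative re-statement of checkBuckets using sliding windows; not commit code.
object BucketCheckSketch {
  def isValidBuckets(buckets: Array[Double]): Boolean =
    buckets.nonEmpty && buckets.sliding(2).forall {
      case Array(prev, curr) => prev <= curr
      case _ => true // a single-element array yields one window of size 1
    }

  def main(args: Array[String]): Unit = {
    println(isValidBuckets(Array(0.0, 0.5, 1.0)))  // true
    println(isValidBuckets(Array(0.0, 0.0, 1.0)))  // true: ties allowed by <=
    println(isValidBuckets(Array(1.0, 0.5)))       // false: descending pair
    println(isValidBuckets(Array.empty[Double]))   // false: empty array
  }
}

One subtle difference: the fold in the commit seeds with Double.MinValue, so an array starting at Double.NegativeInfinity fails there but passes this sketch.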
@@ -55,7 +73,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {

   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double => binarySearchForBins($(buckets), feature) }
+    val bucketizer = udf { feature: Double => binarySearchForBuckets($(buckets), feature) }
     val outputColName = $(outputCol)
     val metadata = NominalAttribute.defaultAttr
       .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
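
For context, a hedged usage sketch of the transformer as a whole. The setter names (setInputCol, setOutputCol, setBuckets) and the SQLContext-era setup are assumptions about code outside this diff, not lines the commit shows:

// Hypothetical usage; setters and surrounding API are assumed, Spark ~1.4 style.
import org.apache.spark.ml.feature.Bucketizer

val df = sqlContext.createDataFrame(Seq(-0.5, 0.2, 1.5).map(Tuple1.apply)).toDF("features")

val bucketizer = new Bucketizer()
  .setInputCol("features")
  .setOutputCol("bucketedFeatures")
  .setBuckets(Array(0.0, 1.0)) // rejected up front by checkBuckets if unordered or empty

// Given the Double.MinValue/MaxValue sentinels below, features map to bucket
// indices: -0.5 -> 0.0, 0.2 -> 1.0, 1.5 -> 2.0.
bucketizer.transform(df).show()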
@@ -65,7 +83,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
   /**
    * Binary searching in several buckets to place each data point.
    */
-  private def binarySearchForBins(splits: Array[Double], feature: Double): Double = {
+  private def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
     val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
     var left = 0
     var right = wrappedSplits.length - 2
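
The hunk is truncated here in the original view, so the loop body is not shown. For orientation only, a self-contained sketch of how such a sentinel-wrapped bucket search is typically completed; this is an illustration consistent with the lines above, not the commit's actual continuation:

// Hypothetical completion; not the commit's code. Bucket i covers the range
// [wrappedSplits(i), wrappedSplits(i + 1)), so valid indices run 0 to length - 2.
def binarySearchForBucketsSketch(splits: Array[Double], feature: Double): Double = {
  val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
  var left = 0
  var right = wrappedSplits.length - 2
  while (left < right) {
    val mid = (left + right) / 2
    if (feature < wrappedSplits(mid + 1)) right = mid // bucket is mid or to its left
    else left = mid + 1                               // bucket is strictly to the right
  }
  left.toDouble // return the bucket index as a Double for the UDF's output column
}

For example, binarySearchForBucketsSketch(Array(0.0, 1.0), 0.5) returns 1.0, since 0.5 falls in the middle bucket [0.0, 1.0).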
