Skip to content

Commit

Permalink
SPARK-3278 Isotonic regression java api
Browse files Browse the repository at this point in the history
  • Loading branch information
zapletal-martin committed Jan 12, 2015
1 parent 45aa7e8 commit 3c2954b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class IsotonicRegressionModel (
* @param testData features to be labeled
* @return predicted labels
*/
def predict(testData: JavaRDD[java.lang.Double]): RDD[java.lang.Double] =
testData.rdd.map(x => x.doubleValue()).map(predict)
def predict(testData: JavaRDD[java.lang.Double]): JavaRDD[java.lang.Double] =
testData.rdd.map(_.doubleValue()).map(predict).map(new java.lang.Double(_))

/**
* Predict a single label
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,28 @@ import org.apache.spark.annotation.DeveloperApi
import scala.collection.JavaConversions._
import java.lang.{Double => JDouble}

/**
* :: DeveloperApi ::
* Generate test data for Isotonic regresision.
*/
@DeveloperApi
object IsotonicDataGenerator {

/**
* Return a Java List of ordered labeled points
*
* @param labels list of labels for the data points
* @return Java List of input.
*/
def generateIsotonicInputAsList(labels: Array[Double]): java.util.List[(JDouble, JDouble)] = {
seqAsJavaList(generateIsotonicInput(wrapDoubleArray(labels):_*).map(x => (new JDouble(x._1), new JDouble(x._2))))
seqAsJavaList(
generateIsotonicInput(
wrapDoubleArray(labels):_*).map(x => (new JDouble(x._1), new JDouble(x._2))))
}

/**
* Return an ordered sequence of labeled data points with default weights
*
* @param labels list of labels for the data points
* @return sequence of data points
*/
Expand All @@ -45,6 +53,7 @@ object IsotonicDataGenerator {

/**
* Return an ordered sequence of labeled weighted data points
*
* @param labels list of labels for the data points
* @param weights list of weights for the data points
* @return sequence of data points
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ public Double call(Tuple2<Double, Double> v) throws Exception {
}
});

Double[] predictions = model.predict(testRDD).collect();
List<Double> predictions = model.predict(testRDD).collect();

Assert.assertTrue(predictions[0] == 1d);
Assert.assertTrue(predictions[11] == 12d);
Assert.assertTrue(predictions.get(0) == 1d);
Assert.assertTrue(predictions.get(11) == 12d);
}
}

0 comments on commit 3c2954b

Please sign in to comment.