[SPARK-4373][MLLIB] fix MLlib maven tests
We want to make sure there is at most one SparkContext inside the same JVM. /cc JoshRosen

Author: Xiangrui Meng <[email protected]>

Closes #3235 from mengxr/SPARK-4373 and squashes the following commits:

6574b69 [Xiangrui Meng] rename LocalSparkContext to MLlibTestSparkContext
913d48d [Xiangrui Meng] make sure there is at most one spark context inside the same jvm
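
Background for the summary above: around the Spark 1.2 timeframe the driver began rejecting a second active SparkContext in the same JVM, so a test suite that creates or leaks its own context can break every suite that runs after it in the same test JVM. A minimal sketch of that failure mode (illustrative only; the `spark.driver.allowMultipleContexts` flag and the exact 1.x error behavior are assumptions, not part of this commit):

import org.apache.spark.{SparkConf, SparkContext}

object OneContextPerJvm {
  def main(args: Array[String]): Unit = {
    val sc1 = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("first"))
    try {
      // In Spark 1.x this second constructor call fails unless
      // spark.driver.allowMultipleContexts is explicitly enabled.
      val sc2 = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("second"))
      sc2.stop()
    } catch {
      case e: Exception => println(s"second context rejected: ${e.getMessage}")
    } finally {
      sc1.stop() // free the per-JVM slot so later suites can start cleanly
    }
  }
}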
mengxr authored and JoshRosen committed Nov 13, 2014
1 parent 723a86b commit 23f5bdf
Showing 36 changed files with 108 additions and 82 deletions.
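
The new MLlibTestSparkContext trait itself is not among the hunks shown below. A minimal sketch of how such a shared-context trait is typically written (a plausible reconstruction, not the committed source):

import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, Suite}

trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("MLlibUnitTest")
    sc = new SparkContext(conf)
  }

  override def afterAll(): Unit = {
    if (sc != null) {
      sc.stop() // release the context so the next suite can create its own
    }
    sc = null
    super.afterAll()
  }
}

Each suite mixes in the trait and reuses `sc`, instead of constructing a context of its own.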
@@ -20,16 +20,24 @@ package org.apache.spark.ml.classification
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
-import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{SQLContext, SchemaRDD}
 
-class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
+class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
 
-  import sqlContext._
+  @transient var sqlContext: SQLContext = _
+  @transient var dataset: SchemaRDD = _
 
-  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+    dataset = sqlContext.createSchemaRDD(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+  }
 
   test("logistic regression") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression
     val model = lr.fit(dataset)
     model.transform(dataset)
@@ -38,6 +46,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   test("logistic regression with setters") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression()
       .setMaxIter(10)
       .setRegParam(1.0)
@@ -48,6 +58,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   test("logistic regression fit and transform with varargs") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression
     val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
     model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
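A side note on the repeated `val sqlContext = this.sqlContext` / `import sqlContext._` pair above: Scala only permits imports from stable identifiers (`val`s and objects), and `sqlContext` is now a `var` assigned in `beforeAll`, so each test pins it to a local `val` before importing its implicits. A self-contained sketch of the language rule (names here are illustrative):

import org.scalatest.FunSuite

class StableImportSuite extends FunSuite {
  class Ctx { implicit val greeting: String = "hi" }
  @transient var ctx: Ctx = _ // assigned lazily, like sqlContext above

  test("imports require a stable identifier") {
    ctx = new Ctx
    // `import ctx._` would not compile here: `ctx` is a var.
    val stable = ctx // pin the var to a val
    import stable._  // now the implicit member is in scope
    assert(implicitly[String] == "hi")
  }
}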
@@ -22,14 +22,19 @@ import org.scalatest.FunSuite
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
-import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{SQLContext, SchemaRDD}
 
-class CrossValidatorSuite extends FunSuite with LocalSparkContext {
+class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
 
-  import sqlContext._
+  @transient var dataset: SchemaRDD = _
 
-  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val sqlContext = new SQLContext(sc)
+    dataset = sqlContext.createSchemaRDD(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+  }
 
   test("cross validation with logistic regression") {
     val lr = new LogisticRegression
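Note the contrast with LogisticRegressionSuite above: this suite's test bodies never import the SQL implicits, so the SQLContext can remain a local val inside beforeAll instead of becoming a field.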
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
 object LogisticRegressionSuite {
@@ -57,7 +57,7 @@
   }
 }
 
-class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers {
+class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
   def validatePrediction(
       predictions: Seq[Double],
       input: Seq[LabeledPoint],
@@ -24,7 +24,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 
 object NaiveBayesSuite {
 
@@ -60,7 +60,7 @@
   }
 }
 
-class NaiveBayesSuite extends FunSuite with LocalSparkContext {
+class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOfPredictions = predictions.zip(input).count {
@@ -26,7 +26,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 
 object SVMSuite {
 
@@ -58,7 +58,7 @@
 
 }
 
-class SVMSuite extends FunSuite with LocalSparkContext {
+class SVMSuite extends FunSuite with MLlibTestSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -22,10 +22,10 @@ import scala.util.Random
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
-class KMeansSuite extends FunSuite with LocalSparkContext {
+class KMeansSuite extends FunSuite with MLlibTestSparkContext {
 
   import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
 
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
+class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext {
   test("auc computation") {
     val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
     val auc = 4.0
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
+class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {
 
   def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
 
@@ -20,9 +20,9 @@ package org.apache.spark.mllib.evaluation
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Matrices
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
+class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext {
   test("Multiclass evaluation metrics") {
     /*
      * Confusion matrix for 3-class classification with total 9 instances:
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 
-class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
+class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext {
   test("Multilabel evaluation metrics") {
     /*
      * Documents true labels (5x class0, 3x class1, 4x class2):
@@ -20,9 +20,9 @@ package org.apache.spark.mllib.evaluation
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class RankingMetricsSuite extends FunSuite with LocalSparkContext {
+class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext {
   test("Ranking metrics: map, ndcg") {
     val predictionAndLabels = sc.parallelize(
       Seq(
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class RegressionMetricsSuite extends FunSuite with LocalSparkContext {
+class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext {
 
   test("regression metrics") {
     val predictionAndObservations = sc.parallelize(
@@ -20,9 +20,9 @@ package org.apache.spark.mllib.feature
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class HashingTFSuite extends FunSuite with LocalSparkContext {
+class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
 
   test("hashing tf on a single doc") {
     val hashingTF = new HashingTF(1000)
@@ -21,10 +21,10 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class IDFSuite extends FunSuite with LocalSparkContext {
+class IDFSuite extends FunSuite with MLlibTestSparkContext {
 
   test("idf") {
     val n = 4
@@ -22,10 +22,10 @@ import org.scalatest.FunSuite
 
 import breeze.linalg.{norm => brzNorm}
 
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class NormalizerSuite extends FunSuite with LocalSparkContext {
+class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
 
   val data = Array(
     Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
@@ -20,13 +20,13 @@ package org.apache.spark.mllib.feature
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
 import org.apache.spark.rdd.RDD
 
-class StandardScalerSuite extends FunSuite with LocalSparkContext {
+class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
 
   private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
     data.treeAggregate(new MultivariateOnlineSummarizer)(
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.feature
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class Word2VecSuite extends FunSuite with LocalSparkContext {
+class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
 
   // TODO: add more tests
 
@@ -21,10 +21,10 @@ import org.scalatest.FunSuite
 
 import breeze.linalg.{DenseMatrix => BDM}
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.linalg.Vectors
 
-class CoordinateMatrixSuite extends FunSuite with LocalSparkContext {
+class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext {
 
   val m = 5
   val n = 4
@@ -21,11 +21,11 @@ import org.scalatest.FunSuite
 
 import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV}
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Matrices, Vectors}
 
-class IndexedRowMatrixSuite extends FunSuite with LocalSparkContext {
+class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
 
   val m = 4
   val n = 3
@@ -23,9 +23,9 @@ import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, s
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 
-class RowMatrixSuite extends FunSuite with LocalSparkContext {
+class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
 
   val m = 4
   val n = 3
@@ -24,7 +24,7 @@ import org.scalatest.{FunSuite, Matchers}
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
 object GradientDescentSuite {
@@ -61,7 +61,7 @@
   }
 }
 
-class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers {
+class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers {
 
   test("Assert the loss is decreasing.") {
     val nPoints = 10000
@@ -23,10 +23,10 @@ import org.scalatest.{FunSuite, Matchers}
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
-class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
+class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
 
   val nPoints = 10000
   val A = 2.0
@@ -24,7 +24,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.StatCounter
 
@@ -34,7 +34,7 @@ import org.apache.spark.util.StatCounter
  *
  * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
  */
-class RandomRDDsSuite extends FunSuite with LocalSparkContext with Serializable {
+class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable {
 
   def testGeneratedRDD(rdd: RDD[Double],
       expectedSize: Long,
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.rdd
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.rdd.RDDFunctions._
 
-class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
+class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
 
   test("sliding") {
     val data = 0 until 6
[Diff truncated: 24 of the 36 changed files are shown above; the remaining 12 files are not included in this capture.]
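With every suite sharing one context per JVM, the MLlib module's tests should again pass under Maven, which typically runs all suites in a single test JVM; e.g. `mvn -pl mllib -am test` (an illustrative invocation; the exact flags and profiles for a 2014-era Spark checkout may differ).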