Skip to content

Commit

Permalink
updated test suite with model type fix
Browse files Browse the repository at this point in the history
  • Loading branch information
leahmcguire committed Mar 5, 2015
1 parent 85f298f commit e016569
Showing 1 changed file with 8 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import scala.util.Random
import org.scalatest.FunSuite

import org.apache.spark.SparkException
import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
Expand All @@ -49,7 +48,7 @@ object NaiveBayesSuite {
theta: Array[Array[Double]], // CXD
nPoints: Int,
seed: Int,
dataModel: NaiveBayesModels= NaiveBayesModels.Multinomial,
dataModel: NaiveBayes.ModelType = NaiveBayes.Multinomial,
sample: Int = 10): Seq[LabeledPoint] = {
val D = theta(0).length
val rnd = new Random(seed)
Expand All @@ -60,10 +59,10 @@ object NaiveBayesSuite {
for (i <- 0 until nPoints) yield {
val y = calcLabel(rnd.nextDouble(), _pi)
val xi = dataModel match {
case NaiveBayesModels.Bernoulli => Array.tabulate[Double] (D) {j =>
case NaiveBayes.Bernoulli => Array.tabulate[Double] (D) {j =>
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
}
case NaiveBayesModels.Multinomial =>
case NaiveBayes.Multinomial =>
val mult = Multinomial(BDV(_theta(y)))
val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
Expand All @@ -78,7 +77,7 @@ object NaiveBayesSuite {

/** Binary labels, 3 features */
private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayesModels.Bernoulli)
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli)
}

class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
Expand Down Expand Up @@ -121,7 +120,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
).map(_.map(math.log))

val testData = NaiveBayesSuite.generateNaiveBayesInput(
pi, theta, nPoints, 42, NaiveBayesModels.Multinomial)
pi, theta, nPoints, 42, NaiveBayes.Multinomial)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()

Expand All @@ -133,7 +132,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
theta,
nPoints,
17,
NaiveBayesModels.Multinomial)
NaiveBayes.Multinomial)
val validationRDD = sc.parallelize(validationData, 2)

// Test prediction on RDD.
Expand All @@ -158,7 +157,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
theta,
nPoints,
45,
NaiveBayesModels.Bernoulli)
NaiveBayes.Bernoulli)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()

Expand All @@ -170,7 +169,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
theta,
nPoints,
20,
NaiveBayesModels.Bernoulli)
NaiveBayes.Bernoulli)
val validationRDD = sc.parallelize(validationData, 2)

// Test prediction on RDD.
Expand Down

0 comments on commit e016569

Please sign in to comment.