diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index c7cb86fa30f0b..70c7ff305541d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -186,6 +186,8 @@ class NaiveBayes private (
     private var lambda: Double,
     private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {
 
+  private def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
+
   def this() = this(1.0, NaiveBayes.Multinomial)
 
   /** Set the smoothing parameter. Default: 1.0. */
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 7ce9be4e3cdd4..c296a4b7b4627 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -146,11 +146,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     ).map(_.map(math.log))
 
     val testData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi,
-      theta,
-      nPoints,
-      45,
-      NaiveBayes.Bernoulli)
+      pi, theta, nPoints, 45, NaiveBayes.Bernoulli)
     val testRDD = sc.parallelize(testData, 2)
     testRDD.cache()
 
@@ -158,11 +154,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     validateModelFit(pi, theta, model)
 
     val validationData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi,
-      theta,
-      nPoints,
-      20,
-      NaiveBayes.Bernoulli)
+      pi, theta, nPoints, 20, NaiveBayes.Bernoulli)
     val validationRDD = sc.parallelize(validationData, 2)
 
     // Test prediction on RDD.