diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index d0a299cb894b2..b352847ab6cc2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -295,7 +295,7 @@ class OnlineLDAOptimizer extends LDAOptimizer {
   }
 
   /**
-   * The function is for test only now. In the future, it can help support training strop/resume
+   * The function is for test only now. In the future, it can help support training stop/resume
    */
   private[clustering] def setLambda(lambda: BDM[Double]): this.type = {
     this.lambda = lambda
@@ -310,8 +310,9 @@ class OnlineLDAOptimizer extends LDAOptimizer {
     this
   }
 
-  override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA):
-  OnlineLDAOptimizer = {
+  override private[clustering] def initialize(
+      docs: RDD[(Long, Vector)],
+      lda: LDA): OnlineLDAOptimizer = {
     this.k = lda.getK
     this.corpusSize = docs.count()
     this.vocabSize = docs.first()._2.size
@@ -333,7 +334,6 @@ class OnlineLDAOptimizer extends LDAOptimizer {
     submitMiniBatch(batch)
   }
 
-
   /**
    * Submit a subset (like 1%, decide by the miniBatchFraction) of the corpus to the Online LDA
    * model, and it will update the topic distribution adaptively for the terms appearing in the
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index 2e8a9e99f6d01..dd61f054b18c7 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -109,35 +109,38 @@ public void distributedLDAModel() {
     assert(model.logPrior() < 0.0);
   }
 
-
   @Test
   public void OnlineOptimizerCompatibility() {
-      int k = 3;
-      double topicSmoothing = 1.2;
-      double termSmoothing = 1.2;
-
-      // Train a model
-      OnlineLDAOptimizer op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51)
-        .setGammaShape(1e40).setMiniBatchFraction(0.5);
-      LDA lda = new LDA();
-      lda.setK(k)
-        .setDocConcentration(topicSmoothing)
-        .setTopicConcentration(termSmoothing)
-        .setMaxIterations(5)
-        .setSeed(12345)
-        .setOptimizer(op);
-
-      LDAModel model = lda.run(corpus);
-
-      // Check: basic parameters
-      assertEquals(model.k(), k);
-      assertEquals(model.vocabSize(), tinyVocabSize);
-
-      // Check: topic summaries
-      Tuple2[] roundedTopicSummary = model.describeTopics();
-      assertEquals(roundedTopicSummary.length, k);
-      Tuple2[] roundedLocalTopicSummary = model.describeTopics();
-      assertEquals(roundedLocalTopicSummary.length, k);
+    int k = 3;
+    double topicSmoothing = 1.2;
+    double termSmoothing = 1.2;
+
+    // Train a model
+    OnlineLDAOptimizer op = new OnlineLDAOptimizer()
+      .setTau_0(1024)
+      .setKappa(0.51)
+      .setGammaShape(1e40)
+      .setMiniBatchFraction(0.5);
+
+    LDA lda = new LDA();
+    lda.setK(k)
+      .setDocConcentration(topicSmoothing)
+      .setTopicConcentration(termSmoothing)
+      .setMaxIterations(5)
+      .setSeed(12345)
+      .setOptimizer(op);
+
+    LDAModel model = lda.run(corpus);
+
+    // Check: basic parameters
+    assertEquals(model.k(), k);
+    assertEquals(model.vocabSize(), tinyVocabSize);
+
+    // Check: topic summaries
+    Tuple2[] roundedTopicSummary = model.describeTopics();
+    assertEquals(roundedTopicSummary.length, k);
+    Tuple2[] roundedLocalTopicSummary = model.describeTopics();
+    assertEquals(roundedLocalTopicSummary.length, k);
   }
 
   private static int tinyK = LDASuite$.MODULE$.tinyK();
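
For reference, here is a minimal Scala sketch (not part of the patch) showing the same OnlineLDAOptimizer configuration that the Java test above exercises. The optimizer and LDA parameters mirror the test; the OnlineLDAExample object name, the local SparkContext setup, and the toy (docId, termCounts) corpus are illustrative assumptions only.

import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}

object OnlineLDAExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("OnlineLDAExample").setMaster("local[2]")
    val sc = new SparkContext(conf)

    // Hypothetical toy corpus of (document id, term-count vector) pairs.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 0.0, 3.0)),
      (1L, Vectors.dense(0.0, 1.0, 4.0, 0.0)),
      (2L, Vectors.dense(2.0, 0.0, 1.0, 1.0))))

    // Same optimizer settings as JavaLDASuite.OnlineOptimizerCompatibility.
    val op = new OnlineLDAOptimizer()
      .setTau_0(1024)
      .setKappa(0.51)
      .setGammaShape(1e40)
      .setMiniBatchFraction(0.5)

    // Same LDA settings as the test: k = 3, symmetric 1.2 priors, 5 iterations.
    val lda = new LDA()
      .setK(3)
      .setDocConcentration(1.2)
      .setTopicConcentration(1.2)
      .setMaxIterations(5)
      .setSeed(12345)
      .setOptimizer(op)

    val model = lda.run(corpus)
    println(s"k = ${model.k}, vocabSize = ${model.vocabSize}")

    sc.stop()
  }
}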