From 68c2318f4bba22d86a895f1c9b76e2a431c3241b Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 30 Apr 2015 10:31:39 +0800 Subject: [PATCH] add a java ut --- .../spark/mllib/clustering/JavaLDASuite.java | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index fbe171b4b1ab1..2e8a9e99f6d01 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -109,6 +109,37 @@ public void distributedLDAModel() { assert(model.logPrior() < 0.0); } + + @Test + public void OnlineOptimizerCompatibility() { + int k = 3; + double topicSmoothing = 1.2; + double termSmoothing = 1.2; + + // Train a model + OnlineLDAOptimizer op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51) + .setGammaShape(1e40).setMiniBatchFraction(0.5); + LDA lda = new LDA(); + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op); + + LDAModel model = lda.run(corpus); + + // Check: basic parameters + assertEquals(model.k(), k); + assertEquals(model.vocabSize(), tinyVocabSize); + + // Check: topic summaries + Tuple2[] roundedTopicSummary = model.describeTopics(); + assertEquals(roundedTopicSummary.length, k); + Tuple2[] roundedLocalTopicSummary = model.describeTopics(); + assertEquals(roundedLocalTopicSummary.length, k); + } + private static int tinyK = LDASuite$.MODULE$.tinyK(); private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();