diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
index c0d1a622ffad8..36207ae38d9a9 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
@@ -58,7 +58,7 @@ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
     corpus.cache();
 
     // Cluster the documents into three topics using LDA
-    DistributedLDAModel ldaModel = (DistributedLDAModel) new LDA().setK(3).run(corpus);
+    DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
 
     // Output topics. Each is a distribution over words (matching word count vectors)
     System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index 0e1b27a8bd2ee..11399a7633638 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -159,7 +159,7 @@ object LDAExample {
       }
       println()
     }
-
+    sc.stop()
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 1453e4dac768e..76ecdf92f26ed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -32,7 +32,6 @@ import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
-import org.apache.spark.mllib.rdd.RDDFunctions._
 
 
 /**
@@ -223,10 +222,6 @@ class LDA private (
     this
   }
 
-  object LDAMode extends Enumeration {
-    val EM, Online = Value
-  }
-
   /**
    * Learn an LDA model using the given dataset.
    *
@@ -236,37 +231,34 @@ class LDA private (
    * Document IDs must be unique and >= 0.
    * @return Inferred LDA model
    */
-  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM ): LDAModel = {
-    mode match {
-      case LDAMode.EM =>
-        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-          checkpointInterval)
-        var iter = 0
-        val iterationTimes = Array.fill[Double](maxIterations)(0)
-        while (iter < maxIterations) {
-          val start = System.nanoTime()
-          state.next()
-          val elapsedSeconds = (System.nanoTime() - start) / 1e9
-          iterationTimes(iter) = elapsedSeconds
-          iter += 1
-        }
-        state.graphCheckpointer.deleteAllCheckpoints()
-        new DistributedLDAModel(state, iterationTimes)
-      case LDAMode.Online =>
-        val vocabSize = documents.first._2.size
-        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
-        var iter = 0
-        while (iter < onlineLDA.batchNumber) {
-          onlineLDA.next()
-          iter += 1
-        }
-        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
-      case _ => throw new IllegalArgumentException(s"Do not support mode $mode.")
+  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
+    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+      checkpointInterval)
+    var iter = 0
+    val iterationTimes = Array.fill[Double](maxIterations)(0)
+    while (iter < maxIterations) {
+      val start = System.nanoTime()
+      state.next()
+      val elapsedSeconds = (System.nanoTime() - start) / 1e9
+      iterationTimes(iter) = elapsedSeconds
+      iter += 1
     }
+    state.graphCheckpointer.deleteAllCheckpoints()
+    new DistributedLDAModel(state, iterationTimes)
+  }
+
+  /**
+   * Learn an LDA model using online variational Bayes, making a single pass over the corpus
+   * in mini-batches (see OnlineLDAOptimizer).
+   */
+  def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
+    (0 until onlineLDA.batchNumber).foreach(_ => onlineLDA.next())
+    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
+  }
 
   /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
   }
 }
@@ -418,42 +406,48 @@ private[clustering] object LDA {
 
   }
 
-  // todo: add reference to paper and Hoffman
+  /**
+   * Optimizer for online LDA, which processes the corpus in mini-batches and scans it only once.
+   * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
+   */
   private[clustering] class OnlineLDAOptimizer(
-      val documents: RDD[(Long, Vector)],
-      val k: Int,
-      val vocabSize: Int) extends Serializable{
+      private val documents: RDD[(Long, Vector)],
+      private val k: Int) extends Serializable {
 
-    private val kappa = 0.5   // (0.5, 1] how quickly old information is forgotten
-    private val tau0 = 1024   // down weights early iterations
-    private val D = documents.count()
+    private val vocabSize = documents.first._2.size
+    private val D = documents.count().toInt
     private val batchSize = if (D / 1000 > 4096) 4096
                             else if (D / 1000 < 4) 4
                             else D / 1000
-    val batchNumber = (D/batchSize + 1).toInt
-    private val batches = documents.sliding(batchNumber).collect()
+    val batchNumber = math.max(1, D / batchSize) // number of mini-batches; at least one
 
     // Initialize the variational distribution q(beta|lambda)
-    var _lambda = getGammaMatrix(k, vocabSize)              // K * V
-    private var _Elogbeta = dirichlet_expectation(_lambda) // K * V
-    private var _expElogbeta = exp(_Elogbeta)               // K * V
+    var lambda = getGammaMatrix(k, vocabSize)               // K * V
+    private var Elogbeta = dirichlet_expectation(lambda)   // K * V
+    private var expElogbeta = exp(Elogbeta)                 // K * V
 
-    private var batchCount = 0
+    private var batchId = 0
 
     def next(): Unit = {
-      // weight of the mini-batch.
-      val rhot = math.pow(tau0 + batchCount, -kappa)
+      require(batchId < batchNumber)
+      // Mini-batch weight rho_t = (tau0 + t)^(-kappa); tau0 = 1024 down-weights early batches
+      val weight = math.pow(1024 + batchId, -0.5)
+      val batch = documents.filter(doc => doc._1 % batchNumber == batchId)
+      // Given a mini-batch of documents, estimate the parameters gamma controlling the
+      // variational distribution over the topic weights for each document in the mini-batch.
       var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
-
-      stat = stat :* _expElogbeta
-      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
-      _Elogbeta = dirichlet_expectation(_lambda)
-      _expElogbeta = exp(_Elogbeta)
-      batchCount += 1
+      stat = batch.aggregate(stat)(seqOp, _ += _)
+      stat = stat :* expElogbeta
+
+      // Update lambda based on the mini-batch's sufficient statistics.
+      lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
+      Elogbeta = dirichlet_expectation(lambda)
+      expElogbeta = exp(Elogbeta)
+      batchId += 1
     }
 
-    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    // For each document, update its gamma and phi and accumulate sufficient statistics in stat
+    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
       val termCounts = doc._2
       val (ids, cts) = termCounts match {
         case v: DenseVector => (((0 until v.size).toList), v.values)
@@ -461,15 +455,17 @@ private[clustering] object LDA {
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
 
+      // Initialize the variational distribution q(theta|gamma) for the mini-batch
       var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t // 1 * K
       var Elogthetad = vector_dirichlet_expectation(gammad.t).t   // 1 * K
       var expElogthetad = exp(Elogthetad.t).t                     // 1 * K
-      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix      // K * ids
+      val expElogbetad = expElogbeta(::, ids).toDenseMatrix       // K * ids
 
       var phinorm = expElogthetad * expElogbetad + 1e-100         // 1 * ids
       var meanchange = 1D
-      val ctsVector = new BDV[Double](cts).t // 1 * ids
+      val ctsVector = new BDV[Double](cts).t                      // 1 * ids
 
+      // Iterate between gamma and phi until convergence
       while (meanchange > 1e-6) {
         val lastgamma = gammad
         //        1*K                  1 * ids               ids * k
@@ -480,22 +476,22 @@
         meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
       }
 
-      val v1 = expElogthetad.t.toDenseMatrix.t
-      val v2 = (ctsVector / phinorm).t.toDenseMatrix
-      val outerResult = kron(v1, v2) // K * ids
+      val m1 = expElogthetad.t.toDenseMatrix.t
+      val m2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(m1, m2) // K * ids
       for (i <- 0 until ids.size) {
-        other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+        stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
       }
-      other
+      stat
     }
 
-    def getGammaMatrix(row:Int, col:Int): BDM[Double] ={
+    private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
       val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
       val temp = gammaRandomGenerator.sample(row * col).toArray
       (new BDM[Double](col, row, temp)).t
     }
 
-    def dirichlet_expectation(alpha : BDM[Double]): BDM[Double] = {
+    private def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
       val rowSum = sum(alpha(breeze.linalg.*, ::))
       val digAlpha = digamma(alpha)
       val digRowSum = digamma(rowSum)
@@ -503,7 +499,7 @@
       result
     }
 
-    def vector_dirichlet_expectation(v : BDV[Double]): (BDV[Double]) ={
+    private def vector_dirichlet_expectation(v: BDV[Double]): BDV[Double] = {
       digamma(v) - digamma(sum(v))
     }
   }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index fbe171b4b1ab1..dc10aa67c7c1f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -88,7 +88,7 @@ public void distributedLDAModel() {
       .setMaxIterations(5)
       .setSeed(12345);
 
-    DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
+    DistributedLDAModel model = lda.run(corpus);
 
     // Check: basic parameters
     LocalLDAModel localModel = model.toLocal();
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index d36fb9b479c67..302d751eb8a94 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
       .setSeed(12345)
 
     val corpus = sc.parallelize(tinyCorpus, 2)
 
-    val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
+    val model: DistributedLDAModel = lda.run(corpus)
 
     // Check: basic parameters
     val localModel = model.toLocal
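
For reviewers, a minimal usage sketch of the public API after this patch (assuming an existing
SparkContext named sc; the tiny corpus and variable names below are illustrative, not part of
the patch):

    import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA, LDAModel}
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.rdd.RDD

    // Corpus of (documentId, termCountVector) pairs; IDs must be unique and >= 0.
    val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 0.0)),
      (1L, Vectors.dense(0.0, 1.0, 3.0))))

    // Batch EM: run() now returns DistributedLDAModel directly, which is why the
    // asInstanceOf / cast workarounds are removed from the examples and tests above.
    val emModel: DistributedLDAModel = new LDA().setK(3).run(corpus)

    // Online variational Bayes: a single pass over the corpus in mini-batches,
    // returning a LocalLDAModel behind the LDAModel interface.
    val onlineModel: LDAModel = new LDA().setK(3).runOnlineLDA(corpus)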