diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index f4c545ad70e96..87190e9d002b6 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -26,7 +26,7 @@ import scopt.OptionParser
 import org.apache.log4j.{Level, Logger}
 
 import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.rdd.RDD
 
@@ -137,7 +137,7 @@ object LDAExample {
       lda.setCheckpointDir(params.checkpointDir.get)
     }
     val startTime = System.nanoTime()
-    val ldaModel = lda.run(corpus)
+    val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
     val elapsed = (System.nanoTime() - startTime) / 1e9
 
     println(s"Finished training LDA model. Summary:")
@@ -159,6 +159,7 @@ object LDAExample {
       }
       println()
     }
+    sc.stop()
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index d8f82867a09d2..74183864de6ff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -19,7 +19,9 @@ package org.apache.spark.mllib.clustering
 
 import java.util.Random
 
-import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy}
+import breeze.linalg.{DenseVector => BDV, normalize, kron, sum, axpy => brzAxpy, DenseMatrix => BDM}
+import breeze.numerics.{exp, abs, digamma}
+import breeze.stats.distributions.Gamma
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
@@ -27,7 +29,7 @@ import org.apache.spark.api.java.JavaPairRDD
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
@@ -250,6 +252,10 @@ class LDA private (
     this
   }
 
+  object LDAMode extends Enumeration {
+    val EM, Online = Value
+  }
+
   /**
    * Learn an LDA model using the given dataset.
    *
@@ -259,24 +265,39 @@ class LDA private (
   *          Document IDs must be unique and >= 0.
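+   * @param mode  Inference algorithm to use: LDAMode.EM (default) or LDAMode.Online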
   * @return  Inferred LDA model
   */
-  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
-    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-      checkpointDir, checkpointInterval)
-    var iter = 0
-    val iterationTimes = Array.fill[Double](maxIterations)(0)
-    while (iter < maxIterations) {
-      val start = System.nanoTime()
-      state.next()
-      val elapsedSeconds = (System.nanoTime() - start) / 1e9
-      iterationTimes(iter) = elapsedSeconds
-      iter += 1
+  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM): LDAModel = {
+    mode match {
+      case LDAMode.EM =>
+        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration,
+          seed, checkpointDir, checkpointInterval)
+        var iter = 0
+        val iterationTimes = Array.fill[Double](maxIterations)(0)
+        while (iter < maxIterations) {
+          val start = System.nanoTime()
+          state.next()
+          val elapsedSeconds = (System.nanoTime() - start) / 1e9
+          iterationTimes(iter) = elapsedSeconds
+          iter += 1
+        }
+        state.graphCheckpointer.deleteAllCheckpoints()
+        new DistributedLDAModel(state, iterationTimes)
+      case LDAMode.Online =>
+        // The online path returns a LocalLDAModel rather than a DistributedLDAModel,
+        // since DistributedLDAModel is tied to the GraphX representation used by EM.
+        val vocabSize = documents.first()._2.size
+        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
+        var iter = 0
+        while (iter < onlineLDA.batchNumber) {
+          onlineLDA.next()
+          iter += 1
+        }
+        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
+      case _ => throw new IllegalArgumentException(s"Mode $mode is not supported.")
     }
-    state.graphCheckpointer.deleteAllCheckpoints()
-    new DistributedLDAModel(state, iterationTimes)
   }
 
   /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
   }
 }
@@ -429,6 +450,97 @@ private[clustering] object LDA {
 
   }
 
+  // Online variational Bayes inference, following Hoffman, Blei and Bach,
+  // "Online Learning for Latent Dirichlet Allocation" (NIPS 2010).
+  class OnlineLDAOptimizer(
+      val documents: RDD[(Long, Vector)],
+      val k: Int,
+      val vocabSize: Int) extends Serializable {
+
+    // Learning-rate decay; must lie in (0.5, 1] for the stochastic updates to
+    // converge. Larger values forget old information more quickly.
+    private val kappa = 0.5
+    // Down-weights early iterations so the first mini-batches do not dominate.
+    private val tau0 = 1024
+    private val D = documents.count()
+    private val batchSize =
+      if (D / 1000 > 4096) 4096
+      else if (D / 1000 < 4) 4
+      else D / 1000
+    val batchNumber = (D / batchSize + 1).toInt
+    // TODO: randomSplit over the full corpus is a performance bottleneck; it
+    // needs to be replaced, e.g. by sampling each mini-batch on the fly.
+    private val batches = documents.randomSplit(Array.fill[Double](batchNumber)(1.0))
+
+    // Initialize the variational distribution q(beta | lambda)
+    var _lambda = getGammaMatrix(k, vocabSize)              // K * V
+    private var _Elogbeta = dirichletExpectation(_lambda)  // K * V
+    private var _expElogbeta = exp(_Elogbeta)               // K * V
+
+    private var batchCount = 0
+    def next(): Unit = {
+      // Weight of the mini-batch.
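+      // rho_t = (tau0 + t)^(-kappa), the decaying step size of Hoffman et al.;
+      // with kappa in (0.5, 1] the step sizes satisfy the Robbins-Monro
+      // conditions, so the stochastic updates converge.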
+      val rhot = math.pow(tau0 + batchCount, -kappa)
+
+      // Sufficient statistics for this mini-batch: expected topic-word counts.
+      var stat = BDM.zeros[Double](k, vocabSize)
+      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
+
+      // Blend the rescaled mini-batch statistics into lambda with weight rhot.
+      stat = stat :* _expElogbeta
+      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
+      _Elogbeta = dirichletExpectation(_lambda)
+      _expElogbeta = exp(_Elogbeta)
+      batchCount += 1
+    }
+
+    // Per-document E-step: update the variational distribution q(theta | gamma)
+    // for one document and accumulate its contribution to the batch statistics.
+    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+      val termCounts = doc._2
+      val (ids, cts) = termCounts match {
+        case v: DenseVector => ((0 until v.size).toList, v.values)
+        case v: SparseVector => (v.indices.toList, v.values)
+        case v => throw new IllegalArgumentException(
+          "Vector type " + v.getClass + " is not supported.")
+      }
+
+      var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t  // 1 * K
+      var Elogthetad = vectorDirichletExpectation(gammad.t).t      // 1 * K
+      var expElogthetad = exp(Elogthetad.t).t                      // 1 * K
+      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix       // K * ids
+
+      var phinorm = expElogthetad * expElogbetad + 1e-100          // 1 * ids
+      var meanchange = 1.0
+      val ctsVector = new BDV[Double](cts).t                       // 1 * ids
+
+      // Iterate the gamma/phi updates until gamma stops changing.
+      while (meanchange > 1e-6) {
+        val lastgamma = gammad
+        //         1 * K              1 * ids               ids * K
+        gammad = (expElogthetad :* ((ctsVector / phinorm) * (expElogbetad.t))) + 1.0 / k
+        Elogthetad = vectorDirichletExpectation(gammad.t).t
+        expElogthetad = exp(Elogthetad.t).t
+        phinorm = expElogthetad * expElogbetad + 1e-100
+        meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
+      }
+
+      val v1 = expElogthetad.t.toDenseMatrix.t
+      val v2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(v1, v2)                               // K * ids
+      for (i <- 0 until ids.size) {
+        other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+      }
+      other
+    }
+
+    // Random initialization: each entry is drawn from Gamma(100, 1/100).
+    def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
+      val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
+      val temp = gammaRandomGenerator.sample(row * col).toArray
+      new BDM[Double](col, row, temp).t
+    }
+
+    // For each row alpha of a Dirichlet parameter matrix, computes
+    // E[log(theta)] for theta ~ Dirichlet(alpha), i.e.
+    // digamma(alpha_k) - digamma(sum(alpha)).
+    def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
+      val rowSum = sum(alpha(breeze.linalg.*, ::))
+      val digAlpha = digamma(alpha)
+      val digRowSum = digamma(rowSum)
+      digAlpha(::, breeze.linalg.*) - digRowSum
+    }
+
+    def vectorDirichletExpectation(v: BDV[Double]): BDV[Double] = {
+      digamma(v) - digamma(sum(v))
+    }
+  }
+
   /**
    * Compute gamma_{wjk}, a distribution over topics k.
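+   * (Used by the LDAMode.EM path; the online path computes its per-document
+   * topic distribution gamma in OnlineLDAOptimizer.seqOp instead.)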
   */
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index dc10aa67c7c1f..fbe171b4b1ab1 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -88,7 +88,7 @@ public void distributedLDAModel() {
       .setMaxIterations(5)
       .setSeed(12345);
 
-    DistributedLDAModel model = lda.run(corpus);
+    DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
 
     // Check: basic parameters
     LocalLDAModel localModel = model.toLocal();
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 302d751eb8a94..d36fb9b479c67 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
      .setSeed(12345)
    val corpus = sc.parallelize(tinyCorpus, 2)
 
-    val model: DistributedLDAModel = lda.run(corpus)
+    val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
 
    // Check: basic parameters
    val localModel = model.toLocal
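
For reviewers, a minimal usage sketch of the new entry point (not part of the patch; the corpus values and app name are illustrative, and it uses the same bag-of-words (docId, termCountVector) layout as LDAExample):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA, LDAModel}
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.rdd.RDD

    val sc = new SparkContext(new SparkConf().setAppName("LDAModeSketch").setMaster("local"))

    // Tiny bag-of-words corpus over a 3-term vocabulary.
    val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 0.0)),
      (1L, Vectors.dense(0.0, 3.0, 1.0))))

    val lda = new LDA().setK(2).setMaxIterations(10).setSeed(12345)

    // Default EM path: still produces a DistributedLDAModel, but run() is now
    // typed as LDAModel, so existing callers cast (as the example and tests
    // updated above do).
    val emModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]

    // New online variational path: returns a LocalLDAModel. LDAMode is nested
    // inside the LDA class, so it is addressed through the instance.
    val onlineModel: LDAModel = lda.run(corpus, lda.LDAMode.Online)

    sc.stop()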