
online lda initial checkin
hhbyyh committed Feb 6, 2015
1 parent 6580929 commit d640d9c
Showing 4 changed files with 133 additions and 20 deletions.
@@ -26,7 +26,7 @@ import scopt.OptionParser
import org.apache.log4j.{Level, Logger}

import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

@@ -137,7 +137,7 @@ object LDAExample {
      lda.setCheckpointDir(params.checkpointDir.get)
    }
    val startTime = System.nanoTime()
-    val ldaModel = lda.run(corpus)
+    val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
    val elapsed = (System.nanoTime() - startTime) / 1e9

    println(s"Finished training LDA model. Summary:")
@@ -159,6 +159,7 @@ object LDAExample {
      }
      println()
    }
+    sc.stop()

  }
144 changes: 128 additions & 16 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -19,15 +19,17 @@ package org.apache.spark.mllib.clustering

import java.util.Random

-import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy}
+import breeze.linalg.{DenseVector => BDV, normalize, kron, sum, axpy => brzAxpy, DenseMatrix => BDM}
+import breeze.numerics.{exp, abs, digamma}
+import breeze.stats.distributions.Gamma

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -250,6 +252,10 @@ class LDA private (
    this
  }

+  object LDAMode extends Enumeration {
+    val EM, Online = Value
+  }
+
  /**
   * Learn an LDA model using the given dataset.
   *
@@ -259,24 +265,39 @@ class LDA private (
   * Document IDs must be unique and >= 0.
   * @return Inferred LDA model
   */
-  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
-    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-      checkpointDir, checkpointInterval)
-    var iter = 0
-    val iterationTimes = Array.fill[Double](maxIterations)(0)
-    while (iter < maxIterations) {
-      val start = System.nanoTime()
-      state.next()
-      val elapsedSeconds = (System.nanoTime() - start) / 1e9
-      iterationTimes(iter) = elapsedSeconds
-      iter += 1
-    }
-    state.graphCheckpointer.deleteAllCheckpoints()
-    new DistributedLDAModel(state, iterationTimes)
-  }
+  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM): LDAModel = {
+    mode match {
+      case LDAMode.EM =>
+        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration,
+          seed, checkpointDir, checkpointInterval)
+        var iter = 0
+        val iterationTimes = Array.fill[Double](maxIterations)(0)
+        while (iter < maxIterations) {
+          val start = System.nanoTime()
+          state.next()
+          val elapsedSeconds = (System.nanoTime() - start) / 1e9
+          iterationTimes(iter) = elapsedSeconds
+          iter += 1
+        }
+        state.graphCheckpointer.deleteAllCheckpoints()
+        new DistributedLDAModel(state, iterationTimes)
+      case LDAMode.Online =>
+        // The return type is widened to LDAModel: the online path produces a
+        // LocalLDAModel, whereas DistributedLDAModel is tied to the GraphX representation.
+        val vocabSize = documents.first._2.size
+        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
+        var iter = 0
+        while (iter < onlineLDA.batchNumber) {
+          onlineLDA.next()
+          iter += 1
+        }
+        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
+      case _ => throw new IllegalArgumentException(s"LDA mode $mode is not supported.")
+    }
+  }

  /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
    run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
  }
}
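For orientation before the new optimizer below, here is a rough usage sketch of the API after this patch. It assumes LDAMode is importable from org.apache.spark.mllib.clustering and elides corpus construction; it is an illustration, not code from the commit.

  import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA, LDAMode}
  import org.apache.spark.mllib.linalg.Vector
  import org.apache.spark.rdd.RDD

  def trainExample(corpus: RDD[(Long, Vector)]): Unit = {
    val lda = new LDA().setK(10).setMaxIterations(20)
    // Default EM mode still trains the Graph-based model, now behind the base
    // LDAModel type, hence the asInstanceOf casts added to the example and tests.
    val emModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
    // Online mode returns a LocalLDAModel holding the estimated topics matrix.
    val onlineModel = lda.run(corpus, LDAMode.Online)
  }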
@@ -429,6 +450,97 @@ private[clustering] object LDA {

  }

+  // Online LDA optimizer, implementing the online variational Bayes algorithm from
+  // Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation" (NIPS 2010).
+  class OnlineLDAOptimizer(
+      val documents: RDD[(Long, Vector)],
+      val k: Int,
+      val vocabSize: Int) extends Serializable {
+
+    private val kappa = 0.5  // (0.5, 1] controls how quickly old information is forgotten
+    private val tau0 = 1024  // down-weights early iterations
+    private val D = documents.count()
+    // Mini-batch size: roughly one thousandth of the corpus, clamped to [4, 4096].
+    private val batchSize =
+      if (D / 1000 > 4096) 4096
+      else if (D / 1000 < 4) 4
+      else D / 1000
+    val batchNumber = (D / batchSize + 1).toInt
+    // todo: performance killer, needs to be replaced
+    private val batches = documents.randomSplit(Array.fill[Double](batchNumber)(1.0))
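To make the sizing heuristic concrete: for D = 1,000,000 documents, D / 1000 = 1000 falls inside [4, 4096], giving batchSize = 1000 and batchNumber = 1001; for D = 100,000,000 the batch size is clamped at 4096. Note also that randomSplit with equal weights only approximates equal split sizes, so each batch holds roughly, not exactly, batchSize documents.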

+    // Initialize the variational distribution q(beta | lambda)
+    var _lambda = getGammaMatrix(k, vocabSize)               // K * V
+    private var _Elogbeta = dirichlet_expectation(_lambda)   // K * V
+    private var _expElogbeta = exp(_Elogbeta)                // K * V
+
+    private var batchCount = 0
+    def next(): Unit = {
+      // Weight of the mini-batch: rhot = (tau0 + t)^(-kappa), decaying in the batch count t.
+      val rhot = math.pow(tau0 + batchCount, -kappa)
+
+      // E-step: accumulate expected sufficient statistics over the documents in this batch.
+      var stat = BDM.zeros[Double](k, vocabSize)
+      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
+
+      // M-step: blend the batch estimate into lambda with step size rhot.
+      stat = stat :* _expElogbeta
+      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
+      _Elogbeta = dirichlet_expectation(_lambda)
+      _expElogbeta = exp(_Elogbeta)
+      batchCount += 1
+    }
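The update in next() is the stochastic step from the Hoffman et al. paper: each mini-batch yields a candidate lambda of eta + (D / batchSize) * stat (with eta = 1.0 / k here), as if the whole corpus looked like this batch, and the running lambda moves toward it by the decaying weight rhot. A standalone sketch of that schedule, with the constants copied from above (not part of the patch):

  // rho_t = (tau0 + t)^(-kappa) with tau0 = 1024, kappa = 0.5
  val rhots = (0 to 4).map(t => math.pow(1024.0 + t, -0.5))
  // rhots(0) = 1 / math.sqrt(1024) = 0.03125; the weights decay slowly toward
  // zero, so early noisy batches are gradually forgotten.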

+    // Per-document E-step: fit the variational topic distribution gammad for one
+    // document and fold its expected word-topic counts into the accumulator `other`.
+    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+      val termCounts = doc._2
+      val (ids, cts) = termCounts match {
+        case v: DenseVector => ((0 until v.size).toList, v.values)
+        case v: SparseVector => (v.indices.toList, v.values)
+        case v => throw new IllegalArgumentException("Vector type " + v.getClass + " is not supported.")
+      }
+
+      var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t  // 1 * K
+      var Elogthetad = vector_dirichlet_expectation(gammad.t).t    // 1 * K
+      var expElogthetad = exp(Elogthetad.t).t                      // 1 * K
+      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix       // K * ids
+
+      var phinorm = expElogthetad * expElogbetad + 1e-100          // 1 * ids
+      var meanchange = 1D
+      val ctsVector = new BDV[Double](cts).t                       // 1 * ids
+
+      // Iterate gammad to convergence (mean absolute change below 1e-6).
+      while (meanchange > 1e-6) {
+        val lastgamma = gammad
+        //        1 * K                  1 * ids           ids * K
+        gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + 1.0 / k
+        Elogthetad = vector_dirichlet_expectation(gammad.t).t
+        expElogthetad = exp(Elogthetad.t).t
+        phinorm = expElogthetad * expElogbetad + 1e-100
+        meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
+      }
+
+      // Outer product of the document's topic weights with its scaled term counts:
+      // this document's contribution to the expected sufficient statistics.
+      val v1 = expElogthetad.t.toDenseMatrix.t
+      val v2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(v1, v2)                               // K * ids
+      for (i <- 0 until ids.size) {
+        other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+      }
+      other
+    }
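One detail worth spelling out: v1 is a K x 1 matrix and v2 is 1 x ids, so kron(v1, v2) reduces to an ordinary outer product. A standalone Breeze check of that fact (not from the patch):

  import breeze.linalg.{DenseMatrix, kron}

  val a = new DenseMatrix(3, 1, Array(1.0, 2.0, 3.0))  // column matrix, K = 3
  val b = new DenseMatrix(1, 2, Array(10.0, 20.0))     // row matrix, ids = 2
  val outer = kron(a, b)
  // outer is 3 x 2 with rows (10, 20), (20, 40), (30, 60), i.e. a times b.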

+    // Draw a (row x col) matrix of Gamma(100, 1/100) samples (mean 1.0), used to
+    // randomly initialize lambda and the per-document gamma.
+    def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
+      val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
+      val temp = gammaRandomGenerator.sample(row * col).toArray
+      (new BDM[Double](col, row, temp)).t
+    }

+    // For theta ~ Dirichlet(alpha): E[log theta] = digamma(alpha) - digamma(sum(alpha)),
+    // computed row-wise for a matrix of Dirichlet parameters.
+    def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
+      val rowSum = sum(alpha(breeze.linalg.*, ::))
+      val digAlpha = digamma(alpha)
+      val digRowSum = digamma(rowSum)
+      val result = digAlpha(::, breeze.linalg.*) - digRowSum
+      result
+    }
+
+    // Vector form of the same expectation.
+    def vector_dirichlet_expectation(v: BDV[Double]): BDV[Double] = {
+      digamma(v) - digamma(sum(v))
+    }
+  }
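Both helpers implement the standard Dirichlet identity E[log theta_k] = digamma(alpha_k) - digamma(sum_j alpha_j), applied row-wise in the matrix form. Exponentiating these expectations gives the expElogbeta and expElogtheta factors used throughout the E-step above.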

  /**
   * Compute gamma_{wjk}, a distribution over topics k.
   */
@@ -88,7 +88,7 @@ public void distributedLDAModel() {
      .setMaxIterations(5)
      .setSeed(12345);

-    DistributedLDAModel model = lda.run(corpus);
+    DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);

    // Check: basic parameters
    LocalLDAModel localModel = model.toLocal();
@@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
      .setSeed(12345)
    val corpus = sc.parallelize(tinyCorpus, 2)

-    val model: DistributedLDAModel = lda.run(corpus)
+    val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]

    // Check: basic parameters
    val localModel = model.toLocal
