
Commit
separate API and adjust batch split

hhbyyh committed Mar 2, 2015
1 parent 37af91a commit 581c623
Showing 5 changed files with 65 additions and 69 deletions.
examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
@@ -58,7 +58,7 @@ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
   corpus.cache();

   // Cluster the documents into three topics using LDA
-  DistributedLDAModel ldaModel = (DistributedLDAModel) new LDA().setK(3).run(corpus);
+  DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);

   // Output topics. Each is a distribution over words (matching word count vectors)
   System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -159,7 +159,7 @@ object LDAExample {
       }
       println()
     }
-
+    sc.stop()
   }

   /**
126 changes: 61 additions & 65 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -32,7 +32,6 @@ import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
-import org.apache.spark.mllib.rdd.RDDFunctions._


 /**
@@ -223,10 +222,6 @@ class LDA private (
     this
   }

-  object LDAMode extends Enumeration {
-    val EM, Online = Value
-  }
-
   /**
    * Learn an LDA model using the given dataset.
    *
@@ -236,37 +231,30 @@
   * Document IDs must be unique and >= 0.
   * @return Inferred LDA model
   */
-  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM ): LDAModel = {
-    mode match {
-      case LDAMode.EM =>
-        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-          checkpointInterval)
-        var iter = 0
-        val iterationTimes = Array.fill[Double](maxIterations)(0)
-        while (iter < maxIterations) {
-          val start = System.nanoTime()
-          state.next()
-          val elapsedSeconds = (System.nanoTime() - start) / 1e9
-          iterationTimes(iter) = elapsedSeconds
-          iter += 1
-        }
-        state.graphCheckpointer.deleteAllCheckpoints()
-        new DistributedLDAModel(state, iterationTimes)
-      case LDAMode.Online =>
-        val vocabSize = documents.first._2.size
-        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
-        var iter = 0
-        while (iter < onlineLDA.batchNumber) {
-          onlineLDA.next()
-          iter += 1
-        }
-        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
-      case _ => throw new IllegalArgumentException(s"Do not support mode $mode.")
-    }
+  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
+    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+      checkpointInterval)
+    var iter = 0
+    val iterationTimes = Array.fill[Double](maxIterations)(0)
+    while (iter < maxIterations) {
+      val start = System.nanoTime()
+      state.next()
+      val elapsedSeconds = (System.nanoTime() - start) / 1e9
+      iterationTimes(iter) = elapsedSeconds
+      iter += 1
+    }
+    state.graphCheckpointer.deleteAllCheckpoints()
+    new DistributedLDAModel(state, iterationTimes)
+  }
+
+  def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
+    (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())
+    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
   }

   /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
   }
 }
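
With the LDAMode flag gone, the two training paths are reached through separate entry points. A minimal usage sketch of the API as of this commit, assuming a live SparkContext `sc`; the corpus values below are illustrative, not from the patch:

```scala
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA, LDAModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Corpus of (documentId, termCountVector) pairs; IDs must be unique and >= 0.
val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0)),
  (1L, Vectors.dense(0.0, 1.0, 3.0))
))

val lda = new LDA().setK(2).setMaxIterations(10)

// EM path: run() now returns DistributedLDAModel directly, which is why the
// (DistributedLDAModel) casts disappear from the example and tests below.
val emModel: DistributedLDAModel = lda.run(corpus)

// Online variational Bayes path, split out into its own method.
val onlineModel: LDAModel = lda.runOnlineLDA(corpus)
```

One caveat worth noting: `runOnlineLDA` computes `batchNumber = D / batchSize` with integer division, so a corpus smaller than one mini-batch (here, fewer than 4 documents) yields `batchNumber == 0` and the model is returned from the random initial lambda without any update.
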
@@ -418,58 +406,66 @@ private[clustering] object LDA {

   }

-  // todo: add reference to paper and Hoffman
   /**
    * Optimizer for Online LDA algorithm which breaks corpus into mini-batches and scans only once.
+   * Hoffman, Blei and Bach, “Online Learning for Latent Dirichlet Allocation.” NIPS, 2010.
    */
   private[clustering] class OnlineLDAOptimizer(
-      val documents: RDD[(Long, Vector)],
-      val k: Int,
-      val vocabSize: Int) extends Serializable{
+      private val documents: RDD[(Long, Vector)],
+      private val k: Int) extends Serializable{

-    private val kappa = 0.5 // (0.5, 1] how quickly old information is forgotten
-    private val tau0 = 1024 // down weights early iterations
-    private val D = documents.count()
+    private val vocabSize = documents.first._2.size
+    private val D = documents.count().toInt
     private val batchSize = if (D / 1000 > 4096) 4096
       else if (D / 1000 < 4) 4
       else D / 1000
-    val batchNumber = (D/batchSize + 1).toInt
-    private val batches = documents.sliding(batchNumber).collect()
+    val batchNumber = D/batchSize

     // Initialize the variational distribution q(beta|lambda)
-    var _lambda = getGammaMatrix(k, vocabSize) // K * V
-    private var _Elogbeta = dirichlet_expectation(_lambda) // K * V
-    private var _expElogbeta = exp(_Elogbeta) // K * V
+    var lambda = getGammaMatrix(k, vocabSize) // K * V
+    private var Elogbeta = dirichlet_expectation(lambda) // K * V
+    private var expElogbeta = exp(Elogbeta) // K * V

-    private var batchCount = 0
+    private var batchId = 0
     def next(): Unit = {
-      // weight of the mini-batch.
-      val rhot = math.pow(tau0 + batchCount, -kappa)
+      require(batchId < batchNumber)
+      // weight of the mini-batch. 1024 down weights early iterations
+      val weight = math.pow(1024 + batchId, -0.5)
+      val batch = documents.filter(doc => doc._1 % batchNumber == batchId)

       // Given a mini-batch of documents, estimates the parameters gamma controlling the
       // variational distribution over the topic weights for each document in the mini-batch.
       var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
-
-      stat = stat :* _expElogbeta
-      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
-      _Elogbeta = dirichlet_expectation(_lambda)
-      _expElogbeta = exp(_Elogbeta)
-      batchCount += 1
+      stat = batch.aggregate(stat)(seqOp, _ += _)
+      stat = stat :* expElogbeta
+
+      // Update lambda based on documents.
+      lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
+      Elogbeta = dirichlet_expectation(lambda)
+      expElogbeta = exp(Elogbeta)
+      batchId += 1
     }

-    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    // for each document d update that document's gamma and phi
+    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
       val termCounts = doc._2
       val (ids, cts) = termCounts match {
         case v: DenseVector => (((0 until v.size).toList), v.values)
         case v: SparseVector => (v.indices.toList, v.values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }

       // Initialize the variational distribution q(theta|gamma) for the mini-batch
       var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t // 1 * K
       var Elogthetad = vector_dirichlet_expectation(gammad.t).t // 1 * K
       var expElogthetad = exp(Elogthetad.t).t // 1 * K
-      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix // K * ids
+      val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids

       var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids
       var meanchange = 1D
-      val ctsVector = new BDV[Double](cts).t // 1 * ids
+      val ctsVector = new BDV[Double](cts).t  // 1 * ids

       // Iterate between gamma and phi until convergence
       while (meanchange > 1e-6) {
         val lastgamma = gammad
         // 1*K 1 * ids ids * k
@@ -480,30 +476,30 @@
         meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
       }

-      val v1 = expElogthetad.t.toDenseMatrix.t
-      val v2 = (ctsVector / phinorm).t.toDenseMatrix
-      val outerResult = kron(v1, v2) // K * ids
+      val m1 = expElogthetad.t.toDenseMatrix.t
+      val m2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(m1, m2) // K * ids
       for (i <- 0 until ids.size) {
-        other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+        stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
       }
-      other
+      stat
     }

-    def getGammaMatrix(row:Int, col:Int): BDM[Double] ={
+    private def getGammaMatrix(row:Int, col:Int): BDM[Double] ={
       val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
       val temp = gammaRandomGenerator.sample(row * col).toArray
       (new BDM[Double](col, row, temp)).t
     }

-    def dirichlet_expectation(alpha : BDM[Double]): BDM[Double] = {
+    private def dirichlet_expectation(alpha : BDM[Double]): BDM[Double] = {
       val rowSum = sum(alpha(breeze.linalg.*, ::))
       val digAlpha = digamma(alpha)
       val digRowSum = digamma(rowSum)
       val result = digAlpha(::, breeze.linalg.*) - digRowSum
       result
     }

-    def vector_dirichlet_expectation(v : BDV[Double]): (BDV[Double]) ={
+    private def vector_dirichlet_expectation(v : BDV[Double]): (BDV[Double]) ={
       digamma(v) - digamma(sum(v))
     }
   }
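
For reference, `next()` performs one step of the online variational Bayes update from the Hoffman, Blei and Bach paper cited above. In this commit's variables (mini-batch B_t selected by document ID modulo `batchNumber`, S = `batchSize`, D documents in total), a sketch of the implemented update is:

```latex
\rho_t = (\tau_0 + t)^{-\kappa}, \qquad
\lambda \leftarrow (1 - \rho_t)\,\lambda
  + \rho_t \Bigl( \tfrac{D}{S} \textstyle\sum_{d \in B_t} n_{dw}\,\phi_{dwk} + \tfrac{1}{k} \Bigr)
```

with tau0 = 1024 and kappa = 0.5 hard-coded into `weight`, and the 1/k term standing in for the topic-word prior. `stat :* expElogbeta` accumulates the sum of n_dw * phi_dwk statistics, and `dirichlet_expectation` supplies the E-step quantity E_q[log beta_kw] = psi(lambda_kw) - psi(sum_w' lambda_kw').
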
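The "adjust batch split" half of the commit replaces the old `sliding`-based batching with a deterministic split: a mini-batch is roughly one-thousandth of the corpus, clamped to between 4 and 4096 documents, and document d lands in mini-batch d % batchNumber. A self-contained sketch of that arithmetic, assuming dense document IDs (`batchSplit` is an illustrative helper, not part of the patch):

```scala
// Illustrative helper mirroring the commit's batch-split rule.
def batchSplit(numDocs: Int): (Int, Int) = {
  val batchSize =
    if (numDocs / 1000 > 4096) 4096      // cap: at most 4096 docs per mini-batch
    else if (numDocs / 1000 < 4) 4       // floor: at least 4 docs per mini-batch
    else numDocs / 1000                  // otherwise ~0.1% of the corpus
  val batchNumber = numDocs / batchSize  // document d goes to batch (d % batchNumber)
  (batchSize, batchNumber)
}

// Worked examples (integer division throughout):
// batchSplit(100000)   == (100, 1000)
// batchSplit(10000000) == (4096, 2441)
// batchSplit(2)        == (4, 0)   -- degenerate: next() never runs
```
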
mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -88,7 +88,7 @@ public void distributedLDAModel() {
       .setMaxIterations(5)
       .setSeed(12345);

-    DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
+    DistributedLDAModel model = lda.run(corpus);

     // Check: basic parameters
     LocalLDAModel localModel = model.toLocal();
mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
       .setSeed(12345)
     val corpus = sc.parallelize(tinyCorpus, 2)

-    val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
+    val model: DistributedLDAModel = lda.run(corpus)

     // Check: basic parameters
     val localModel = model.toLocal
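
The commit updates the EM test to use the new return type but adds no coverage for the online path. A hypothetical companion test in the suite's style, sketched under the assumption that `tinyCorpus` is the suite's usual in-memory fixture (this test is not in the patch):

```scala
test("online LDA (sketch, not part of this commit)") {
  val k = 3
  val lda = new LDA().setK(k).setMaxIterations(5).setSeed(12345)
  val corpus = sc.parallelize(tinyCorpus, 2)

  // runOnlineLDA returns a LocalLDAModel typed as LDAModel.
  val model: LDAModel = lda.runOnlineLDA(corpus)

  // Check: basic parameters survive the round trip through lambda.
  assert(model.k === k)
  assert(model.vocabSize === tinyCorpus.head._2.size)
}
```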
