[SPARK-7090] [MLLIB] Introduce LDAOptimizer to LDA to further improve extensibility

jira: https://issues.apache.org/jira/browse/SPARK-7090

LDA was implemented with extensibility in mind, and with the ongoing development of OnlineLDA and Gibbs Sampling we are collecting more detailed requirements from different algorithms.
As Joseph Bradley (jkbradley) proposed in apache#4807, and after some further discussion, we'd like to adjust the code structure slightly to present the common interface and the extension point clearly.
Class LDA becomes the common entry point for LDA computation, and each LDA object refers to an LDAOptimizer for the concrete algorithm implementation. Users can configure an LDAOptimizer with specific parameters and assign it to the LDA instance.
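
For illustration, here is a minimal usage sketch of the resulting API, based on the setOptimizer methods and the new run signature in this diff; the corpus parameter is a hypothetical RDD[(Long, Vector)] of (document ID, term-count vector) pairs:

    import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA}
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    // corpus: RDD[(Long, Vector)] of (document ID, term-count vector), assumed to exist.
    def trainExample(corpus: RDD[(Long, Vector)]): Unit = {
      // EMLDAOptimizer is the default; it is set explicitly here only to show setOptimizer.
      val lda = new LDA()
        .setK(3)
        .setOptimizer(new EMLDAOptimizer)  // or equivalently .setOptimizer("em")

      // run now returns the generic LDAModel, so callers that need the EM-specific model
      // cast it back to DistributedLDAModel, as the updated examples below do.
      val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
      println(s"Learned ${ldaModel.k} topics over a vocabulary of ${ldaModel.vocabSize} terms")
    }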

Concrete changes:

1. Add a trait `LDAOptimizer`, which defines the common interface for the concrete implementations; each subclass is a wrapper for a specific LDA algorithm (a rough sketch of the trait is given after this list).

2. Move EMOptimizer to the LDAOptimizer file, make it extend LDAOptimizer, and rename it to EMLDAOptimizer (keeping the name EMOptimizer free in case a more generic EM optimizer comes in the future).
        - Adjust the constructor of EMLDAOptimizer, since all parameters should be passed in through the initialState method; this avoids unwanted confusion or accidental overwrites.
        - Move the code from LDA.initialState into the initialState method of EMLDAOptimizer.

3. Add an ldaOptimizer property to LDA with a getter/setter; EMLDAOptimizer is the default optimizer.

4. Change the return type of LDA.run from DistributedLDAModel to LDAModel.
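
As an illustration of change 1, the sketch below shows one plausible shape for the LDAOptimizer trait, inferred from how LDA.run uses it in this diff (initialState builds the per-algorithm state, next() runs one iteration, getLDAModel produces the result). The real definition lives in the new LDAOptimizer file, which is not shown in full here, so the exact signatures are an assumption:

    package org.apache.spark.mllib.clustering

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    trait LDAOptimizer {

      // Build the optimizer-specific initial state from the corpus and the LDA parameters
      // (signature inferred from the call site in LDA.run).
      private[clustering] def initialState(
          docs: RDD[(Long, Vector)],
          k: Int,
          docConcentration: Double,
          topicConcentration: Double,
          randomSeed: Long,
          checkpointInterval: Int): LDAOptimizer

      // Run one iteration of the algorithm, updating internal state.
      private[clustering] def next(): LDAOptimizer

      // Produce the trained model once all iterations have completed.
      private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel
    }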

Further work:
Add OnlineLDAOptimizer and other optimizers once they are ready.

Author: Yuhao Yang <[email protected]>

Closes apache#5661 from hhbyyh/ldaRefactor and squashes the following commits:

0e2e006 [Yuhao Yang] respond to review comments
08a45da [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor
e756ce4 [Yuhao Yang] solve mima exception
d74fd8f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor
0bb8400 [Yuhao Yang] refactor LDA with Optimizer
ec2f857 [Yuhao Yang] protoptype for discussion
hhbyyh authored and jkbradley committed Apr 28, 2015
1 parent 62888a4 commit 4d9e560
Showing 8 changed files with 256 additions and 151 deletions.
@@ -58,7 +58,7 @@ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
corpus.cache();

// Cluster the documents into three topics using LDA
DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus);

// Output topics. Each is a distribution over words (matching word count vectors)
System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
@@ -26,7 +26,7 @@ import scopt.OptionParser
import org.apache.log4j.{Level, Logger}

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

@@ -137,7 +137,7 @@ object LDAExample {
sc.setCheckpointDir(params.checkpointDir.get)
}
val startTime = System.nanoTime()
val ldaModel = lda.run(corpus)
val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
val elapsed = (System.nanoTime() - startTime) / 1e9

println(s"Finished training LDA model. Summary:")
181 changes: 36 additions & 145 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -17,16 +17,11 @@

package org.apache.spark.mllib.clustering

import java.util.Random

import breeze.linalg.{DenseVector => BDV, normalize}

import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -42,16 +37,9 @@ import org.apache.spark.util.Utils
* - "token": instance of a term appearing in a document
* - "topic": multinomial distribution over words representing some concept
*
* Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
* according to the Asuncion et al. (2009) paper referenced below.
*
* References:
* - Original LDA paper (journal version):
* Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
* - This class implements their "smoothed" LDA model.
* - Paper which clearly explains several algorithms, including EM:
* Asuncion, Welling, Smyth, and Teh.
* "On Smoothing and Inference for Topic Models." UAI, 2009.
*
* @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
* (Wikipedia)]]
@@ -63,10 +51,11 @@ class LDA private (
private var docConcentration: Double,
private var topicConcentration: Double,
private var seed: Long,
private var checkpointInterval: Int) extends Logging {
private var checkpointInterval: Int,
private var ldaOptimizer: LDAOptimizer) extends Logging {

def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
seed = Utils.random.nextLong(), checkpointInterval = 10)
seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)

/**
* Number of topics to infer. I.e., the number of soft cluster centers.
@@ -220,6 +209,32 @@ class LDA private (
this
}


/** LDAOptimizer used to perform the actual calculation */
def getOptimizer: LDAOptimizer = ldaOptimizer

/**
* LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer)
*/
def setOptimizer(optimizer: LDAOptimizer): this.type = {
this.ldaOptimizer = optimizer
this
}

/**
* Set the LDAOptimizer used to perform the actual calculation by algorithm name.
* Currently "em" is supported.
*/
def setOptimizer(optimizerName: String): this.type = {
this.ldaOptimizer =
optimizerName.toLowerCase match {
case "em" => new EMLDAOptimizer
case other =>
throw new IllegalArgumentException(s"Only em is supported but got $other.")
}
this
}

/**
* Learn an LDA model using the given dataset.
*
@@ -229,9 +244,9 @@
* Document IDs must be unique and >= 0.
* @return Inferred LDA model
*/
def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
checkpointInterval)
def run(documents: RDD[(Long, Vector)]): LDAModel = {
val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration,
seed, checkpointInterval)
var iter = 0
val iterationTimes = Array.fill[Double](maxIterations)(0)
while (iter < maxIterations) {
@@ -241,12 +256,11 @@
iterationTimes(iter) = elapsedSeconds
iter += 1
}
state.graphCheckpointer.deleteAllCheckpoints()
new DistributedLDAModel(state, iterationTimes)
state.getLDAModel(iterationTimes)
}

/** Java-friendly version of [[run()]] */
def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
}
}
@@ -320,88 +334,10 @@ private[clustering] object LDA {

private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0

/**
* Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
*
* @param graph EM graph, storing current parameter estimates in vertex descriptors and
* data (token counts) in edge descriptors.
* @param k Number of topics
* @param vocabSize Number of unique terms
* @param docConcentration "alpha"
* @param topicConcentration "beta" or "eta"
*/
private[clustering] class EMOptimizer(
var graph: Graph[TopicCounts, TokenCount],
val k: Int,
val vocabSize: Int,
val docConcentration: Double,
val topicConcentration: Double,
checkpointInterval: Int) {

private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
graph, checkpointInterval)

def next(): EMOptimizer = {
val eta = topicConcentration
val W = vocabSize
val alpha = docConcentration

val N_k = globalTopicTotals
val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
(edgeContext) => {
// Compute N_{wj} gamma_{wjk}
val N_wj = edgeContext.attr
// E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
// N_{wj}.
val scaledTopicDistribution: TopicCounts =
computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
edgeContext.sendToDst((false, scaledTopicDistribution))
edgeContext.sendToSrc((false, scaledTopicDistribution))
}
// This is a hack to detect whether we could modify the values in-place.
// TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
(m0, m1) => {
val sum =
if (m0._1) {
m0._2 += m1._2
} else if (m1._1) {
m1._2 += m0._2
} else {
m0._2 + m1._2
}
(true, sum)
}
// M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
val docTopicDistributions: VertexRDD[TopicCounts] =
graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
.mapValues(_._2)
// Update the vertex descriptors with the new counts.
val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
graph = newGraph
graphCheckpointer.updateGraph(newGraph)
globalTopicTotals = computeGlobalTopicTotals()
this
}

/**
* Aggregate distributions over topics from all term vertices.
*
* Note: This executes an action on the graph RDDs.
*/
var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()

private def computeGlobalTopicTotals(): TopicCounts = {
val numTopics = k
graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
}

}

/**
* Compute gamma_{wjk}, a distribution over topics k.
*/
private def computePTopic(
private[clustering] def computePTopic(
docTopicCounts: TopicCounts,
termTopicCounts: TopicCounts,
totalTopicCounts: TopicCounts,
@@ -427,49 +363,4 @@
// normalize
BDV(gamma_wj) /= sum
}

/**
* Compute bipartite term/doc graph.
*/
private def initialState(
docs: RDD[(Long, Vector)],
k: Int,
docConcentration: Double,
topicConcentration: Double,
randomSeed: Long,
checkpointInterval: Int): EMOptimizer = {
// For each document, create an edge (Document -> Term) for each unique term in the document.
val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
// Add edges for terms with non-zero counts.
termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
Edge(docID, term2index(term), cnt)
}
}

val vocabSize = docs.take(1).head._2.size

// Create vertices.
// Initially, we use random soft assignments of tokens to topics (random gamma).
def createVertices(): RDD[(VertexId, TopicCounts)] = {
val verticesTMP: RDD[(VertexId, TopicCounts)] =
edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
val random = new Random(partIndex + randomSeed)
partEdges.flatMap { edge =>
val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
val sum = gamma * edge.attr
Seq((edge.srcId, sum), (edge.dstId, sum))
}
}
verticesTMP.reduceByKey(_ + _)
}

val docTermVertices = createVertices()

// Partition such that edges are grouped by document
val graph = Graph(docTermVertices, edges)
.partitionBy(PartitionStrategy.EdgePartition1D)

new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
}

}
@@ -203,7 +203,7 @@ class DistributedLDAModel private (

import LDA._

private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = {
private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = {
this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
state.topicConcentration, iterationTimes)
}