diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 919bde55a1b8c..716cfd9e103c8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
+import org.apache.spark.mllib.rdd.RDDFunctions._
 
 
 /**
@@ -430,8 +431,7 @@ private[clustering] object LDA {
       else if (D / 1000 < 4) 4
       else D / 1000
     val batchNumber = (D/batchSize + 1).toInt
-    // todo: performance killer, need to be replaced
-    private val batches = documents.randomSplit(Array.fill[Double](batchNumber)(1.0))
+    private val batches = documents.sliding(batchNumber).collect()
 
     // Initialize the variational distribution q(beta|lambda)
     var _lambda = getGammaMatrix(k, vocabSize)       // K * V
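
For context, the change swaps `RDD.randomSplit` (which assigns each document to one of `batchNumber` disjoint, roughly equal-sized random splits) for MLlib's `RDDFunctions.sliding` (which groups consecutive elements into fixed-size windows, overlapping by one-element steps, and here materializes them on the driver via `collect()`), so the two are not semantically equivalent batchings. Below is a minimal, hypothetical sketch contrasting the two primitives; the object name, the local master, and the toy `documents` RDD are illustrative and not part of the patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.rdd.RDDFunctions._   // adds sliding() to RDDs via implicit conversion

// Hypothetical standalone demo; `documents` and `batchNumber` mirror the names in the diff.
object BatchingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("batching-sketch").setMaster("local[2]"))
    val documents = sc.parallelize(1L to 10L)   // stand-in for the corpus RDD
    val batchNumber = 3

    // Old approach: each element lands in exactly one of batchNumber random splits,
    // each split returned as its own RDD.
    val randomBatches: Array[org.apache.spark.rdd.RDD[Long]] =
      documents.randomSplit(Array.fill[Double](batchNumber)(1.0))

    // Approach in the diff: consecutive elements are grouped into windows of size
    // batchNumber (adjacent windows overlap) and collected to the driver.
    val slidingBatches: Array[Array[Long]] = documents.sliding(batchNumber).collect()

    println(randomBatches.map(_.count()).mkString("randomSplit sizes: ", ", ", ""))
    println(slidingBatches.map(_.mkString("[", ",", "]")).mkString("sliding windows: ", " ", ""))

    sc.stop()
  }
}
```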