From 0d0f3eef6d4e2754bfa2904f30bf9e21005ae392 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 10 Feb 2015 12:30:48 +0800 Subject: [PATCH] replace random split with sliding --- .../main/scala/org/apache/spark/mllib/clustering/LDA.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 919bde55a1b8c..716cfd9e103c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import org.apache.spark.mllib.rdd.RDDFunctions._ /** @@ -430,8 +431,7 @@ private[clustering] object LDA { else if (D / 1000 < 4) 4 else D / 1000 val batchNumber = (D/batchSize + 1).toInt - // todo: performance killer, need to be replaced - private val batches = documents.randomSplit(Array.fill[Double](batchNumber)(1.0)) + private val batches = documents.sliding(batchNumber).collect() // Initialize the variational distribution q(beta|lambda) var _lambda = getGammaMatrix(k, vocabSize) // K * V