From 7efbb6f91ca94f9243dbb7a16ea3fc9b6f548b99 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 3 Aug 2014 12:16:19 -0700 Subject: [PATCH] use broadcast version of vocab in aggregate --- .../apache/spark/mllib/feature/Word2Vec.scala | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 21c2395fb18ae..3ace0800fb9f8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.HashPartitioner - +import org.apache.spark.storage.StorageLevel /** * Entry in vocabulary */ @@ -215,10 +215,10 @@ class Word2Vec( val sc = dataset.context val expTable = sc.broadcast(createExpTable()) - val V = sc.broadcast(vocab) - val VHash = sc.broadcast(vocabHash) + val bcVocab = sc.broadcast(vocab) + val bcVocabHash = sc.broadcast(vocabHash) - val sentences = words.mapPartitions { + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => { new Iterator[Array[Int]] { def hasNext = iter.hasNext @@ -226,7 +226,7 @@ class Word2Vec( var sentence = new ArrayBuffer[Int] var sentenceLength = 0 while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { - val word = VHash.value.get(iter.next) + val word = bcVocabHash.value.get(iter.next) word match { case Some(w) => { sentence += w @@ -278,14 +278,14 @@ class Word2Vec( val neu1e = new Array[Double](layer1Size) // Hierarchical softmax var d = 0 - while (d < vocab(word).codeLen) { - val l2 = vocab(word).point(d) * layer1Size + while (d < bcVocab.value(word).codeLen) { + val l2 = bcVocab.value(word).point(d) * layer1Size // Propagate hidden -> output var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt f = expTable.value(ind) - val g = (1 - vocab(word).code(d) - f) * alpha + val g = (1 - bcVocab.value(word).code(d) - f) * alpha blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) } @@ -310,17 +310,21 @@ class Word2Vec( syn0Global = aggSyn0 syn1Global = aggSyn1 } + newSentences.unpersist() + val wordMap = new Array[(String, Array[Double])](vocabSize) var i = 0 while (i < vocabSize) { - val word = vocab(i).word + val word = bcVocab.value(i).word val vector = new Array[Double](layer1Size) Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) wordMap(i) = (word, vector) i += 1 } val modelRDD = sc.parallelize(wordMap, modelPartitionNum) - .partitionBy(new HashPartitioner(modelPartitionNum)).cache() + .partitionBy(new HashPartitioner(modelPartitionNum)) + .persist(StorageLevel.MEMORY_AND_DISK) + new Word2VecModel(modelRDD) } }