Skip to content

Commit

Permalink
use broadcast version of vocab in aggregate
Browse files (browse the repository at this point in the history)
  • Loading branch information
Liquan Pei committed Aug 3, 2014
1 parent 6bcc8be commit 7efbb6f
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.HashPartitioner

import org.apache.spark.storage.StorageLevel
/**
* Entry in vocabulary
*/
Expand Down Expand Up @@ -215,18 +215,18 @@ class Word2Vec(
val sc = dataset.context

val expTable = sc.broadcast(createExpTable())
val V = sc.broadcast(vocab)
val VHash = sc.broadcast(vocabHash)
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)

val sentences = words.mapPartitions {
val sentences: RDD[Array[Int]] = words.mapPartitions {
iter => { new Iterator[Array[Int]] {
def hasNext = iter.hasNext

def next = {
var sentence = new ArrayBuffer[Int]
var sentenceLength = 0
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = VHash.value.get(iter.next)
val word = bcVocabHash.value.get(iter.next)
word match {
case Some(w) => {
sentence += w
Expand Down Expand Up @@ -278,14 +278,14 @@ class Word2Vec(
val neu1e = new Array[Double](layer1Size)
// Hierarchical softmax
var d = 0
while (d < vocab(word).codeLen) {
val l2 = vocab(word).point(d) * layer1Size
while (d < bcVocab.value(word).codeLen) {
val l2 = bcVocab.value(word).point(d) * layer1Size
// Propagate hidden -> output
var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
val g = (1 - vocab(word).code(d) - f) * alpha
val g = (1 - bcVocab.value(word).code(d) - f) * alpha
blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
}
Expand All @@ -310,17 +310,21 @@ class Word2Vec(
syn0Global = aggSyn0
syn1Global = aggSyn1
}
newSentences.unpersist()

val wordMap = new Array[(String, Array[Double])](vocabSize)
var i = 0
while (i < vocabSize) {
val word = vocab(i).word
val word = bcVocab.value(i).word
val vector = new Array[Double](layer1Size)
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
wordMap(i) = (word, vector)
i += 1
}
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
.partitionBy(new HashPartitioner(modelPartitionNum)).cache()
.partitionBy(new HashPartitioner(modelPartitionNum))
.persist(StorageLevel.MEMORY_AND_DISK)

new Word2VecModel(modelRDD)
}
}
Expand Down

0 comments on commit 7efbb6f

Please sign in to comment.