From 6bcc8be34f6253bc7d4f9d4dcb478bf91f108c86 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sun, 3 Aug 2014 11:15:09 -0700
Subject: [PATCH] add multiple iteration support

---
 .../apache/spark/mllib/feature/Word2Vec.scala | 130 ++++++++++--------
 1 file changed, 70 insertions(+), 60 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index b55122d3c9f1e..21c2395fb18ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -70,7 +70,8 @@ class Word2Vec(
     val startingAlpha: Double,
     val window: Int,
     val minCount: Int,
-    val parallelism:Int = 1)
+    val parallelism: Int = 1,
+    val numIterations: Int = 1)
   extends Serializable with Logging {
 
   private val EXP_TABLE_SIZE = 1000
@@ -241,73 +242,80 @@ class Word2Vec(
     }
 
     val newSentences = sentences.repartition(parallelism).cache()
-    val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    val (aggSyn0, _, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
-      // or initialize the model in each executor
-      newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
-        seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
-          var lwc = lastWordCount
-          var wc = wordCount
-          if (wordCount - lastWordCount > 10000) {
-            lwc = wordCount
-            alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
-            if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
-            logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
-          }
-          wc += sentence.size
-          var pos = 0
-          while (pos < sentence.size) {
-            val word = sentence(pos)
-            // TODO: fix random seed
-            val b = Random.nextInt(window)
-            // Train Skip-gram
-            var a = b
-            while (a < window * 2 + 1 - b) {
-              if (a != window) {
-                val c = pos - window + a
-                if (c >= 0 && c < sentence.size) {
-                  val lastWord = sentence(c)
-                  val l1 = lastWord * layer1Size
-                  val neu1e = new Array[Double](layer1Size)
-                  // Hierarchical softmax
-                  var d = 0
-                  while (d < vocab(word).codeLen) {
-                    val l2 = vocab(word).point(d) * layer1Size
-                    // Propagate hidden -> output
-                    var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
-                    if (f > -MAX_EXP && f < MAX_EXP) {
-                      val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
-                      f = expTable.value(ind)
-                      val g = (1 - vocab(word).code(d) - f) * alpha
-                      blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                      blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+    var syn0Global =
+      Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
+    var syn1Global = new Array[Double](vocabSize * layer1Size)
+
+    for (iter <- 1 to numIterations) {
+      val (aggSyn0, aggSyn1, _, _) =
+        // TODO: broadcast syn0Global and syn1Global instead of serializing them directly
+        // or initialize the model in each executor
+        newSentences.aggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+          seqOp = (c, v) => (c, v) match {
+            case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
+              var lwc = lastWordCount
+              var wc = wordCount
+              if (wordCount - lastWordCount > 10000) {
+                lwc = wordCount
+                alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+                if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+                logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
+              }
+              wc += sentence.size
+              var pos = 0
+              while (pos < sentence.size) {
+                val word = sentence(pos)
+                // TODO: fix random seed
+                val b = Random.nextInt(window)
+                // Train Skip-gram
+                var a = b
+                while (a < window * 2 + 1 - b) {
+                  if (a != window) {
+                    val c = pos - window + a
+                    if (c >= 0 && c < sentence.size) {
+                      val lastWord = sentence(c)
+                      val l1 = lastWord * layer1Size
+                      val neu1e = new Array[Double](layer1Size)
+                      // Hierarchical softmax
+                      var d = 0
+                      while (d < vocab(word).codeLen) {
+                        val l2 = vocab(word).point(d) * layer1Size
+                        // Propagate hidden -> output
+                        var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                        if (f > -MAX_EXP && f < MAX_EXP) {
+                          val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
+                          f = expTable.value(ind)
+                          val g = (1 - vocab(word).code(d) - f) * alpha
+                          blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                          blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                        }
+                        d += 1
                       }
-                    d += 1
+                      blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
                     }
-                  blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
                   }
+                  a += 1
                 }
-              a += 1
+                pos += 1
               }
-            pos += 1
-          }
-          (syn0, syn1, lwc, wc)
-        },
-        combOp = (c1, c2) => (c1, c2) match {
-          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-            val n = syn0_1.length
-            blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
-            blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
-            (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-        })
-
+              (syn0, syn1, lwc, wc)
+          },
+          combOp = (c1, c2) => (c1, c2) match {
+            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+              val n = syn0_1.length
+              blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
+              blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
+              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+          })
+      syn0Global = aggSyn0
+      syn1Global = aggSyn1
+    }
     val wordMap = new Array[(String, Array[Double])](vocabSize)
     var i = 0
     while (i < vocabSize) {
       val word = vocab(i).word
       val vector = new Array[Double](layer1Size)
-      Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size)
+      Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
       wordMap(i) = (word, vector)
       i += 1
     }
@@ -398,7 +406,9 @@ object Word2Vec{
       size: Int,
       startingAlpha: Double,
       window: Int,
-      minCount: Int): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount).fit[S](input)
+      minCount: Int,
+      parallelism: Int = 1,
+      numIterations: Int = 1): Word2VecModel = {
+    new Word2Vec(size, startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
   }
 }
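Reviewer note on the shape of the second hunk: the model matrices become vars
that outlive a single pass, and each of the numIterations passes re-runs the
same aggregate over the cached sentences, seeded with clones of the current
globals; the aggregated result then becomes the seed for the next pass. Below
is a minimal, self-contained sketch of that pattern (not the patch's code):
every name and value in it is illustrative, and plain vector sums stand in
for the skip-gram updates.

    import org.apache.spark.{SparkConf, SparkContext}

    object IterativeAggregateSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("iter-agg-sketch").setMaster("local[2]"))
        // Stand-in for newSentences: cached once, traversed every iteration.
        val data = sc.parallelize(Seq(Array(1.0, 2.0), Array(3.0, 4.0)), 2).cache()
        val numIterations = 3

        // Stand-in for syn0Global/syn1Global: state carried across passes.
        var weightsGlobal = new Array[Double](2)
        for (iter <- 1 to numIterations) {
          val agg = data.aggregate(weightsGlobal.clone())(
            // seqOp: fold one record into the partition-local copy of the
            // weights (the skip-gram gradient steps in the real code).
            seqOp = (w, x) => {
              var i = 0
              while (i < w.length) { w(i) += x(i); i += 1 }
              w
            },
            // combOp: merge per-partition copies, as the patch does with daxpy.
            combOp = (w1, w2) => {
              var i = 0
              while (i < w1.length) { w1(i) += w2(i); i += 1 }
              w1
            })
          // The aggregated state seeds the next iteration.
          weightsGlobal = agg
        }
        println(weightsGlobal.mkString(", "))
        sc.stop()
      }
    }

One property the sketch shares with the patch: aggregate clones its zero value
into every task, so a summing combOp folds the starting weights in once per
partition; the TODO retained in the diff about broadcasting the model instead
of serializing it with the closure points at exactly that shipping cost.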
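Usage: assuming the companion-object method in the final hunk is MLlib's train
(the method name sits just above the excerpted context), the two new trailing
parameters default to 1, so existing call sites keep compiling. A spark-shell
flavored sketch with an illustrative toy corpus, where sc is the shell's
SparkContext:

    import org.apache.spark.mllib.feature.Word2Vec

    // Toy corpus: an RDD whose elements are tokenized sentences
    // (any S <: Iterable[String] fits the fit[S] signature).
    val input = sc.parallelize(Seq(
      Seq("spark", "mllib", "word2vec"),
      Seq("multiple", "iterations", "over", "the", "same", "corpus")))

    // parallelism and numIterations are the parameters this patch adds.
    val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025,
      window = 5, minCount = 1, parallelism = 4, numIterations = 3)

    // The returned Word2VecModel answers nearest-neighbor queries.
    val synonyms = model.findSynonyms("spark", 2)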