
Commit

add multiple iteration support
Liquan Pei committed Aug 3, 2014
1 parent 720b5a3 commit 6bcc8be
Showing 1 changed file with 70 additions and 60 deletions.
mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
130 changes: 70 additions & 60 deletions
@@ -70,7 +70,8 @@ class Word2Vec(
    val startingAlpha: Double,
    val window: Int,
    val minCount: Int,
-   val parallelism:Int = 1)
+   val parallelism:Int = 1,
+   val numIterations:Int = 1)
extends Serializable with Logging {

private val EXP_TABLE_SIZE = 1000
@@ -241,73 +242,80 @@ class Word2Vec(
}

val newSentences = sentences.repartition(parallelism).cache()
-    val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    val (aggSyn0, _, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
-      // or initialize the model in each executor
-      newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
-        seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
-          var lwc = lastWordCount
-          var wc = wordCount
-          if (wordCount - lastWordCount > 10000) {
-            lwc = wordCount
-            alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
-            if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
-            logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
-          }
-          wc += sentence.size
-          var pos = 0
-          while (pos < sentence.size) {
-            val word = sentence(pos)
-            // TODO: fix random seed
-            val b = Random.nextInt(window)
-            // Train Skip-gram
-            var a = b
-            while (a < window * 2 + 1 - b) {
-              if (a != window) {
-                val c = pos - window + a
-                if (c >= 0 && c < sentence.size) {
-                  val lastWord = sentence(c)
-                  val l1 = lastWord * layer1Size
-                  val neu1e = new Array[Double](layer1Size)
-                  // Hierarchical softmax
-                  var d = 0
-                  while (d < vocab(word).codeLen) {
-                    val l2 = vocab(word).point(d) * layer1Size
-                    // Propagate hidden -> output
-                    var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
-                    if (f > -MAX_EXP && f < MAX_EXP) {
-                      val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
-                      f = expTable.value(ind)
-                      val g = (1 - vocab(word).code(d) - f) * alpha
-                      blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                      blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
-                    }
-                    d += 1
-                  }
-                  blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
-                }
-              }
-              a += 1
-            }
-            pos += 1
-          }
-          (syn0, syn1, lwc, wc)
-        },
-        combOp = (c1, c2) => (c1, c2) match {
-          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-            val n = syn0_1.length
-            blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
-            blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
-            (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-        })
-
+    var syn0Global
+      = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
+    var syn1Global = new Array[Double](vocabSize * layer1Size)
+
+    for(iter <- 1 to numIterations) {
+      val (aggSyn0, aggSyn1, _, _) =
+        // TODO: broadcast temp instead of serializing it directly
+        // or initialize the model in each executor
+        newSentences.aggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+          seqOp = (c, v) => (c, v) match {
+            case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
+              var lwc = lastWordCount
+              var wc = wordCount
+              if (wordCount - lastWordCount > 10000) {
+                lwc = wordCount
+                alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+                if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+                logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
+              }
+              wc += sentence.size
+              var pos = 0
+              while (pos < sentence.size) {
+                val word = sentence(pos)
+                // TODO: fix random seed
+                val b = Random.nextInt(window)
+                // Train Skip-gram
+                var a = b
+                while (a < window * 2 + 1 - b) {
+                  if (a != window) {
+                    val c = pos - window + a
+                    if (c >= 0 && c < sentence.size) {
+                      val lastWord = sentence(c)
+                      val l1 = lastWord * layer1Size
+                      val neu1e = new Array[Double](layer1Size)
+                      // Hierarchical softmax
+                      var d = 0
+                      while (d < vocab(word).codeLen) {
+                        val l2 = vocab(word).point(d) * layer1Size
+                        // Propagate hidden -> output
+                        var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                        if (f > -MAX_EXP && f < MAX_EXP) {
+                          val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
+                          f = expTable.value(ind)
+                          val g = (1 - vocab(word).code(d) - f) * alpha
+                          blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                          blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                        }
+                        d += 1
+                      }
+                      blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
+                    }
+                  }
+                  a += 1
+                }
+                pos += 1
+              }
+              (syn0, syn1, lwc, wc)
+          },
+          combOp = (c1, c2) => (c1, c2) match {
+            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+              val n = syn0_1.length
+              blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
+              blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
+              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+          })
+      syn0Global = aggSyn0
+      syn1Global = aggSyn1
+    }
val wordMap = new Array[(String, Array[Double])](vocabSize)
var i = 0
while (i < vocabSize) {
val word = vocab(i).word
val vector = new Array[Double](layer1Size)
-      Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size)
+      Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
wordMap(i) = (word, vector)
i += 1
}
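
The control flow introduced in the hunk above is simpler than it looks once the Word2Vec internals are stripped away: the two weight arrays now live outside the aggregation as syn0Global and syn1Global, each pass over the corpus is seeded with the previous pass's aggregated result, and the whole thing repeats numIterations times. The sketch below is an editor's illustration of just that loop structure, not code from the commit; runOnePass is a hypothetical stand-in for the distributed newSentences.aggregate(...) pass, and the sizes are arbitrary.

object IterationSketch {
  // Hypothetical stand-in for one distributed training pass; it would return
  // updated copies of the two weight arrays (syn0, syn1).
  def runOnePass(syn0: Array[Double], syn1: Array[Double]): (Array[Double], Array[Double]) =
    (syn0.clone(), syn1.clone())

  def main(args: Array[String]): Unit = {
    val vocabSize = 4
    val layer1Size = 3
    val numIterations = 2
    // Weights are carried across iterations instead of being created inside the pass.
    var syn0Global =
      Array.fill[Double](vocabSize * layer1Size)((util.Random.nextDouble - 0.5) / layer1Size)
    var syn1Global = new Array[Double](vocabSize * layer1Size)
    for (iter <- 1 to numIterations) {
      val (aggSyn0, aggSyn1) = runOnePass(syn0Global, syn1Global)
      syn0Global = aggSyn0 // the next pass starts from the aggregated result
      syn1Global = aggSyn1
    }
    println(s"trained ${syn0Global.length} weights over $numIterations iterations")
  }
}
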
@@ -398,7 +406,9 @@ object Word2Vec{
size: Int,
startingAlpha: Double,
window: Int,
-      minCount: Int): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount).fit[S](input)
+      minCount: Int,
+      parallelism: Int = 1,
+      numIterations:Int = 1): Word2VecModel = {
+    new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
}
}
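
For reference, a minimal usage sketch of the updated helper, suitable for spark-shell. This is not part of the commit: it assumes an existing SparkContext named sc, an input corpus path, and the train parameter list shown in the hunk above; the numeric values are illustrative only.

import org.apache.spark.rdd.RDD

// Assumed: a whitespace-tokenized corpus, one sentence per line.
val sentences: RDD[Iterable[String]] =
  sc.textFile("data/corpus.txt").map(line => line.split(" ").toIterable)

val model = org.apache.spark.mllib.feature.Word2Vec.train(
  sentences,
  size = 100,            // dimension of the word vectors (layer1Size)
  startingAlpha = 0.025, // initial learning rate
  window = 5,
  minCount = 5,
  parallelism = 4,       // partitions used during training
  numIterations = 3)     // new in this commit: number of passes over the corpus
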
