Add test for Word2Vec algorithm, minor fixes
Liquan Pei committed Aug 2, 2014
1 parent 2e92b59 commit 720b5a3
Showing 2 changed files with 31 additions and 8 deletions.
17 changes: 10 additions & 7 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -50,24 +50,27 @@ private case class VocabWord(
* natural language processing and machine learning algorithms.
*
* We used the skip-gram model in our implementation and hierarchical softmax
- * method to train the model.
+ * method to train the model. The variable names in the implementation
+ * match the original C implementation.
*
* For the original C implementation, see https://code.google.com/p/word2vec/
* For research papers, see
* Efficient Estimation of Word Representations in Vector Space
* and
- * Distributed Representations of Words and Phrases and their Compositionality
+ * Distributed Representations of Words and Phrases and their Compositionality.
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequency to consider a vocabulary word
+ * @param parallelism number of partitions to run Word2Vec
*/
@Experimental
class Word2Vec(
val size: Int,
val startingAlpha: Double,
val window: Int,
- val minCount: Int)
+ val minCount: Int,
+ val parallelism: Int = 1)
extends Serializable with Logging {

private val EXP_TABLE_SIZE = 1000
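
For context, a minimal driver-side sketch of how this class is used, based on the Word2Vec.train helper exercised in the test suite below; the SparkContext sc, the corpus path, and the hyperparameter values are illustrative assumptions, not part of this commit:

// Hypothetical usage sketch; `sc` and "corpus.txt" are assumed.
val sentences = sc.textFile("corpus.txt").map(line => line.split(" ").toSeq)
// Arguments mirror the call in the suite: size, startingAlpha, window, minCount.
val model = Word2Vec.train(sentences, 100, 0.025, 5, 5)
// findSynonyms returns (word, similarity) pairs, as the suite asserts.
model.findSynonyms("spark", 2).foreach { case (w, sim) => println(s"$w: $sim") }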
@@ -237,7 +240,7 @@ class Word2Vec(
}
}

- val newSentences = sentences.repartition(1).cache()
+ val newSentences = sentences.repartition(parallelism).cache()
val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
val (aggSyn0, _, _, _) =
// TODO: broadcast temp instead of serializing it directly
@@ -248,7 +251,7 @@
var wc = wordCount
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
- alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1))
+ alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
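
A standalone sketch of the learning-rate schedule above, with made-up constants: each of the parallelism workers sees roughly 1/parallelism of the corpus, so its local word count is scaled back up to estimate global training progress, and the rate is floored at 0.01% of startingAlpha:

// Illustrative sketch of the decay schedule; all constants are assumed.
val startingAlpha = 0.025
val trainWordsCount = 1000000L
val parallelism = 4
def decayedAlpha(localWordCount: Long): Double = {
  val a = startingAlpha * (1 - parallelism * localWordCount.toDouble / (trainWordsCount + 1))
  math.max(a, startingAlpha * 0.0001) // floor keeps the rate positive
}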
@@ -296,7 +299,7 @@
val n = syn0_1.length
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
- (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
+ (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
})

val wordMap = new Array[(String, Array[Double])](vocabSize)
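
The reduce above merges per-partition weight matrices in place with BLAS daxpy (y := a*x + y); the one-line fix makes the tuple return the merged syn1_1 rather than the unmerged syn0_2. A toy sketch, assuming the netlib-java binding MLlib aliases as blas:

import com.github.fommil.netlib.BLAS.{getInstance => blas}
// daxpy adds a*x into y in place: y := a*x + y.
val y = Array(1.0, 2.0, 3.0)
val x = Array(0.5, 0.5, 0.5)
blas.daxpy(y.length, 1.0, x, 1, y, 1)
// y is now [1.5, 2.5, 3.5], i.e. the two partitions' weights summed.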
@@ -309,7 +312,7 @@
i += 1
}
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
- .partitionBy(new HashPartitioner(modelPartitionNum))
+ .partitionBy(new HashPartitioner(modelPartitionNum)).cache()
new Word2VecModel(modelRDD)
}
}
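
Hash-partitioning the model RDD and caching it means repeated queries hit a known partition and reuse materialized data instead of recomputing the lineage; a generic illustration with toy data and an assumed SparkContext sc:

import org.apache.spark.HashPartitioner
// Toy pair RDD, hash-partitioned and cached like modelRDD above.
val pairs = sc.parallelize(Seq(("china", 1), ("japan", 2)))
val cached = pairs.partitionBy(new HashPartitioner(2)).cache()
cached.lookup("china") // routed to one partition, served from memory thereafter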
22 changes: 21 additions & 1 deletion mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -23,7 +23,27 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.util.LocalSparkContext

class Word2VecSuite extends FunSuite with LocalSparkContext {
test("word2vec") {
test("Word2Vec") {
val sentence = "a b " * 100 + "a c " * 10
val localDoc = Seq(sentence, sentence)
val doc = sc.parallelize(localDoc)
.map(line => line.split(" ").toSeq)
val size = 10
val startingAlpha = 0.025
val window = 2
val minCount = 2
val num = 2
val word = "a"

val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
val synons = model.findSynonyms("a", 2)
assert(synons.length == num)
assert(synons(0)._1 == "b")
assert(synons(1)._1 == "c")
}


test("Word2VecModel") {
val num = 2
val localModel = Seq(
("china" , Array(0.50, 0.50, 0.50, 0.50)),
