From 720b5a3ea697a881fc7d7c286b65ef110421f89e Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Fri, 1 Aug 2014 22:53:03 -0700 Subject: [PATCH] Add test for Word2Vec algorithm, minor fixes --- .../apache/spark/mllib/feature/Word2Vec.scala | 17 ++++++++------ .../spark/mllib/feature/Word2VecSuite.scala | 22 ++++++++++++++++++- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f4266c94f63e0..b55122d3c9f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -50,24 +50,27 @@ private case class VocabWord( * natural language processing and machine learning algorithms. * * We used skip-gram model in our implementation and hierarchical softmax - * method to train the model. + * method to train the model. The variable names in the implementation + * mathes the original C implementation. * * For original C implementation, see https://code.google.com/p/word2vec/ * For research papers, see * Efficient Estimation of Word Representations in Vector Space * and - * Distributed Representations of Words and Phrases and their Compositionality + * Distributed Representations of Words and Phrases and their Compositionality. * @param size vector dimension * @param startingAlpha initial learning rate * @param window context words from [-window, window] * @param minCount minimum frequncy to consider a vocabulary word + * @param parallelisum number of partitions to run Word2Vec */ @Experimental class Word2Vec( val size: Int, val startingAlpha: Double, val window: Int, - val minCount: Int) + val minCount: Int, + val parallelism:Int = 1) extends Serializable with Logging { private val EXP_TABLE_SIZE = 1000 @@ -237,7 +240,7 @@ class Word2Vec( } } - val newSentences = sentences.repartition(1).cache() + val newSentences = sentences.repartition(parallelism).cache() val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size) val (aggSyn0, _, _, _) = // TODO: broadcast temp instead of serializing it directly @@ -248,7 +251,7 @@ class Word2Vec( var wc = wordCount if (wordCount - lastWordCount > 10000) { lwc = wordCount - alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1)) + alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1)) if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } @@ -296,7 +299,7 @@ class Word2Vec( val n = syn0_1.length blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1) blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1) - (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2) + (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) }) val wordMap = new Array[(String, Array[Double])](vocabSize) @@ -309,7 +312,7 @@ class Word2Vec( i += 1 } val modelRDD = sc.parallelize(wordMap, modelPartitionNum) - .partitionBy(new HashPartitioner(modelPartitionNum)) + .partitionBy(new HashPartitioner(modelPartitionNum)).cache() new Word2VecModel(modelRDD) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 54e56529c5a47..e2b71c16f3308 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -23,7 +23,27 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.util.LocalSparkContext class Word2VecSuite extends FunSuite with LocalSparkContext { - test("word2vec") { + test("Word2Vec") { + val sentence = "a b " * 100 + "a c " * 10 + val localDoc = Seq(sentence, sentence) + val doc = sc.parallelize(localDoc) + .map(line => line.split(" ").toSeq) + val size = 10 + val startingAlpha = 0.025 + val window = 2 + val minCount = 2 + val num = 2 + val word = "a" + + val model = Word2Vec.train(doc, size, startingAlpha, window, minCount) + val synons = model.findSynonyms("a", 2) + assert(synons.length == num) + assert(synons(0)._1 == "b") + assert(synons(1)._1 == "c") + } + + + test("Word2VecModel") { val num = 2 val localModel = Seq( ("china" , Array(0.50, 0.50, 0.50, 0.50)),