From 0aafb1b02a19fe4f1689543baf1882a49a7ff11a Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Fri, 1 Aug 2014 08:34:11 -0700
Subject: [PATCH] Add comments, minor fixes

---
 .../apache/spark/mllib/feature/Word2Vec.scala | 69 ++++++++++++-------
 1 file changed, 46 insertions(+), 23 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 9461d0cc1ba2d..18f507c2f1b46 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -31,6 +31,9 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.HashPartitioner
 
+/**
+ * Entry in vocabulary
+ */
 private case class VocabWord(
   var word: String,
   var cn: Int,
@@ -39,6 +42,9 @@ private case class VocabWord(
   var codeLen:Int
 )
 
+/**
+ * Vector representation of word
+ */
 class Word2Vec(
   val size: Int,
   val startingAlpha: Double,
@@ -51,7 +57,8 @@ class Word2Vec(
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
   private val layer1Size = size
-
+  private val modelPartitionNum = 100
+
   private var trainWordsCount = 0
   private var vocabSize = 0
   private var vocab: Array[VocabWord] = null
@@ -169,6 +176,7 @@ class Word2Vec(
    * Computes the vector representation of each word in
    * vocabulary
    * @param dataset an RDD of strings
+   * @return a Word2VecModel
    */
   def fit(dataset:RDD[String]): Word2VecModel = {
 
@@ -274,11 +282,14 @@
       wordMap(i) = (word, vector)
       i += 1
     }
-    val modelRDD = sc.parallelize(wordMap,100).partitionBy(new HashPartitioner(100))
+    val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
     new Word2VecModel(modelRDD)
   }
 }
 
+/**
+ * Word2Vec model
+ */
 class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {
 
   val model = _model
@@ -292,22 +303,46 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
     blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
   }
 
+  /**
+   * Transforms a word to its vector representation
+   * @param word a word
+   * @return vector representation of word
+   */
+
   def transform(word: String): Array[Double] = {
     val result = model.lookup(word)
     if (result.isEmpty) Array[Double]()
     else result(0)
   }
 
+  /**
+   * Transforms an RDD to its vector representation
+   * @param dataset an RDD of words
+   * @return RDD of vector representation
+   */
+
   def transform(dataset: RDD[String]): RDD[Array[Double]] = {
     dataset.map(word => transform(word))
   }
 
+  /**
+   * Find synonyms of a word
+   * @param word a word
+   * @param num number of synonyms to find
+   * @return array of (word, similarity)
+   */
   def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
     val vector = transform(word)
     if (vector.isEmpty) Array[(String, Double)]()
     else findSynonyms(vector,num)
   }
 
+  /**
+   * Find synonyms of the vector representation of a word
+   * @param vector vector representation of a word
+   * @param num number of synonyms to find
+   * @return array of (word, similarity)
+   */
   def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
     val topK = model.map(
@@ -321,6 +356,15 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
 }
 
 object Word2Vec extends Serializable with Logging {
+  /**
+   * Train Word2Vec model
+   * @param input RDD of words
+   * @param size vector dimension
+   * @param startingAlpha initial learning rate
+   * @param window context words from [-window, window]
+   * @param minCount minimum frequency to consider a vocabulary word
+   * @return Word2Vec model
+   */
   def train(
     input: RDD[String],
     size: Int,
@@ -329,25 +373,4 @@ object Word2Vec extends Serializable with Logging {
     minCount: Int): Word2VecModel = {
     new Word2Vec(size,startingAlpha, window, minCount).fit(input)
   }
-
-  def main(args: Array[String]) {
-    if (args.length < 6) {
-      println("Usage: word2vec input size startingAlpha window minCount num")
-      sys.exit(1)
-    }
-    val conf = new SparkConf()
-      .setAppName("word2vec")
-
-    val sc = new SparkContext(conf)
-    val input = sc.textFile(args(0))
-    val size = args(1).toInt
-    val startingAlpha = args(2).toDouble
-    val window = args(3).toInt
-    val minCount = args(4).toInt
-    val num = args(5).toInt
-    val model = train(input, size, startingAlpha, window, minCount)
-    val vec = model.findSynonyms("china", num)
-    for((w, dist) <- vec) logInfo(w.toString + " " + dist.toString)
-    sc.stop()
-  }
 }
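
For reference, the driver that this patch removes from the library object can still be written against the documented public API. Below is a minimal standalone sketch along those lines; the object name, input path, and hyperparameter values are illustrative placeholders rather than anything prescribed by the patch (only the query word "china" and the train/findSynonyms calls come from the removed main):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.feature.Word2Vec

    object Word2VecExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("Word2VecExample"))

        // Corpus as an RDD[String], the input type expected by Word2Vec.train.
        val input = sc.textFile("hdfs:///path/to/corpus.txt") // placeholder path

        // Hyperparameter values are illustrative, not defaults set by the patch.
        val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025,
          window = 5, minCount = 5)

        // Print the ten words most similar to the query word, with similarities.
        for ((word, similarity) <- model.findSynonyms("china", 10)) {
          println(s"$word $similarity")
        }

        sc.stop()
      }
    }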