diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index b966f775f7b01..87c81e7b0bd2f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -17,20 +17,19 @@ package org.apache.spark.mllib.feature
 
-import scala.util.Random
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.Logging
-import org.apache.spark.rdd._
+import org.apache.spark.{HashPartitioner, Logging}
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.HashPartitioner
-import org.apache.spark.storage.StorageLevel
 import org.apache.spark.mllib.rdd.RDDFunctions._
+import org.apache.spark.rdd._
+import org.apache.spark.storage.StorageLevel
+
 /**
  * Entry in vocabulary
  */
@@ -52,7 +51,7 @@ private case class VocabWord(
  *
  * We used skip-gram model in our implementation and hierarchical softmax
  * method to train the model. The variable names in the implementation
- * mathes the original C implementation.
+ * match the original C implementation.
  *
  * For original C implementation, see https://code.google.com/p/word2vec/
 * For research papers, see
@@ -61,34 +60,41 @@ private case class VocabWord(
  * Distributed Representations of Words and Phrases and their Compositionality.
  * @param size vector dimension
  * @param startingAlpha initial learning rate
- * @param window context words from [-window, window]
- * @param minCount minimum frequncy to consider a vocabulary word
- * @param parallelisum number of partitions to run Word2Vec
+ * @param parallelism number of partitions to run Word2Vec (use a small number for accuracy)
+ * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
  */
 @Experimental
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
-    val window: Int,
-    val minCount: Int,
-    val parallelism:Int = 1,
-    val numIterations:Int = 1)
-  extends Serializable with Logging {
-
+    val parallelism: Int,
+    val numIterations: Int) extends Serializable with Logging {
+
+  /**
+   * Word2Vec with a single thread.
+   */
+  def this(size: Int, startingAlpha: Double) = this(size, startingAlpha, 1, 1)
+
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
   private val layer1Size = size
   private val modelPartitionNum = 100
-  
+
+  /** context words from [-window, window] */
+  private val window = 5
+
+  /** minimum frequency to consider a vocabulary word */
+  private val minCount = 5
+
   private var trainWordsCount = 0
   private var vocabSize = 0
   private var vocab: Array[VocabWord] = null
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]){
+  private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
@@ -99,7 +105,7 @@ class Word2Vec(
         0))
       .filter(_.cn >= minCount)
       .collect()
-      .sortWith((a, b)=> a.cn > b.cn)
+      .sortWith((a, b) => a.cn > b.cn)
 
     vocabSize = vocab.length
     var a = 0
@@ -111,22 +117,18 @@ class Word2Vec(
     logInfo("trainWordsCount = " + trainWordsCount)
   }
 
-  private def learnVocabPerPartition(words:RDD[String]) {
-
-  }
-
-  private def createExpTable(): Array[Double] = {
-    val expTable = new Array[Double](EXP_TABLE_SIZE)
+  private def createExpTable(): Array[Float] = {
+    val expTable = new Array[Float](EXP_TABLE_SIZE)
     var i = 0
     while (i < EXP_TABLE_SIZE) {
       val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
-      expTable(i) = tmp / (tmp + 1)
+      expTable(i) = (tmp / (tmp + 1.0)).toFloat
      i += 1
     }
     expTable
   }
 
-  private def createBinaryTree() {
+  private def createBinaryTree(): Unit = {
     val count = new Array[Long](vocabSize * 2 + 1)
     val binary = new Array[Int](vocabSize * 2 + 1)
     val parentNode = new Array[Int](vocabSize * 2 + 1)
@@ -208,8 +210,7 @@ class Word2Vec(
    * @param dataset an RDD of words
    * @return a Word2VecModel
    */
-
-  def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {
+  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
 
     val words = dataset.flatMap(x => x)
 
@@ -223,39 +224,37 @@ class Word2Vec(
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)
 
-    val sentences: RDD[Array[Int]] = words.mapPartitions {
-      iter => { new Iterator[Array[Int]] {
-          def hasNext = iter.hasNext
-
-          def next = {
-            var sentence = new ArrayBuffer[Int]
-            var sentenceLength = 0
-            while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-              val word = bcVocabHash.value.get(iter.next)
-              word match {
-                case Some(w) => {
-                  sentence += w
-                  sentenceLength += 1
-                }
-                case None =>
-              }
+    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
+      new Iterator[Array[Int]] {
+        def hasNext: Boolean = iter.hasNext
+
+        def next(): Array[Int] = {
+          var sentence = new ArrayBuffer[Int]
+          var sentenceLength = 0
+          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
+            val word = bcVocabHash.value.get(iter.next())
+            word match {
+              case Some(w) =>
+                sentence += w
+                sentenceLength += 1
+              case None =>
            }
             }
-            sentence.toArray
           }
+          sentence.toArray
         }
       }
     }
 
     val newSentences = sentences.repartition(parallelism).cache()
-    var syn0Global
-      = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    var syn1Global = new Array[Double](vocabSize * layer1Size)
+    var syn0Global =
+      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+    var syn1Global = new Array[Float](vocabSize * layer1Size)
 
     for(iter <- 1 to numIterations) {
       val (aggSyn0, aggSyn1, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
+        // TODO: broadcast temp instead of serializing it directly
         // or initialize the model in each executor
-      newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+      newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
         seqOp = (c, v) => (c, v) match {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
@@ -280,23 +279,23 @@ class Word2Vec(
               if (c >= 0 && c < sentence.size) {
                 val lastWord = sentence(c)
                 val l1 = lastWord * layer1Size
-                val neu1e = new Array[Double](layer1Size)
+                val neu1e = new Array[Float](layer1Size)
                 // Hierarchical softmax
                 var d = 0
                 while (d < bcVocab.value(word).codeLen) {
                   val l2 = bcVocab.value(word).point(d) * layer1Size
                   // Propagate hidden -> output
-                  var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                  var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                   if (f > -MAX_EXP && f < MAX_EXP) {
                     val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                     f = expTable.value(ind)
-                    val g = (1 - bcVocab.value(word).code(d) - f) * alpha
-                    blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                    blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                    val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
+                    blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                    blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                   }
                   d += 1
                 }
-                blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
+                blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
               }
             }
             a += 1
@@ -308,12 +307,12 @@ class Word2Vec(
         combOp = (c1, c2) => (c1, c2) match {
           case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
             val n = syn0_1.length
-            val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
-            val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
-            blas.dscal(n, weight1, syn0_1, 1)
-            blas.dscal(n, weight1, syn1_1, 1)
-            blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-            blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+            val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+            val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+            blas.sscal(n, weight1, syn0_1, 1)
+            blas.sscal(n, weight1, syn1_1, 1)
+            blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+            blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
             (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
         })
       syn0Global = aggSyn0
@@ -321,11 +320,11 @@ class Word2Vec(
     }
     newSentences.unpersist()
 
-    val wordMap = new Array[(String, Array[Double])](vocabSize)
+    val wordMap = new Array[(String, Array[Float])](vocabSize)
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
-      val vector = new Array[Double](layer1Size)
+      val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
       wordMap(i) = (word, vector)
       i += 1
@@ -341,15 +340,15 @@ class Word2Vec(
 /**
  * Word2Vec model
  */
-class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
+class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
 
-  private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
+  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
     val n = v1.length
-    val norm1 = blas.dnrm2(n, v1, 1)
-    val norm2 = blas.dnrm2(n, v2, 1)
+    val norm1 = blas.snrm2(n, v1, 1)
+    val norm2 = blas.snrm2(n, v2, 1)
     if (norm1 == 0 || norm2 == 0) return 0.0
-    blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
+    blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
   }
 
   /**
@@ -360,9 +359,9 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
   def transform(word: String): Vector = {
     val result = model.lookup(word)
     if (result.isEmpty) {
-      throw new IllegalStateException(s"${word} not in vocabulary")
+      throw new IllegalStateException(s"$word not in vocabulary")
     }
-    else Vectors.dense(result(0))
+    else Vectors.dense(result(0).map(_.toDouble))
   }
 
   /**
@@ -394,7 +393,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
     val topK = model.map { case(w, vec) =>
-      (cosineSimilarity(vector.toArray, vec), w) }
+      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
       .sortByKey(ascending = false)
       .take(num + 1)
       .map(_.swap)
@@ -410,18 +409,16 @@ object Word2Vec{
    * @param input RDD of words
    * @param size vector dimension
    * @param startingAlpha initial learning rate
-   * @param window context words from [-window, window]
-   * @param minCount minimum frequncy to consider a vocabulary word
-   * @return Word2Vec model
-   */
+   * @param parallelism number of partitions to run Word2Vec (use a small number for accuracy)
+   * @param numIterations number of iterations, should be smaller than or equal to parallelism
+   * @return Word2Vec model
+   */
   def train[S <: Iterable[String]](
       input: RDD[S],
      size: Int,
      startingAlpha: Double,
-      window: Int,
-      minCount: Int,
      parallelism: Int = 1,
      numIterations:Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
+    new Word2Vec(size, startingAlpha, parallelism, numIterations).fit[S](input)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index e2b71c16f3308..3ec3208f5fa34 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.feature
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.util.LocalSparkContext
 
 class Word2VecSuite extends FunSuite with LocalSparkContext {
+
+  // TODO: add more tests
+
   test("Word2Vec") {
     val sentence = "a b " * 100 + "a c " * 10
     val localDoc = Seq(sentence, sentence)
@@ -33,28 +35,27 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
     val window = 2
     val minCount = 2
     val num = 2
-    val word = "a"
 
     val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
-    val synons = model.findSynonyms("a", 2)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "b")
-    assert(synons(1)._1 == "c")
+    val syms = model.findSynonyms("a", 2)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "b")
+    assert(syms(1)._1 == "c")
   }
 
   test("Word2VecModel") {
     val num = 2
     val localModel = Seq(
-      ("china" , Array(0.50, 0.50, 0.50, 0.50)),
-      ("japan" , Array(0.40, 0.50, 0.50, 0.50)),
-      ("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
-      ("korea" , Array(0.45, 0.60, 0.60, 0.60))
+      ("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
     )
 
     val model = new Word2VecModel(sc.parallelize(localModel, 2))
-    val synons = model.findSynonyms("china", num)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "taiwan")
-    assert(synons(1)._1 == "japan")
+    val syms = model.findSynonyms("china", num)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "taiwan")
+    assert(syms(1)._1 == "japan")
"japan") } }