From c14da411d4da1b6553759afff7952ac746c9fa15 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Sun, 3 Aug 2014 22:09:58 -0700
Subject: [PATCH] fix styles

---
 .../apache/spark/mllib/feature/Word2Vec.scala | 37 ++++++++++---------
 .../spark/mllib/feature/Word2VecSuite.scala   | 29 ++++++++-------
 2 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 03cb0ff11027f..87c81e7b0bd2f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -17,20 +17,18 @@
 package org.apache.spark.mllib.feature
 
-import scala.util.Random
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.Logging
-import org.apache.spark.rdd._
+import org.apache.spark.{HashPartitioner, Logging}
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.HashPartitioner
-import org.apache.spark.storage.StorageLevel
 import org.apache.spark.mllib.rdd.RDDFunctions._
+import org.apache.spark.rdd._
+import org.apache.spark.storage.StorageLevel
 
 /**
  * Entry in vocabulary
@@ -53,7 +51,7 @@ private case class VocabWord(
  *
  * We used skip-gram model in our implementation and hierarchical softmax
  * method to train the model. The variable names in the implementation
- * mathes the original C implementation.
+ * match the original C implementation.
  *
  * For original C implementation, see https://code.google.com/p/word2vec/
  * For research papers, see
@@ -69,10 +67,14 @@ private case class VocabWord(
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
-    val parallelism: Int = 1,
-    val numIterations: Int = 1)
-  extends Serializable with Logging {
-
+    val parallelism: Int,
+    val numIterations: Int) extends Serializable with Logging {
+
+  /**
+   * Word2Vec with a single thread.
+   */
+  def this(size: Int, startingAlpha: Double) = this(size, startingAlpha, 1, 1)
+
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
@@ -92,7 +94,7 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]){
+  private def learnVocab(words:RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
@@ -126,7 +128,7 @@ class Word2Vec(
     expTable
   }
 
-  private def createBinaryTree() {
+  private def createBinaryTree(): Unit = {
     val count = new Array[Long](vocabSize * 2 + 1)
     val binary = new Array[Int](vocabSize * 2 + 1)
     val parentNode = new Array[Int](vocabSize * 2 + 1)
@@ -208,7 +210,6 @@ class Word2Vec(
    * @param dataset an RDD of words
    * @return a Word2VecModel
    */
-
   def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
 
     val words = dataset.flatMap(x => x)
@@ -339,7 +340,7 @@ class Word2Vec(
 /**
  * Word2Vec model
  */
-class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Serializable {
+class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
 
   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
@@ -358,7 +359,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Seri
   def transform(word: String): Vector = {
     val result = model.lookup(word)
     if (result.isEmpty) {
-      throw new IllegalStateException(s"${word} not in vocabulary")
+      throw new IllegalStateException(s"$word not in vocabulary")
     } else Vectors.dense(result(0).map(_.toDouble))
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index e2b71c16f3308..3ec3208f5fa34 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.feature
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.util.LocalSparkContext
 
 class Word2VecSuite extends FunSuite with LocalSparkContext {
+
+  // TODO: add more tests
+
   test("Word2Vec") {
     val sentence = "a b " * 100 + "a c " * 10
     val localDoc = Seq(sentence, sentence)
@@ -33,28 +35,27 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
     val window = 2
     val minCount = 2
     val num = 2
-    val word = "a"
 
     val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
-    val synons = model.findSynonyms("a", 2)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "b")
-    assert(synons(1)._1 == "c")
+    val syms = model.findSynonyms("a", 2)
+    assert(syms.length == num)
+    assert(syms(0)._1 == "b")
+    assert(syms(1)._1 == "c")
   }
 
   test("Word2VecModel") {
     val num = 2
     val localModel = Seq(
-      ("china" , Array(0.50, 0.50, 0.50, 0.50)),
-      ("japan" , Array(0.40, 0.50, 0.50, 0.50)),
-      ("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
-      ("korea" , Array(0.45, 0.60, 0.60, 0.60))
+      ("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
     )
 
     val model = new Word2VecModel(sc.parallelize(localModel, 2))
-    val synons = model.findSynonyms("china", num)
-    assert(synons.length == num)
-    assert(synons(0)._1 == "taiwan")
-    assert(synons(1)._1 == "japan")
"japan") + val syms = model.findSynonyms("china", num) + assert(syms.length == num) + assert(syms(0)._1 == "taiwan") + assert(syms(1)._1 == "japan") } }