Skip to content

Commit

Permalink
fix styles
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Aug 4, 2014
1 parent 384c771 commit c14da41
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 32 deletions.
37 changes: 19 additions & 18 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@

package org.apache.spark.mllib.feature

import scala.util.Random
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.rdd._
import org.apache.spark.{HashPartitioner, Logging}
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.HashPartitioner
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel

/**
* Entry in vocabulary
Expand All @@ -53,7 +51,7 @@ private case class VocabWord(
*
* We use the skip-gram model in our implementation and the hierarchical softmax
* method to train the model. The variable names in the implementation
* match the original C implementation.
*
* For original C implementation, see https://code.google.com/p/word2vec/
* For research papers, see
Expand All @@ -69,10 +67,14 @@ private case class VocabWord(
class Word2Vec(
val size: Int,
val startingAlpha: Double,
val parallelism: Int = 1,
val numIterations: Int = 1)
extends Serializable with Logging {

val parallelism: Int,
val numIterations: Int) extends Serializable with Logging {

/**
* Word2Vec with a single thread.
*
* Delegates to the primary constructor with `parallelism = 1` and `numIterations = 1`.
*
* @param size forwarded unchanged as the primary constructor's `size`
* @param startingAlpha forwarded unchanged as the primary constructor's `startingAlpha`;
*                      typed as Double (like the primary constructor) so that fractional
*                      values can be passed, while Int arguments still widen implicitly
*/
def this(size: Int, startingAlpha: Double) = this(size, startingAlpha, 1, 1)

private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
Expand All @@ -92,7 +94,7 @@ class Word2Vec(
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha

private def learnVocab(words:RDD[String]){
private def learnVocab(words:RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(
Expand Down Expand Up @@ -126,7 +128,7 @@ class Word2Vec(
expTable
}

private def createBinaryTree() {
private def createBinaryTree(): Unit = {
val count = new Array[Long](vocabSize * 2 + 1)
val binary = new Array[Int](vocabSize * 2 + 1)
val parentNode = new Array[Int](vocabSize * 2 + 1)
Expand Down Expand Up @@ -208,7 +210,6 @@ class Word2Vec(
* @param dataset an RDD of words
* @return a Word2VecModel
*/

def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

val words = dataset.flatMap(x => x)
Expand Down Expand Up @@ -339,7 +340,7 @@ class Word2Vec(
/**
* Word2Vec model
*/
class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Serializable {
class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {

private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
Expand All @@ -358,7 +359,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Float])]) extends Seri
def transform(word: String): Vector = {
  // Only the first vector stored under `word` is used.
  model.lookup(word).headOption match {
    case Some(values) =>
      // Widen the stored Float entries to Double to build the dense vector.
      Vectors.dense(values.map(_.toDouble))
    case None =>
      // Nothing stored under this key: report the word as unknown.
      throw new IllegalStateException(s"$word not in vocabulary")
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ package org.apache.spark.mllib.feature

import org.scalatest.FunSuite

import org.apache.spark.SparkContext._
import org.apache.spark.mllib.util.LocalSparkContext

class Word2VecSuite extends FunSuite with LocalSparkContext {

// TODO: add more tests

test("Word2Vec") {
val sentence = "a b " * 100 + "a c " * 10
val localDoc = Seq(sentence, sentence)
Expand All @@ -33,28 +35,27 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
val window = 2
val minCount = 2
val num = 2
val word = "a"

val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
val synons = model.findSynonyms("a", 2)
assert(synons.length == num)
assert(synons(0)._1 == "b")
assert(synons(1)._1 == "c")
val syms = model.findSynonyms("a", 2)
assert(syms.length == num)
assert(syms(0)._1 == "b")
assert(syms(1)._1 == "c")
}


test("Word2VecModel") {
  // Number of synonyms requested from the model.
  val num = 2
  // Float literals are required: Word2VecModel is built from an RDD[(String, Array[Float])].
  val localModel = Seq(
    ("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
    ("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
    ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
    ("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
  )
  // Distribute the fixture over 2 partitions.
  val model = new Word2VecModel(sc.parallelize(localModel, 2))
  val syms = model.findSynonyms("china", num)
  // Exactly `num` results, with "taiwan" first and "japan" second.
  assert(syms.length == num)
  assert(syms(0)._1 == "taiwan")
  assert(syms(1)._1 == "japan")
}
}

0 comments on commit c14da41

Please sign in to comment.