Merge pull request #1 from mengxr/Ishiihara-master
some updates
Ishiihara committed Aug 4, 2014
2 parents e93e726 + c14da41 commit 26a948d
Showing 2 changed files with 94 additions and 96 deletions.
161 changes: 79 additions & 82 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -17,20 +17,19 @@

package org.apache.spark.mllib.feature

import scala.util.Random
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.rdd._
import org.apache.spark.{HashPartitioner, Logging}
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.HashPartitioner
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel

/**
* Entry in vocabulary
*/
@@ -52,7 +51,7 @@ private case class VocabWord(
*
* We used skip-gram model in our implementation and hierarchical softmax
* method to train the model. The variable names in the implementation
* mathes the original C implementation.
* matches the original C implementation.
*
* For original C implementation, see https://code.google.com/p/word2vec/
* For research papers, see
@@ -61,34 +60,41 @@ private case class VocabWord(
* Distributed Representations of Words and Phrases and their Compositionality.
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequncy to consider a vocabulary word
* @param parallelisum number of partitions to run Word2Vec
* @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
* @param numIterations number of iterations to run, should be smaller than or equal to parallelism
*/
@Experimental
class Word2Vec(
val size: Int,
val startingAlpha: Double,
val window: Int,
val minCount: Int,
val parallelism:Int = 1,
val numIterations:Int = 1)
extends Serializable with Logging {

val parallelism: Int,
val numIterations: Int) extends Serializable with Logging {

/**
* Word2Vec with a single thread.
*/
def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)

private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000
private val layer1Size = size
private val modelPartitionNum = 100


/** context words from [-window, window] */
private val window = 5

/** minimum frequency to consider a vocabulary word */
private val minCount = 5

private var trainWordsCount = 0
private var vocabSize = 0
private var vocab: Array[VocabWord] = null
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha

private def learnVocab(words:RDD[String]){
private def learnVocab(words:RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(
@@ -99,7 +105,7 @@ class Word2Vec(
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b)=> a.cn > b.cn)
.sortWith((a, b) => a.cn > b.cn)

vocabSize = vocab.length
var a = 0
@@ -111,22 +117,18 @@
logInfo("trainWordsCount = " + trainWordsCount)
}
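
Aside: learnVocab counts word frequencies, keeps words occurring at least minCount times, and sorts by descending count so the most frequent words get the smallest indices. A local-collection sketch of the same pipeline (illustrative, not part of this commit; minCount lowered to 2 for the toy data):

```scala
val words = Seq("a", "b", "a", "c", "a", "b")
val vocab = words.groupBy(identity).mapValues(_.size).toSeq
  .filter(_._2 >= 2)   // minCount = 2 for this toy corpus
  .sortBy(-_._2)       // most frequent word first
// vocab == Seq(("a", 3), ("b", 2))
val vocabHash = vocab.zipWithIndex.map { case ((w, _), i) => w -> i }.toMap
// vocabHash == Map("a" -> 0, "b" -> 1)
```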

private def learnVocabPerPartition(words:RDD[String]) {

}

private def createExpTable(): Array[Double] = {
val expTable = new Array[Double](EXP_TABLE_SIZE)
private def createExpTable(): Array[Float] = {
val expTable = new Array[Float](EXP_TABLE_SIZE)
var i = 0
while (i < EXP_TABLE_SIZE) {
val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
expTable(i) = tmp / (tmp + 1)
expTable(i) = (tmp / (tmp + 1.0)).toFloat
i += 1
}
expTable
}
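
Aside: createExpTable precomputes the sigmoid 1 / (1 + e^(-x)) on a uniform grid over [-MAX_EXP, MAX_EXP] so the hot training loop can replace every exponential with an array lookup. A standalone sketch using the same index arithmetic as the training loop further down (illustrative, not part of this commit):

```scala
val EXP_TABLE_SIZE = 1000
val MAX_EXP = 6

// expTable(i) = sigmoid(x_i) with x_i = (2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP
val expTable: Array[Float] = Array.tabulate(EXP_TABLE_SIZE) { i =>
  val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
  (tmp / (tmp + 1.0)).toFloat
}

// Lookup used in place of computing the sigmoid; the caller first checks
// that f lies strictly inside (-MAX_EXP, MAX_EXP).
def sigmoidLookup(f: Double): Float =
  expTable(((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt)

// sigmoidLookup(0.0) ≈ 0.5f, sigmoidLookup(2.0) ≈ 0.88f
```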

private def createBinaryTree() {
private def createBinaryTree(): Unit = {
val count = new Array[Long](vocabSize * 2 + 1)
val binary = new Array[Int](vocabSize * 2 + 1)
val parentNode = new Array[Int](vocabSize * 2 + 1)
@@ -208,8 +210,7 @@ class Word2Vec(
* @param dataset an RDD of words
* @return a Word2VecModel
*/

def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {
def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

val words = dataset.flatMap(x => x)
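
Aside: in skip-gram training, each center word is paired with the context words at most `window` positions away, and each pair drives one hierarchical-softmax update in the loop below. A minimal sketch of the pair generation (illustrative, not part of this commit; it ignores the per-word window randomization used by the original C implementation):

```scala
// (center, context) pairs for one tokenized sentence.
def skipGramPairs(sentence: Array[Int], window: Int): Seq[(Int, Int)] =
  for {
    pos    <- sentence.indices
    offset <- -window to window
    ctx     = pos + offset
    if offset != 0 && ctx >= 0 && ctx < sentence.length
  } yield (sentence(pos), sentence(ctx))

// skipGramPairs(Array(1, 2, 3), window = 1)
//   == Vector((1, 2), (2, 1), (2, 3), (3, 2))
```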

@@ -223,39 +224,37 @@
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)

val sentences: RDD[Array[Int]] = words.mapPartitions {
iter => { new Iterator[Array[Int]] {
def hasNext = iter.hasNext

def next = {
var sentence = new ArrayBuffer[Int]
var sentenceLength = 0
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = bcVocabHash.value.get(iter.next)
word match {
case Some(w) => {
sentence += w
sentenceLength += 1
}
case None =>
}
val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
new Iterator[Array[Int]] {
def hasNext: Boolean = iter.hasNext

def next(): Array[Int] = {
var sentence = new ArrayBuffer[Int]
var sentenceLength = 0
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = bcVocabHash.value.get(iter.next())
word match {
case Some(w) =>
sentence += w
sentenceLength += 1
case None =>
}
sentence.toArray
}
sentence.toArray
}
}
}

val newSentences = sentences.repartition(parallelism).cache()
var syn0Global
= Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
var syn1Global = new Array[Double](vocabSize * layer1Size)
var syn0Global =
Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
var syn1Global = new Array[Float](vocabSize * layer1Size)

for(iter <- 1 to numIterations) {
val (aggSyn0, aggSyn1, _, _) =
// TODO: broadcast temp instead of serializing it directly
// TODO: broadcast temp instead of serializing it directly
// or initialize the model in each executor
newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
seqOp = (c, v) => (c, v) match {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
@@ -280,23 +279,23 @@ class Word2Vec(
if (c >= 0 && c < sentence.size) {
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Double](layer1Size)
val neu1e = new Array[Float](layer1Size)
// Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
val l2 = bcVocab.value(word).point(d) * layer1Size
// Propagate hidden -> output
var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
val g = (1 - bcVocab.value(word).code(d) - f) * alpha
blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
}
d += 1
}
blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
}
}
a += 1
@@ -308,24 +307,24 @@ class Word2Vec(
combOp = (c1, c2) => (c1, c2) match {
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
blas.dscal(n, weight1, syn0_1, 1)
blas.dscal(n, weight1, syn1_1, 1)
blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
blas.sscal(n, weight1, syn0_1, 1)
blas.sscal(n, weight1, syn1_1, 1)
blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
})
syn0Global = aggSyn0
syn1Global = aggSyn1
}
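
Aside: the combOp above merges per-partition models by a weighted average, weighting each partial model by the number of words that partition processed. The same merge rule without BLAS (illustrative, not part of this commit):

```scala
// Weighted average of two partial weight arrays, as the sscal/saxpy calls do.
def merge(m1: Array[Float], wc1: Int, m2: Array[Float], wc2: Int): Array[Float] = {
  val w1 = wc1.toFloat / (wc1 + wc2)
  val w2 = wc2.toFloat / (wc1 + wc2)
  Array.tabulate(m1.length)(i => w1 * m1(i) + w2 * m2(i))
}

// merge(Array(1f, 1f), 300, Array(0f, 2f), 100) == Array(0.75f, 1.25f)
```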
newSentences.unpersist()

val wordMap = new Array[(String, Array[Double])](vocabSize)
val wordMap = new Array[(String, Array[Float])](vocabSize)
var i = 0
while (i < vocabSize) {
val word = bcVocab.value(i).word
val vector = new Array[Double](layer1Size)
val vector = new Array[Float](layer1Size)
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
wordMap(i) = (word, vector)
i += 1
@@ -341,15 +340,15 @@ class Word2Vec(
/**
* Word2Vec model
*/
class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {

private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
val n = v1.length
val norm1 = blas.dnrm2(n, v1, 1)
val norm2 = blas.dnrm2(n, v2, 1)
val norm1 = blas.snrm2(n, v1, 1)
val norm2 = blas.snrm2(n, v2, 1)
if (norm1 == 0 || norm2 == 0) return 0.0
blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
}

/**
@@ -360,9 +359,9 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
def transform(word: String): Vector = {
val result = model.lookup(word)
if (result.isEmpty) {
throw new IllegalStateException(s"${word} not in vocabulary")
throw new IllegalStateException(s"$word not in vocabulary")
}
else Vectors.dense(result(0))
else Vectors.dense(result(0).map(_.toDouble))
}
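
Aside: the model stores Float vectors internally while the public linalg.Vector API is Double-based, hence the map(_.toDouble) above. Usage sketch (assumes a trained model; the lookup word is illustrative):

```scala
val vec = model.transform("a") // dense Vector of dimension `size`
// model.transform("xyz")      // throws IllegalStateException: xyz not in vocabulary
```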

/**
@@ -394,7 +393,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
val topK = model.map { case(w, vec) =>
(cosineSimilarity(vector.toArray, vec), w) }
(cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
.sortByKey(ascending = false)
.take(num + 1)
.map(_.swap)
Expand All @@ -410,18 +409,16 @@ object Word2Vec{
* @param input RDD of words
* @param size vector dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequncy to consider a vocabulary word
* @return Word2Vec model
*/
* @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
* @param numIterations number of iterations, should be smaller than or equal to parallelism
* @return Word2Vec model
*/
def train[S <: Iterable[String]](
input: RDD[S],
size: Int,
startingAlpha: Double,
window: Int,
minCount: Int,
parallelism: Int = 1,
numIterations:Int = 1): Word2VecModel = {
new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
}
}
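
Aside: a minimal end-to-end sketch of the updated train signature (assumes an existing SparkContext sc; the toy corpus mirrors the test below, and parallelism and numIterations take their defaults of 1):

```scala
import org.apache.spark.mllib.feature.Word2Vec

val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence), 2).map(_.split(" ").toSeq)

val model = Word2Vec.train(doc, size = 10, startingAlpha = 0.025)
model.findSynonyms("a", 2) // expected: "b" first, then "c"
```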
29 changes: 15 additions & 14 deletions mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.feature

import org.scalatest.FunSuite

import org.apache.spark.SparkContext._
import org.apache.spark.mllib.util.LocalSparkContext

class Word2VecSuite extends FunSuite with LocalSparkContext {

// TODO: add more tests

test("Word2Vec") {
val sentence = "a b " * 100 + "a c " * 10
val localDoc = Seq(sentence, sentence)
@@ -33,28 +35,27 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
val window = 2
val minCount = 2
val num = 2
val word = "a"

val model = Word2Vec.train(doc, size, startingAlpha, window, minCount)
val synons = model.findSynonyms("a", 2)
assert(synons.length == num)
assert(synons(0)._1 == "b")
assert(synons(1)._1 == "c")
val syms = model.findSynonyms("a", 2)
assert(syms.length == num)
assert(syms(0)._1 == "b")
assert(syms(1)._1 == "c")
}


test("Word2VecModel") {
val num = 2
val localModel = Seq(
("china" , Array(0.50, 0.50, 0.50, 0.50)),
("japan" , Array(0.40, 0.50, 0.50, 0.50)),
("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
("korea" , Array(0.45, 0.60, 0.60, 0.60))
("china" , Array(0.50f, 0.50f, 0.50f, 0.50f)),
("japan" , Array(0.40f, 0.50f, 0.50f, 0.50f)),
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
("korea" , Array(0.45f, 0.60f, 0.60f, 0.60f))
)
val model = new Word2VecModel(sc.parallelize(localModel, 2))
val synons = model.findSynonyms("china", num)
assert(synons.length == num)
assert(synons(0)._1 == "taiwan")
assert(synons(1)._1 == "japan")
val syms = model.findSynonyms("china", num)
assert(syms.length == num)
assert(syms(0)._1 == "taiwan")
assert(syms(1)._1 == "japan")
}
}
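
Aside: why the Word2VecModel test expects "taiwan" before "japan": against china = (0.5, 0.5, 0.5, 0.5), the cosine similarities are roughly 0.9966 (taiwan), 0.9959 (japan), and 0.9934 (korea). A standalone check using the same arithmetic as cosineSimilarity (illustrative, not part of this commit):

```scala
def cosine(v1: Array[Float], v2: Array[Float]): Double = {
  val dot = v1.zip(v2).map { case (a, b) => a * b }.sum
  dot / (math.sqrt(v1.map(x => x * x).sum) * math.sqrt(v2.map(x => x * x).sum))
}

val china = Array(0.50f, 0.50f, 0.50f, 0.50f)
cosine(china, Array(0.60f, 0.50f, 0.50f, 0.50f)) // taiwan ≈ 0.9966
cosine(china, Array(0.40f, 0.50f, 0.50f, 0.50f)) // japan  ≈ 0.9959
cosine(china, Array(0.45f, 0.60f, 0.60f, 0.60f)) // korea  ≈ 0.9934
```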
