
Commit

initial commit
Liquan Pei committed Aug 1, 2014
1 parent c475540 commit 8d6befe
Showing 2 changed files with 393 additions and 0 deletions.
353 changes: 353 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -0,0 +1,353 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature

import scala.util.Random
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._

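/**
 * Entry in the vocabulary: a word, its count `cn`, and the Huffman code
 * assigned to it for hierarchical softmax (`code` holds the code bits,
 * `point` the indices of the inner tree nodes along the path, and
 * `codeLen` the length of the code).
 */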
private case class VocabWord(
var word: String,
var cn: Int,
var point: Array[Int],
var code: Array[Int],
var codeLen: Int
)

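/**
 * Word2Vec computes distributed vector representations of words, using the
 * skip-gram model with hierarchical softmax and closely following the
 * original C implementation.
 *
 * Hypothetical usage sketch (assumes an existing SparkContext `sc` and a
 * corpus file path):
 * {{{
 *   val input = sc.textFile("corpus.txt")
 *   val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025,
 *     window = 5, minCount = 5)
 *   val synonyms = model.findSynonyms("china", 10)
 * }}}
 *
 * @param size dimensionality of the word vectors
 * @param startingAlpha initial learning rate
 * @param window maximum distance between the current and predicted word
 * @param minCount minimum number of occurrences for a word to be kept in the vocabulary
 */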
class Word2Vec(
val size: Int,
val startingAlpha: Double,
val window: Int,
val minCount: Int)
extends Serializable with Logging {

private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000
private val layer1Size = size

private var trainWordsCount = 0
private var vocabSize = 0
private var vocab: Array[VocabWord] = null
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha

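/**
 * Builds the vocabulary from the dataset: counts all space-separated words,
 * drops those occurring fewer than minCount times, and sorts the rest by
 * descending frequency.
 */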
private def learnVocab(dataset: RDD[String]) {
vocab = dataset.flatMap(line => line.split(" "))
.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)

vocabSize = vocab.length
var a = 0
while (a < vocabSize) {
vocabHash += vocab(a).word -> a
trainWordsCount += vocab(a).cn
a += 1
}
logInfo("trainWordsCount = " + trainWordsCount)
}

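/**
 * Precomputes the sigmoid function on [-MAX_EXP, MAX_EXP] so the inner
 * training loop can use a table lookup instead of calling math.exp.
 */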
private def createExpTable(): Array[Double] = {
val expTable = new Array[Double](EXP_TABLE_SIZE)
var i = 0
while (i < EXP_TABLE_SIZE) {
val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
expTable(i) = tmp / (tmp + 1)
i += 1
}
expTable
}

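/**
 * Builds a Huffman tree over the vocabulary, so that frequent words get
 * short binary codes, and records each word's code and its path through
 * the inner nodes (used by hierarchical softmax).
 */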
private def createBinaryTree() {
val count = new Array[Long](vocabSize * 2 + 1)
val binary = new Array[Int](vocabSize * 2 + 1)
val parentNode = new Array[Int](vocabSize * 2 + 1)
val code = new Array[Int](MAX_CODE_LENGTH)
val point = new Array[Int](MAX_CODE_LENGTH)
var a = 0
while (a < vocabSize) {
count(a) = vocab(a).cn
a += 1
}
while (a < 2 * vocabSize) {
count(a) = 1e9.toInt
a += 1
}
var pos1 = vocabSize - 1
var pos2 = vocabSize

var min1i = 0
var min2i = 0

a = 0
while (a < vocabSize - 1) {
if (pos1 >= 0) {
if (count(pos1) < count(pos2)) {
min1i = pos1
pos1 -= 1
} else {
min1i = pos2
pos2 += 1
}
} else {
min1i = pos2
pos2 += 1
}
if (pos1 >= 0) {
if (count(pos1) < count(pos2)) {
min2i = pos1
pos1 -= 1
} else {
min2i = pos2
pos2 += 1
}
} else {
min2i = pos2
pos2 += 1
}
count(vocabSize + a) = count(min1i) + count(min2i)
parentNode(min1i) = vocabSize + a
parentNode(min2i) = vocabSize + a
binary(min2i) = 1
a += 1
}
// Now assign binary code to each vocabulary word
var i = 0
a = 0
while (a < vocabSize) {
var b = a
i = 0
while (b != vocabSize * 2 - 2) {
code(i) = binary(b)
point(i) = b
i += 1
b = parentNode(b)
}
vocab(a).codeLen = i
vocab(a).point(0) = vocabSize - 2
b = 0
while (b < i) {
vocab(a).code(i - b - 1) = code(b)
vocab(a).point(i - b) = point(b) - vocabSize
b += 1
}
a += 1
}
}

/**
* Computes the vector representation of each word in the vocabulary.
* @param dataset an RDD of sentences, each a string of space-separated words
* @return a Word2VecModel holding the learned word vectors
*/
def fit(dataset: RDD[String]): Word2VecModel = {

learnVocab(dataset)

createBinaryTree()

val sc = dataset.context

val expTable = sc.broadcast(createExpTable())
val V = sc.broadcast(vocab)
val VHash = sc.broadcast(vocabHash)

val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions { iter =>
new Iterator[Array[Int]] {
def hasNext: Boolean = iter.hasNext
def next(): Array[Int] = {
val sentence = new ArrayBuffer[Int]
var sentenceLength = 0
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = VHash.value.get(iter.next())
word match {
case Some(w) =>
sentence += w
sentenceLength += 1
case None =>
}
}
sentence.toArray
}
}
}

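// Train skip-gram with RDD.aggregate: seqOp runs SGD over each sentence,
// updating local copies of the input (syn0) and output (syn1) weight
// matrices along with the word counts; combOp sums the per-partition
// results. Training currently runs in a single partition (repartition(1)).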
val newSentences = sentences.repartition(1).cache()
val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
val (aggSyn0, _, _, _) =
// TODO: broadcast temp instead of serializing it directly or initialize the model in each executor
newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1))
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
wc += sentence.size
var pos = 0
while (pos < sentence.size) {
val word = sentence(pos)
// TODO: fix random seed
val b = Random.nextInt(window)
// Train Skip-gram
var a = b
while (a < window * 2 + 1 - b) {
if (a != window) {
val c = pos - window + a
if (c >= 0 && c < sentence.size) {
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Double](layer1Size)
//HS
var d = 0
while (d < vocab(word).codeLen) {
val l2 = vocab(word).point(d) * layer1Size
// Propagate hidden -> output
var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
val g = (1 - vocab(word).code(d) - f) * alpha
blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
}
d += 1
}
blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
}
}
a += 1
}
pos += 1
}
(syn0, syn1, lwc, wc)
},
combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
})

val wordMap = new Array[(String, Array[Double])](vocabSize)
var i = 0
while (i < vocabSize) {
val word = vocab(i).word
val vector = new Array[Double](layer1Size)
Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size)
wordMap(i) = (word, vector)
i += 1
}
val modelRDD = sc.parallelize(wordMap, 100).partitionBy(new HashPartitioner(100))
new Word2VecModel(modelRDD)
}
}

class Word2VecModel(val model: RDD[(String, Array[Double])]) extends Serializable {

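/** Cosine similarity between two vectors; returns 0.0 if either vector has zero norm. */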
private def distance(v1: Array[Double], v2: Array[Double]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
val n = v1.length
val norm1 = blas.dnrm2(n, v1, 1)
val norm2 = blas.dnrm2(n, v2, 1)
if (norm1 == 0 || norm2 == 0) return 0.0
blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2
}

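/** Returns the learned vector for a word, or an empty array if the word is not in the vocabulary. */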
def transform(word: String): Array[Double] = {
val result = model.lookup(word)
if (result.isEmpty) Array[Double]()
else result(0)
}

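/** Transforms an RDD of words into an RDD of their vector representations. */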
def transform(dataset: RDD[String]): RDD[Array[Double]] = {
dataset.map(word => transform(word))
}

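/** Returns the num words closest to the given word by cosine similarity. */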
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
if (vector.isEmpty) Array[(String, Double)]()
else findSynonyms(vector, num)
}

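/**
* Returns the num words whose vectors are closest to the given vector by
* cosine similarity. The single nearest match is dropped on the assumption
* that it is the query word itself.
*/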
def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should be > 0")
val topK = model.map { case (w, vec) => (distance(vector, vec), w) }
.sortByKey(ascending = false)
.take(num + 1)
.map { case (dist, w) => (w, dist) }
.drop(1)

topK
}
}

object Word2Vec extends Serializable with Logging {
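/**
* Trains a Word2Vec model on the input corpus with the given
* hyperparameters and returns the fitted model.
*/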
def train(
input: RDD[String],
size: Int,
startingAlpha: Double,
window: Int,
minCount: Int): Word2VecModel = {
new Word2Vec(size, startingAlpha, window, minCount).fit(input)
}

def main(args: Array[String]) {
if (args.length < 6) {
println("Usage: word2vec input size startingAlpha window minCount num")
sys.exit(1)
}
val conf = new SparkConf()
.setAppName("word2vec")

val sc = new SparkContext(conf)
val input = sc.textFile(args(0))
val size = args(1).toInt
val startingAlpha = args(2).toDouble
val window = args(3).toInt
val minCount = args(4).toInt
val num = args(5).toInt
val model = train(input, size, startingAlpha, window, minCount)
val vec = model.findSynonyms("china", num)
for ((w, dist) <- vec) logInfo(w + " " + dist)
sc.stop()
}
}
40 changes: 40 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature

import org.scalatest.FunSuite
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.util.LocalSparkContext

class Word2VecSuite extends FunSuite with LocalSparkContext {
test("word2vec") {
val num = 2
val localModel = Seq(
("china" , Array(0.50, 0.50, 0.50, 0.50)),
("japan" , Array(0.40, 0.50, 0.50, 0.50)),
("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
("korea" , Array(0.45, 0.60, 0.60, 0.60))
)
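// "taiwan" has the highest cosine similarity to "china" among the other
// words, followed by "japan"; "china" itself is excluded from the results.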
val model = new Word2VecModel(sc.parallelize(localModel, 2))
val synonyms = model.findSynonyms("china", num)
assert(synonyms.length == num)
assert(synonyms(0)._1 == "taiwan")
assert(synonyms(1)._1 == "japan")
}
}
