Skip to content

Commit

Permalink
Add comments, minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Liquan Pei committed Aug 1, 2014
1 parent 8d6befe commit 0aafb1b
Showing 1 changed file with 46 additions and 23 deletions.
69 changes: 46 additions & 23 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.HashPartitioner

/**
* Entry in vocabulary
*/
private case class VocabWord(
var word: String,
var cn: Int,
Expand All @@ -39,6 +42,9 @@ private case class VocabWord(
var codeLen:Int
)

/**
* Vector representation of word
*/
class Word2Vec(
val size: Int,
val startingAlpha: Double,
Expand All @@ -51,7 +57,8 @@ class Word2Vec(
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000
private val layer1Size = size

private val modelPartitionNum = 100

private var trainWordsCount = 0
private var vocabSize = 0
private var vocab: Array[VocabWord] = null
Expand Down Expand Up @@ -169,6 +176,7 @@ class Word2Vec(
* Computes the vector representation of each word in
* vocabulary
* @param dataset an RDD of strings
* @return a Word2VecModel
*/

def fit(dataset:RDD[String]): Word2VecModel = {
Expand Down Expand Up @@ -274,11 +282,14 @@ class Word2Vec(
wordMap(i) = (word, vector)
i += 1
}
val modelRDD = sc.parallelize(wordMap,100).partitionBy(new HashPartitioner(100))
val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
new Word2VecModel(modelRDD)
}
}

/**
* Word2Vec model
*/
class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {

val model = _model
Expand All @@ -292,22 +303,46 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
}

/**
* Transforms a word to its vector representation
* @param word a word
* @return vector representation of word
*/

def transform(word: String): Array[Double] = {
val result = model.lookup(word)
if (result.isEmpty) Array[Double]()
else result(0)
}

/**
* Transforms an RDD to its vector representation
* @param dataset a an RDD of words
* @return RDD of vector representation
*/

def transform(dataset: RDD[String]): RDD[Array[Double]] = {
dataset.map(word => transform(word))
}

/**
* Find synonyms of a word
* @param word a word
* @param num number of synonyms to find
* @return array of (word, similarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
if (vector.isEmpty) Array[(String, Double)]()
else findSynonyms(vector,num)
}

/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
* @param num number of synonyms to find
* @return array of (word, similarity)
*/
def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
val topK = model.map(
Expand All @@ -321,6 +356,15 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
}

object Word2Vec extends Serializable with Logging {
/**
* Train Word2Vec model
* @param input RDD of words
* @param size vectoer dimension
* @param startingAlpha initial learning rate
* @param window context words from [-window, window]
* @param minCount minimum frequncy to consider a vocabulary word
* @return Word2Vec model
*/
def train(
input: RDD[String],
size: Int,
Expand All @@ -329,25 +373,4 @@ object Word2Vec extends Serializable with Logging {
minCount: Int): Word2VecModel = {
new Word2Vec(size,startingAlpha, window, minCount).fit(input)
}

def main(args: Array[String]) {
if (args.length < 6) {
println("Usage: word2vec input size startingAlpha window minCount num")
sys.exit(1)
}
val conf = new SparkConf()
.setAppName("word2vec")

val sc = new SparkContext(conf)
val input = sc.textFile(args(0))
val size = args(1).toInt
val startingAlpha = args(2).toDouble
val window = args(3).toInt
val minCount = args(4).toInt
val num = args(5).toInt
val model = train(input, size, startingAlpha, window, minCount)
val vec = model.findSynonyms("china", num)
for((w, dist) <- vec) logInfo(w.toString + " " + dist.toString)
sc.stop()
}
}

0 comments on commit 0aafb1b

Please sign in to comment.