[SPARK-12016] [MLLIB] [PYSPARK] Wrap Word2VecModel when loading it in pyspark

JIRA: https://issues.apache.org/jira/browse/SPARK-12016

We should not directly use Word2VecModel in pyspark. We need to wrap it in a Word2VecModelWrapper when loading it in pyspark.
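
For context, a minimal end-to-end sketch (not part of this commit) of the pyspark flow that the wrapping repairs. The corpus mirrors the Word2Vec doctest touched below; the application name and temporary save path are hypothetical.

from tempfile import mkdtemp
from pyspark import SparkContext
from pyspark.mllib.feature import Word2Vec, Word2VecModel

sc = SparkContext("local[2]", "word2vec-roundtrip-sketch")  # hypothetical app name

# Train a tiny model on a toy corpus, mirroring the doctest in python/pyspark/mllib/feature.py.
sentence = "a b " * 100 + "a c " * 10
doc = sc.parallelize([sentence, sentence]).map(lambda line: line.split(" "))
model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc)

# Round-trip through save/load; the path is a hypothetical temporary directory.
path = mkdtemp() + "/word2vec"
model.save(sc, path)
sameModel = Word2VecModel.load(sc, path)

# Before this fix the loaded handle was the raw Scala Word2VecModel, so calls that
# go through the Python wrapper (e.g. findSynonyms, getVectors) broke after load.
print(list(sameModel.findSynonyms("a", 2)))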

Author: Liang-Chi Hsieh <[email protected]>

Closes apache#10100 from viirya/fix-load-py-wordvecmodel.
viirya authored and davies committed Dec 14, 2015
1 parent e25f1fe commit b51a4cd
Showing 3 changed files with 67 additions and 34 deletions.
33 changes: 0 additions & 33 deletions mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -680,39 +680,6 @@ private[python] class PythonMLLibAPI extends Serializable {
}
}

private[python] class Word2VecModelWrapper(model: Word2VecModel) {
  def transform(word: String): Vector = {
    model.transform(word)
  }

  /**
   * Transforms an RDD of words to its vector representation
   * @param rdd an RDD of words
   * @return an RDD of vector representations of words
   */
  def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
    rdd.rdd.map(model.transform)
  }

  def findSynonyms(word: String, num: Int): JList[Object] = {
    val vec = transform(word)
    findSynonyms(vec, num)
  }

  def findSynonyms(vector: Vector, num: Int): JList[Object] = {
    val result = model.findSynonyms(vector, num)
    val similarity = Vectors.dense(result.map(_._2))
    val words = result.map(_._1)
    List(words, similarity).map(_.asInstanceOf[Object]).asJava
  }

  def getVectors: JMap[String, JList[Float]] = {
    model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
  }

  def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
}

  /**
   * Java stub for Python mllib DecisionTree.train().
   * This stub returns a handle to the Java object instead of the content of the Java object.
62 changes: 62 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.api.python

import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._

import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
 * Wrapper around Word2VecModel to provide helper methods in Python
 */
private[python] class Word2VecModelWrapper(model: Word2VecModel) {
  def transform(word: String): Vector = {
    model.transform(word)
  }

  /**
   * Transforms an RDD of words to its vector representation
   * @param rdd an RDD of words
   * @return an RDD of vector representations of words
   */
  def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
    rdd.rdd.map(model.transform)
  }

  def findSynonyms(word: String, num: Int): JList[Object] = {
    val vec = transform(word)
    findSynonyms(vec, num)
  }

  def findSynonyms(vector: Vector, num: Int): JList[Object] = {
    val result = model.findSynonyms(vector, num)
    val similarity = Vectors.dense(result.map(_._2))
    val words = result.map(_._1)
    List(words, similarity).map(_.asInstanceOf[Object]).asJava
  }

  def getVectors: JMap[String, JList[Float]] = {
    model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
  }

  def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
}
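
As orientation (an assumption about the Python side, not shown in this diff): findSynonyms returns a two-element Java list, with the matched words at index 0 and a dense vector of similarity scores at index 1, which pyspark unpacks into (word, similarity) pairs roughly as sketched here.

# Rough sketch of the Python-side unpacking of Word2VecModelWrapper.findSynonyms.
# `py_model` stands for a pyspark Word2VecModel, whose inherited call() helper
# (from pyspark.mllib.common.JavaModelWrapper) invokes the named JVM method.
def find_synonyms(py_model, word, num):
    words, similarity = py_model.call("findSynonyms", word, num)
    return list(zip(words, similarity))  # e.g. [(u'b', <score>), (u'c', <score>)]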
6 changes: 5 additions & 1 deletion python/pyspark/mllib/feature.py
@@ -504,7 +504,8 @@ def load(cls, sc, path):
"""
jmodel = sc._jvm.org.apache.spark.mllib.feature \
.Word2VecModel.load(sc._jsc.sc(), path)
return Word2VecModel(jmodel)
model = sc._jvm.Word2VecModelWrapper(jmodel)
return Word2VecModel(model)


@ignore_unicode_prefix
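
The load() change above is needed because pyspark's Word2VecModel dispatches its calls by name to the stored JVM object, and it expects the Java-friendly signatures and return types that only Word2VecModelWrapper provides (e.g. findSynonyms returning a two-element java.util.List, getVectors returning a Java map). A stripped-down sketch of that dispatch, with names assumed rather than quoted from this diff:

# Stripped-down sketch (names assumed) of the dispatch behind pyspark model wrappers:
# every method call is forwarded by name to the stored JVM object via Py4J, so the
# stored object must expose the wrapper's methods, not the raw Scala model's.
class JavaModelWrapperSketch(object):
    def __init__(self, sc, java_model):
        self._sc = sc
        self._java_model = java_model

    def call(self, name, *args):
        # Look up `name` on the JVM object and invoke it.
        return getattr(self._java_model, name)(*args)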
@@ -546,6 +547,9 @@ class Word2Vec(object):
>>> sameModel = Word2VecModel.load(sc, path)
>>> model.transform("a") == sameModel.transform("a")
True
>>> syms = sameModel.findSynonyms("a", 2)
>>> [s[0] for s in syms]
[u'b', u'c']
>>> from shutil import rmtree
>>> try:
... rmtree(path)
