
Commit

add unit tests to HashingTF and Tokenizer
mengxr committed Jan 27, 2015
1 parent ba0ba1e commit 7521d1c
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions python/pyspark/ml/feature.py
@@ -22,6 +22,19 @@

@inherit_doc
class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
A tokenizer that converts the input string to lowercase and then splits it by white spaces.
>>> from pyspark.sql import Row
>>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")]))
>>> tokenizer = Tokenizer() \
.setInputCol("text") \
.setOutputCol("words")
>>> print tokenizer.transform(dataset).first()
Row(text=u'a b c', words=[u'a', u'b', u'c'])
>>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).first()
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
"""

    def __init__(self):
        super(Tokenizer, self).__init__()
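
Aside: the chained-setter example above works because the docstring is not a raw string, so each trailing backslash folds its continuation line into a single doctest source line before doctest parses it. As a minimal, Spark-free sketch of the behaviour the Tokenizer docstring describes (lowercasing followed by whitespace splitting); this is an illustration only, not the JVM implementation that JavaTransformer delegates to:

# Illustration only: the real Tokenizer delegates to the Scala
# org.apache.spark.ml.feature.Tokenizer on the JVM side.
def tokenize(text):
    """Lowercase the input and split it on whitespace."""
    return text.lower().split()

print(tokenize("A b C"))  # ['a', 'b', 'c'], matching the 'words' column above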
@@ -33,10 +46,43 @@ def _java_class(self):

@inherit_doc
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
"""
Maps a sequence of terms to their term frequencies using the hashing trick.
>>> from pyspark.sql import Row
>>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])]))
>>> hashingTF = HashingTF() \
.setNumFeatures(10) \
.setInputCol("words") \
.setOutputCol("features")
>>> print hashingTF.transform(dataset).first().features
(10,[7,8,9],[1.0,1.0,1.0])
>>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
>>> print hashingTF.transform(dataset, params).first().vector
(5,[2,3,4],[1.0,1.0,1.0])
"""

    def __init__(self):
        super(HashingTF, self).__init__()

    @property
    def _java_class(self):
        return "org.apache.spark.ml.feature.HashingTF"


if __name__ == "__main__":
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import Row, SQLContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext("local[2]", "ml.feature tests")
    sqlCtx = SQLContext(sc)
    globs['sc'] = sc
    globs['sqlCtx'] = sqlCtx
    (failure_count, test_count) = doctest.testmod(
        globs=globs, optionflags=doctest.ELLIPSIS)
    sc.stop()
    if failure_count:
        exit(-1)
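
The __main__ block follows the usual PySpark doctest-harness pattern: names placed in globs (here sc and sqlCtx) become available inside the docstring examples, and a non-zero exit code signals failure to the surrounding test runner. A minimal, standalone sketch of the same pattern, with a hypothetical double function and an injected answer name used purely for illustration:

import doctest

def double(x):
    """
    >>> double(answer)  # `answer` is injected via globs below
    42
    """
    return 2 * x

if __name__ == "__main__":
    globs = globals().copy()
    globs["answer"] = 21                  # injected name, like sc / sqlCtx above
    (failure_count, test_count) = doctest.testmod(
        globs=globs, optionflags=doctest.ELLIPSIS)
    print(failure_count, test_count)      # 0 failures out of 1 example when it passes
    if failure_count:
        exit(-1)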
