Skip to content

Commit

Permalink
add unit test for LR
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jan 27, 2015
1 parent 7521d1c commit a4f4dbf
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
34 changes: 34 additions & 0 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
HasRegParam):
"""
Logistic regression.
>>> from pyspark.sql import Row
>>> from pyspark.mllib.linalg import Vectors
>>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
Row(label=1.0, features=Vectors.dense(1.0)), \
Row(label=0.0, features=Vectors.sparse(1, [], []))]))
>>> lr = LogisticRegression() \
.setMaxIter(5) \
.setRegParam(0.01)
>>> model = lr.fit(dataset)
>>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
>>> print model.transform(test0).first().prediction
0.0
>>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
>>> print model.transform(test1).first().prediction
1.0
"""

def __init__(self):
Expand All @@ -52,3 +68,21 @@ def __init__(self, java_model):
@property
def _java_class(self):
    """Fully qualified name of the JVM class this Python wrapper delegates to.

    NOTE(review): the enclosing class header is outside this diff hunk —
    presumably `LogisticRegressionModel` given the returned class name; confirm
    against the full file.
    """
    return "org.apache.spark.ml.classification.LogisticRegressionModel"


if __name__ == "__main__":
    # Doctest runner: spins up a local SparkContext/SQLContext, injects them
    # into the doctest globals (the doctests reference `sc` and `sqlCtx`
    # directly), runs all doctests in this module, and exits non-zero on any
    # failure so CI can detect it.
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = globals().copy()
    # Fix: app name said "ml.feature tests" — a copy-paste leftover from
    # pyspark/ml/feature.py. This runner tests ml.classification.
    # (The stale "small batch size" comment is dropped: no batch size is
    # configured anywhere in this block.)
    sc = SparkContext("local[2]", "ml.classification tests")
    sqlCtx = SQLContext(sc)
    globs['sc'] = sc
    globs['sqlCtx'] = sqlCtx
    # ELLIPSIS lets doctests elide noisy/variable output with "...".
    (failure_count, test_count) = doctest.testmod(
        globs=globs, optionflags=doctest.ELLIPSIS)
    sc.stop()
    if failure_count:
        exit(-1)
2 changes: 1 addition & 1 deletion python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _java_class(self):
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
from pyspark.sql import SQLContext
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
Expand Down

0 comments on commit a4f4dbf

Please sign in to comment.