Skip to content

Commit

Permalink
[SPARK-7474] [MLLIB] update ParamGridBuilder doctest
Browse files (browse the repository at this point in the history)
Multiline commands in the ParamGridBuilder doctest are now handled properly in this PR. cc @oefirouz

![screen shot 2015-05-07 at 10 53 25 pm](https://cloud.githubusercontent.com/assets/829644/7531290/02ad2fd4-f50c-11e4-8c04-e58d1a61ad69.png)

Author: Xiangrui Meng <[email protected]>

Closes apache#6001 from mengxr/SPARK-7474 and squashes the following commits:

b94b11d [Xiangrui Meng] update ParamGridBuilder doctest
  • Loading branch information
mengxr authored and nemccarthy committed Jun 19, 2015
1 parent 681d4a5 commit 1641a80
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions python/pyspark/ml/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,22 @@


class ParamGridBuilder(object):
"""
r"""
Builder for a param grid used in grid search-based model selection.
>>> from classification import LogisticRegression
>>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression()
>>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \
.baseOn([lr.predictionCol, 'p']) \
.addGrid(lr.regParam, [1.0, 2.0, 3.0]) \
.addGrid(lr.maxIter, [1, 5]) \
.addGrid(lr.featuresCol, ['f']) \
.build()
>>> expected = [ \
{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
>>> output = ParamGridBuilder() \
... .baseOn({lr.labelCol: 'l'}) \
... .baseOn([lr.predictionCol, 'p']) \
... .addGrid(lr.regParam, [1.0, 2.0]) \
... .addGrid(lr.maxIter, [1, 5]) \
... .build()
>>> expected = [
... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
>>> len(output) == len(expected)
True
>>> all([m in expected for m in output])
Expand Down

0 comments on commit 1641a80

Please sign in to comment.