[SPARK-29867][ML][PYTHON] Add __repr__ in Python ML Models
### What changes were proposed in this pull request?
Add `__repr__` to Python ML Models

### Why are the changes needed?
Some of the Python ML Models define `__repr__`; others don't. In the doctests, when `Model.setXXX` is called, the models that have `__repr__` print `XXXModel...` correctly, while the ones that lack it fall back to a representation carrying the estimator's name instead. For example:
```
    >>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)
    >>> model = gm.fit(df)
    >>> model.setPredictionCol("newPrediction")
    GaussianMixture...
```
After the change, the same code prints the model's class name:
```
    >>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)
    >>> model = gm.fit(df)
    >>> model.setPredictionCol("newPrediction")
    GaussianMixtureModel...
```
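
The mechanics of the fix sit in the `wrapper.py` hunk at the bottom of this diff: `JavaModel`, the common base class of the Java-backed models, gains a single `__repr__` that delegates to the JVM object's `toString()`, and the per-class copies elsewhere are removed. A minimal, self-contained sketch of that delegation pattern (the two classes below are illustrative stand-ins, not pyspark's actual hierarchy):

```python
class _FakeJavaModel:
    # Stand-in for the py4j JavaObject that backs a fitted model.
    def toString(self):
        return "GaussianMixtureModel: uid=GaussianMixture_abc123, k=3"


class JavaModelSketch:
    # Simplified wrapper: pyspark's JavaModel forwards calls to the
    # underlying JVM object through a _call_java helper like this one.
    def __init__(self, java_obj):
        self._java_obj = java_obj

    def _call_java(self, name, *args):
        return getattr(self._java_obj, name)(*args)

    def __repr__(self):
        # The one-line fix: repr() mirrors the Java model's toString(),
        # so the printed name matches the model, not the estimator.
        return self._call_java("toString")


print(JavaModelSketch(_FakeJavaModel()))
# GaussianMixtureModel: uid=GaussianMixture_abc123, k=3
```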

### Does this PR introduce any user-facing change?
Yes. `repr()` of a fitted model now shows the model class name (e.g. `GaussianMixtureModel`) instead of the estimator-derived name.

### How was this patch tested?
Doctests: existing expected outputs were updated, and new doctest lines were added for models that previously had no `repr` coverage.
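
A detail that makes these doctests work: pyspark's ML modules run `doctest.testmod` with the `ELLIPSIS` option enabled in their `_test()` helpers, so an expected output like `GaussianMixtureModel...` matches the full repr string. A minimal standalone illustration of that flag (the class below is a hypothetical stand-in):

```python
import doctest


class FakeModel:
    """
    >>> FakeModel()
    GaussianMixtureModel...
    """
    def __repr__(self):
        return "GaussianMixtureModel: uid=GaussianMixture_abc123, k=3"


if __name__ == "__main__":
    # ELLIPSIS lets "..." in expected output match any text,
    # as in pyspark's _test() helpers.
    doctest.testmod(optionflags=doctest.ELLIPSIS)
```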

Closes apache#26489 from huaxingao/spark-29876.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
huaxingao authored and dongjoon-hyun committed Nov 16, 2019
1 parent 6d6b233 commit 1112fc6
Showing 8 changed files with 44 additions and 28 deletions.
11 changes: 4 additions & 7 deletions python/pyspark/ml/classification.py
@@ -192,11 +192,11 @@ class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable
0.01
>>> model = svm.fit(df)
>>> model.setPredictionCol("newPrediction")
-LinearSVC...
+LinearSVCModel...
>>> model.getPredictionCol()
'newPrediction'
>>> model.setThreshold(0.5)
-LinearSVC...
+LinearSVCModel...
>>> model.getThreshold()
0.5
>>> model.coefficients
@@ -812,9 +812,6 @@ def evaluate(self, dataset):
java_blr_summary = self._call_java("evaluate", dataset)
return BinaryLogisticRegressionSummary(java_blr_summary)

-    def __repr__(self):
-        return self._call_java("toString")
-

class LogisticRegressionSummary(JavaWrapper):
"""
@@ -1921,7 +1918,7 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
>>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
>>> model = nb.fit(df)
>>> model.setFeaturesCol("features")
-NaiveBayes_...
+NaiveBayesModel...
>>> model.getSmoothing()
1.0
>>> model.pi
@@ -2114,7 +2111,7 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
100
>>> model = mlp.fit(df)
>>> model.setFeaturesCol("features")
-MultilayerPerceptronClassifier...
+MultilayerPerceptronClassificationModel...
>>> model.layers
[2, 2, 2]
>>> model.weights.size
8 changes: 5 additions & 3 deletions python/pyspark/ml/clustering.py
@@ -234,7 +234,7 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
>>> model.getFeaturesCol()
'features'
>>> model.setPredictionCol("newPrediction")
-GaussianMixture...
+GaussianMixtureModel...
>>> model.predict(df.head().features)
2
>>> model.predictProbability(df.head().features)
@@ -532,7 +532,7 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
>>> model.getDistanceMeasure()
'euclidean'
>>> model.setPredictionCol("newPrediction")
-KMeans...
+KMeansModel...
>>> model.predict(df.head().features)
0
>>> centers = model.clusterCenters()
@@ -794,7 +794,7 @@ class BisectingKMeans(JavaEstimator, _BisectingKMeansParams, JavaMLWritable, Jav
>>> model.getMaxIter()
20
>>> model.setPredictionCol("newPrediction")
-BisectingKMeans...
+BisectingKMeansModel...
>>> model.predict(df.head().features)
0
>>> centers = model.clusterCenters()
@@ -1265,6 +1265,8 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable):
10
>>> lda.clear(lda.maxIter)
>>> model = lda.fit(df)
+>>> model.setSeed(1)
+DistributedLDAModel...
>>> model.getTopicDistributionCol()
'topicDistribution'
>>> model.isDistributed()
30 changes: 24 additions & 6 deletions python/pyspark/ml/feature.py
@@ -337,6 +337,8 @@ class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
>>> model = brp.fit(df)
>>> model.getBucketLength()
1.0
>>> model.setOutputCol("hashes")
BucketedRandomProjectionLSHModel...
>>> model.transform(df).head()
Row(id=0, features=DenseVector([-1.0, -1.0]), hashes=[DenseVector([-1.0])])
>>> data2 = [(4, Vectors.dense([2.0, 2.0 ]),),
@@ -733,6 +735,8 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav
>>> cv.setOutputCol("vectors")
CountVectorizer...
>>> model = cv.fit(df)
>>> model.setInputCol("raw")
CountVectorizerModel...
>>> model.transform(df).show(truncate=False)
+-----+---------------+-------------------------+
|label|raw |vectors |
@@ -1345,6 +1349,8 @@ class IDF(JavaEstimator, _IDFParams, JavaMLReadable, JavaMLWritable):
>>> idf.setOutputCol("idf")
IDF...
>>> model = idf.fit(df)
>>> model.setOutputCol("idf")
IDFModel...
>>> model.getMinDocFreq()
3
>>> model.idf
@@ -1519,6 +1525,8 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
>>> imputer.getRelativeError()
0.001
>>> model = imputer.fit(df)
+>>> model.setInputCols(["a", "b"])
+ImputerModel...
>>> model.getStrategy()
'mean'
>>> model.surrogateDF.show()
@@ -1810,7 +1818,7 @@ class MaxAbsScaler(JavaEstimator, _MaxAbsScalerParams, JavaMLReadable, JavaMLWri
MaxAbsScaler...
>>> model = maScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
-MaxAbsScaler...
+MaxAbsScalerModel...
>>> model.transform(df).show()
+-----+------------+
| a|scaledOutput|
@@ -1928,6 +1936,8 @@ class MinHashLSH(_LSH, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, JavaM
>>> mh.setSeed(12345)
MinHashLSH...
>>> model = mh.fit(df)
>>> model.setInputCol("features")
MinHashLSHModel...
>>> model.transform(df).head()
Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668...
>>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),),
@@ -2056,7 +2066,7 @@ class MinMaxScaler(JavaEstimator, _MinMaxScalerParams, JavaMLReadable, JavaMLWri
MinMaxScaler...
>>> model = mmScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
-MinMaxScaler...
+MinMaxScalerModel...
>>> model.originalMin
DenseVector([0.0])
>>> model.originalMax
@@ -2421,6 +2431,8 @@ class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLW
>>> ohe.setOutputCols(["output"])
OneHotEncoder...
>>> model = ohe.fit(df)
+>>> model.setOutputCols(["output"])
+OneHotEncoderModel...
>>> model.getHandleInvalid()
'error'
>>> model.transform(df).head().output
@@ -2935,7 +2947,7 @@ class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWri
RobustScaler...
>>> model = scaler.fit(df)
>>> model.setOutputCol("output")
-RobustScaler...
+RobustScalerModel...
>>> model.median
DenseVector([2.0, -2.0])
>>> model.range
@@ -3330,7 +3342,7 @@ class StandardScaler(JavaEstimator, _StandardScalerParams, JavaMLReadable, JavaM
>>> model.getInputCol()
'a'
>>> model.setOutputCol("output")
-StandardScaler...
+StandardScalerModel...
>>> model.mean
DenseVector([1.0])
>>> model.std
@@ -3490,6 +3502,8 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
>>> stringIndexer.setHandleInvalid("error")
StringIndexer...
>>> model = stringIndexer.fit(stringIndDf)
+>>> model.setHandleInvalid("error")
+StringIndexerModel...
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
... key=lambda x: x[0])
@@ -4166,7 +4180,7 @@ class VectorIndexer(JavaEstimator, _VectorIndexerParams, JavaMLReadable, JavaMLW
>>> indexer.getHandleInvalid()
'error'
>>> model.setOutputCol("output")
-VectorIndexer...
+VectorIndexerModel...
>>> model.transform(df).head().output
DenseVector([1.0, 0.0])
>>> model.numFeatures
@@ -4487,6 +4501,8 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
>>> model = word2Vec.fit(doc)
>>> model.getMinCount()
5
>>> model.setInputCol("sentence")
Word2VecModel...
>>> model.getVectors().show()
+----+--------------------+
|word| vector|
@@ -4714,7 +4730,7 @@ class PCA(JavaEstimator, _PCAParams, JavaMLReadable, JavaMLWritable):
>>> model.getK()
2
>>> model.setOutputCol("output")
-PCA...
+PCAModel...
>>> model.transform(df).collect()[0].output
DenseVector([1.648..., -4.013...])
>>> model.explainedVariance
@@ -5139,6 +5155,8 @@ class ChiSqSelector(JavaEstimator, _ChiSqSelectorParams, JavaMLReadable, JavaMLW
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
+>>> model.setFeaturesCol("features")
+ChiSqSelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([18.0])
>>> model.selectedFeatures
2 changes: 1 addition & 1 deletion python/pyspark/ml/fpm.py
@@ -166,7 +166,7 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
>>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
>>> fpm = fp.fit(data)
>>> fpm.setPredictionCol("newPrediction")
-FPGrowth...
+FPGrowthModel...
>>> fpm.freqItemsets.show(5)
+---------+----+
| items|freq|
2 changes: 2 additions & 0 deletions python/pyspark/ml/recommendation.py
@@ -225,6 +225,8 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable):
>>> model = als.fit(df)
>>> model.getUserCol()
'user'
+>>> model.setUserCol("user")
+ALSModel...
>>> model.getItemCol()
'item'
>>> model.setPredictionCol("newPrediction")
10 changes: 5 additions & 5 deletions python/pyspark/ml/regression.py
@@ -105,9 +105,9 @@ class LinearRegression(JavaPredictor, _LinearRegressionParams, JavaMLWritable, J
LinearRegression...
>>> model = lr.fit(df)
>>> model.setFeaturesCol("features")
-LinearRegression...
+LinearRegressionModel...
>>> model.setPredictionCol("newPrediction")
-LinearRegression...
+LinearRegressionModel...
>>> model.getMaxIter()
5
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -591,7 +591,7 @@ class IsotonicRegression(JavaEstimator, _IsotonicRegressionParams, HasWeightCol,
>>> ir = IsotonicRegression()
>>> model = ir.fit(df)
>>> model.setFeaturesCol("features")
-IsotonicRegression...
+IsotonicRegressionModel...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -1546,7 +1546,7 @@ class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
>>> aftsr.clear(aftsr.maxIter)
>>> model = aftsr.fit(df)
>>> model.setFeaturesCol("features")
-AFTSurvivalRegression...
+AFTSurvivalRegressionModel...
>>> model.predict(Vectors.dense(6.3))
1.0
>>> model.predictQuantiles(Vectors.dense(6.3))
@@ -1881,7 +1881,7 @@ class GeneralizedLinearRegression(JavaPredictor, _GeneralizedLinearRegressionPar
>>> glr.clear(glr.maxIter)
>>> model = glr.fit(df)
>>> model.setFeaturesCol("features")
-GeneralizedLinearRegression...
+GeneralizedLinearRegressionModel...
>>> model.getMaxIter()
25
>>> model.getAggregationDepth()
6 changes: 0 additions & 6 deletions python/pyspark/ml/tree.py
@@ -56,9 +56,6 @@ def predictLeaf(self, value):
"""
return self._call_java("predictLeaf", value)

-    def __repr__(self):
-        return self._call_java("toString")
-

class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
"""
@@ -208,9 +205,6 @@ def predictLeaf(self, value):
"""
return self._call_java("predictLeaf", value)

-    def __repr__(self):
-        return self._call_java("toString")
-

class _TreeEnsembleParams(_DecisionTreeParams):
"""
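
These deletions (and the matching one in `classification.py` above) are safe because `__repr__` moves to the shared base class in `wrapper.py` below: ordinary Python inheritance gives every model subclass the method for free. A toy illustration with hypothetical class names:

```python
class JavaModelToy:
    # One definition on the base class...
    def __repr__(self):
        return type(self).__name__ + "..."


class DecisionTreeRegressionModelToy(JavaModelToy):
    pass  # ...is inherited here, so no per-class copy is needed


print(DecisionTreeRegressionModelToy())  # DecisionTreeRegressionModelToy...
```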
3 changes: 3 additions & 0 deletions python/pyspark/ml/wrapper.py
@@ -372,6 +372,9 @@ def __init__(self, java_model=None):

self._resetUid(java_model.uid())

+    def __repr__(self):
+        return self._call_java("toString")
+

@inherit_doc
class _JavaPredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
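
With `__repr__` defined once on `JavaModel`, every fitted model now prints its own class name. A quick doctest-style check (assuming a live SparkSession and a `df` with a `features` column, as in the clustering doctests above):

```python
>>> from pyspark.ml.clustering import KMeans
>>> model = KMeans(k=2, seed=1).fit(df)
>>> model  # __repr__ inherited from JavaModel
KMeansModel...
```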
