-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-18282][ML][PYSPARK] Add python clustering summaries for GMM and BKM #15777
Changes from 6 commits
edc2c44
c3859da
f599175
6f89617
952d24a
428348d
d6caa02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -309,13 +309,16 @@ def interceptVector(self): | |
@since("2.0.0") | ||
def summary(self): | ||
""" | ||
Gets summary (e.g. residuals, mse, r-squared ) of model on | ||
training set. An exception is thrown if | ||
`trainingSummary is None`. | ||
Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model | ||
trained on the training set. An exception is thrown if `trainingSummary is None`. | ||
""" | ||
java_blrt_summary = self._call_java("summary") | ||
# Note: Once multiclass is added, update this to return correct summary | ||
return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) | ||
if self.hasSummary: | ||
java_blrt_summary = self._call_java("summary") | ||
# Note: Once multiclass is added, update this to return correct summary | ||
return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) | ||
else: | ||
raise RuntimeError("No training summary available for this %s" % | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Reviewer: Before, this would throw a [comment truncated in source]. — Reviewer: I think that's generally a good improvement; the [comment truncated in source]. — Reviewer: I like this change; we should always throw an exception that is easy for users to understand. |
||
self.__class__.__name__) | ||
|
||
@property | ||
@since("2.0.0") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,16 +17,74 @@ | |
|
||
from pyspark import since, keyword_only | ||
from pyspark.ml.util import * | ||
from pyspark.ml.wrapper import JavaEstimator, JavaModel | ||
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper | ||
from pyspark.ml.param.shared import * | ||
from pyspark.ml.common import inherit_doc | ||
|
||
__all__ = ['BisectingKMeans', 'BisectingKMeansModel', | ||
__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary', | ||
'KMeans', 'KMeansModel', | ||
'GaussianMixture', 'GaussianMixtureModel', | ||
'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary', | ||
'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] | ||
|
||
|
||
class ClusteringSummary(JavaWrapper): | ||
""" | ||
.. note:: Experimental | ||
|
||
Clustering results for a given model. | ||
|
||
.. versionadded:: 2.1.0 | ||
""" | ||
|
||
@property | ||
@since("2.1.0") | ||
def predictionCol(self): | ||
""" | ||
Name for column of predicted clusters in `predictions`. | ||
""" | ||
return self._call_java("predictionCol") | ||
|
||
@property | ||
@since("2.1.0") | ||
def predictions(self): | ||
""" | ||
DataFrame produced by the model's `transform` method. | ||
""" | ||
return self._call_java("predictions") | ||
|
||
@property | ||
@since("2.1.0") | ||
def featuresCol(self): | ||
""" | ||
Name for column of features in `predictions`. | ||
""" | ||
return self._call_java("featuresCol") | ||
|
||
@property | ||
@since("2.1.0") | ||
def k(self): | ||
""" | ||
The number of clusters the model was trained with. | ||
""" | ||
return self._call_java("k") | ||
|
||
@property | ||
@since("2.1.0") | ||
def cluster(self): | ||
""" | ||
DataFrame of predicted cluster centers for each training data point. | ||
""" | ||
return self._call_java("cluster") | ||
|
||
@property | ||
@since("2.1.0") | ||
def clusterSizes(self): | ||
""" | ||
Size of (number of data points in) each cluster. | ||
""" | ||
return self._call_java("clusterSizes") | ||
|
||
|
||
class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): | ||
""" | ||
.. note:: Experimental | ||
|
@@ -56,6 +114,28 @@ def gaussiansDF(self): | |
""" | ||
return self._call_java("gaussiansDF") | ||
|
||
@property | ||
@since("2.1.0") | ||
def hasSummary(self): | ||
""" | ||
Indicates whether a training summary exists for this model | ||
instance. | ||
""" | ||
return self._call_java("hasSummary") | ||
|
||
@property | ||
@since("2.1.0") | ||
def summary(self): | ||
""" | ||
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the | ||
training set. An exception is thrown if no summary exists. | ||
""" | ||
if self.hasSummary: | ||
return GaussianMixtureSummary(self._call_java("summary")) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. — Reviewer: Typo, should be [comment truncated in source]. — Reply: Wow, good catch! |
||
else: | ||
raise RuntimeError("No training summary available for this %s" % | ||
self.__class__.__name__) | ||
|
||
|
||
@inherit_doc | ||
class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, | ||
|
@@ -92,6 +172,13 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte | |
>>> gm = GaussianMixture(k=3, tol=0.0001, | ||
... maxIter=10, seed=10) | ||
>>> model = gm.fit(df) | ||
>>> model.hasSummary | ||
True | ||
>>> summary = model.summary | ||
>>> summary.k | ||
3 | ||
>>> summary.clusterSizes | ||
[2, 2, 2] | ||
>>> weights = model.weights | ||
>>> len(weights) | ||
3 | ||
|
@@ -118,6 +205,8 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte | |
>>> model_path = temp_path + "/gmm_model" | ||
>>> model.save(model_path) | ||
>>> model2 = GaussianMixtureModel.load(model_path) | ||
>>> model2.hasSummary | ||
False | ||
>>> model2.weights == model.weights | ||
True | ||
>>> model2.gaussiansDF.show() | ||
|
@@ -181,6 +270,32 @@ def getK(self): | |
return self.getOrDefault(self.k) | ||
|
||
|
||
class GaussianMixtureSummary(ClusteringSummary): | ||
""" | ||
.. note:: Experimental | ||
|
||
Gaussian mixture clustering results for a given model. | ||
|
||
.. versionadded:: 2.1.0 | ||
""" | ||
|
||
@property | ||
@since("2.1.0") | ||
def probabilityCol(self): | ||
""" | ||
Name for column of predicted probability of each cluster in `predictions`. | ||
""" | ||
return self._call_java("probabilityCol") | ||
|
||
@property | ||
@since("2.1.0") | ||
def probability(self): | ||
""" | ||
DataFrame of probabilities of each cluster for each training data point. | ||
""" | ||
return self._call_java("probability") | ||
|
||
|
||
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): | ||
""" | ||
Model fitted by KMeans. | ||
|
@@ -346,6 +461,27 @@ def computeCost(self, dataset): | |
""" | ||
return self._call_java("computeCost", dataset) | ||
|
||
@property | ||
@since("2.1.0") | ||
def hasSummary(self): | ||
""" | ||
Indicates whether a training summary exists for this model instance. | ||
""" | ||
return self._call_java("hasSummary") | ||
|
||
@property | ||
@since("2.1.0") | ||
def summary(self): | ||
""" | ||
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the | ||
training set. An exception is thrown if no summary exists. | ||
""" | ||
if self.hasSummary: | ||
return BisectingKMeansSummary(self._call_java("summary")) | ||
else: | ||
raise RuntimeError("No training summary available for this %s" % | ||
self.__class__.__name__) | ||
|
||
|
||
@inherit_doc | ||
class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, | ||
|
@@ -373,6 +509,13 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte | |
2 | ||
>>> model.computeCost(df) | ||
2.000... | ||
>>> model.hasSummary | ||
True | ||
>>> summary = model.summary | ||
>>> summary.k | ||
2 | ||
>>> summary.clusterSizes | ||
[2, 2] | ||
>>> transformed = model.transform(df).select("features", "prediction") | ||
>>> rows = transformed.collect() | ||
>>> rows[0].prediction == rows[1].prediction | ||
|
@@ -387,6 +530,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte | |
>>> model_path = temp_path + "/bkm_model" | ||
>>> model.save(model_path) | ||
>>> model2 = BisectingKMeansModel.load(model_path) | ||
>>> model2.hasSummary | ||
False | ||
>>> model.clusterCenters()[0] == model2.clusterCenters()[0] | ||
array([ True, True], dtype=bool) | ||
>>> model.clusterCenters()[1] == model2.clusterCenters()[1] | ||
|
@@ -460,6 +605,17 @@ def _create_model(self, java_model): | |
return BisectingKMeansModel(java_model) | ||
|
||
|
||
class BisectingKMeansSummary(ClusteringSummary): | ||
""" | ||
.. note:: Experimental | ||
|
||
Bisecting KMeans clustering results for a given model. | ||
|
||
.. versionadded:: 2.1.0 | ||
""" | ||
pass | ||
|
||
|
||
@inherit_doc | ||
class LDAModel(JavaModel): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks better. Could you make the change for Scala LiR, LoR, GLM and KMeans as well? I think they should be consistent. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated. I also added tests. Thanks for reviewing!