diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 18b9b3043db8a..1c25287efd238 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -661,7 +661,7 @@ class LogisticRegression @Since("1.2.0") ( $(labelCol), $(featuresCol), objectiveHistory) - model.setSummary(logRegSummary) + model.setSummary(Some(logRegSummary)) } else { model } @@ -803,9 +803,9 @@ class LogisticRegressionModel private[spark] ( } } - private[classification] def setSummary( - summary: LogisticRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[classification] + def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -900,8 +900,7 @@ class LogisticRegressionModel private[spark] ( override def copy(extra: ParamMap): LogisticRegressionModel = { val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector, numClasses, isMultinomial), extra) - if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f8a606d60b2aa..e6ca3aedffd9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -95,8 +95,7 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -132,8 +131,8 @@ class BisectingKMeansModel private[ml] ( private var trainingSummary: Option[BisectingKMeansSummary] = None - private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -265,7 +264,7 @@ class BisectingKMeans @Since("2.0.0") ( val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index a0bd66e731a1d..9522d6671abbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -90,8 +90,7 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -150,8 +149,8 @@ class GaussianMixtureModel private[ml] ( private var trainingSummary: Option[GaussianMixtureSummary] = None - private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = { + this.trainingSummary = summary this } @@ -340,7 +339,7 @@ class GaussianMixture @Since("2.0.0") ( .setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logNumFeatures(model.gaussians.head.mean.size) instr.logSuccess(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 26505b4cc1501..152bd13b7a17a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -110,8 +110,7 @@ class KMeansModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { val copied = copyValues(new KMeansModel(uid, parentModel), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(this.parent) + copied.setSummary(trainingSummary).setParent(this.parent) } /** @group setParam */ @@ -165,8 +164,8 @@ class KMeansModel private[ml] ( private var trainingSummary: Option[KMeansSummary] = None - private[clustering] def setSummary(summary: KMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -325,7 +324,7 @@ class KMeans @Since("1.5.0") ( val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 1d2961e0277f5..62f47d25524d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -270,7 +270,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, wlsModel.diagInvAtWA.toArray, 1, getSolver) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). @@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("2.0.0") @@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.nonEmpty private[regression] - def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -778,8 +778,7 @@ class GeneralizedLinearRegressionModel private[ml] ( override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) - if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) - copied.setParent(parent) + copied.setSummary(trainingSummary).setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 71c542adf6f6f..5fba9ab9e28ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -223,7 +223,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) - return lrModel.setSummary(trainingSummary) + return lrModel.setSummary(Some(trainingSummary)) } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -276,7 +276,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), Array(0D)) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") @@ -398,7 +398,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), objectiveHistory) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("1.4.0") @@ -444,8 +444,9 @@ class LinearRegressionModel private[ml] ( throw new SparkException("No training summary available for this LinearRegressionModel") } - private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[regression] + def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -488,8 +489,7 @@ class LinearRegressionModel private[ml] ( @Since("1.4.0") override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra) - if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 2877285eb4d59..e360542eae2ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -147,6 +147,8 @@ class LogisticRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) } test("empty probabilityCol") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 49797d938d751..fc491cd6161fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -109,6 +109,9 @@ class BisectingKMeansSuite assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 7165b63ed3b96..07299123f8a47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -111,6 +111,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 73972557d2631..c1b7242e11a8f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -123,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("KMeansModel transform with non-default feature and prediction cols") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 6a4ac1735b2cb..9b0fa67630d2e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -197,6 +197,8 @@ class GeneralizedLinearRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index df97d0b2ae7ad..0be82742a33be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -146,6 +146,8 @@ class LinearRegressionSuite assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) model.transform(datasetWithDenseFeature) .select("label", "prediction") diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 56c8c62259e79..83e1e89347660 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -309,13 +309,16 @@ def interceptVector(self): @since("2.0.0") def summary(self): """ - Gets summary (e.g. residuals, mse, r-squared ) of model on - training set. An exception is thrown if - `trainingSummary is None`. + Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. """ - java_blrt_summary = self._call_java("summary") - # Note: Once multiclass is added, update this to return correct summary - return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + if self.hasSummary: + java_blrt_summary = self._call_java("summary") + # Note: Once multiclass is added, update this to return correct summary + return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 7632f05c3b68c..e58ec1e7ac296 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -17,16 +17,74 @@ from pyspark import since, keyword_only from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc -__all__ = ['BisectingKMeans', 'BisectingKMeansModel', +__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary', 'KMeans', 'KMeansModel', - 'GaussianMixture', 'GaussianMixtureModel', + 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary', 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] +class ClusteringSummary(JavaWrapper): + """ + .. note:: Experimental + + Clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + + @property + @since("2.1.0") + def predictionCol(self): + """ + Name for column of predicted clusters in `predictions`. + """ + return self._call_java("predictionCol") + + @property + @since("2.1.0") + def predictions(self): + """ + DataFrame produced by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.1.0") + def featuresCol(self): + """ + Name for column of features in `predictions`. + """ + return self._call_java("featuresCol") + + @property + @since("2.1.0") + def k(self): + """ + The number of clusters the model was trained with. + """ + return self._call_java("k") + + @property + @since("2.1.0") + def cluster(self): + """ + DataFrame of predicted cluster centers for each training data point. + """ + return self._call_java("cluster") + + @property + @since("2.1.0") + def clusterSizes(self): + """ + Size of (number of data points in) each cluster. + """ + return self._call_java("clusterSizes") + + class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -56,6 +114,28 @@ def gaussiansDF(self): """ return self._call_java("gaussiansDF") + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + return GaussianMixtureSummary(self._call_java("summary")) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + @inherit_doc class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, @@ -92,6 +172,13 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> gm = GaussianMixture(k=3, tol=0.0001, ... maxIter=10, seed=10) >>> model = gm.fit(df) + >>> model.hasSummary + True + >>> summary = model.summary + >>> summary.k + 3 + >>> summary.clusterSizes + [2, 2, 2] >>> weights = model.weights >>> len(weights) 3 @@ -118,6 +205,8 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> model_path = temp_path + "/gmm_model" >>> model.save(model_path) >>> model2 = GaussianMixtureModel.load(model_path) + >>> model2.hasSummary + False >>> model2.weights == model.weights True >>> model2.gaussiansDF.show() @@ -181,6 +270,32 @@ def getK(self): return self.getOrDefault(self.k) +class GaussianMixtureSummary(ClusteringSummary): + """ + .. note:: Experimental + + Gaussian mixture clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + + @property + @since("2.1.0") + def probabilityCol(self): + """ + Name for column of predicted probability of each cluster in `predictions`. + """ + return self._call_java("probabilityCol") + + @property + @since("2.1.0") + def probability(self): + """ + DataFrame of probabilities of each cluster for each training data point. + """ + return self._call_java("probability") + + class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by KMeans. @@ -346,6 +461,27 @@ def computeCost(self, dataset): """ return self._call_java("computeCost", dataset) + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + return BisectingKMeansSummary(self._call_java("summary")) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + @inherit_doc class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, @@ -373,6 +509,13 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte 2 >>> model.computeCost(df) 2.000... + >>> model.hasSummary + True + >>> summary = model.summary + >>> summary.k + 2 + >>> summary.clusterSizes + [2, 2] >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction @@ -387,6 +530,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> model_path = temp_path + "/bkm_model" >>> model.save(model_path) >>> model2 = BisectingKMeansModel.load(model_path) + >>> model2.hasSummary + False >>> model.clusterCenters()[0] == model2.clusterCenters()[0] array([ True, True], dtype=bool) >>> model.clusterCenters()[1] == model2.clusterCenters()[1] @@ -460,6 +605,17 @@ def _create_model(self, java_model): return BisectingKMeansModel(java_model) +class BisectingKMeansSummary(ClusteringSummary): + """ + .. note:: Experimental + + Bisecting KMeans clustering results for a given model. + + .. versionadded:: 2.1.0 + """ + pass + + @inherit_doc class LDAModel(JavaModel): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0bc319ca4d601..385391ba53fd4 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -160,8 +160,12 @@ def summary(self): training set. An exception is thrown if `trainingSummary is None`. """ - java_lrt_summary = self._call_java("summary") - return LinearRegressionTrainingSummary(java_lrt_summary) + if self.hasSummary: + java_lrt_summary = self._call_java("summary") + return LinearRegressionTrainingSummary(java_lrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") @@ -1459,8 +1463,12 @@ def summary(self): training set. An exception is thrown if `trainingSummary is None`. """ - java_glrt_summary = self._call_java("summary") - return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + if self.hasSummary: + java_glrt_summary = self._call_java("summary") + return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) @property @since("2.0.0") diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9d46cc3b4ae64..c0f0d4073564e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1097,6 +1097,38 @@ def test_logistic_regression_summary(self): sameSummary = model.evaluate(df) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + def test_gaussian_mixture_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + gmm = GaussianMixture(k=2) + model = gmm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertTrue(isinstance(s.probability, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + + def test_bisecting_kmeans_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + bkm = BisectingKMeans(k=2) + model = bkm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + class OneVsRestTests(SparkSessionTestCase):