Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor][ML] Refactor clustering summary. #15555

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -297,27 +297,13 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
@Since("2.1.0")
@Experimental
class BisectingKMeansSummary private[clustering] (
@Since("2.1.0") @transient val predictions: DataFrame,
@Since("2.1.0") val predictionCol: String,
@Since("2.1.0") val featuresCol: String,
@Since("2.1.0") val k: Int) extends Serializable {

/**
* Cluster centers of the transformed data.
*/
@Since("2.1.0")
@transient lazy val cluster: DataFrame = predictions.select(predictionCol)

/**
* Size of (number of data points in) each cluster.
*/
@Since("2.1.0")
lazy val clusterSizes: Array[Long] = {
val sizes = Array.fill[Long](k)(0)
cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
case Row(cluster: Int, count: Long) => sizes(cluster) = count
}
sizes
}

}
predictions: DataFrame,
predictionCol: String,
featuresCol: String,
k: Int)
extends ClusteringSummary (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: this can fit on one line like: k: Int) extends ClusteringSummary(..., ..., ...)

predictions,
predictionCol,
featuresCol,
k
)
Original file line number Diff line number Diff line change
Expand Up @@ -365,33 +365,20 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
@Since("2.0.0")
@Experimental
class GaussianMixtureSummary private[clustering] (
@Since("2.0.0") @transient val predictions: DataFrame,
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val probabilityCol: String,
@Since("2.0.0") val featuresCol: String,
@Since("2.0.0") val k: Int) extends Serializable {

/**
* Cluster centers of the transformed data.
*/
@Since("2.0.0")
@transient lazy val cluster: DataFrame = predictions.select(predictionCol)
predictions: DataFrame,
predictionCol: String,
val probabilityCol: String,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could do a Since tag here

featuresCol: String,
k: Int)
extends ClusteringSummary (
predictions,
predictionCol,
featuresCol,
k) {

/**
* Probability of each cluster.
*/
@Since("2.0.0")
@transient lazy val probability: DataFrame = predictions.select(probabilityCol)

/**
* Size of (number of data points in) each cluster.
*/
@Since("2.0.0")
lazy val clusterSizes: Array[Long] = {
val sizes = Array.fill[Long](k)(0)
cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
case Row(cluster: Int, count: Long) => sizes(cluster) = count
}
sizes
}
}
32 changes: 26 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
Original file line number Diff line number Diff line change
Expand Up @@ -354,21 +354,41 @@ object KMeans extends DefaultParamsReadable[KMeans] {
@Since("2.0.0")
@Experimental
class KMeansSummary private[clustering] (
@Since("2.0.0") @transient val predictions: DataFrame,
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val featuresCol: String,
@Since("2.0.0") val k: Int) extends Serializable {
predictions: DataFrame,
predictionCol: String,
featuresCol: String,
k: Int)
extends ClusteringSummary (
predictions,
predictionCol,
featuresCol,
k
)

/**
* :: Experimental ::
* Summary of clustering.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Summary of clustering algorithms." ?

*
* @param predictions [[DataFrame]] produced by model.transform()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Add periods for each line

* @param predictionCol Name for column of predicted clusters in `predictions`
* @param featuresCol Name for column of features in `predictions`
* @param k Number of clusters
*/
@Experimental
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about adding @Since("2.1.0") here?
Create a new scala file named Clustering.scala and move ClusteringSummary into it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ClusteringSummary will be succeeded by summaries who were added in different version, so I think we should not add since version here. To the issue for a new file, I think ClusteringSummary is a small class, we can place it here temporarily.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely certain on the official policy for the @Since tags, but it seems better to me to put @Since("2.1.0") here for the class and the methods. It will be correct for some and will at least not be incorrect for others. I'm not positive though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also ambivalent about this, the reason behind my change is that some classes such as KMeansSummary and GaussianMixtureSummary were added at 2.0. If I put @Since("2.1.0") here, it looks not quite right, but I'm not sure whether it's OK. @jkbradley What's your opinion? Thanks.

class ClusteringSummary private[clustering] (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is generic to clustering, how about putting it in a new file?

@transient val predictions: DataFrame,
val predictionCol: String,
val featuresCol: String,
val k: Int) extends Serializable {

/**
* Cluster centers of the transformed data.
*/
@Since("2.0.0")
@transient lazy val cluster: DataFrame = predictions.select(predictionCol)

/**
* Size of (number of data points in) each cluster.
*/
@Since("2.0.0")
lazy val clusterSizes: Array[Long] = {
val sizes = Array.fill[Long](k)(0)
cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
Expand Down