Skip to content

Commit

Permalink
[SPARK-7380] [MLLIB] pipeline stages should be copyable in Python
Browse files (browse the repository at this point in the history)
This PR makes pipeline stages in Python copyable and hence simplifies some implementations. It also includes the following changes:

1. Rename `paramMap` and `defaultParamMap` to `_paramMap` and `_defaultParamMap`, respectively.
2. Accept a list of param maps in `fit`.
3. Use the parent's uid together with the param name to identify a param.

jkbradley

Author: Xiangrui Meng <[email protected]>
Author: Joseph K. Bradley <[email protected]>

Closes #6088 from mengxr/SPARK-7380 and squashes the following commits:

413c463 [Xiangrui Meng] remove unnecessary doc
4159f35 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
611c719 [Xiangrui Meng] fix python style
68862b8 [Xiangrui Meng] update _java_obj initialization
927ad19 [Xiangrui Meng] fix ml/tests.py
0138fc3 [Xiangrui Meng] update feature transformers and fix a bug in RegexTokenizer
9ca44fb [Xiangrui Meng] simplify Java wrappers and add tests
c7d84ef [Xiangrui Meng] update ml/tests.py to test copy params
7e0d27f [Xiangrui Meng] merge master
46840fb [Xiangrui Meng] update wrappers
b6db1ed [Xiangrui Meng] update all self.paramMap to self._paramMap
46cb6ed [Xiangrui Meng] merge master
a163413 [Xiangrui Meng] fix style
1042e80 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
9630eae [Xiangrui Meng] fix Identifiable._randomUID
13bd70a [Xiangrui Meng] update ml/tests.py
64a536c [Xiangrui Meng] use _fit/_transform/_evaluate to simplify the impl
02abf13 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into copyable-python
66ce18c [Joseph K. Bradley] some cleanups before sending to Xiangrui
7431272 [Joseph K. Bradley] Rebased with master
  • Loading branch information
mengxr committed May 18, 2015
1 parent 56ede88 commit 9c7e802
Show file tree
Hide file tree
Showing 16 changed files with 498 additions and 261 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class RegexTokenizer(override val uid: String)
* Default: 1, to avoid returning empty strings
* @group param
*/
- val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)",
+ val minTokenLength: IntParam = new IntParam(this, "minTokenLength", "minimum token length (>= 0)",
ParamValidators.gtEq(0))

/** @group setParam */
Expand Down
7 changes: 3 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -483,16 +483,15 @@ trait Params extends Identifiable with Serializable {
def copy(extra: ParamMap): Params = {
val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
copyValues(that, extra)
- that
}

/**
* Extracts the embedded default param values and user-supplied values, and then merges them with
* extra values from input into a flat param map, where the latter value is used if there exist
- * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap.
+ * conflicts, i.e., with ordering: default param values < user-supplied values < extra.
*/
- final def extractParamMap(extraParamMap: ParamMap): ParamMap = {
- defaultParamMap ++ paramMap ++ extraParamMap
+ final def extractParamMap(extra: ParamMap): ParamMap = {
+ defaultParamMap ++ paramMap ++ extra
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ import java.util.UUID
/**
* Trait for an object with an immutable unique ID that identifies itself and its derivatives.
*/
- trait Identifiable {
+ private[spark] trait Identifiable {

/**
* An immutable unique ID for the object and its derivatives.
*/
val uid: String
+
+ override def toString: String = uid
}

- object Identifiable {
+ private[spark] object Identifiable {

/**
* Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars.
Expand Down
35 changes: 20 additions & 15 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
...
TypeError: Method setParams forces keyword arguments.
"""
- _java_class = "org.apache.spark.ml.classification.LogisticRegression"
+
# a placeholder to make it appear in the generated doc
elasticNetParam = \
Param(Params._dummy(), "elasticNetParam",
Expand All @@ -75,6 +75,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
threshold=0.5, probabilityCol="probability")
"""
super(LogisticRegression, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.LogisticRegression", self.uid)
#: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
# is an L2 penalty. For alpha = 1, it is an L1 penalty.
self.elasticNetParam = \
Expand Down Expand Up @@ -111,7 +113,7 @@ def setElasticNetParam(self, value):
"""
Sets the value of :py:attr:`elasticNetParam`.
"""
- self.paramMap[self.elasticNetParam] = value
+ self._paramMap[self.elasticNetParam] = value
return self

def getElasticNetParam(self):
Expand All @@ -124,7 +126,7 @@ def setFitIntercept(self, value):
"""
Sets the value of :py:attr:`fitIntercept`.
"""
- self.paramMap[self.fitIntercept] = value
+ self._paramMap[self.fitIntercept] = value
return self

def getFitIntercept(self):
Expand All @@ -137,7 +139,7 @@ def setThreshold(self, value):
"""
Sets the value of :py:attr:`threshold`.
"""
- self.paramMap[self.threshold] = value
+ self._paramMap[self.threshold] = value
return self

def getThreshold(self):
Expand Down Expand Up @@ -208,7 +210,6 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1.0
"""

- _java_class = "org.apache.spark.ml.classification.DecisionTreeClassifier"
# a placeholder to make it appear in the generated doc
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
Expand All @@ -224,6 +225,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
"""
super(DecisionTreeClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
#: param for Criterion used for information gain calculation (case-insensitive).
self.impurity = \
Param(self, "impurity",
Expand Down Expand Up @@ -256,7 +259,7 @@ def setImpurity(self, value):
"""
Sets the value of :py:attr:`impurity`.
"""
- self.paramMap[self.impurity] = value
+ self._paramMap[self.impurity] = value
return self

def getImpurity(self):
Expand Down Expand Up @@ -299,7 +302,6 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1.0
"""

- _java_class = "org.apache.spark.ml.classification.RandomForestClassifier"
# a placeholder to make it appear in the generated doc
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
Expand All @@ -325,6 +327,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
numTrees=20, featureSubsetStrategy="auto", seed=42)
"""
super(RandomForestClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
#: param for Criterion used for information gain calculation (case-insensitive).
self.impurity = \
Param(self, "impurity",
Expand Down Expand Up @@ -370,7 +374,7 @@ def setImpurity(self, value):
"""
Sets the value of :py:attr:`impurity`.
"""
- self.paramMap[self.impurity] = value
+ self._paramMap[self.impurity] = value
return self

def getImpurity(self):
Expand All @@ -383,7 +387,7 @@ def setSubsamplingRate(self, value):
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
- self.paramMap[self.subsamplingRate] = value
+ self._paramMap[self.subsamplingRate] = value
return self

def getSubsamplingRate(self):
Expand All @@ -396,7 +400,7 @@ def setNumTrees(self, value):
"""
Sets the value of :py:attr:`numTrees`.
"""
- self.paramMap[self.numTrees] = value
+ self._paramMap[self.numTrees] = value
return self

def getNumTrees(self):
Expand All @@ -409,7 +413,7 @@ def setFeatureSubsetStrategy(self, value):
"""
Sets the value of :py:attr:`featureSubsetStrategy`.
"""
- self.paramMap[self.featureSubsetStrategy] = value
+ self._paramMap[self.featureSubsetStrategy] = value
return self

def getFeatureSubsetStrategy(self):
Expand Down Expand Up @@ -452,7 +456,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
1.0
"""

- _java_class = "org.apache.spark.ml.classification.GBTClassifier"
# a placeholder to make it appear in the generated doc
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
Expand All @@ -476,6 +479,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
lossType="logistic", maxIter=20, stepSize=0.1)
"""
super(GBTClassifier, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.GBTClassifier", self.uid)
#: param for Loss function which GBT tries to minimize (case-insensitive).
self.lossType = Param(self, "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
Expand Down Expand Up @@ -517,7 +522,7 @@ def setLossType(self, value):
"""
Sets the value of :py:attr:`lossType`.
"""
- self.paramMap[self.lossType] = value
+ self._paramMap[self.lossType] = value
return self

def getLossType(self):
Expand All @@ -530,7 +535,7 @@ def setSubsamplingRate(self, value):
"""
Sets the value of :py:attr:`subsamplingRate`.
"""
- self.paramMap[self.subsamplingRate] = value
+ self._paramMap[self.subsamplingRate] = value
return self

def getSubsamplingRate(self):
Expand All @@ -543,7 +548,7 @@ def setStepSize(self, value):
"""
Sets the value of :py:attr:`stepSize`.
"""
- self.paramMap[self.stepSize] = value
+ self._paramMap[self.stepSize] = value
return self

def getStepSize(self):
Expand Down
6 changes: 3 additions & 3 deletions python/pyspark/ml/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
0.83...
"""

- _java_class = "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator"
-
# a placeholder to make it appear in the generated doc
metricName = Param(Params._dummy(), "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)")
Expand All @@ -56,6 +54,8 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
metricName="areaUnderROC")
"""
super(BinaryClassificationEvaluator, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid)
#: param for metric name in evaluation (areaUnderROC|areaUnderPR)
self.metricName = Param(self, "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)")
Expand All @@ -68,7 +68,7 @@ def setMetricName(self, value):
"""
Sets the value of :py:attr:`metricName`.
"""
- self.paramMap[self.metricName] = value
+ self._paramMap[self.metricName] = value
return self

def getMetricName(self):
Expand Down
Loading

0 comments on commit 9c7e802

Please sign in to comment.