diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py
index d303d4d97de79..eb80f44f4fa97 100644
--- a/examples/src/main/python/ml/simple_text_classification_pipeline.py
+++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py
@@ -21,6 +21,7 @@
 from pyspark.ml.feature import HashingTF, Tokenizer
 from pyspark.ml.classification import LogisticRegression
 
+
 if __name__ == "__main__":
     sc = SparkContext(appName="SimpleTextClassificationPipeline")
     sqlCtx = SQLContext(sc)
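For context, a minimal sketch of how this example script drives the pipeline API reworked below. The `training`/`test` SchemaRDDs and the `setStages` setter are assumptions based on the surrounding code (only `getStages` appears in this patch); column names follow the example's conventions:

```python
# Hypothetical driver flow, mirroring simple_text_classification_pipeline.py.
tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
hashingTF = HashingTF().setInputCol(tokenizer.getOutputCol()).setOutputCol("features")
lr = LogisticRegression().setMaxIter(10).setRegParam(0.01)
pipeline = Pipeline().setStages([tokenizer, hashingTF, lr])  # setter assumed

model = pipeline.fit(training)       # fits estimators, collects transformers
prediction = model.transform(test)   # applies the fitted stages in order
```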
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 4666ce7bc2499..1cf9d3065f3d1 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -15,10 +15,10 @@
 # limitations under the License.
 #
 
-from abc import ABCMeta, abstractmethod
+from abc import ABCMeta, abstractmethod, abstractproperty
 
 from pyspark import SparkContext
-from pyspark.sql import inherit_doc  # TODO: move inherit_doc to Spark Core
+from pyspark.sql import SchemaRDD, inherit_doc  # TODO: move inherit_doc to Spark Core
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import Identifiable
@@ -146,16 +146,17 @@ def getStages(self):
         if self.stages in self.paramMap:
             return self.paramMap[self.stages]
 
-    def fit(self, dataset):
+    def fit(self, dataset, params={}):
+        map = self._merge_params(params)
         transformers = []
         for stage in self.getStages():
             if isinstance(stage, Transformer):
                 transformers.append(stage)
-                dataset = stage.transform(dataset)
+                dataset = stage.transform(dataset, map)
             elif isinstance(stage, Estimator):
-                model = stage.fit(dataset)
+                model = stage.fit(dataset, map)
                 transformers.append(model)
-                dataset = model.transform(dataset)
+                dataset = model.transform(dataset, map)
             else:
                 raise ValueError(
                     "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
@@ -169,7 +170,65 @@ def __init__(self, transformers):
         super(PipelineModel, self).__init__()
         self.transformers = transformers
 
-    def transform(self, dataset):
+    def transform(self, dataset, params={}):
+        map = self._merge_params(params)
         for t in self.transformers:
-            dataset = t.transform(dataset)
+            dataset = t.transform(dataset, map)
         return dataset
+
+
+@inherit_doc
+class JavaWrapper(object):
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(JavaWrapper, self).__init__()
+
+    @abstractproperty
+    def _java_class(self):
+        raise NotImplementedError
+
+    def _create_java_obj(self):
+        java_obj = _jvm()
+        for name in self._java_class.split("."):
+            java_obj = getattr(java_obj, name)
+        return java_obj()
+
+
+@inherit_doc
+class JavaEstimator(Estimator, JavaWrapper):
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(JavaEstimator, self).__init__()
+
+    @abstractmethod
+    def _create_model(self, java_model):
+        raise NotImplementedError
+
+    def _fit_java(self, dataset, params={}):
+        java_obj = self._create_java_obj()
+        self._transfer_params_to_java(params, java_obj)
+        return java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
+
+    def fit(self, dataset, params={}):
+        java_model = self._fit_java(dataset, params)
+        return self._create_model(java_model)
+
+
+@inherit_doc
+class JavaTransformer(Transformer, JavaWrapper):
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(JavaTransformer, self).__init__()
+
+    def transform(self, dataset, params={}):
+        java_obj = self._create_java_obj()
+        self._transfer_params_to_java(params, java_obj)
+        return SchemaRDD(java_obj.transform(dataset._jschema_rdd,
+                                            _jvm().org.apache.spark.ml.param.ParamMap()),
+                         dataset.sql_ctx)
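The three new base classes concentrate the Py4J plumbing: `_java_class` names the peer Scala class, `_create_java_obj` resolves it on the JVM gateway one package component at a time, and `_transfer_params_to_java` (added to `Params` further down) copies the merged params onto the Java object before fitting or transforming. A concrete wrapper then only needs a class name and, for estimators, a model factory. A sketch with a hypothetical `Binarizer` wrapper (not part of this patch):

```python
@inherit_doc
class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
    """Hypothetical example; the real wrappers in this patch look the same."""

    @property
    def _java_class(self):
        # _create_java_obj turns this string into
        # _jvm().org.apache.spark.ml.feature.Binarizer()
        return "org.apache.spark.ml.feature.Binarizer"
```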
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index fd1fb906ca5c1..3a105f1a10a3d 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -16,42 +16,40 @@
 #
 
 from pyspark.sql import SchemaRDD, inherit_doc
-from pyspark.ml import Estimator, Transformer, _jvm
+from pyspark.ml import JavaEstimator, Transformer, _jvm
 from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
     HasRegParam
 
 
 @inherit_doc
-class LogisticRegression(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
+class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                          HasRegParam):
     """
     Logistic regression.
     """
 
-    # _java_class = "org.apache.spark.ml.classification.LogisticRegression"
-
     def __init__(self):
         super(LogisticRegression, self).__init__()
-        self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
-
-    def fit(self, dataset, params=None):
-        """
-        Fits a dataset with optional parameters.
-        """
-        java_model = self._java_obj.fit(dataset._jschema_rdd,
-                                        _jvm().org.apache.spark.ml.param.ParamMap())
+
+    @property
+    def _java_class(self):
+        return "org.apache.spark.ml.classification.LogisticRegression"
+
+    def _create_model(self, java_model):
         return LogisticRegressionModel(java_model)
 
 
+@inherit_doc
 class LogisticRegressionModel(Transformer):
     """
     Model fitted by LogisticRegression.
     """
 
-    def __init__(self, _java_model):
-        self._java_model = _java_model
+    def __init__(self, java_model):
+        self._java_model = java_model
 
-    def transform(self, dataset):
+    def transform(self, dataset, params={}):
+        # TODO: handle params here.
         return SchemaRDD(self._java_model.transform(
             dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 9e4b6574574a4..9f3b1c7a055cb 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -15,91 +15,41 @@
 # limitations under the License.
 #
 
-from pyspark.sql import SchemaRDD, ArrayType, StringType, inherit_doc
-from pyspark.ml import Transformer, _jvm
+from pyspark.sql import inherit_doc
+from pyspark.ml import JavaTransformer
 from pyspark.ml.param import Param
+from pyspark.ml.param.shared import HasInputCol, HasOutputCol
+
 
 @inherit_doc
-class Tokenizer(Transformer):
+class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
 
     def __init__(self):
         super(Tokenizer, self).__init__()
-        self.inputCol = Param(self, "inputCol", "input column name", None)
-        self.outputCol = Param(self, "outputCol", "output column name", None)
-        self.paramMap = {}
-
-    def setInputCol(self, value):
-        self.paramMap[self.inputCol] = value
-        return self
-
-    def getInputCol(self):
-        if self.inputCol in self.paramMap:
-            return self.paramMap[self.inputCol]
-
-    def setOutputCol(self, value):
-        self.paramMap[self.outputCol] = value
-        return self
-
-    def getOutputCol(self):
-        if self.outputCol in self.paramMap:
-            return self.paramMap[self.outputCol]
-
-    def transform(self, dataset, params={}):
-        sqlCtx = dataset.sql_ctx
-        if isinstance(params, dict):
-            paramMap = self.paramMap.copy()
-            paramMap.update(params)
-            inputCol = paramMap[self.inputCol]
-            outputCol = paramMap[self.outputCol]
-            # TODO: make names unique
-            sqlCtx.registerFunction("tokenize", lambda text: text.split(),
-                                    ArrayType(StringType(), False))
-            dataset.registerTempTable("dataset")
-            return sqlCtx.sql("SELECT *, tokenize(%s) AS %s FROM dataset" % (inputCol, outputCol))
-        elif isinstance(params, list):
-            return [self.transform(dataset, paramMap) for paramMap in params]
-        else:
-            raise ValueError("The input params must be either a dict or a list.")
+
+    @property
+    def _java_class(self):
+        return "org.apache.spark.ml.feature.Tokenizer"
 
 
 @inherit_doc
-class HashingTF(Transformer):
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol):
 
     def __init__(self):
         super(HashingTF, self).__init__()
-        self._java_obj = _jvm().org.apache.spark.ml.feature.HashingTF()
+        #: param for number of features
         self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
-        self.inputCol = Param(self, "inputCol", "input column name")
-        self.outputCol = Param(self, "outputCol", "output column name")
+
+    @property
+    def _java_class(self):
+        return "org.apache.spark.ml.feature.HashingTF"
 
     def setNumFeatures(self, value):
-        self._java_obj.setNumFeatures(value)
+        self.paramMap[self.numFeatures] = value
         return self
 
     def getNumFeatures(self):
-        return self._java_obj.getNumFeatures()
-
-    def setInputCol(self, value):
-        self._java_obj.setInputCol(value)
-        return self
-
-    def getInputCol(self):
-        return self._java_obj.getInputCol()
-
-    def setOutputCol(self, value):
-        self._java_obj.setOutputCol(value)
-        return self
-
-    def getOutputCol(self):
-        return self._java_obj.getOutputCol()
-
-    def transform(self, dataset, paramMap={}):
-        if isinstance(paramMap, dict):
-            javaParamMap = _jvm().org.apache.spark.ml.param.ParamMap()
-            for k, v in paramMap.items():
-                param = self._java_obj.getParam(k.name)
-                javaParamMap.put(param, v)
-            return SchemaRDD(self._java_obj.transform(dataset._jschema_rdd, javaParamMap),
-                             dataset.sql_ctx)
+        if self.numFeatures in self.paramMap:
+            return self.paramMap[self.numFeatures]
         else:
-            raise ValueError("paramMap must be a dict.")
+            return self.numFeatures.defaultValue
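With params stored in `self.paramMap` rather than pushed eagerly to a Java object, `HashingTF` no longer touches the JVM at construction time, and the getter falls back to the `Param`'s declared default. The intended behavior, as a quick sketch:

```python
tf = HashingTF()
tf.getNumFeatures()          # 1 << 18 == 262144, i.e. numFeatures.defaultValue
tf.setNumFeatures(1 << 10)   # recorded in tf.paramMap, not on a Java object
tf.getNumFeatures()          # 1024
```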
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 89e5d732f7586..f81b836a2242f 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -63,3 +63,14 @@ def params(self):
         :py:class:`Param`.
         """
         return filter(lambda x: isinstance(x, Param), map(lambda x: getattr(self, x), dir(self)))
+
+    def _merge_params(self, params):
+        map = self.paramMap.copy()
+        map.update(params)
+        return map
+
+    def _transfer_params_to_java(self, params, java_obj):
+        map = self._merge_params(params)
+        for param in self.params():
+            if param in map:
+                java_obj.set(param.name, map[param])
diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py
index 8c3aa7eba9483..f40823b906221 100644
--- a/python/pyspark/ml/param/_gen_shared_params.py
+++ b/python/pyspark/ml/param/_gen_shared_params.py
@@ -54,13 +54,13 @@ def set%s(self, value):
         self.paramMap[self.%s] = value
         return self
 
-    def get%s(self, value):
+    def get%s(self):
         if self.%s in self.paramMap:
             return self.paramMap[self.%s]
         else:
-            return self.defaultValue""" % (
+            return self.%s.defaultValue""" % (
         upperCamelName, upperCamelName, doc, name, name, doc, defaultValue, upperCamelName, name,
-        upperCamelName, name, name)
+        upperCamelName, name, name, name)
 
 
 if __name__ == "__main__":
     print header
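`_merge_params` gives per-call params precedence over instance-level ones without mutating `self.paramMap`, which is what lets `Pipeline.fit` thread a single merged map through every stage. A sketch of the intended semantics:

```python
lr = LogisticRegression().setMaxIter(10)
merged = lr._merge_params({lr.maxIter: 100})
merged[lr.maxIter]       # 100: the call-site value wins
lr.paramMap[lr.maxIter]  # 10: the instance map is left untouched
```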
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 88afb5481f7b8..8680f389577b6 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -31,11 +31,11 @@ def setMaxIter(self, value):
         self.paramMap[self.maxIter] = value
         return self
 
-    def getMaxIter(self, value):
+    def getMaxIter(self):
         if self.maxIter in self.paramMap:
             return self.paramMap[self.maxIter]
         else:
-            return self.defaultValue
+            return self.maxIter.defaultValue
 
 
 class HasRegParam(Params):
@@ -49,11 +49,11 @@ def setRegParam(self, value):
         self.paramMap[self.regParam] = value
         return self
 
-    def getRegParam(self, value):
+    def getRegParam(self):
         if self.regParam in self.paramMap:
             return self.paramMap[self.regParam]
         else:
-            return self.defaultValue
+            return self.regParam.defaultValue
 
 
 class HasFeaturesCol(Params):
@@ -67,11 +67,11 @@ def setFeaturesCol(self, value):
         self.paramMap[self.featuresCol] = value
         return self
 
-    def getFeaturesCol(self, value):
+    def getFeaturesCol(self):
         if self.featuresCol in self.paramMap:
             return self.paramMap[self.featuresCol]
         else:
-            return self.defaultValue
+            return self.featuresCol.defaultValue
 
 
 class HasLabelCol(Params):
@@ -85,11 +85,11 @@ def setLabelCol(self, value):
         self.paramMap[self.labelCol] = value
         return self
 
-    def getLabelCol(self, value):
+    def getLabelCol(self):
         if self.labelCol in self.paramMap:
             return self.paramMap[self.labelCol]
         else:
-            return self.defaultValue
+            return self.labelCol.defaultValue
 
 
 class HasPredictionCol(Params):
@@ -103,11 +103,11 @@ def setPredictionCol(self, value):
         self.paramMap[self.predictionCol] = value
         return self
 
-    def getPredictionCol(self, value):
+    def getPredictionCol(self):
         if self.predictionCol in self.paramMap:
             return self.paramMap[self.predictionCol]
         else:
-            return self.defaultValue
+            return self.predictionCol.defaultValue
 
 
 class HasInputCol(Params):
@@ -121,11 +121,11 @@ def setInputCol(self, value):
         self.paramMap[self.inputCol] = value
         return self
 
-    def getInputCol(self, value):
+    def getInputCol(self):
         if self.inputCol in self.paramMap:
             return self.paramMap[self.inputCol]
         else:
-            return self.defaultValue
+            return self.inputCol.defaultValue
 
 
 class HasOutputCol(Params):
@@ -139,8 +139,8 @@ def setOutputCol(self, value):
         self.paramMap[self.outputCol] = value
         return self
 
-    def getOutputCol(self, value):
+    def getOutputCol(self):
         if self.outputCol in self.paramMap:
             return self.paramMap[self.outputCol]
         else:
-            return self.defaultValue
+            return self.outputCol.defaultValue
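The regenerated getters fix two bugs at once: the old signature `getX(self, value)` took an argument it never used, and the fallback returned `self.defaultValue`, which does not exist on these mixins (the default lives on the `Param` object itself), so reading an unset param raised an error instead of returning its default:

```python
lr = LogisticRegression()
lr.getMaxIter()  # previously a TypeError (missing 'value') or, with an
                 # argument, an AttributeError on the fallback path;
                 # now returns lr.maxIter.defaultValue when maxIter is unset
```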