diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index 0a065b3155e1a..a4901622bf816 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -18,4 +18,4 @@ from pyspark.ml.param import * from pyspark.ml.pipeline import * -__all__ = ["Pipeline", "Transformer", "Estimator"] +__all__ = ["Param", "Params", "Pipeline", "Transformer", "Estimator"] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4b8f7fbe7d0ae..6bd2aa8e47837 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -15,7 +15,8 @@ # limitations under the License. # -from pyspark.ml.util import JavaEstimator, JavaModel, inherit_doc +from pyspark.ml.util import inherit_doc +from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ HasRegParam @@ -39,10 +40,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti .setRegParam(0.01) >>> model = lr.fit(dataset) >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))])) - >>> print model.transform(test0).first().prediction + >>> print model.transform(test0).head().prediction 0.0 >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))])) - >>> print model.transform(test1).first().prediction + >>> print model.transform(test1).head().prediction 1.0 """ _java_class = "org.apache.spark.ml.classification.LogisticRegression" diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 2cfe9a8dddc1f..e088acd0ca82d 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,9 +15,9 @@ # limitations under the License. # -from pyspark.ml.util import JavaTransformer, inherit_doc from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures - +from pyspark.ml.util import inherit_doc +from pyspark.ml.wrapper import JavaTransformer __all__ = ['Tokenizer', 'HashingTF'] @@ -33,9 +33,9 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): >>> tokenizer = Tokenizer() \ .setInputCol("text") \ .setOutputCol("words") - >>> print tokenizer.transform(dataset).first() + >>> print tokenizer.transform(dataset).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) - >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).first() + >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head() Row(text=u'a b c', tokens=[u'a', u'b', u'c']) """ @@ -54,10 +54,10 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): .setNumFeatures(10) \ .setInputCol("words") \ .setOutputCol("features") - >>> print hashingTF.transform(dataset).first().features + >>> print hashingTF.transform(dataset).head().features (10,[7,8,9],[1.0,1.0,1.0]) >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} - >>> print hashingTF.transform(dataset, params).first().vector + >>> print hashingTF.transform(dataset, params).head().vector (5,[2,3,4],[1.0,1.0,1.0]) """ diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 2396e20d23fa4..9d657acdd94f4 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -15,13 +15,26 @@ # limitations under the License. # +import uuid from abc import ABCMeta -from pyspark.ml.util import Identifiable - __all__ = ['Param', 'Params'] +class Identifiable(object): + """ + Object with a unique ID. 
+ """ + + def __init__(self): + #: A unique id for the object. The default implementation + #: concatenates the class name, "-", and 8 random hex chars. + self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + + def __repr__(self): + return self.uid + + class Param(object): """ A param with self-contained documentation and optionally default value. diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 3ef5a201a04b8..0c5ec86620a97 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -17,12 +17,25 @@ from abc import ABCMeta, abstractmethod -from pyspark.sql import inherit_doc # TODO: move inherit_doc to Spark Core from pyspark.ml.param import Param, Params __all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] +def inherit_doc(cls): + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls + + @inherit_doc class Estimator(Params): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index a154b51c11843..b627c2b4e930b 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -31,12 +31,12 @@ import unittest from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import SchemaRDD -from pyspark.ml import Transformer, Estimator, Pipeline +from pyspark.sql import DataFrame from pyspark.ml.param import Param +from pyspark.ml.pipeline import Transformer, Estimator, Pipeline -class MockDataset(SchemaRDD): +class MockDataset(DataFrame): def __init__(self): self.index = 0 diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 82e1f9fa087e7..991330f78e983 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -15,28 +15,6 @@ # limitations under the License. # -import uuid -from abc import ABCMeta - -from pyspark import SparkContext -from pyspark.sql import DataFrame -from pyspark.ml.param import Params -from pyspark.ml.pipeline import Estimator, Transformer - - -class Identifiable(object): - """ - Object with a unique ID. - """ - - def __init__(self): - #: A unique id for the object. The default implementation - #: concatenates the class name, "-", and 8 random hex chars. - self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] - - def __repr__(self): - return self.uid - def inherit_doc(cls): for name, func in vars(cls).items(): @@ -50,128 +28,3 @@ def inherit_doc(cls): func.__doc__ = parent_func.__doc__ break return cls - - -def _jvm(): - """ - Returns the JVM view associated with SparkContext. Must be called - after SparkContext is initialized. - """ - jvm = SparkContext._jvm - if jvm: - return jvm - else: - raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") - - -@inherit_doc -class JavaWrapper(Params): - """ - Utility class to help create wrapper classes from Java/Scala - implementations of pipeline components. - """ - - __metaclass__ = ABCMeta - - #: Fully-qualified class name of the wrapped Java component. - _java_class = None - - def _java_obj(self): - """ - Returns or creates a Java object. 
- """ - java_obj = _jvm() - for name in self._java_class.split("."): - java_obj = getattr(java_obj, name) - return java_obj() - - def _transfer_params_to_java(self, params, java_obj): - """ - Transforms the embedded params and additional params to the - input Java object. - :param params: additional params (overwriting embedded values) - :param java_obj: Java object to receive the params - """ - paramMap = self._merge_params(params) - for param in self.params: - if param in paramMap: - java_obj.set(param.name, paramMap[param]) - - def _empty_java_param_map(self): - """ - Returns an empty Java ParamMap reference. - """ - return _jvm().org.apache.spark.ml.param.ParamMap() - - def _create_java_param_map(self, params, java_obj): - paramMap = self._empty_java_param_map() - for param, value in params.items(): - if param.parent is self: - paramMap.put(java_obj.getParam(param.name), value) - return paramMap - - -@inherit_doc -class JavaEstimator(Estimator, JavaWrapper): - """ - Base class for :py:class:`Estimator`s that wrap Java/Scala - implementations. - """ - - __metaclass__ = ABCMeta - - def _create_model(self, java_model): - """ - Creates a model from the input Java model reference. - """ - return JavaModel(java_model) - - def _fit_java(self, dataset, params={}): - """ - Fits a Java model to the input dataset. - :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.SchemaRDD` - :param params: additional params (overwriting embedded values) - :return: fitted Java model - """ - java_obj = self._java_obj() - self._transfer_params_to_java(params, java_obj) - return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map()) - - def fit(self, dataset, params={}): - java_model = self._fit_java(dataset, params) - return self._create_model(java_model) - - -@inherit_doc -class JavaTransformer(Transformer, JavaWrapper): - """ - Base class for :py:class:`Transformer`s that wrap Java/Scala - implementations. - """ - - __metaclass__ = ABCMeta - - def transform(self, dataset, params={}): - java_obj = self._java_obj() - self._transfer_params_to_java({}, java_obj) - java_param_map = self._create_java_param_map(params, java_obj) - return DataFrame(java_obj.transform(dataset._jschema_rdd, java_param_map), - dataset.sql_ctx) - - -@inherit_doc -class JavaModel(JavaTransformer): - """ - Base class for :py:class:`Model`s that wrap Java/Scala - implementations. - """ - - __metaclass__ = ABCMeta - - def __init__(self, java_model): - super(JavaTransformer, self).__init__() - self._java_model = java_model - - def _java_obj(self): - return self._java_model diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py new file mode 100644 index 0000000000000..9e12ddc3d9b8f --- /dev/null +++ b/python/pyspark/ml/wrapper.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from pyspark.ml.param import Params +from pyspark.ml.pipeline import Estimator, Transformer +from pyspark.ml.util import inherit_doc + + +def _jvm(): + """ + Returns the JVM view associated with SparkContext. Must be called + after SparkContext is initialized. + """ + jvm = SparkContext._jvm + if jvm: + return jvm + else: + raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") + + +@inherit_doc +class JavaWrapper(Params): + """ + Utility class to help create wrapper classes from Java/Scala + implementations of pipeline components. + """ + + __metaclass__ = ABCMeta + + #: Fully-qualified class name of the wrapped Java component. + _java_class = None + + def _java_obj(self): + """ + Returns or creates a Java object. + """ + java_obj = _jvm() + for name in self._java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj() + + def _transfer_params_to_java(self, params, java_obj): + """ + Transforms the embedded params and additional params to the + input Java object. + :param params: additional params (overwriting embedded values) + :param java_obj: Java object to receive the params + """ + paramMap = self._merge_params(params) + for param in self.params: + if param in paramMap: + java_obj.set(param.name, paramMap[param]) + + def _empty_java_param_map(self): + """ + Returns an empty Java ParamMap reference. + """ + return _jvm().org.apache.spark.ml.param.ParamMap() + + def _create_java_param_map(self, params, java_obj): + paramMap = self._empty_java_param_map() + for param, value in params.items(): + if param.parent is self: + paramMap.put(java_obj.getParam(param.name), value) + return paramMap + + +@inherit_doc +class JavaEstimator(Estimator, JavaWrapper): + """ + Base class for :py:class:`Estimator`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def _create_model(self, java_model): + """ + Creates a model from the input Java model reference. + """ + return JavaModel(java_model) + + def _fit_java(self, dataset, params={}): + """ + Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.SchemaRDD` + :param params: additional params (overwriting embedded values) + :return: fitted Java model + """ + java_obj = self._java_obj() + self._transfer_params_to_java(params, java_obj) + return java_obj.fit(dataset._jdf, self._empty_java_param_map()) + + def fit(self, dataset, params={}): + java_model = self._fit_java(dataset, params) + return self._create_model(java_model) + + +@inherit_doc +class JavaTransformer(Transformer, JavaWrapper): + """ + Base class for :py:class:`Transformer`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def transform(self, dataset, params={}): + java_obj = self._java_obj() + self._transfer_params_to_java({}, java_obj) + java_param_map = self._create_java_param_map(params, java_obj) + return DataFrame(java_obj.transform(dataset._jdf, java_param_map), + dataset.sql_ctx) + + +@inherit_doc +class JavaModel(JavaTransformer): + """ + Base class for :py:class:`Model`s that wrap Java/Scala + implementations. 
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self, java_model):
+        super(JavaModel, self).__init__()
+        self._java_model = java_model
+
+    def _java_obj(self):
+        return self._java_model
diff --git a/python/run-tests b/python/run-tests
index 9ee19ed6e6b26..57e58c1341c62 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -81,6 +81,13 @@ function run_mllib_tests() {
     run_test "pyspark/mllib/tests.py"
 }
 
+function run_ml_tests() {
+    echo "Run ml tests ..."
+    run_test "pyspark/ml/feature.py"
+    run_test "pyspark/ml/classification.py"
+    run_test "pyspark/ml/tests.py"
+}
+
 function run_streaming_tests() {
     echo "Run streaming tests ..."
     run_test "pyspark/streaming/util.py"
@@ -102,6 +109,7 @@ $PYSPARK_PYTHON --version
 run_core_tests
 run_sql_tests
 run_mllib_tests
+run_ml_tests
 run_streaming_tests
 
 # Try to test with PyPy
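
Usage sketch: a minimal end-to-end example assembled from the doctests above, showing how the classes moved into pyspark.ml.wrapper are exercised through the public LogisticRegression API. It assumes a live SparkContext sc and SQLContext sqlCtx, as in those doctests; the Vectors import path, the setMaxIter(5) call, and the toy rows are illustrative assumptions, while the fit/transform/head() calls are exactly the ones shown in the patch.

from pyspark.sql import Row
from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import LogisticRegression

# Toy training data (illustrative); LogisticRegression expects a "features"
# column of Vectors and a "label" column.
dataset = sqlCtx.inferSchema(sc.parallelize([
    Row(label=0.0, features=Vectors.dense(0.0)),
    Row(label=1.0, features=Vectors.dense(1.0))]))

# JavaEstimator.fit() transfers the Python params onto the wrapped Java object
# (via _transfer_params_to_java) and returns a JavaModel around the fitted model.
lr = LogisticRegression() \
    .setMaxIter(5) \
    .setRegParam(0.01)
model = lr.fit(dataset)

# JavaModel.transform() calls the Java transformer on dataset._jdf and wraps the
# result back into a pyspark.sql.DataFrame; note head() instead of first().
test = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
print model.transform(test).head().prediction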