From dadd84ee3d5f70b6a9b2af286cea9cac2057a764 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Mon, 19 Jan 2015 10:05:15 -0800
Subject: [PATCH] add base classes and docs

---
 python/docs/pyspark.ml.rst          |  13 +++
 python/pyspark/ml/__init__.py       | 144 +++++++++++++++++++++-------
 python/pyspark/ml/classification.py |  25 ++++-
 python/pyspark/ml/feature.py        |  17 ++++
 python/pyspark/ml/param.py          |  50 +++++++++-
 python/pyspark/ml/test.py           |  15 ---
 python/pyspark/ml/util.py           |  29 ++++++
 7 files changed, 242 insertions(+), 51 deletions(-)
 create mode 100644 python/docs/pyspark.ml.rst
 delete mode 100644 python/pyspark/ml/test.py
 create mode 100644 python/pyspark/ml/util.py

diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
new file mode 100644
index 0000000000000..9015a3c15edb2
--- /dev/null
+++ b/python/docs/pyspark.ml.rst
@@ -0,0 +1,13 @@
+pyspark.ml package
+=====================
+
+Submodules
+----------
+
+pyspark.ml module
+-------------------------
+
+.. automodule:: pyspark.ml
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 90cf56b97e093..b6606c76063db 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -1,57 +1,149 @@
-import inspect
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
 
 from pyspark import SparkContext
-from pyspark.ml.param import Param
+from pyspark.sql import inherit_doc
+from pyspark.ml.param import Param, Params
+from pyspark.ml.util import Identifiable
 
 __all__ = ["Pipeline", "Transformer", "Estimator"]
 
-# An implementation of PEP3102 for Python 2.
-_keyword_only_secret = 70861589
+
+def _jvm():
+    return SparkContext._jvm
 
 
-def _assert_keyword_only_args():
+@inherit_doc
+class PipelineStage(Params):
     """
-    Checks whether the _keyword_only trick is applied and validates input arguments.
+    A stage in a pipeline, either an :py:class:`Estimator` or a
+    :py:class:`Transformer`.
     """
-    # Get the frame of the function that calls this function.
-    frame = inspect.currentframe().f_back
-    info = inspect.getargvalues(frame)
-    if "_keyword_only" not in info.args:
-        raise ValueError("Function does not have argument _keyword_only.")
-    if info.locals["_keyword_only"] != _keyword_only_secret:
-        raise ValueError("Must use keyword arguments instead of positional ones.")
 
-def _jvm():
-    return SparkContext._jvm
+    def __init__(self):
+        super(PipelineStage, self).__init__()
+
+
+@inherit_doc
+class Estimator(PipelineStage):
+    """
+    Abstract class for estimators that fit models to data.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(Estimator, self).__init__()
+
+    @abstractmethod
+    def fit(self, dataset, params={}):
+        """
+        Fits a model to the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.SchemaRDD`
+        :param params: an optional param map that overwrites embedded
+                       params
+        :returns: fitted model
+        """
+        raise NotImplementedError()
+
+
+@inherit_doc
+class Transformer(PipelineStage):
+    """
+    Abstract class for transformers that transform one dataset into
+    another.
+    """
+
+    __metaclass__ = ABCMeta
 
-class Pipeline(object):
+    @abstractmethod
+    def transform(self, dataset, params={}):
+        """
+        Transforms the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.SchemaRDD`
+        :param params: an optional param map that overwrites embedded
+                       params
+        :returns: transformed dataset
+        """
+        raise NotImplementedError()
+
+
+@inherit_doc
+class Pipeline(Estimator):
+    """
+    A simple pipeline, which acts as an estimator. A Pipeline consists
+    of a sequence of stages, each of which is either an
+    :py:class:`Estimator` or a :py:class:`Transformer`. When
+    :py:meth:`Pipeline.fit` is called, the stages are executed in
+    order. If a stage is an :py:class:`Estimator`, its
+    :py:meth:`Estimator.fit` method will be called on the input
+    dataset to fit a model. Then the model, which is a transformer,
+    will be used to transform the dataset as the input to the next
+    stage. If a stage is a :py:class:`Transformer`, its
+    :py:meth:`Transformer.transform` method will be called to produce
+    the dataset for the next stage. The fitted model from a
+    :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
+    consists of fitted models and transformers, corresponding to the
+    pipeline stages. If there are no stages, the pipeline acts as an
+    identity transformer.
+    """
 
     def __init__(self):
+        super(Pipeline, self).__init__()
+        #: Param for pipeline stages.
         self.stages = Param(self, "stages", "pipeline stages")
-        self.paramMap = {}
 
     def setStages(self, value):
+        """
+        Set pipeline stages.
+        :param value: a list of transformers or estimators
+        :return: the pipeline instance
+        """
         self.paramMap[self.stages] = value
         return self
 
     def getStages(self):
+        """
+        Get pipeline stages.
+        """
         if self.stages in self.paramMap:
             return self.paramMap[self.stages]
 
     def fit(self, dataset):
         transformers = []
         for stage in self.getStages():
-            if hasattr(stage, "transform"):
+            if isinstance(stage, Transformer):
                 transformers.append(stage)
                 dataset = stage.transform(dataset)
-            elif hasattr(stage, "fit"):
+            elif isinstance(stage, Estimator):
                 model = stage.fit(dataset)
                 transformers.append(model)
                 dataset = model.transform(dataset)
         return PipelineModel(transformers)
 
 
-class PipelineModel(object):
+@inherit_doc
+class PipelineModel(Transformer):
 
     def __init__(self, transformers):
         self.transformers = transformers
@@ -60,15 +152,3 @@ def transform(self, dataset):
         for t in self.transformers:
             dataset = t.transform(dataset)
         return dataset
-
-
-class Estimator(object):
-
-    def fit(self, dataset, params={}):
-        raise NotImplementedError()
-
-
-class Transformer(object):
-
-    def transform(self, dataset, paramMap={}):
-        raise NotImplementedError()
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 957d93d45cbf7..2c9aaad03cedf 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 from pyspark.sql import SchemaRDD
 from pyspark.ml import Estimator, Transformer, _jvm
 from pyspark.ml.param import Param
@@ -41,7 +58,8 @@ def fit(self, dataset, params=None):
         """
         Fits a dataset with optional parameters.
         """
-        java_model = self._java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
+        java_model = self._java_obj.fit(dataset._jschema_rdd,
+                                        _jvm().org.apache.spark.ml.param.ParamMap())
         return LogisticRegressionModel(java_model)
 
 
@@ -54,5 +72,6 @@ def __init__(self, _java_model):
         self._java_model = _java_model
 
     def transform(self, dataset):
-        return SchemaRDD(self._java_model.transform(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
-
+        return SchemaRDD(self._java_model.transform(
+            dataset._jschema_rdd,
+            _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 23923204b60a9..ce45105ba2b28 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 from pyspark.sql import SchemaRDD, ArrayType, StringType
 from pyspark.ml import _jvm
 from pyspark.ml.param import Param
diff --git a/python/pyspark/ml/param.py b/python/pyspark/ml/param.py
index 181a158cb94c8..427a70cc11d5c 100644
--- a/python/pyspark/ml/param.py
+++ b/python/pyspark/ml/param.py
@@ -1,3 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta
+
+from pyspark.ml.util import Identifiable
+
+
+__all__ = ["Param"]
+
+
 class Param(object):
     """
     A param with self-contained documentation and optionally default value.
@@ -12,5 +37,28 @@ def __init__(self, parent, name, doc, defaultValue=None):
     def __str__(self):
         return self.parent + "_" + self.name
 
-    def __repr_(self):
+    def __repr__(self):
         return self.parent + "_" + self.name
+
+
+class Params(Identifiable):
+    """
+    Components that take parameters. This also provides an internal
+    param map to store parameter values attached to the instance.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(Params, self).__init__()
+        #: Internal param map.
+        self.paramMap = {}
+
+    def params(self):
+        """
+        Returns all params. The default implementation uses
+        :py:func:`dir` to get all attributes of type
+        :py:class:`Param`.
+        """
+        return [getattr(self, x) for x in dir(self)
+                if isinstance(getattr(self, x), Param)]
diff --git a/python/pyspark/ml/test.py b/python/pyspark/ml/test.py
deleted file mode 100644
index aad7483488ad7..0000000000000
--- a/python/pyspark/ml/test.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import subprocess
-
-def funcA(dataset, **kwargs):
-    """
-    funcA
-    :param dataset:
-    :param kwargs:
-
-    :return:
-    """
-    pass
-
-
-dataset = []
-funcA(dataset, )
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
new file mode 100644
index 0000000000000..c6561a13a5d9d
--- /dev/null
+++ b/python/pyspark/ml/util.py
@@ -0,0 +1,29 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import uuid
+
+
+class Identifiable(object):
+    """
+    Object with a unique ID.
+    """
+
+    def __init__(self):
+        #: A unique id for the object. The default implementation
+        #: concatenates the class name, "-", and 8 random hex chars.
+        self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
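
Note on usage (not part of the patch): below is a minimal sketch of how the new base classes are meant to compose. The stages Doubler, MeanCenterer, and MeanCenterModel are hypothetical stand-ins invented for illustration; real stages operate on pyspark.sql.SchemaRDD, but plain Python lists are used here so the fit/transform control flow is easy to trace.

    from pyspark.ml import Pipeline, Transformer, Estimator


    class Doubler(Transformer):
        """Hypothetical feature step: doubles every value."""
        def transform(self, dataset, params={}):
            return [2 * x for x in dataset]


    class MeanCenterModel(Transformer):
        """Hypothetical fitted model: subtracts a learned mean."""
        def __init__(self, mean):
            super(MeanCenterModel, self).__init__()
            self.mean = mean

        def transform(self, dataset, params={}):
            return [x - self.mean for x in dataset]


    class MeanCenterer(Estimator):
        """Hypothetical estimator: learns the mean of the dataset."""
        def fit(self, dataset, params={}):
            return MeanCenterModel(float(sum(dataset)) / len(dataset))


    # Pipeline.fit applies Doubler.transform first, then fits
    # MeanCenterer on the transformed data, and returns a
    # PipelineModel that replays both steps in order.
    pipeline = Pipeline().setStages([Doubler(), MeanCenterer()])
    model = pipeline.fit([1.0, 2.0, 3.0])
    print(model.transform([1.0, 2.0, 3.0]))  # [-2.0, 0.0, 2.0]

This also shows why Pipeline.fit's isinstance checks matter: a fitted model such as MeanCenterModel is itself a Transformer, so the returned PipelineModel can treat all its stages uniformly.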