add base classes and docs
mengxr committed Jan 19, 2015
1 parent a3015cf commit dadd84e
Showing 7 changed files with 240 additions and 51 deletions.
13 changes: 13 additions & 0 deletions python/docs/pyspark.ml.rst
@@ -0,0 +1,13 @@
pyspark.ml package
=====================

Submodules
----------

pyspark.ml module
-------------------------

.. automodule:: pyspark.ml
:members:
:undoc-members:
:show-inheritance:
144 changes: 112 additions & 32 deletions python/pyspark/ml/__init__.py
@@ -1,57 +1,149 @@
import inspect
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABCMeta, abstractmethod

from pyspark import SparkContext
from pyspark.ml.param import Param
from pyspark.sql import inherit_doc
from pyspark.ml.param import Param, Params
from pyspark.ml.util import Identifiable

__all__ = ["Pipeline", "Transformer", "Estimator"]

# An implementation of PEP3102 for Python 2.
_keyword_only_secret = 70861589

def _jvm():
return SparkContext._jvm


def _assert_keyword_only_args():
@inherit_doc
class PipelineStage(Params):
"""
Checks whether the _keyword_only trick is applied and validates input arguments.
A stage in a pipeline, either an :py:class:`Estimator` or a
:py:class:`Transformer`.
"""
# Get the frame of the function that calls this function.
frame = inspect.currentframe().f_back
info = inspect.getargvalues(frame)
if "_keyword_only" not in info.args:
raise ValueError("Function does not have argument _keyword_only.")
if info.locals["_keyword_only"] != _keyword_only_secret:
raise ValueError("Must use keyword arguments instead of positional ones.")

def _jvm():
return SparkContext._jvm
def __init__(self):
super(PipelineStage, self).__init__()


@inherit_doc
class Estimator(PipelineStage):
"""
Abstract class for estimators that fit models to data.
"""

__metaclass__ = ABCMeta

def __init__(self):
super(Estimator, self).__init__()

@abstractmethod
def fit(self, dataset, params={}):
"""
Fits a model to the input dataset with optional parameters.
:param dataset: input dataset, which is an instance of
:py:class:`pyspark.sql.SchemaRDD`
:param params: an optional param map that overwrites embedded
params
:returns: fitted model
"""
raise NotImplementedError()


@inherit_doc
class Transformer(PipelineStage):
"""
Abstract class for transformers that transform one dataset into
another.
"""

__metaclass__ = ABCMeta

class Pipeline(object):
@abstractmethod
def transform(self, dataset, params={}):
"""
Transforms the input dataset with optional parameters.
:param dataset: input dataset, which is an instance of
:py:class:`pyspark.sql.SchemaRDD`
:param params: an optional param map that overwrites embedded
params
:returns: transformed dataset
"""
raise NotImplementedError()


@inherit_doc
class Pipeline(Estimator):
"""
A simple pipeline, which acts as an estimator. A Pipeline consists
of a sequence of stages, each of which is either an
:py:class:`Estimator` or a :py:class:`Transformer`. When
:py:meth:`Pipeline.fit` is called, the stages are executed in
order. If a stage is an :py:class:`Estimator`, its
:py:meth:`Estimator.fit` method will be called on the input
dataset to fit a model. Then the model, which is a transformer,
will be used to transform the dataset as the input to the next
stage. If a stage is a :py:class:`Transformer`, its
:py:meth:`Transformer.transform` method will be called to produce
the dataset for the next stage. The fitted model from a
:py:class:`Pipeline` is a :py:class:`PipelineModel`, which
consists of fitted models and transformers, corresponding to the
pipeline stages. If there are no stages, the pipeline acts as an
identity transformer.
"""

def __init__(self):
super(Pipeline, self).__init__()
#: Param for pipeline stages.
self.stages = Param(self, "stages", "pipeline stages")
self.paramMap = {}

def setStages(self, value):
"""
Set pipeline stages.
:param value: a list of transformers or estimators
:return: the pipeline instance
"""
self.paramMap[self.stages] = value
return self

def getStages(self):
"""
Get pipeline stages.
"""
if self.stages in self.paramMap:
return self.paramMap[self.stages]

def fit(self, dataset):
transformers = []
for stage in self.getStages():
if hasattr(stage, "transform"):
if isinstance(stage, Transformer):
transformers.append(stage)
dataset = stage.transform(dataset)
elif hasattr(stage, "fit"):
elif isinstance(stage, Estimator):
model = stage.fit(dataset)
transformers.append(model)
dataset = model.transform(dataset)
return PipelineModel(transformers)


class PipelineModel(object):
@inherit_doc
class PipelineModel(Transformer):

def __init__(self, transformers):
self.transformers = transformers
@@ -60,15 +152,3 @@ def transform(self, dataset):
for t in self.transformers:
dataset = t.transform(dataset)
return dataset


class Estimator(object):

def fit(self, dataset, params={}):
raise NotImplementedError()


class Transformer(object):

def transform(self, dataset, paramMap={}):
raise NotImplementedError()
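
To see the dispatch in the new Pipeline.fit outside the diff, here is a self-contained sketch of the same control flow using plain Python classes. It does not import pyspark.ml; the stage classes and the toy string dataset are made up for illustration, and the duck-typed hasattr checks stand in for the isinstance(stage, Transformer) / isinstance(stage, Estimator) checks used above.

class MyTokenizer(object):
    """Transformer-like stage: splits each input string into words."""
    def transform(self, dataset, params={}):
        return [s.split() for s in dataset]


class MyLengthEstimator(object):
    """Estimator-like stage: fit() returns a model that is itself a transformer."""
    def fit(self, dataset, params={}):
        class MyLengthModel(object):
            def transform(self, dataset, params={}):
                return [len(words) for words in dataset]
        return MyLengthModel()


def fit_pipeline(stages, dataset):
    """Mirrors Pipeline.fit: a transformer stage transforms the dataset and is
    kept as-is; an estimator stage is fit and its model transforms the dataset
    for the next stage. The collected transformers play the role of PipelineModel."""
    transformers = []
    for stage in stages:
        if hasattr(stage, "transform"):      # Transformer branch
            transformers.append(stage)
            dataset = stage.transform(dataset)
        elif hasattr(stage, "fit"):          # Estimator branch
            model = stage.fit(dataset)
            transformers.append(model)
            dataset = model.transform(dataset)
    return transformers


models = fit_pipeline([MyTokenizer(), MyLengthEstimator()], ["a b c", "d e"])
# Applying the fitted stages in order, as PipelineModel.transform does:
result = ["a b c", "d e"]
for t in models:
    result = t.transform(result)
print(result)  # [3, 2]
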
25 changes: 22 additions & 3 deletions python/pyspark/ml/classification.py
@@ -1,3 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.sql import SchemaRDD
from pyspark.ml import Estimator, Transformer, _jvm
from pyspark.ml.param import Param
@@ -41,7 +58,8 @@ def fit(self, dataset, params=None):
"""
Fits a dataset with optional parameters.
"""
java_model = self._java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
java_model = self._java_obj.fit(dataset._jschema_rdd,
_jvm().org.apache.spark.ml.param.ParamMap())
return LogisticRegressionModel(java_model)


@@ -54,5 +72,6 @@ def __init__(self, _java_model):
self._java_model = _java_model

def transform(self, dataset):
return SchemaRDD(self._java_model.transform(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)

return SchemaRDD(self._java_model.transform(
dataset._jschema_rdd,
_jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
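
Since LogisticRegression.fit and LogisticRegressionModel.transform simply forward to the JVM objects, the intended call pattern from Python is the usual fit/transform chain. A hypothetical session is sketched below; the SchemaRDD named training and its column layout (a "label" double and a "features" vector, the Scala defaults) are assumptions, not part of this diff.

from pyspark.ml.classification import LogisticRegression

# `training` is assumed to be an existing SchemaRDD with "label" and "features"
# columns, e.g. built via SQLContext.inferSchema on labeled rows.
lr = LogisticRegression()
model = lr.fit(training)                 # wraps the fitted JVM model in a LogisticRegressionModel
predictions = model.transform(training)  # new SchemaRDD with prediction columns appended
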
17 changes: 17 additions & 0 deletions python/pyspark/ml/feature.py
@@ -1,3 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.sql import SchemaRDD, ArrayType, StringType
from pyspark.ml import _jvm
from pyspark.ml.param import Param
50 changes: 49 additions & 1 deletion python/pyspark/ml/param.py
@@ -1,3 +1,28 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABCMeta, abstractmethod

from pyspark.ml.util import Identifiable


__all__ = ["Param"]


class Param(object):
"""
A param with self-contained documentation and an optional default value.
@@ -12,5 +37,28 @@ def __init__(self, parent, name, doc, defaultValue=None):
def __str__(self):
return self.parent + "_" + self.name

def __repr_(self):
def __repr__(self):
return self.parent + "_" + self.name


class Params(Identifiable):
"""
Components that take parameters. This also provides an internal
param map to store parameter values attached to the instance.
"""

__metaclass__ = ABCMeta

def __init__(self):
super(Params, self).__init__()
#: Internal param map.
self.paramMap = {}

@abstractmethod
def params(self):
"""
Returns all params. The default implementation uses
:py:func:`dir` to get all attributes of type
:py:class:`Param`.
"""
return [getattr(self, attr) for attr in dir(self) if isinstance(getattr(self, attr), Param)]
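
The dir()-based discovery in Params.params is worth spelling out: any attribute of the instance that happens to be a Param is reported, so subclasses only need to assign their params in __init__ (as Pipeline does with stages). A standalone illustration of the same idea follows, using plain classes rather than pyspark.ml; the maxIter/regParam names are made up, and the parent is passed as a plain identifier string.

class Param(object):
    """Same shape as pyspark.ml.param.Param: parent id, name, doc, optional default."""
    def __init__(self, parent, name, doc, defaultValue=None):
        self.parent = parent
        self.name = name
        self.doc = doc
        self.defaultValue = defaultValue

    def __repr__(self):
        return self.parent + "_" + self.name


class MyParams(object):
    """Params-like component: params live as attributes, values in a param map."""
    def __init__(self):
        self.paramMap = {}
        self.maxIter = Param("MyParams-1a2b3c4d", "maxIter", "max number of iterations", 100)
        self.regParam = Param("MyParams-1a2b3c4d", "regParam", "regularization constant", 0.0)

    def params(self):
        # dir() lists attribute names; keep those whose values are Param instances.
        return [getattr(self, name) for name in dir(self)
                if isinstance(getattr(self, name), Param)]


p = MyParams()
print(p.params())           # [MyParams-1a2b3c4d_maxIter, MyParams-1a2b3c4d_regParam]
p.paramMap[p.maxIter] = 10  # attach a value, as Pipeline.setStages does for stages
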
15 changes: 0 additions & 15 deletions python/pyspark/ml/test.py

This file was deleted.

27 changes: 27 additions & 0 deletions python/pyspark/ml/util.py
@@ -0,0 +1,27 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import uuid


class Identifiable(object):
"""
Object with a unique ID.
"""

def __init__(self):
#: A unique id for the object. The default implementation
#: concatenates the class name, "-", and 8 random hex chars.
self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
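
With the missing import uuid added above, the uid scheme can be checked interactively; the subclass below is purely illustrative.

from pyspark.ml.util import Identifiable

print(Identifiable().uid)  # e.g. "Identifiable-3f9c81ab": class name, "-", 8 random hex chars

class MyStageBase(Identifiable):  # hypothetical subclass: the uid picks up the subclass name
    pass

print(MyStageBase().uid)   # e.g. "MyStageBase-7d02c4e1"
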
