
Commit

fix tests
Davies Liu committed Jan 28, 2015
1 parent 78638df commit 54ca7df
Showing 9 changed files with 200 additions and 163 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/ml/__init__.py
@@ -18,4 +18,4 @@
 from pyspark.ml.param import *
 from pyspark.ml.pipeline import *

-__all__ = ["Pipeline", "Transformer", "Estimator"]
+__all__ = ["Param", "Params", "Pipeline", "Transformer", "Estimator"]
7 changes: 4 additions & 3 deletions python/pyspark/ml/classification.py
@@ -15,7 +15,8 @@
 # limitations under the License.
 #

-from pyspark.ml.util import JavaEstimator, JavaModel, inherit_doc
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
     HasRegParam

@@ -39,10 +40,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
     .setRegParam(0.01)
 >>> model = lr.fit(dataset)
 >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
->>> print model.transform(test0).first().prediction
+>>> print model.transform(test0).head().prediction
 0.0
 >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
->>> print model.transform(test1).first().prediction
+>>> print model.transform(test1).head().prediction
 1.0
 """
 _java_class = "org.apache.spark.ml.classification.LogisticRegression"
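Note: the doctest edits track this commit's SchemaRDD-to-DataFrame migration: called with no argument, DataFrame.head() returns a single Row, so the prediction is read off as an attribute. The access pattern, reusing names from the doctest above (the lr/dataset setup lines live in the docstring preamble this hunk does not show):

>>> row = model.transform(test0).head()  # a single Row, not a list
>>> row.prediction
0.0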
12 changes: 6 additions & 6 deletions python/pyspark/ml/feature.py
@@ -15,9 +15,9 @@
 # limitations under the License.
 #

-from pyspark.ml.util import JavaTransformer, inherit_doc
 from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
-
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.wrapper import JavaTransformer

 __all__ = ['Tokenizer', 'HashingTF']

@@ -33,9 +33,9 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
 >>> tokenizer = Tokenizer() \
     .setInputCol("text") \
     .setOutputCol("words")
->>> print tokenizer.transform(dataset).first()
+>>> print tokenizer.transform(dataset).head()
 Row(text=u'a b c', words=[u'a', u'b', u'c'])
->>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).first()
+>>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head()
 Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
 """

@@ -54,10 +54,10 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
     .setNumFeatures(10) \
     .setInputCol("words") \
     .setOutputCol("features")
->>> print hashingTF.transform(dataset).first().features
+>>> print hashingTF.transform(dataset).head().features
 (10,[7,8,9],[1.0,1.0,1.0])
 >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
->>> print hashingTF.transform(dataset, params).first().vector
+>>> print hashingTF.transform(dataset, params).head().vector
 (5,[2,3,4],[1.0,1.0,1.0])
 """

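Note: the expected outputs demonstrate the hashing trick: each term is hashed into one of numFeatures buckets, so shrinking numFeatures from 10 to 5 merely changes the modulus, and the param map handed to transform overrides the embedded values for that call only. A rough sketch of the bucketing idea in plain Python (conceptual only, not Spark's actual hash function):

num_features = 5
tokens = [u'a', u'b', u'c']
buckets = sorted(set(hash(t) % num_features for t in tokens))
# each occupied bucket carries its term frequency, giving a sparse
# vector like the (5,[2,3,4],[1.0,1.0,1.0]) in the doctest above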
17 changes: 15 additions & 2 deletions python/pyspark/ml/param/__init__.py
@@ -15,13 +15,26 @@
 # limitations under the License.
 #

+import uuid
 from abc import ABCMeta

-from pyspark.ml.util import Identifiable
-
 __all__ = ['Param', 'Params']


+class Identifiable(object):
+    """
+    Object with a unique ID.
+    """
+
+    def __init__(self):
+        #: A unique id for the object. The default implementation
+        #: concatenates the class name, "-", and 8 random hex chars.
+        self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
+
+    def __repr__(self):
+        return self.uid
+
+
 class Param(object):
     """
     A param with self-contained documentation and optionally default value.
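Note: relocating Identifiable into pyspark.ml.param breaks an import cycle: param previously pulled it from util, while util imports Params back from param (both deleted imports are visible in the util.py diff below). Every instance gets a uid of the form ClassName-xxxxxxxx; a quick illustration (the hex suffix is random, so the exact value here is made up):

ident = Identifiable()
print ident.uid  # e.g. Identifiable-3b9f12ac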
15 changes: 14 additions & 1 deletion python/pyspark/ml/pipeline.py
@@ -17,12 +17,25 @@

 from abc import ABCMeta, abstractmethod

-from pyspark.sql import inherit_doc # TODO: move inherit_doc to Spark Core
 from pyspark.ml.param import Param, Params

 __all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']


+def inherit_doc(cls):
+    for name, func in vars(cls).items():
+        # only inherit docstring for public functions
+        if name.startswith("_"):
+            continue
+        if not func.__doc__:
+            for parent in cls.__bases__:
+                parent_func = getattr(parent, name, None)
+                if parent_func and getattr(parent_func, "__doc__", None):
+                    func.__doc__ = parent_func.__doc__
+                    break
+    return cls
+
+
 @inherit_doc
 class Estimator(Params):
     """
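Note: pipeline.py now carries its own copy of inherit_doc instead of importing it from pyspark.sql (pyspark.ml.util keeps a copy too, as its diff below shows). The decorator copies a parent's docstring onto each public method that lacks one. A small demonstration under the same logic (class names invented for illustration):

class Base(object):
    def fit(self, dataset):
        """Fit a model to the dataset."""
        raise NotImplementedError

@inherit_doc
class Sub(Base):
    def fit(self, dataset):  # defined without a docstring
        return None

# after decoration, Sub.fit.__doc__ == "Fit a model to the dataset."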
6 changes: 3 additions & 3 deletions python/pyspark/ml/tests.py
@@ -31,12 +31,12 @@
 import unittest

 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import SchemaRDD
-from pyspark.ml import Transformer, Estimator, Pipeline
+from pyspark.sql import DataFrame
 from pyspark.ml.param import Param
+from pyspark.ml.pipeline import Transformer, Estimator, Pipeline


-class MockDataset(SchemaRDD):
+class MockDataset(DataFrame):

     def __init__(self):
         self.index = 0
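Note: the mock follows the same rename. Because its __init__ never calls the DataFrame constructor, the tests need no SQLContext or JVM; the object only has to look enough like a DataFrame to flow through the pipeline code under test. Roughly (illustrative, not part of the diff):

dataset = MockDataset()  # cheap stand-in: no SQLContext or JVM required
# mock Transformer/Estimator stages in these tests consume `dataset`
# exactly where a real DataFrame would appear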
147 changes: 0 additions & 147 deletions python/pyspark/ml/util.py
@@ -15,28 +15,6 @@
 # limitations under the License.
 #

-import uuid
-from abc import ABCMeta
-
-from pyspark import SparkContext
-from pyspark.sql import DataFrame
-from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer
-
-
-class Identifiable(object):
-    """
-    Object with a unique ID.
-    """
-
-    def __init__(self):
-        #: A unique id for the object. The default implementation
-        #: concatenates the class name, "-", and 8 random hex chars.
-        self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
-
-    def __repr__(self):
-        return self.uid
-

 def inherit_doc(cls):
     for name, func in vars(cls).items():
@@ -50,128 +28,3 @@ def inherit_doc(cls):
                     func.__doc__ = parent_func.__doc__
                     break
     return cls


-def _jvm():
-    """
-    Returns the JVM view associated with SparkContext. Must be called
-    after SparkContext is initialized.
-    """
-    jvm = SparkContext._jvm
-    if jvm:
-        return jvm
-    else:
-        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
-
-
-@inherit_doc
-class JavaWrapper(Params):
-    """
-    Utility class to help create wrapper classes from Java/Scala
-    implementations of pipeline components.
-    """
-
-    __metaclass__ = ABCMeta
-
-    #: Fully-qualified class name of the wrapped Java component.
-    _java_class = None
-
-    def _java_obj(self):
-        """
-        Returns or creates a Java object.
-        """
-        java_obj = _jvm()
-        for name in self._java_class.split("."):
-            java_obj = getattr(java_obj, name)
-        return java_obj()
-
-    def _transfer_params_to_java(self, params, java_obj):
-        """
-        Transforms the embedded params and additional params to the
-        input Java object.
-        :param params: additional params (overwriting embedded values)
-        :param java_obj: Java object to receive the params
-        """
-        paramMap = self._merge_params(params)
-        for param in self.params:
-            if param in paramMap:
-                java_obj.set(param.name, paramMap[param])
-
-    def _empty_java_param_map(self):
-        """
-        Returns an empty Java ParamMap reference.
-        """
-        return _jvm().org.apache.spark.ml.param.ParamMap()
-
-    def _create_java_param_map(self, params, java_obj):
-        paramMap = self._empty_java_param_map()
-        for param, value in params.items():
-            if param.parent is self:
-                paramMap.put(java_obj.getParam(param.name), value)
-        return paramMap
-
-
-@inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
-    """
-    Base class for :py:class:`Estimator`s that wrap Java/Scala
-    implementations.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def _create_model(self, java_model):
-        """
-        Creates a model from the input Java model reference.
-        """
-        return JavaModel(java_model)
-
-    def _fit_java(self, dataset, params={}):
-        """
-        Fits a Java model to the input dataset.
-        :param dataset: input dataset, which is an instance of
-                        :py:class:`pyspark.sql.SchemaRDD`
-        :param params: additional params (overwriting embedded values)
-        :return: fitted Java model
-        """
-        java_obj = self._java_obj()
-        self._transfer_params_to_java(params, java_obj)
-        return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map())
-
-    def fit(self, dataset, params={}):
-        java_model = self._fit_java(dataset, params)
-        return self._create_model(java_model)
-
-
-@inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
-    """
-    Base class for :py:class:`Transformer`s that wrap Java/Scala
-    implementations.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def transform(self, dataset, params={}):
-        java_obj = self._java_obj()
-        self._transfer_params_to_java({}, java_obj)
-        java_param_map = self._create_java_param_map(params, java_obj)
-        return DataFrame(java_obj.transform(dataset._jschema_rdd, java_param_map),
-                         dataset.sql_ctx)
-
-
-@inherit_doc
-class JavaModel(JavaTransformer):
-    """
-    Base class for :py:class:`Model`s that wrap Java/Scala
-    implementations.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def __init__(self, java_model):
-        super(JavaTransformer, self).__init__()
-        self._java_model = java_model
-
-    def _java_obj(self):
-        return self._java_model
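Note: none of the code removed above leaves the project; it is relocated. The new from pyspark.ml.wrapper imports in classification.py and feature.py show that the JavaWrapper family now lives in pyspark/ml/wrapper.py (presumably one of the two changed files this page does not render), Identifiable moved into pyspark.ml.param, and pipeline.py gained a copy of inherit_doc while util.py keeps one. For reference, the Py4J lookup that _java_obj performed can be spelled out for a concrete class name taken from the classification.py diff (assumes an initialized SparkContext):

from pyspark import SparkContext

java_obj = SparkContext._jvm  # the gateway view the deleted _jvm() helper returned
for name in "org.apache.spark.ml.classification.LogisticRegression".split("."):
    java_obj = getattr(java_obj, name)
lr = java_obj()  # instantiates the wrapped Java estimator via Py4J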