
Commit

a working copy
mengxr committed Jan 22, 2015
1 parent bce72f4 commit d0c5bb8
Showing 7 changed files with 127 additions and 108 deletions.
@@ -21,6 +21,7 @@
 from pyspark.ml.feature import HashingTF, Tokenizer
 from pyspark.ml.classification import LogisticRegression
 
+
 if __name__ == "__main__":
     sc = SparkContext(appName="SimpleTextClassificationPipeline")
     sqlCtx = SQLContext(sc)
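For context, this example feeds the three imported stages into a single pipeline. Roughly how the stages are wired, as a sketch reconstructed from the imports above rather than copied from the file; the Pipeline import, setStages usage, column names, and the training/test datasets are assumptions:

    from pyspark.ml import Pipeline

    tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
    hashingTF = HashingTF().setInputCol(tokenizer.getOutputCol()).setOutputCol("features")
    lr = LogisticRegression().setMaxIter(10).setRegParam(0.01)
    pipeline = Pipeline().setStages([tokenizer, hashingTF, lr])

    model = pipeline.fit(training)      # training: a SchemaRDD with "label" and "text" columns
    prediction = model.transform(test)  # test: a SchemaRDD with a "text" column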
75 changes: 67 additions & 8 deletions python/pyspark/ml/__init__.py
@@ -15,10 +15,10 @@
 # limitations under the License.
 #
 
-from abc import ABCMeta, abstractmethod
+from abc import ABCMeta, abstractmethod, abstractproperty
 
 from pyspark import SparkContext
-from pyspark.sql import inherit_doc  # TODO: move inherit_doc to Spark Core
+from pyspark.sql import SchemaRDD, inherit_doc  # TODO: move inherit_doc to Spark Core
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import Identifiable
 
@@ -146,16 +146,17 @@ def getStages(self):
         if self.stages in self.paramMap:
             return self.paramMap[self.stages]
 
-    def fit(self, dataset):
+    def fit(self, dataset, params={}):
+        map = self._merge_params(params)
         transformers = []
         for stage in self.getStages():
             if isinstance(stage, Transformer):
                 transformers.append(stage)
-                dataset = stage.transform(dataset)
+                dataset = stage.transform(dataset, map)
             elif isinstance(stage, Estimator):
-                model = stage.fit(dataset)
+                model = stage.fit(dataset, map)
                 transformers.append(model)
-                dataset = model.transform(dataset)
+                dataset = model.transform(dataset, map)
             else:
                 raise ValueError(
                     "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
@@ -169,7 +170,65 @@ def __init__(self, transformers):
         super(PipelineModel, self).__init__()
         self.transformers = transformers
 
-    def transform(self, dataset):
+    def transform(self, dataset, params={}):
+        map = self._merge_params(params)
         for t in self.transformers:
-            dataset = t.transform(dataset)
+            dataset = t.transform(dataset, map)
         return dataset
+
+
+@inherit_doc
+class JavaWrapper(object):
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(JavaWrapper, self).__init__()
+
+    @abstractproperty
+    def _java_class(self):
+        raise NotImplementedError
+
+    def _create_java_obj(self):
+        java_obj = _jvm()
+        for name in self._java_class.split("."):
+            java_obj = getattr(java_obj, name)
+        return java_obj()
+
+
+@inherit_doc
+class JavaEstimator(Estimator, JavaWrapper):
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(JavaEstimator, self).__init__()
+
+    @abstractmethod
+    def _create_model(self, java_model):
+        raise NotImplementedError
+
+    def _fit_java(self, dataset, params={}):
+        java_obj = self._create_java_obj()
+        self._transfer_params_to_java(params, java_obj)
+        return java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
+
+    def fit(self, dataset, params={}):
+        java_model = self._fit_java(dataset, params)
+        return self._create_model(java_model)
+
+
+@inherit_doc
+class JavaTransformer(Transformer, JavaWrapper):
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(JavaTransformer, self).__init__()
+
+    def transform(self, dataset, params={}):
+        java_obj = self._create_java_obj()
+        self._transfer_params_to_java(params, java_obj)
+        return SchemaRDD(java_obj.transform(dataset._jschema_rdd,
+                                            _jvm().org.apache.spark.ml.param.ParamMap()),
+                         dataset.sql_ctx)
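A note on the new params plumbing: fit() and transform() now take a plain dict from Param to value, which _merge_params layers on top of each stage's own paramMap, so call-site values win and each Java-backed stage only transfers the params it owns. A hedged usage sketch, with stage names referring to the wiring sketch above rather than anything in this diff:

    # per-call overrides take precedence over values set via the setters
    model = pipeline.fit(training, {hashingTF.numFeatures: 1 << 10, lr.maxIter: 20})
    prediction = model.transform(test)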
28 changes: 13 additions & 15 deletions python/pyspark/ml/classification.py
@@ -16,42 +16,40 @@
 #
 
 from pyspark.sql import SchemaRDD, inherit_doc
-from pyspark.ml import Estimator, Transformer, _jvm
+from pyspark.ml import JavaEstimator, Transformer, _jvm
 from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
     HasRegParam
 
 
 @inherit_doc
-class LogisticRegression(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
+class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                          HasRegParam):
     """
     Logistic regression.
     """
 
-    # _java_class = "org.apache.spark.ml.classification.LogisticRegression"
-
     def __init__(self):
         super(LogisticRegression, self).__init__()
-        self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
 
-    def fit(self, dataset, params=None):
-        """
-        Fits a dataset with optional parameters.
-        """
-        java_model = self._java_obj.fit(dataset._jschema_rdd,
-                                        _jvm().org.apache.spark.ml.param.ParamMap())
+    @property
+    def _java_class(self):
+        return "org.apache.spark.ml.classification.LogisticRegression"
+
+    def _create_model(self, java_model):
         return LogisticRegressionModel(java_model)
 
 
 @inherit_doc
 class LogisticRegressionModel(Transformer):
     """
     Model fitted by LogisticRegression.
     """
 
-    def __init__(self, _java_model):
-        self._java_model = _java_model
+    def __init__(self, java_model):
+        self._java_model = java_model
 
-    def transform(self, dataset):
+    def transform(self, dataset, params={}):
+        # TODO: handle params here.
         return SchemaRDD(self._java_model.transform(
             dataset._jschema_rdd,
             _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
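With the estimator now built on JavaEstimator, fitting reduces to naming the JVM class and wrapping the returned model. A minimal sketch of the call path under the new layout; training and test are assumed to be SchemaRDDs with the expected label and features columns:

    lr = LogisticRegression().setMaxIter(10).setRegParam(0.01)
    model = lr.fit(training)             # JavaEstimator.fit -> _fit_java -> _create_model
    prediction = model.transform(test)   # per-call params are still a TODO on the model side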
86 changes: 18 additions & 68 deletions python/pyspark/ml/feature.py
@@ -15,91 +15,41 @@
 # limitations under the License.
 #
 
-from pyspark.sql import SchemaRDD, ArrayType, StringType, inherit_doc
-from pyspark.ml import Transformer, _jvm
+from pyspark.sql import inherit_doc
+from pyspark.ml import JavaTransformer
 from pyspark.ml.param import Param
 from pyspark.ml.param.shared import HasInputCol, HasOutputCol
 
 
 @inherit_doc
-class Tokenizer(Transformer):
+class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
 
     def __init__(self):
         super(Tokenizer, self).__init__()
-        self.inputCol = Param(self, "inputCol", "input column name", None)
-        self.outputCol = Param(self, "outputCol", "output column name", None)
-        self.paramMap = {}
-
-    def setInputCol(self, value):
-        self.paramMap[self.inputCol] = value
-        return self
-
-    def getInputCol(self):
-        if self.inputCol in self.paramMap:
-            return self.paramMap[self.inputCol]
-
-    def setOutputCol(self, value):
-        self.paramMap[self.outputCol] = value
-        return self
-
-    def getOutputCol(self):
-        if self.outputCol in self.paramMap:
-            return self.paramMap[self.outputCol]
-
-    def transform(self, dataset, params={}):
-        sqlCtx = dataset.sql_ctx
-        if isinstance(params, dict):
-            paramMap = self.paramMap.copy()
-            paramMap.update(params)
-            inputCol = paramMap[self.inputCol]
-            outputCol = paramMap[self.outputCol]
-            # TODO: make names unique
-            sqlCtx.registerFunction("tokenize", lambda text: text.split(),
-                                    ArrayType(StringType(), False))
-            dataset.registerTempTable("dataset")
-            return sqlCtx.sql("SELECT *, tokenize(%s) AS %s FROM dataset" % (inputCol, outputCol))
-        elif isinstance(params, list):
-            return [self.transform(dataset, paramMap) for paramMap in params]
-        else:
-            raise ValueError("The input params must be either a dict or a list.")
+
+    @property
+    def _java_class(self):
+        return "org.apache.spark.ml.feature.Tokenizer"
 
 
 @inherit_doc
-class HashingTF(Transformer):
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol):
 
     def __init__(self):
         super(HashingTF, self).__init__()
-        self._java_obj = _jvm().org.apache.spark.ml.feature.HashingTF()
         #: param for number of features
         self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
-        self.inputCol = Param(self, "inputCol", "input column name")
-        self.outputCol = Param(self, "outputCol", "output column name")
 
+    @property
+    def _java_class(self):
+        return "org.apache.spark.ml.feature.HashingTF"
+
     def setNumFeatures(self, value):
-        self._java_obj.setNumFeatures(value)
+        self.paramMap[self.numFeatures] = value
         return self
 
     def getNumFeatures(self):
-        return self._java_obj.getNumFeatures()
-
-    def setInputCol(self, value):
-        self._java_obj.setInputCol(value)
-        return self
-
-    def getInputCol(self):
-        return self._java_obj.getInputCol()
-
-    def setOutputCol(self, value):
-        self._java_obj.setOutputCol(value)
-        return self
-
-    def getOutputCol(self):
-        return self._java_obj.getOutputCol()
-
-    def transform(self, dataset, paramMap={}):
-        if isinstance(paramMap, dict):
-            javaParamMap = _jvm().org.apache.spark.ml.param.ParamMap()
-            for k, v in paramMap.items():
-                param = self._java_obj.getParam(k.name)
-                javaParamMap.put(param, v)
-            return SchemaRDD(self._java_obj.transform(dataset._jschema_rdd, javaParamMap),
-                             dataset.sql_ctx)
+        if self.numFeatures in self.paramMap:
+            return self.paramMap[self.numFeatures]
         else:
-            raise ValueError("paramMap must be a dict.")
+            return self.numFeatures.defaultValue
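The trimmed-down HashingTF keeps its configuration on the Python side until transform time, and getNumFeatures now falls back to the Param's declared default. A small sketch of the intended behavior; the values come from the defaults shown above, not from a recorded session:

    tf = HashingTF().setInputCol("words").setOutputCol("features")
    tf.getNumFeatures()        # 1 << 18, the declared defaultValue
    tf.setNumFeatures(1 << 10)
    tf.getNumFeatures()        # 1024, now read from the instance paramMap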
11 changes: 11 additions & 0 deletions python/pyspark/ml/param/__init__.py
@@ -63,3 +63,14 @@ def params(self):
         :py:class:`Param`.
         """
         return filter(lambda x: isinstance(x, Param), map(lambda x: getattr(self, x), dir(self)))
+
+    def _merge_params(self, params):
+        map = self.paramMap.copy()
+        map.update(params)
+        return map
+
+    def _transfer_params_to_java(self, params, java_obj):
+        map = self._merge_params(params)
+        for param in self.params():
+            if param in map:
+                java_obj.set(param.name, map[param])
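These two helpers define the precedence used by every fit and transform call above: explicitly passed params override the instance's paramMap, and only params owned by the object are pushed to its Java counterpart. A sketch of the merge order, where tf stands in for any Params subclass with a numFeatures Param and the values are illustrative:

    tf.setNumFeatures(1 << 10)
    tf._merge_params({})                          # {tf.numFeatures: 1024}
    tf._merge_params({tf.numFeatures: 1 << 12})   # call-site value wins: {tf.numFeatures: 4096}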
6 changes: 3 additions & 3 deletions python/pyspark/ml/param/_gen_shared_params.py
@@ -54,13 +54,13 @@ def set%s(self, value):
         self.paramMap[self.%s] = value
         return self
 
-    def get%s(self, value):
+    def get%s(self):
         if self.%s in self.paramMap:
             return self.paramMap[self.%s]
         else:
-            return self.defaultValue""" % (
+            return self.%s.defaultValue""" % (
         upperCamelName, upperCamelName, doc, name, name, doc, defaultValue, upperCamelName, name,
-        upperCamelName, name, name)
+        upperCamelName, name, name, name)
 
 if __name__ == "__main__":
     print header
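After this fix, the generated getter drops the stray value argument and falls back to the Param's own default. Roughly what the template now expands to for, e.g., maxIter; this is reconstructed from the template above, not copied from the generated shared params file:

    def setMaxIter(self, value):
        self.paramMap[self.maxIter] = value
        return self

    def getMaxIter(self):
        if self.maxIter in self.paramMap:
            return self.paramMap[self.maxIter]
        else:
            return self.maxIter.defaultValue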
