Commit

a pipeline in python
mengxr committed Dec 31, 2014
1 parent 33b68e0 commit 46eea43
Showing 7 changed files with 200 additions and 35 deletions.
33 changes: 33 additions & 0 deletions examples/src/main/python/ml/simple_text_classification_pipeline.py
@@ -0,0 +1,33 @@
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row
from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.classification import LogisticRegression

if __name__ == "__main__":
    sc = SparkContext(appName="SimpleTextClassificationPipeline")
    sqlCtx = SQLContext(sc)
    training = sqlCtx.inferSchema(
        sc.parallelize([(0L, "a b c d e spark", 1.0), (1L, "b d", 0.0), (2L, "spark f g h", 1.0), (3L, "hadoop mapreduce", 0.0)]) \
            .map(lambda x: Row(id=x[0], text=x[1], label=x[2])))

    tokenizer = Tokenizer() \
        .setInputCol("text") \
        .setOutputCol("words")
    hashingTF = HashingTF() \
        .setInputCol(tokenizer.getOutputCol()) \
        .setOutputCol("features")
    lr = LogisticRegression() \
        .setMaxIter(10) \
        .setRegParam(0.01)
    pipeline = Pipeline() \
        .setStages([tokenizer, hashingTF, lr])

    model = pipeline.fit(training)

    test = sqlCtx.inferSchema(
        sc.parallelize([(4L, "spark i j k"), (5L, "l m n"), (6L, "mapreduce spark"), (7L, "apache hadoop")]) \
            .map(lambda x: Row(id=x[0], text=x[1])))

    for row in model.transform(test).collect():
        print row
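For comparison, here is a minimal sketch of the same flow with the Pipeline unrolled by hand. It is not part of the committed file; it reuses the `training`, `test`, `tokenizer`, `hashingTF`, and `lr` objects built in the script above and assumes no API beyond this commit's setters, fit, and transform.

# Hand-unrolled equivalent of pipeline.fit(training) / model.transform(test),
# reusing the objects defined in simple_text_classification_pipeline.py above.
train_features = hashingTF.transform(tokenizer.transform(training))
lr_model = lr.fit(train_features)

# Replay the same transformers on the test data, then apply the fitted model.
test_features = hashingTF.transform(tokenizer.transform(test))
for row in lr_model.transform(test_features).collect():
    print row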
7 changes: 7 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -164,6 +164,13 @@ trait Params extends Identifiable with Serializable {
    this
  }

  /**
   * Sets a parameter (by name) in the embedded param map.
   */
  private[ml] def set(param: String, value: Any): this.type = {
    set(getParam(param), value)
  }

  /**
   * Gets the value of a parameter in the embedded param map.
   */
40 changes: 40 additions & 0 deletions python/pyspark/ml/__init__.py
@@ -1,6 +1,9 @@
import inspect

from pyspark import SparkContext
from pyspark.ml.param import Param

__all__ = ["Pipeline"]

# An implementation of PEP3102 for Python 2.
_keyword_only_secret = 70861589
@@ -20,3 +23,40 @@ def _assert_keyword_only_args():

def _jvm():
    return SparkContext._jvm

class Pipeline(object):

    def __init__(self):
        self.stages = Param(self, "stages", "pipeline stages")
        self.paramMap = {}

    def setStages(self, value):
        self.paramMap[self.stages] = value
        return self

    def getStages(self):
        if self.stages in self.paramMap:
            return self.paramMap[self.stages]

    def fit(self, dataset):
        transformers = []
        for stage in self.getStages():
            if hasattr(stage, "transform"):
                transformers.append(stage)
                dataset = stage.transform(dataset)
            elif hasattr(stage, "fit"):
                model = stage.fit(dataset)
                transformers.append(model)
                dataset = model.transform(dataset)
        return PipelineModel(transformers)


class PipelineModel(object):

    def __init__(self, transformers):
        self.transformers = transformers

    def transform(self, dataset):
        for t in self.transformers:
            dataset = t.transform(dataset)
        return dataset
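Pipeline.fit above relies only on duck typing: any stage exposing transform is treated as a transformer, and any stage exposing fit is treated as an estimator whose fitted model is then applied. A minimal sketch of that contract follows; CopyColumn and MeanEstimator are hypothetical names invented here for illustration (not part of this commit), and the "dataset" is a plain list of dicts rather than a SchemaRDD, so the snippet runs without Spark.

# Sketch of the duck-typed stage contract used by Pipeline.fit.
# CopyColumn and MeanEstimator are hypothetical stand-ins, not Spark classes.

class CopyColumn(object):
    """A transformer: only needs transform(dataset)."""

    def __init__(self, inputCol, outputCol):
        self.inputCol, self.outputCol = inputCol, outputCol

    def transform(self, dataset):
        return [dict(r, **{self.outputCol: r[self.inputCol]}) for r in dataset]


class MeanEstimator(object):
    """An estimator: fit(dataset) returns a model that itself has transform(dataset)."""

    def __init__(self, inputCol, outputCol):
        self.inputCol, self.outputCol = inputCol, outputCol

    def fit(self, dataset):
        mean = sum(r[self.inputCol] for r in dataset) / float(len(dataset))
        est = self

        class Model(object):
            def transform(self, dataset):
                return [dict(r, **{est.outputCol: r[est.inputCol] - mean}) for r in dataset]

        return Model()


data = [{"x": 1.0}, {"x": 3.0}]
model = Pipeline().setStages([CopyColumn("x", "y"),
                              MeanEstimator("y", "y_centered")]).fit(data)
print model.transform(data)  # each row gains a "y" copy and a mean-centered "y_centered"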
49 changes: 20 additions & 29 deletions python/pyspark/ml/classification.py
@@ -1,5 +1,5 @@
from pyspark.sql import SchemaRDD
from pyspark.ml import _keyword_only_secret, _assert_keyword_only_args, _jvm
from pyspark.ml import _jvm
from pyspark.ml.param import Param


@@ -8,45 +8,39 @@ class LogisticRegression(object):
    Logistic regression.
    """

    _java_class = "org.apache.spark.ml.classification.LogisticRegression"
    # _java_class = "org.apache.spark.ml.classification.LogisticRegression"

    def __init__(self):
        self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
        self.paramMap = {}
        self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
        self.regParam = Param(self, "regParam", "regularization constant", 0.1)
        self.featuresCol = Param(self, "featuresCol", "features column name", "features")

    def set(self, _keyword_only=_keyword_only_secret,
            maxIter=None, regParam=None):
        _assert_keyword_only_args()
        if maxIter is not None:
            self.paramMap[self.maxIter] = maxIter
        if regParam is not None:
            self.paramMap[self.regParam] = regParam
        return self

    # cannot be chained
    def setMaxIter(self, value):
        self.paramMap[self.maxIter] = value
        self._java_obj.setMaxIter(value)
        return self

    def getMaxIter(self):
        return self._java_obj.getMaxIter()

    def setRegParam(self, value):
        self.paramMap[self.regParam] = value
        self._java_obj.setRegParam(value)
        return self

    def getMaxIter(self):
        if self.maxIter in self.paramMap:
            return self.paramMap[self.maxIter]
        else:
            return self.maxIter.defaultValue

    def getRegParam(self):
        if self.regParam in self.paramMap:
            return self.paramMap[self.regParam]
        else:
            return self.regParam.defaultValue
        return self._java_obj.getRegParam()

    def setFeaturesCol(self, value):
        self._java_obj.setFeaturesCol(value)
        return self

    def fit(self, dataset):
    def getFeaturesCol(self):
        return self._java_obj.getFeaturesCol()

    def fit(self, dataset, params=None):
        """
        Fits a dataset with optional parameters.
        """
        java_model = self._java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
        return LogisticRegressionModel(java_model)

@@ -62,6 +56,3 @@ def __init__(self, _java_model):
    def transform(self, dataset):
        return SchemaRDD(self._java_model.transform(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)

lr = LogisticRegression()

lr.set(maxIter=10, regParam=0.1)
85 changes: 85 additions & 0 deletions python/pyspark/ml/feature.py
@@ -0,0 +1,85 @@
from pyspark.sql import SchemaRDD, ArrayType, StringType
from pyspark.ml import _jvm
from pyspark.ml.param import Param


class Tokenizer(object):

    def __init__(self):
        self.inputCol = Param(self, "inputCol", "input column name", None)
        self.outputCol = Param(self, "outputCol", "output column name", None)
        self.paramMap = {}

    def setInputCol(self, value):
        self.paramMap[self.inputCol] = value
        return self

    def getInputCol(self):
        if self.inputCol in self.paramMap:
            return self.paramMap[self.inputCol]

    def setOutputCol(self, value):
        self.paramMap[self.outputCol] = value
        return self

    def getOutputCol(self):
        if self.outputCol in self.paramMap:
            return self.paramMap[self.outputCol]

    def transform(self, dataset, params={}):
        sqlCtx = dataset.sql_ctx
        if isinstance(params, dict):
            paramMap = self.paramMap.copy()
            paramMap.update(params)
            inputCol = paramMap[self.inputCol]
            outputCol = paramMap[self.outputCol]
            # TODO: make names unique
            sqlCtx.registerFunction("tokenize", lambda text: text.split(),
                                    ArrayType(StringType(), False))
            dataset.registerTempTable("dataset")
            return sqlCtx.sql("SELECT *, tokenize(%s) AS %s FROM dataset" % (inputCol, outputCol))
        elif isinstance(params, list):
            return [self.transform(dataset, paramMap) for paramMap in params]
        else:
            raise ValueError("The input params must be either a dict or a list.")


class HashingTF(object):

    def __init__(self):
        self._java_obj = _jvm().org.apache.spark.ml.feature.HashingTF()
        self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
        self.inputCol = Param(self, "inputCol", "input column name")
        self.outputCol = Param(self, "outputCol", "output column name")

    def setNumFeatures(self, value):
        self._java_obj.setNumFeatures(value)
        return self

    def getNumFeatures(self):
        return self._java_obj.getNumFeatures()

    def setInputCol(self, value):
        self._java_obj.setInputCol(value)
        return self

    def getInputCol(self):
        return self._java_obj.getInputCol()

    def setOutputCol(self, value):
        self._java_obj.setOutputCol(value)
        return self

    def getOutputCol(self):
        return self._java_obj.getOutputCol()

    def transform(self, dataset, paramMap={}):
        if isinstance(paramMap, dict):
            javaParamMap = _jvm().org.apache.spark.ml.param.ParamMap()
            for k, v in paramMap.items():
                param = self._java_obj.getParam(k.name)
                javaParamMap.put(param, v)
            return SchemaRDD(self._java_obj.transform(dataset._jschema_rdd, javaParamMap),
                             dataset.sql_ctx)
        else:
            raise ValueError("paramMap must be a dict.")
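Both transform methods above accept an optional per-call parameter map keyed by Param objects, and Tokenizer.transform additionally accepts a list of such maps, returning one dataset per map. A hedged sketch of that usage follows; it assumes a SchemaRDD `training` with a "text" column like the one in the example, and assumes the Java-side transform honors the supplied param map.

# Per-call parameter overrides via the transform() signatures defined above.
# Assumes `training` is a SchemaRDD with a "text" column, as in the example.
tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
words = tokenizer.transform(training)

# Override numFeatures for this call only; HashingTF.transform looks up the
# matching Java param by the Python Param's name.
hashingTF = HashingTF().setInputCol("words").setOutputCol("features")
small = hashingTF.transform(words, {hashingTF.numFeatures: 1024})

# A list of param maps makes Tokenizer.transform return one dataset per map.
words_a, words_b = tokenizer.transform(training,
                                        [{tokenizer.outputCol: "words_a"},
                                         {tokenizer.outputCol: "words_b"}])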
6 changes: 0 additions & 6 deletions python/pyspark/ml/param.py
@@ -14,9 +14,3 @@ def __str__(self):

    def __repr_(self):
        return self.parent + "_" + self.name


class Params(object):
    """
    Components that take parameters.
    """
15 changes: 15 additions & 0 deletions python/pyspark/ml/test.py
@@ -0,0 +1,15 @@
import subprocess

def funcA(dataset, **kwargs):
    """
    funcA
    :param dataset:
    :param kwargs:
    :return:
    """
    pass


dataset = []
funcA(dataset, )
