forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 7 changed files with 200 additions and 35 deletions.
There are no files selected for viewing
33 changes: 33 additions & 0 deletions
33
examples/src/main/python/ml/simple_text_classification_pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from pyspark import SparkContext | ||
from pyspark.sql import SQLContext, Row | ||
from pyspark.ml import Pipeline | ||
from pyspark.ml.feature import HashingTF, Tokenizer | ||
from pyspark.ml.classification import LogisticRegression | ||
|
||
if __name__ == "__main__":
    sc = SparkContext(appName="SimpleTextClassificationPipeline")
    sqlCtx = SQLContext(sc)

    # Prepare training documents, labeled as (id, text, label).
    # Plain int literals replace the Python 2-only ``0L`` long literals so the
    # example parses on both Python 2 and Python 3 (ints are accepted
    # wherever longs were).
    training = sqlCtx.inferSchema(
        sc.parallelize([(0, "a b c d e spark", 1.0),
                        (1, "b d", 0.0),
                        (2, "spark f g h", 1.0),
                        (3, "hadoop mapreduce", 0.0)])
        .map(lambda x: Row(id=x[0], text=x[1], label=x[2])))

    # Configure an ML pipeline: tokenizer -> hashingTF -> logistic regression.
    tokenizer = Tokenizer() \
        .setInputCol("text") \
        .setOutputCol("words")
    hashingTF = HashingTF() \
        .setInputCol(tokenizer.getOutputCol()) \
        .setOutputCol("features")
    lr = LogisticRegression() \
        .setMaxIter(10) \
        .setRegParam(0.01)
    pipeline = Pipeline() \
        .setStages([tokenizer, hashingTF, lr])

    # Fit the pipeline to the training documents.
    model = pipeline.fit(training)

    # Prepare unlabeled test documents (id, text).
    test = sqlCtx.inferSchema(
        sc.parallelize([(4, "spark i j k"),
                        (5, "l m n"),
                        (6, "mapreduce spark"),
                        (7, "apache hadoop")])
        .map(lambda x: Row(id=x[0], text=x[1])))

    # Make predictions on the test documents and print the results.
    # print(row) is valid in Python 2 (statement with parens) and Python 3.
    for row in model.transform(test).collect():
        print(row)

    # Stop the context explicitly so the backing JVM can shut down cleanly.
    sc.stop()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from pyspark.sql import SchemaRDD, ArrayType, StringType | ||
from pyspark.ml import _jvm | ||
from pyspark.ml.param import Param | ||
|
||
|
||
class Tokenizer(object):
    """Splits the text in an input column on whitespace and stores the
    resulting word array in an output column.

    Pure-Python transformer: ``transform`` registers a ``tokenize`` SQL UDF
    and a temporary table on the dataset's SQLContext and selects through it.
    """

    def __init__(self):
        # Params are keyed by the Param objects themselves in paramMap.
        self.inputCol = Param(self, "inputCol", "input column name", None)
        self.outputCol = Param(self, "outputCol", "output column name", None)
        self.paramMap = {}

    def setInputCol(self, value):
        """Set the input column name; returns self for chaining."""
        self.paramMap[self.inputCol] = value
        return self

    def getInputCol(self):
        """Return the input column name, or None if unset."""
        if self.inputCol in self.paramMap:
            return self.paramMap[self.inputCol]

    def setOutputCol(self, value):
        """Set the output column name; returns self for chaining."""
        self.paramMap[self.outputCol] = value
        return self

    def getOutputCol(self):
        """Return the output column name, or None if unset."""
        if self.outputCol in self.paramMap:
            return self.paramMap[self.outputCol]

    def transform(self, dataset, params=None):
        """Tokenize ``dataset``'s input column into the output column.

        :param dataset: input dataset, which carries a ``sql_ctx``
        :param params: optional dict of param overrides, or a list of such
            dicts (one transformed dataset is returned per dict)
        :return: the transformed dataset, or a list of them when ``params``
            is a list
        :raise ValueError: if ``params`` is neither a dict nor a list

        Note: the previous mutable ``{}`` default argument was replaced by a
        ``None`` sentinel (shared-state pitfall); behavior is unchanged.
        """
        if params is None:
            params = {}
        sqlCtx = dataset.sql_ctx
        if isinstance(params, dict):
            paramMap = self.paramMap.copy()
            paramMap.update(params)
            inputCol = paramMap[self.inputCol]
            outputCol = paramMap[self.outputCol]
            # TODO: make names unique
            sqlCtx.registerFunction("tokenize", lambda text: text.split(),
                                    ArrayType(StringType(), False))
            dataset.registerTempTable("dataset")
            # NOTE(review): column names are interpolated directly into the
            # SQL string; callers must supply trusted column names.
            return sqlCtx.sql("SELECT *, tokenize(%s) AS %s FROM dataset" % (inputCol, outputCol))
        elif isinstance(params, list):
            return [self.transform(dataset, paramMap) for paramMap in params]
        else:
            raise ValueError("The input params must be either a dict or a list.")
|
||
|
||
class HashingTF(object):
    """Maps a sequence of terms to term-frequency vectors using feature
    hashing.

    Thin wrapper around the JVM ``org.apache.spark.ml.feature.HashingTF``:
    all setters/getters delegate to the underlying Java object.
    """

    def __init__(self):
        self._java_obj = _jvm().org.apache.spark.ml.feature.HashingTF()
        self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
        self.inputCol = Param(self, "inputCol", "input column name")
        self.outputCol = Param(self, "outputCol", "output column name")

    def setNumFeatures(self, value):
        """Set the number of hash buckets; returns self for chaining."""
        self._java_obj.setNumFeatures(value)
        return self

    def getNumFeatures(self):
        """Return the number of hash buckets from the Java object."""
        return self._java_obj.getNumFeatures()

    def setInputCol(self, value):
        """Set the input column name; returns self for chaining."""
        self._java_obj.setInputCol(value)
        return self

    def getInputCol(self):
        """Return the input column name from the Java object."""
        return self._java_obj.getInputCol()

    def setOutputCol(self, value):
        """Set the output column name; returns self for chaining."""
        self._java_obj.setOutputCol(value)
        return self

    def getOutputCol(self):
        """Return the output column name from the Java object."""
        return self._java_obj.getOutputCol()

    def transform(self, dataset, paramMap=None):
        """Transform ``dataset`` via the JVM HashingTF.

        :param paramMap: optional dict of {Param: value} overrides, copied
            into a JVM ParamMap before delegating
        :return: a SchemaRDD wrapping the transformed JVM dataset
        :raise ValueError: if ``paramMap`` is not a dict

        Note: the previous mutable ``{}`` default argument was replaced by a
        ``None`` sentinel (shared-state pitfall); behavior is unchanged.
        """
        if paramMap is None:
            paramMap = {}
        if isinstance(paramMap, dict):
            javaParamMap = _jvm().org.apache.spark.ml.param.ParamMap()
            for k, v in paramMap.items():
                param = self._java_obj.getParam(k.name)
                javaParamMap.put(param, v)
            return SchemaRDD(self._java_obj.transform(dataset._jschema_rdd, javaParamMap),
                             dataset.sql_ctx)
        else:
            raise ValueError("paramMap must be a dict.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import subprocess | ||
|
||
def funcA(dataset, **kwargs):
    """Placeholder operation on *dataset* — intentionally a no-op.

    :param dataset: the data to process
    :param kwargs: optional keyword options (currently ignored)
    :return: None
    """
|
||
|
||
# Build an empty dataset and run the placeholder step over it.
dataset = []
funcA(dataset)