make the example working
mengxr committed Jan 19, 2015
1 parent dadd84e commit c18dca1
Showing 5 changed files with 45 additions and 13 deletions.
27 changes: 25 additions & 2 deletions examples/src/main/python/ml/simple_text_classification_pipeline.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row
from pyspark.ml import Pipeline
@@ -8,7 +25,10 @@
sc = SparkContext(appName="SimpleTextClassificationPipeline")
sqlCtx = SQLContext(sc)
training = sqlCtx.inferSchema(
-        sc.parallelize([(0L, "a b c d e spark", 1.0), (1L, "b d", 0.0), (2L, "spark f g h", 1.0), (3L, "hadoop mapreduce", 0.0)]) \
+        sc.parallelize([(0L, "a b c d e spark", 1.0),
+                        (1L, "b d", 0.0),
+                        (2L, "spark f g h", 1.0),
+                        (3L, "hadoop mapreduce", 0.0)]) \
.map(lambda x: Row(id=x[0], text=x[1], label=x[2])))

tokenizer = Tokenizer() \
@@ -26,7 +46,10 @@
model = pipeline.fit(training)

test = sqlCtx.inferSchema(
-        sc.parallelize([(4L, "spark i j k"), (5L, "l m n"), (6L, "mapreduce spark"), (7L, "apache hadoop")]) \
+        sc.parallelize([(4L, "spark i j k"),
+                        (5L, "l m n"),
+                        (6L, "mapreduce spark"),
+                        (7L, "apache hadoop")]) \
.map(lambda x: Row(id=x[0], text=x[1])))

for row in model.transform(test).collect():
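As a side note, here is a minimal standalone sketch (not part of the commit; app name and toy data are illustrative) of the Row-plus-inferSchema pattern the reflowed example relies on, using the Spark 1.2-era SQLContext API:

from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext(appName="InferSchemaSketch")
sqlCtx = SQLContext(sc)

# Build an RDD of Row objects; inferSchema derives the column names from the
# Row fields (kwargs-built Rows order their fields alphabetically).
rows = sc.parallelize([(0, "a b c d e spark", 1.0),
                       (1, "b d", 0.0)]) \
    .map(lambda x: Row(id=x[0], text=x[1], label=x[2]))
training = sqlCtx.inferSchema(rows)

for row in training.collect():
    print(row)
sc.stop()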
15 changes: 11 additions & 4 deletions python/pyspark/ml/__init__.py
@@ -18,7 +18,7 @@
from abc import ABCMeta, abstractmethod

from pyspark import SparkContext
-from pyspark.sql import inherit_doc
+from pyspark.sql import inherit_doc  # TODO: move inherit_doc to Spark Core
from pyspark.ml.param import Param, Params
from pyspark.ml.util import Identifiable

@@ -37,7 +37,7 @@ class PipelineStage(Params):
"""

def __init__(self):
-        super.__init__(self)
+        super(PipelineStage, self).__init__()


@inherit_doc
@@ -49,7 +49,7 @@ class Estimator(PipelineStage):
__metaclass__ = ABCMeta

def __init__(self):
-        super.__init__(self)
+        super(Estimator, self).__init__()

@abstractmethod
def fit(self, dataset, params={}):
@@ -74,6 +74,9 @@ class Transformer(PipelineStage):

__metaclass__ = ABCMeta

+    def __init__(self):
+        super(Transformer, self).__init__()
+
@abstractmethod
def transform(self, dataset, params={}):
"""
@@ -109,7 +112,7 @@ class Pipeline(Estimator):
"""

def __init__(self):
-        super.__init__(self)
+        super(Pipeline, self).__init__()
#: Param for pipeline stages.
self.stages = Param(self, "stages", "pipeline stages")

@@ -139,13 +142,17 @@ def fit(self, dataset):
model = stage.fit(dataset)
transformers.append(model)
dataset = model.transform(dataset)
+            else:
+                raise ValueError(
+                    "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
return PipelineModel(transformers)


@inherit_doc
class PipelineModel(Transformer):

def __init__(self, transformers):
+        super(PipelineModel, self).__init__()
self.transformers = transformers

def transform(self, dataset):
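The recurring fix in this file replaces `super.__init__(self)` with the explicit `super(ClassName, self).__init__()` form. A small self-contained sketch (not part of the commit) of why the old spelling fails at runtime:

class Base(object):
    def __init__(self):
        self.ready = True

class Broken(Base):
    def __init__(self):
        # `super` here is the builtin type, not a bound super object, so this
        # raises TypeError and Base.__init__ never runs.
        super.__init__(self)

class Fixed(Base):
    def __init__(self):
        # Python 2 style explicit form, as used throughout this commit.
        super(Fixed, self).__init__()

try:
    Broken()
except TypeError as e:
    print("Broken: %s" % e)
print("Fixed: ready=%s" % Fixed().ready)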
9 changes: 5 additions & 4 deletions python/pyspark/ml/feature.py
@@ -16,13 +16,13 @@
#

from pyspark.sql import SchemaRDD, ArrayType, StringType
-from pyspark.ml import _jvm
+from pyspark.ml import Transformer, _jvm
from pyspark.ml.param import Param


-class Tokenizer(object):
+class Tokenizer(Transformer):

def __init__(self):
+        super(Tokenizer, self).__init__()
self.inputCol = Param(self, "inputCol", "input column name", None)
self.outputCol = Param(self, "outputCol", "output column name", None)
self.paramMap = {}
@@ -61,9 +61,10 @@ def transform(self, dataset, params={}):
raise ValueError("The input params must be either a dict or a list.")


-class HashingTF(object):
+class HashingTF(Transformer):

def __init__(self):
+        super(HashingTF, self).__init__()
self._java_obj = _jvm().org.apache.spark.ml.feature.HashingTF()
self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
self.inputCol = Param(self, "inputCol", "input column name")
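Making Tokenizer and HashingTF extend Transformer matters because Pipeline.fit dispatches on isinstance and, with the else-branch added above, rejects anything else. A standalone sketch of that dispatch using stub classes (not the real pyspark.ml types):

class Estimator(object):
    def fit(self, dataset):
        raise NotImplementedError

class Transformer(object):
    def transform(self, dataset):
        return dataset

class MyTokenizer(Transformer):   # analogous to Tokenizer(Transformer) above
    pass

def run(stages, dataset):
    for stage in stages:
        if isinstance(stage, Estimator):
            dataset = stage.fit(dataset).transform(dataset)
        elif isinstance(stage, Transformer):
            dataset = stage.transform(dataset)
        else:
            raise ValueError(
                "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
    return dataset

run([MyTokenizer()], dataset=[])      # accepted: it is a Transformer
try:
    run([object()], dataset=[])       # rejected: neither Estimator nor Transformer
except ValueError as e:
    print(e)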
5 changes: 2 additions & 3 deletions python/pyspark/ml/param.py
@@ -15,7 +15,7 @@
# limitations under the License.
#

-from abc import ABCMeta, abstractmethod
+from abc import ABCMeta

from pyspark.ml.util import Identifiable

@@ -50,11 +50,10 @@ class Params(Identifiable):
__metaclass__ = ABCMeta

def __init__(self):
-        super.__init__(self)
+        super(Params, self).__init__()
#: Internal param map.
self.paramMap = {}

-    @abstractmethod
def params(self):
"""
Returns all params. The default implementation uses
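Dropping @abstractmethod lets Params ship a concrete default params(). The committed body is collapsed above; a hedged sketch of what such a default could look like (the Param stub is included only to keep the sketch self-contained):

class Param(object):
    def __init__(self, parent, name, doc, defaultValue=None):
        self.parent = parent
        self.name = name
        self.doc = doc
        self.defaultValue = defaultValue

class Params(object):
    def params(self):
        # Collect every Param-typed attribute defined on the instance.
        return [getattr(self, name) for name in dir(self)
                if isinstance(getattr(self, name), Param)]

class MyStage(Params):
    def __init__(self):
        self.inputCol = Param(self, "inputCol", "input column name")
        self.outputCol = Param(self, "outputCol", "output column name")

print([p.name for p in MyStage().params()])   # ['inputCol', 'outputCol']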
2 changes: 2 additions & 0 deletions python/pyspark/ml/util.py
@@ -15,6 +15,8 @@
# limitations under the License.
#

+import uuid
+

class Identifiable(object):
"""
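Only the new `import uuid` is visible here; the Identifiable body is collapsed. A hedged guess (not the committed code) at the kind of unique-id generation such an import enables:

import uuid

class Identifiable(object):
    """Object with a unique, human-readable id."""

    def __init__(self):
        # e.g. "Tokenizer-3b9f2c1d"; uuid4 keeps ids unique across instances.
        self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]

    def __repr__(self):
        return self.uid

print(Identifiable())   # something like "Identifiable-5a1c3e77"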
