Commit 036ca04: gen numFeatures

mengxr committed Jan 27, 2015
1 parent 46fa147 commit 036ca04
Showing 4 changed files with 35 additions and 17 deletions.
17 changes: 2 additions & 15 deletions python/pyspark/ml/feature.py
@@ -17,8 +17,7 @@
 
 from pyspark.sql import inherit_doc
 from pyspark.ml import JavaTransformer
-from pyspark.ml.param import Param
-from pyspark.ml.param.shared import HasInputCol, HasOutputCol
+from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
 
 
 @inherit_doc
@@ -33,23 +32,11 @@ def _java_class(self):
 
 
 @inherit_doc
-class HashingTF(JavaTransformer, HasInputCol, HasOutputCol):
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
 
     def __init__(self):
         super(HashingTF, self).__init__()
-        #: param for number of features
-        self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
 
     @property
     def _java_class(self):
         return "org.apache.spark.ml.feature.HashingTF"
-
-    def setNumFeatures(self, value):
-        self.paramMap[self.numFeatures] = value
-        return self
-
-    def getNumFeatures(self):
-        if self.numFeatures in self.paramMap:
-            return self.paramMap[self.numFeatures]
-        else:
-            return self.numFeatures.defaultValue
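After this diff, HashingTF no longer hand-rolls its numFeatures param: setNumFeatures and getNumFeatures come from the generated HasNumFeatures mixin shown further down. A minimal usage sketch, assuming a working SQLContext and an existing DataFrame df with a tokenized "words" column (both are assumptions, not part of this commit):

from pyspark.ml.feature import HashingTF

# Sketch only: a DataFrame `df` with a "words" column is assumed to exist;
# the chained setters below all come from the shared-param mixins.
hashingTF = HashingTF() \
    .setInputCol("words") \
    .setOutputCol("features") \
    .setNumFeatures(1 << 10)   # override the 1 << 18 default

print(hashingTF.getNumFeatures())   # 1024
features = hashingTF.transform(df)  # hashed term-frequency vectors
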
2 changes: 1 addition & 1 deletion python/pyspark/ml/param/__init__.py
@@ -36,7 +36,7 @@ def __init__(self, parent, name, doc, defaultValue=None):
         self.defaultValue = defaultValue
 
     def __str__(self):
-        return str(self.parent) + "_" + self.name
+        return str(self.parent) + "-" + self.name
 
     def __repr__(self):
         return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
3 changes: 2 additions & 1 deletion python/pyspark/ml/param/_gen_shared_params.py
@@ -90,7 +90,8 @@ def get$Name(self):
         ("labelCol", "label column name", "'label'"),
         ("predictionCol", "prediction column name", "'prediction'"),
         ("inputCol", "input column name", "'input'"),
-        ("outputCol", "output column name", "'output'")]
+        ("outputCol", "output column name", "'output'"),
+        ("numFeatures", "number of features", "1 << 18")]
     code = []
     for name, doc, defaultValue in shared:
         code.append(_gen_param_code(name, doc, defaultValue))
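The generator fills a string template once per (name, doc, defaultValue) entry and writes the results into shared.py. The sketch below approximates that idea with string.Template; the actual template text and helper in _gen_shared_params.py differ, so treat the names here as illustrative:

from string import Template

# Rough approximation of the code-generation template; only the class
# placeholder and __init__ portion is shown, setter/getter omitted.
_TEMPLATE = Template('''class Has$Name(Params):
    """
    Params with $name.
    """

    # a placeholder to make it appear in the generated doc
    $name = Param(Params._dummy(), "$name", "$doc", $defaultValueStr)

    def __init__(self):
        super(Has$Name, self).__init__()
        #: param for $doc
        self.$name = Param(self, "$name", "$doc", $defaultValueStr)
''')

def _gen_param_code_sketch(name, doc, defaultValueStr):
    """Fill the template for one shared param."""
    return _TEMPLATE.substitute(
        Name=name[0].upper() + name[1:], name=name, doc=doc,
        defaultValueStr=defaultValueStr)

print(_gen_param_code_sketch("numFeatures", "number of features", "1 << 18"))
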
30 changes: 30 additions & 0 deletions python/pyspark/ml/param/shared.py
@@ -228,3 +228,33 @@ def getOutputCol(self):
             return self.paramMap[self.outputCol]
         else:
             return self.outputCol.defaultValue
+
+
+class HasNumFeatures(Params):
+    """
+    Params with numFeatures.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18)
+
+    def __init__(self):
+        super(HasNumFeatures, self).__init__()
+        #: param for number of features
+        self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
+
+    def setNumFeatures(self, value):
+        """
+        Sets the value of :py:attr:`numFeatures`.
+        """
+        self.paramMap[self.numFeatures] = value
+        return self
+
+    def getNumFeatures(self):
+        """
+        Gets the value of numFeatures or its default value.
+        """
+        if self.numFeatures in self.paramMap:
+            return self.paramMap[self.numFeatures]
+        else:
+            return self.numFeatures.defaultValue
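The generated getter reads from paramMap when the param has been set explicitly and otherwise falls back to the Param's defaultValue. A quick sketch of that behavior, assuming HasNumFeatures can be mixed into a trivial subclass like the one below (no SparkContext needed for this part):

from pyspark.ml.param.shared import HasNumFeatures

class Toy(HasNumFeatures):
    # Illustrative subclass, not part of the commit; it only exists to
    # exercise the generated setter/getter.
    pass

t = Toy()
print(t.getNumFeatures())   # 262144, the 1 << 18 default
t.setNumFeatures(1 << 10)
print(t.getNumFeatures())   # 1024, now read from paramMap
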
