Skip to content

Commit

Permalink
a working LR
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Dec 23, 2014
1 parent c233ab3 commit 33b68e0
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
22 changes: 22 additions & 0 deletions python/pyspark/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import inspect

from pyspark import SparkContext

# An implementation of PEP3102 for Python 2.
_keyword_only_secret = 70861589


def _assert_keyword_only_args():
    """
    Verifies that the calling function received only keyword arguments.

    The caller must declare a ``_keyword_only`` parameter defaulting to
    ``_keyword_only_secret``; a positional call overwrites that default,
    which is how we detect misuse. Raises ValueError on violation.
    """
    # Look one frame up to inspect how the caller itself was invoked.
    caller_frame = inspect.currentframe().f_back
    arg_info = inspect.getargvalues(caller_frame)
    if "_keyword_only" not in arg_info.args:
        raise ValueError("Function does not have argument _keyword_only.")
    if arg_info.locals["_keyword_only"] != _keyword_only_secret:
        raise ValueError("Must use keyword arguments instead of positional ones.")

def _jvm():
    # Returns the py4j JVM gateway view held by SparkContext.
    # NOTE(review): reads SparkContext's private `_jvm` attribute and
    # presumably requires an active SparkContext — confirm with callers.
    return SparkContext._jvm
67 changes: 67 additions & 0 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from pyspark.sql import SchemaRDD
from pyspark.ml import _keyword_only_secret, _assert_keyword_only_args, _jvm
from pyspark.ml.param import Param


class LogisticRegression(object):
    """
    Logistic regression estimator backed by a peer Java object.

    User-set parameter values live in ``paramMap``, keyed by the
    ``Param`` objects; each getter falls back to the Param's default
    when no value has been set.
    """

    _java_class = "org.apache.spark.ml.classification.LogisticRegression"

    def __init__(self):
        # Peer Java estimator that performs the actual training.
        self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
        # Explicitly-set values only; defaults stay on the Param objects.
        self.paramMap = {}
        self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
        self.regParam = Param(self, "regParam", "regularization constant", 0.1)

    def set(self, _keyword_only=_keyword_only_secret,
            maxIter=None, regParam=None):
        """
        Sets any of the named parameters; keyword arguments only
        (positional calls raise ValueError). Returns self for chaining.
        """
        _assert_keyword_only_args()
        if maxIter is not None:
            self.paramMap[self.maxIter] = maxIter
        if regParam is not None:
            self.paramMap[self.regParam] = regParam
        return self

    # These setters return self, so they chain like set() does.
    # (The previous "cannot chained" comment was stale/incorrect.)
    def setMaxIter(self, value):
        """Sets maxIter and returns self."""
        self.paramMap[self.maxIter] = value
        return self

    def setRegParam(self, value):
        """Sets regParam and returns self."""
        self.paramMap[self.regParam] = value
        return self

    def getMaxIter(self):
        """Returns the set maxIter, or its default if unset."""
        return self.paramMap.get(self.maxIter, self.maxIter.defaultValue)

    def getRegParam(self):
        """Returns the set regParam, or its default if unset."""
        return self.paramMap.get(self.regParam, self.regParam.defaultValue)

    def fit(self, dataset):
        """
        Fits on the given SchemaRDD and returns a LogisticRegressionModel
        wrapping the trained Java model.
        """
        # NOTE(review): passes an empty Java ParamMap, so values in
        # self.paramMap are not yet forwarded to the JVM — confirm intent.
        java_model = self._java_obj.fit(
            dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
        return LogisticRegressionModel(java_model)


class LogisticRegressionModel(object):
    """
    Model produced by fitting a LogisticRegression estimator; wraps the
    trained peer Java model.
    """

    def __init__(self, _java_model):
        self._java_model = _java_model

    def transform(self, dataset):
        """Applies the model to a SchemaRDD and returns the result as a SchemaRDD."""
        empty_param_map = _jvm().org.apache.spark.ml.param.ParamMap()
        transformed = self._java_model.transform(dataset._jschema_rdd, empty_param_map)
        return SchemaRDD(transformed, dataset.sql_ctx)

# Smoke test; guarded so merely importing this module does not construct
# a JVM-backed estimator as a side effect.
if __name__ == "__main__":
    lr = LogisticRegression()
    lr.set(maxIter=10, regParam=0.1)
22 changes: 22 additions & 0 deletions python/pyspark/ml/param.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
class Param(object):
    """
    A param with self-contained documentation and optionally default value.

    :param parent: the owning object (e.g. an estimator); any object
        accepted — it is converted with str() when rendering the param.
    :param name: the param name
    :param doc: the param's documentation string
    :param defaultValue: value used when no explicit value has been set
    """

    def __init__(self, parent, name, doc, defaultValue=None):
        self.parent = parent
        self.name = name
        self.doc = doc
        self.defaultValue = defaultValue

    def __str__(self):
        # str() the parent: callers pass estimator objects (not strings),
        # and `object + "_"` would raise TypeError.
        return str(self.parent) + "_" + self.name

    # Fixed typo: this was previously spelled __repr_ (one trailing
    # underscore), so repr() never dispatched to it.
    def __repr__(self):
        return str(self.parent) + "_" + self.name


class Params(object):
    """
    Base type for components that accept parameters.
    """

0 comments on commit 33b68e0

Please sign in to comment.