[MLLIB] Add fit intercept api to ml logisticregression
I have enabled the fit intercept by default for logistic regression; I
wonder what others think here. I understand that this enables allocation
by default, which is undesirable, but one needs a very strong reason to
omit the intercept term, so enabling it is the safer default from a
statistical standpoint.
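
To illustrate the statistical argument, here is a minimal plain-Scala sketch (no Spark dependency). `InterceptDemo`, `sigmoid`, and `predict` are hypothetical names introduced only for this example. It shows that without an intercept the predicted probability at the all-zeros feature vector is 0.5 regardless of the weights, whereas an intercept lets that baseline move.

```scala
object InterceptDemo {
  /** Logistic (sigmoid) function. */
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  /** P(y = 1 | x) for weights w and intercept b (b = 0.0 means "no intercept"). */
  def predict(w: Array[Double], b: Double, x: Array[Double]): Double = {
    require(w.length == x.length, "weights and features must have the same length")
    val margin = w.zip(x).map { case (wi, xi) => wi * xi }.sum + b
    sigmoid(margin)
  }

  def main(args: Array[String]): Unit = {
    val origin = Array(0.0, 0.0)
    // Without an intercept, any weight vector yields a margin of 0 at the origin.
    println(predict(Array(3.0, -7.0), 0.0, origin))
    // With a negative intercept, the baseline probability drops below 0.5.
    println(predict(Array(3.0, -7.0), -2.0, origin))
  }
}
```

In other words, dropping the intercept forces the decision boundary through the origin, which is rarely a justified modeling assumption.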

Explicitly modeling the intercept by adding a column of all 1s does not
work. I believe the reason is that the API for
LogisticRegressionWithLBFGS forces column normalization, and a column of
all 1s has zero variance, so dividing by 0 kills it.
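
The zero-variance problem can be reproduced in isolation. The sketch below is plain Scala and does not show how LogisticRegressionWithLBFGS scales features internally, since this commit does not include that code. It only assumes that scaling divides each column by its standard deviation. `StandardizeDemo`, `scaleNaive`, and `scaleGuarded` are hypothetical names.

```scala
object StandardizeDemo {
  /** Population variance of a column. */
  def variance(col: Array[Double]): Double = {
    val mean = col.sum / col.length
    col.map(v => (v - mean) * (v - mean)).sum / col.length
  }

  /** Naive scaling to unit standard deviation: v / std. */
  def scaleNaive(col: Array[Double]): Array[Double] = {
    val std = math.sqrt(variance(col))
    col.map(_ / std)
  }

  /** Guarded variant: a zero-variance column is mapped to all zeros. */
  def scaleGuarded(col: Array[Double]): Array[Double] = {
    val std = math.sqrt(variance(col))
    if (std == 0.0) Array.fill(col.length)(0.0) else col.map(_ / std)
  }

  def main(args: Array[String]): Unit = {
    val ones = Array(1.0, 1.0, 1.0, 1.0) // the hand-made "intercept" column
    println(variance(ones))                     // zero variance
    println(scaleNaive(ones).mkString(", "))    // division by zero
    println(scaleGuarded(ones).mkString(", "))  // column is zeroed out
  }
}
```

Either way the constant column stops carrying a usable bias signal, which is consistent with the observation that a manually added column of 1s does not behave like a fitted intercept.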
Omede Firouz committed Mar 31, 2015
1 parent 0e2753f commit bd9663c
Showing 2 changed files with 17 additions and 1 deletion.
@@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel
 * Params for logistic regression.
 */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
-with HasRegParam with HasMaxIter with HasThreshold
+with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold


/**
@@ -46,6 +46,7 @@ class LogisticRegression
 with LogisticRegressionParams {

 setRegParam(0.1)
+setFitIntercept(true)
 setMaxIter(100)
 setThreshold(0.5)

@@ -55,6 +56,9 @@ class LogisticRegression
 /** @group setParam */
 def setMaxIter(value: Int): this.type = set(maxIter, value)

+/** @group setParam */
+def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+
 /** @group setParam */
 def setThreshold(value: Double): this.type = set(threshold, value)

@@ -71,6 +75,7 @@ class LogisticRegression
 lr.optimizer
 .setRegParam(paramMap(regParam))
 .setNumIterations(paramMap(maxIter))
+.addIntercept(paramMap(fitIntercept))
 val oldModel = lr.run(oldDataset)
 val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)

11 changes: 11 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -106,6 +106,17 @@ private[ml] trait HasProbabilityCol extends Params {
 def getProbabilityCol: String = get(probabilityCol)
 }

+private[ml] trait HasFitIntercept extends Params {
+/**
+* param for fitting the intercept term
+* @group param
+*/
+val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "fits the intercept term or not")
+
+/** @group getParam */
+def getFitIntercept: Boolean = get(fitIntercept)
+}
+
 private[ml] trait HasThreshold extends Params {
 /**
 * param for threshold in (binary) prediction
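
The shared-param pattern this diff extends can be sketched without Spark. In the block below, `Param`, `Params`, and `ToyLogisticRegression` are hypothetical, heavily simplified stand-ins written for this example, not Spark's actual classes. Only the names `HasFitIntercept`, `fitIntercept`, `getFitIntercept`, and `setFitIntercept` mirror the diff. The point is how a `Has...` trait contributes the param and its getter, while the estimator adds a chainable setter and a default.

```scala
import scala.collection.mutable

// Hypothetical, simplified stand-in for a typed parameter descriptor.
final case class Param[T](name: String, doc: String)

// Hypothetical, simplified stand-in for a param container.
trait Params {
  private val values = mutable.Map.empty[Param[_], Any]

  /** Stores a value and returns this instance so that setters can be chained. */
  protected def set[T](param: Param[T], value: T): this.type = {
    values(param) = value
    this
  }

  /** Reads a previously stored value. */
  protected def get[T](param: Param[T]): T = values(param).asInstanceOf[T]
}

// Mirrors the shape of the trait added in the diff: one param plus its getter.
trait HasFitIntercept extends Params {
  val fitIntercept: Param[Boolean] =
    Param[Boolean]("fitIntercept", "fits the intercept term or not")

  def getFitIntercept: Boolean = get(fitIntercept)
}

// Toy estimator: sets the default in its constructor body and exposes a setter.
class ToyLogisticRegression extends HasFitIntercept {
  setFitIntercept(true) // default, as in the commit

  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
}
```

Under these assumptions, `new ToyLogisticRegression().getFitIntercept` is `true`, and a caller opts out with `setFitIntercept(false)`. Returning `this.type` is what keeps the setter chainable on subclasses.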