Skip to content

Commit

Permalink
update setSummary for other algos
Browse files Browse the repository at this point in the history
  • Loading branch information
sethah committed Nov 21, 2016
1 parent 428348d commit d6caa02
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ class LogisticRegression @Since("1.2.0") (
$(labelCol),
$(featuresCol),
objectiveHistory)
model.setSummary(logRegSummary)
model.setSummary(Some(logRegSummary))
} else {
model
}
Expand Down Expand Up @@ -803,9 +803,9 @@ class LogisticRegressionModel private[spark] (
}
}

private[classification] def setSummary(
summary: LogisticRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
private[classification]
def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand Down Expand Up @@ -900,8 +900,7 @@ class LogisticRegressionModel private[spark] (
override def copy(extra: ParamMap): LogisticRegressionModel = {
val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial), extra)
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
newModel.setParent(parent)
newModel.setSummary(trainingSummary).setParent(parent)
}

override protected def raw2prediction(rawPrediction: Vector): Double = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ class KMeansModel private[ml] (
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
val copied = copyValues(new KMeansModel(uid, parentModel), extra)
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
copied.setParent(this.parent)
copied.setSummary(trainingSummary).setParent(this.parent)
}

/** @group setParam */
Expand Down Expand Up @@ -165,8 +164,8 @@ class KMeansModel private[ml] (

private var trainingSummary: Option[KMeansSummary] = None

private[clustering] def setSummary(summary: KMeansSummary): this.type = {
this.trainingSummary = Some(summary)
private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand Down Expand Up @@ -325,7 +324,7 @@ class KMeans @Since("1.5.0") (
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(summary)
model.setSummary(Some(summary))
instr.logSuccess(model)
model
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
.setParent(this))
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
wlsModel.diagInvAtWA.toArray, 1, getSolver)
return model.setSummary(trainingSummary)
return model.setSummary(Some(trainingSummary))
}

// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
Expand All @@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
.setParent(this))
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
model.setSummary(trainingSummary)
model.setSummary(Some(trainingSummary))
}

@Since("2.0.0")
Expand Down Expand Up @@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] (
def hasSummary: Boolean = trainingSummary.nonEmpty

private[regression]
def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand All @@ -778,8 +778,7 @@ class GeneralizedLinearRegressionModel private[ml] (
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept),
extra)
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
copied.setParent(parent)
copied.setSummary(trainingSummary).setParent(parent)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model.diagInvAtWA.toArray,
model.objectiveHistory)

return lrModel.setSummary(trainingSummary)
return lrModel.setSummary(Some(trainingSummary))
}

val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
Expand Down Expand Up @@ -276,7 +276,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model,
Array(0D),
Array(0D))
return model.setSummary(trainingSummary)
return model.setSummary(Some(trainingSummary))
} else {
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
"Model cannot be regularized.")
Expand Down Expand Up @@ -398,7 +398,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model,
Array(0D),
objectiveHistory)
model.setSummary(trainingSummary)
model.setSummary(Some(trainingSummary))
}

@Since("1.4.0")
Expand Down Expand Up @@ -444,8 +444,9 @@ class LinearRegressionModel private[ml] (
throw new SparkException("No training summary available for this LinearRegressionModel")
}

private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
private[regression]
def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
this
}

Expand Down Expand Up @@ -488,8 +489,7 @@ class LinearRegressionModel private[ml] (
@Since("1.4.0")
override def copy(extra: ParamMap): LinearRegressionModel = {
val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra)
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
newModel.setParent(parent)
newModel.setSummary(trainingSummary).setParent(parent)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class LogisticRegressionSuite
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
model.setSummary(None)
assert(!model.hasSummary)
}

test("empty probabilityCol") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusterSizes.length === k)
assert(clusterSizes.sum === numRows)
assert(clusterSizes.forall(_ >= 0))

model.setSummary(None)
assert(!model.hasSummary)
}

test("KMeansModel transform with non-default feature and prediction cols") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ class GeneralizedLinearRegressionSuite
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
model.setSummary(None)
assert(!model.hasSummary)

assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ class LinearRegressionSuite
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
assert(copiedModel.hasSummary)
model.setSummary(None)
assert(!model.hasSummary)

model.transform(datasetWithDenseFeature)
.select("label", "prediction")
Expand Down

0 comments on commit d6caa02

Please sign in to comment.