Skip to content

Commit

Permalink
[SPARK-1406] Adjusted logistic regression export description and target
Browse files Browse the repository at this point in the history
categories
  • Loading branch information
selvinsource committed Dec 13, 2014
1 parent 03bc3a5 commit 8fe12bb
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ private[mllib] object ModelExportFactory {
svm,
"linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
case logisticRegression: LogisticRegressionModel =>
new LogisticRegressionPMMLModelExport(
logisticRegression,
"logistic regression: if predicted value > 0.5, "
+ "the outcome is positive, or negative otherwise")
new LogisticRegressionPMMLModelExport(logisticRegression, "logistic regression")
case _ =>
throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ private[mllib] class LogisticRegressionPMMLModelExport(
val miningSchema = new MiningSchema()

val regressionTableYES = new RegressionTable(model.intercept)
.withTargetCategory("YES")
.withTargetCategory("1")

val regressionTableNO = new RegressionTable(0.0)
.withTargetCategory("NO")
.withTargetCategory("0")

val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.CLASSIFICATION)
.withModelName(description)
Expand All @@ -83,7 +83,7 @@ private[mllib] class LogisticRegressionPMMLModelExport(
val targetField = FieldName.create("target");
dataDictionary
.withDataFields(
new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)
new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)
)
miningSchema
.withMiningFields(new MiningField(targetField)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite{
//assert that the PMML format is as expected
assert(logisticModelExport.isInstanceOf[PMMLModelExport])
var pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml()
assert(pmml.getHeader().getDescription() === "logistic regression: if predicted value > 0.5, the outcome is positive, or negative otherwise")
assert(pmml.getHeader().getDescription() === "logistic regression")
//check that the number of fields match the weights size
assert(pmml.getDataDictionary().getNumberOfFields() === logisticRegressionModel.weights.size + 1)
//this verify that there is a model attached to the pmml object and the model is a regression one
//it also verifies that the pmml model has a regression table (for target category YES) with the same number of predictors of the model weights
//it also verifies that the pmml model has a regression table (for target category 1) with the same number of predictors of the model weights
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
.getRegressionTables().get(0).getTargetCategory() === "YES")
.getRegressionTables().get(0).getTargetCategory() === "1")
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
.getRegressionTables().get(0).getNumericPredictors().size() === logisticRegressionModel.weights.size)
//verify if there is a second table with target category NO and no predictors
//verify if there is a second table with target category 0 and no predictors
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
.getRegressionTables().get(1).getTargetCategory() === "NO")
.getRegressionTables().get(1).getTargetCategory() === "0")
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
.getRegressionTables().get(1).getNumericPredictors().size() === 0)

Expand Down

0 comments on commit 8fe12bb

Please sign in to comment.