Skip to content

Commit

Permalink
[SPARK-1406] Binary classification for SVM and Logistic Regression
Browse files Browse the repository at this point in the history
  • Loading branch information
selvinsource committed Apr 28, 2015
1 parent cfcb596 commit 7a5e0ec
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,46 @@ import scala.{Array => SArray}

import org.dmg.pmml._

import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.regression.GeneralizedLinearModel

/**
* PMML Model Export for LogisticRegressionModel class
* PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
*/
private[mllib] class LogisticRegressionPMMLModelExport(
model : LogisticRegressionModel,
description : String)
private[mllib] class BinaryClassificationPMMLModelExport(
model : GeneralizedLinearModel,
description : String,
normalizationMethod : RegressionNormalizationMethodType,
threshold: Double)
extends PMMLModelExport {

populateLogisticRegressionPMML(model)
populateBinaryClassificationPMML()

/**
* Export the input LogisticRegressionModel model to PMML format
* Export the input LogisticRegressionModel or SVMModel to PMML format.
*/
private def populateLogisticRegressionPMML(model : LogisticRegressionModel): Unit = {
private def populateBinaryClassificationPMML(): Unit = {
pmml.getHeader.setDescription(description)

if (model.weights.size > 0) {
val fields = new SArray[FieldName](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1")
val regressionTableNO = new RegressionTable(0.0).withTargetCategory("0")
var interceptNO = threshold
if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) {
if (threshold <= 0)
interceptNO = 1000
else if (threshold >= 1)
interceptNO = -1000
else
interceptNO = -math.log(1/threshold -1)
}
val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0")
val regressionModel = new RegressionModel()
.withFunctionName(MiningFunctionType.CLASSIFICATION)
.withMiningSchema(miningSchema)
.withModelName(description)
.withNormalizationMethod(RegressionNormalizationMethodType.LOGIT)
.withNormalizationMethod(normalizationMethod)
.withRegressionTables(regressionTableYES, regressionTableNO)

for (i <- 0 until model.weights.size) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.RegressionNormalizationMethodType

import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.clustering.KMeansModel
Expand All @@ -41,11 +43,14 @@ private[mllib] object PMMLModelExportFactory {
case lasso: LassoModel =>
new GeneralizedLinearPMMLModelExport(lasso, "lasso regression")
case svm: SVMModel =>
new GeneralizedLinearPMMLModelExport(svm,
"linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
new BinaryClassificationPMMLModelExport(
svm, "linear SVM", RegressionNormalizationMethodType.NONE,
svm.getThreshold.getOrElse(0.0))
case logistic: LogisticRegressionModel =>
if(logistic.numClasses == 2)
new LogisticRegressionPMMLModelExport(logistic, "logistic regression")
if (logistic.numClasses == 2)
new BinaryClassificationPMMLModelExport(
logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT,
logistic.getThreshold.getOrElse(0.5))
else
throw new IllegalArgumentException(
"PMML Export not supported for Multinomial Logistic Regression")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
package org.apache.spark.mllib.pmml.export

import org.dmg.pmml.RegressionModel
import org.dmg.pmml.RegressionNormalizationMethodType
import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.util.LinearDataGenerator

class LogisticRegressionPMMLModelExportSuite extends FunSuite {
class BinaryClassificationPMMLModelExportSuite extends FunSuite {

test("LogisticRegressionPMMLModelExport generate PMML format") {
test("logistic regression PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val logisticRegressionModel =
new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
Expand All @@ -48,5 +50,35 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite {
// verify if there is a second table with target category 0 and no predictors
assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
// ensure logistic regression has normalization method set to LOGIT
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT)
}

test("linear SVM PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)

val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)

// assert that the PMML format is as expected
assert(svmModelExport.isInstanceOf[PMMLModelExport])
val pmml = svmModelExport.getPmml
assert(pmml.getHeader.getDescription
=== "linear SVM")
// check that the number of fields match the weights size
assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1)
// This verify that there is a model attached to the pmml object and the model is a regression
// one. It also verifies that the pmml model has a regression table (for target category 1)
// with the same number of predictors of the model weights.
val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1")
assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
=== svmModel.weights.size)
// verify if there is a second table with target category 0 and no predictors
assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
// ensure linear SVM has normalization method set to NONE
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.RegressionModel
import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator

class GeneralizedLinearPMMLModelExportSuite extends FunSuite {

test("linear regression pmml export") {
test("linear regression PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val linearRegressionModel =
new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
Expand All @@ -45,7 +44,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
=== linearRegressionModel.weights.size)
}

test("ridge regression pmml export") {
test("ridge regression PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val ridgeRegressionModel =
new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
Expand All @@ -64,7 +63,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
=== ridgeRegressionModel.weights.size)
}

test("lasso pmml export") {
test("lasso PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
Expand All @@ -82,22 +81,4 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
=== lassoModel.weights.size)
}

test("svm pmml export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
// assert that the PMML format is as expected
assert(svmModelExport.isInstanceOf[PMMLModelExport])
val pmml = svmModelExport.getPmml
assert(pmml.getHeader.getDescription
=== "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
// check that the number of fields match the weights size
assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1)
// This verify that there is a model attached to the pmml object and the model is a regression
// one. It also verifies that the pmml model has a regression table with the same number of
// predictors of the model weights.
val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel]
assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size
=== svmModel.weights.size)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite {
val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class PMMLModelExportFactorySuite extends FunSuite {
}

test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
+ "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") {
+ "LinearRegressionModel, RidgeRegressionModel or LassoModel") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)

val linearRegressionModel =
Expand All @@ -56,22 +56,21 @@ class PMMLModelExportFactorySuite extends FunSuite {
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])

val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
}

test("PMMLModelExportFactory create LogisticRegressionPMMLModelExport "
+ "when passing a LogisticRegressionModel") {
test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport "
+ "when passing a LogisticRegressionModel or SVMModel") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)

val logisticRegressionModel =
new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)

val logisticRegressionModelExport =
PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)

assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport])
assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])

val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
}

test("PMMLModelExportFactory throw IllegalArgumentException "
Expand Down

0 comments on commit 7a5e0ec

Please sign in to comment.