From 7a5e0ec20176f43477e1d2ee21fb4432937ffc30 Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Tue, 28 Apr 2015 06:53:59 +0100 Subject: [PATCH] [SPARK-1406] Binary classification for SVM and Logistic Regression --- ...BinaryClassificationPMMLModelExport.scala} | 31 ++++++++++------ .../pmml/export/PMMLModelExportFactory.scala | 13 ++++--- ...yClassificationPMMLModelExportSuite.scala} | 36 +++++++++++++++++-- ...eneralizedLinearPMMLModelExportSuite.scala | 25 ++----------- .../export/KMeansPMMLModelExportSuite.scala | 1 + .../export/PMMLModelExportFactorySuite.scala | 19 +++++----- 6 files changed, 77 insertions(+), 48 deletions(-) rename mllib/src/main/scala/org/apache/spark/mllib/pmml/export/{LogisticRegressionPMMLModelExport.scala => BinaryClassificationPMMLModelExport.scala} (71%) rename mllib/src/test/scala/org/apache/spark/mllib/pmml/export/{LogisticRegressionPMMLModelExportSuite.scala => BinaryClassificationPMMLModelExportSuite.scala} (56%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala similarity index 71% rename from mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala rename to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 2bf4fa858b09b..58f4b52f1b497 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -21,22 +21,24 @@ import scala.{Array => SArray} import org.dmg.pmml._ -import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.regression.GeneralizedLinearModel /** - * PMML Model Export for LogisticRegressionModel class + * PMML Model Export for GeneralizedLinearModel class with binary 
ClassificationModel */ -private[mllib] class LogisticRegressionPMMLModelExport( - model : LogisticRegressionModel, - description : String) +private[mllib] class BinaryClassificationPMMLModelExport( + model : GeneralizedLinearModel, + description : String, + normalizationMethod : RegressionNormalizationMethodType, + threshold: Double) extends PMMLModelExport { - populateLogisticRegressionPMML(model) + populateBinaryClassificationPMML() /** - * Export the input LogisticRegressionModel model to PMML format + * Export the input LogisticRegressionModel or SVMModel to PMML format. */ - private def populateLogisticRegressionPMML(model : LogisticRegressionModel): Unit = { + private def populateBinaryClassificationPMML(): Unit = { pmml.getHeader.setDescription(description) if (model.weights.size > 0) { @@ -44,12 +46,21 @@ private[mllib] class LogisticRegressionPMMLModelExport( val dataDictionary = new DataDictionary val miningSchema = new MiningSchema val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") - val regressionTableNO = new RegressionTable(0.0).withTargetCategory("0") + var interceptNO = threshold + if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { + if (threshold <= 0) + interceptNO = 1000 + else if (threshold >= 1) + interceptNO = -1000 + else + interceptNO = -math.log(1/threshold -1) + } + val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") val regressionModel = new RegressionModel() .withFunctionName(MiningFunctionType.CLASSIFICATION) .withMiningSchema(miningSchema) .withModelName(description) - .withNormalizationMethod(RegressionNormalizationMethodType.LOGIT) + .withNormalizationMethod(normalizationMethod) .withRegressionTables(regressionTableYES, regressionTableNO) for (i <- 0 until model.weights.size) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index bd8c8f96a6e55..965e2785c3acc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.pmml.export +import org.dmg.pmml.RegressionNormalizationMethodType + import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.clustering.KMeansModel @@ -41,11 +43,14 @@ private[mllib] object PMMLModelExportFactory { case lasso: LassoModel => new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") case svm: SVMModel => - new GeneralizedLinearPMMLModelExport(svm, - "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") + new BinaryClassificationPMMLModelExport( + svm, "linear SVM", RegressionNormalizationMethodType.NONE, + svm.getThreshold.getOrElse(0.0)) case logistic: LogisticRegressionModel => - if(logistic.numClasses == 2) - new LogisticRegressionPMMLModelExport(logistic, "logistic regression") + if (logistic.numClasses == 2) + new BinaryClassificationPMMLModelExport( + logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT, + logistic.getThreshold.getOrElse(0.5)) else throw new IllegalArgumentException( "PMML Export not supported for Multinomial Logistic Regression") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala similarity index 56% rename from mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala index 696f95ed873bb..0b646cf1ce6c4 
100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel +import org.dmg.pmml.RegressionNormalizationMethodType import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.util.LinearDataGenerator -class LogisticRegressionPMMLModelExportSuite extends FunSuite { +class BinaryClassificationPMMLModelExportSuite extends FunSuite { - test("LogisticRegressionPMMLModelExport generate PMML format") { + test("logistic regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) @@ -48,5 +50,35 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite { // verify if there is a second table with target category 0 and no predictors assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) + // ensure logistic regression has normalization method set to LOGIT + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) } + + test("linear SVM PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + + // assert that the PMML format is as expected + assert(svmModelExport.isInstanceOf[PMMLModelExport]) + val pmml = svmModelExport.getPmml + assert(pmml.getHeader.getDescription + === 
"linear SVM") + // check that the number of fields matches the weights size + assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1) + // This verifies that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with the same number of predictors as the model has weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === svmModel.weights.size) + // verify if there is a second table with target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) + // ensure linear SVM has normalization method set to NONE + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala index 417ea12ccfa0b..f9afbd888dfc5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator class GeneralizedLinearPMMLModelExportSuite extends FunSuite { - test("linear regression pmml 
export") { + test("linear regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label) @@ -45,7 +44,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite { === linearRegressionModel.weights.size) } - test("ridge regression pmml export") { + test("ridge regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) @@ -64,7 +63,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite { === ridgeRegressionModel.weights.size) } - test("lasso pmml export") { + test("lasso PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) @@ -82,22 +81,4 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite { === lassoModel.weights.size) } - test("svm pmml export") { - val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - // assert that the PMML format is as expected - assert(svmModelExport.isInstanceOf[PMMLModelExport]) - val pmml = svmModelExport.getPmml - assert(pmml.getHeader.getDescription - === "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") - // check that the number of fields match the weights size - assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1) - // This verify that there is a model attached to the pmml object and the model is a regression - // one. 
It also verifies that the pmml model has a regression table with the same number of - // predictors of the model weights. - val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] - assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size - === svmModel.weights.size) - } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index d3c1dd85fa3b1..b985d0446d7b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -45,4 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite { val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index b87e96e7032f3..f28a4ac8ad01f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -40,7 +40,7 @@ class PMMLModelExportFactorySuite extends FunSuite { } test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " - + "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") { + + "LinearRegressionModel, RidgeRegressionModel or LassoModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val linearRegressionModel = @@ -56,22 +56,21 @@ class PMMLModelExportFactorySuite extends FunSuite { val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) val lassoModelExport = 
PMMLModelExportFactory.createPMMLModelExport(lassoModel) assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) - - val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) } - test("PMMLModelExportFactory create LogisticRegressionPMMLModelExport " - + "when passing a LogisticRegressionModel") { + test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + + "when passing a LogisticRegressionModel or SVMModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) - val logisticRegressionModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) - - assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport]) + assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) + + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) } test("PMMLModelExportFactory throw IllegalArgumentException "