From 03bc3a51b39440ea5f050f2c0470b5e66da85cf1 Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Tue, 9 Dec 2014 08:06:12 +0000 Subject: [PATCH] added logistic regression --- .../mllib/export/ModelExportFactory.scala | 11 +- .../GeneralizedLinearPMMLModelExport.scala | 2 +- .../LogisticRegressionPMMLModelExport.scala | 101 ++++++++++++++++++ .../export/ModelExportFactorySuite.scala | 24 ++++- ...eneralizedLinearPMMLModelExportSuite.scala | 4 +- ...gisticRegressionPMMLModelExportSuite.scala | 62 +++++++++++ 6 files changed, 196 insertions(+), 8 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala index 282a32ebc5ced..4889be8e3c7ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala @@ -17,12 +17,14 @@ package org.apache.spark.mllib.export +import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.export.ModelExportType.ModelExportType import org.apache.spark.mllib.export.ModelExportType.PMML import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport +import org.apache.spark.mllib.export.pmml.LogisticRegressionPMMLModelExport import org.apache.spark.mllib.regression.LassoModel import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel @@ -46,7 +48,14 @@ private[mllib] object ModelExportFactory { case 
lassoRegression: LassoModel => new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression") case svm: SVMModel => - new GeneralizedLinearPMMLModelExport(svm, "linear SVM") + new GeneralizedLinearPMMLModelExport( + svm, + "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") + case logisticRegression: LogisticRegressionModel => + new LogisticRegressionPMMLModelExport( + logisticRegression, + "logistic regression: if predicted value > 0.5, " + + "the outcome is positive, or negative otherwise") case _ => throw new IllegalArgumentException("Export not supported for model: " + model.getClass) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala index 8b3a20602895f..8b3d1ce9e3e0f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala @@ -33,7 +33,7 @@ import org.dmg.pmml.RegressionTable import org.apache.spark.mllib.regression.GeneralizedLinearModel /** - * PMML Model Export for GeneralizedLinear abstract class + * PMML Model Export for GeneralizedLinearModel abstract class */ private[mllib] class GeneralizedLinearPMMLModelExport( model : GeneralizedLinearModel, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala new file mode 100644 index 0000000000000..f0c6708af58c4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.export.pmml + +import org.dmg.pmml.DataDictionary +import org.dmg.pmml.DataField +import org.dmg.pmml.DataType +import org.dmg.pmml.FieldName +import org.dmg.pmml.FieldUsageType +import org.dmg.pmml.MiningField +import org.dmg.pmml.MiningFunctionType +import org.dmg.pmml.MiningSchema +import org.dmg.pmml.NumericPredictor +import org.dmg.pmml.OpType +import org.dmg.pmml.RegressionModel +import org.dmg.pmml.RegressionTable +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.dmg.pmml.RegressionNormalizationMethodType + +/** + * PMML Model Export for LogisticRegressionModel class + */ +private[mllib] class LogisticRegressionPMMLModelExport( + model : LogisticRegressionModel, + description : String) + extends PMMLModelExport{ + + /** + * Export the input LogisticRegressionModel model to PMML format + */ + populateLogisticRegressionPMML(model) + + private def populateLogisticRegressionPMML(model : LogisticRegressionModel): Unit = { + + pmml.getHeader().setDescription(description) + + if(model.weights.size > 0){ + + val fields = new Array[FieldName](model.weights.size) + + val dataDictionary = new DataDictionary() + + val miningSchema = new MiningSchema() + + val regressionTableYES = new RegressionTable(model.intercept) + 
.withTargetCategory("YES") + + val regressionTableNO = new RegressionTable(0.0) + .withTargetCategory("NO") + + val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.CLASSIFICATION) + .withModelName(description) + .withNormalizationMethod(RegressionNormalizationMethodType.LOGIT) + .withRegressionTables(regressionTableYES, regressionTableNO) + + for ( i <- 0 until model.weights.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary + .withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTableYES + .withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } + + // add target field + val targetField = FieldName.create("target"); + dataDictionary + .withDataFields( + new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE) + ) + miningSchema + .withMiningFields(new MiningField(targetField) + .withUsageType(FieldUsageType.TARGET)) + + dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size()) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + + } + + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala index 6792e2d674bb4..d63a544ebdf97 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.export import org.scalatest.FunSuite +import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors @@ -28,6 +29,7 @@ import org.apache.spark.mllib.regression.RidgeRegressionModel import 
org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport +import org.apache.spark.mllib.export.pmml.LogisticRegressionPMMLModelExport class ModelExportFactorySuite extends FunSuite{ @@ -55,10 +57,10 @@ class ModelExportFactorySuite extends FunSuite{ //arrange val linearInput = LinearDataGenerator.generateLinearInput( 3.0, Array(10.0, 10.0), 1, 17) - val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label); - val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label); - val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label); - val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label); + val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label) + val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) //act val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML) @@ -82,6 +84,20 @@ class ModelExportFactorySuite extends FunSuite{ } + test("ModelExportFactory create LogisticRegressionPMMLModelExport when passing a LogisticRegressionModel") { + + //arrange + val linearInput = LinearDataGenerator.generateLinearInput( + 3.0, Array(10.0, 10.0), 1, 17) + val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label); + + //act + val logisticRegressionModelExport = ModelExportFactory.createModelExport(logisticRegressionModel, ModelExportType.PMML) + //assert + assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport]) + + } + test("ModelExportFactory throw 
IllegalArgumentException when passing an unsupported model") { //arrange diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala index 402a84c2c8a47..e27c193a7f7f8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala @@ -84,7 +84,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ //assert that the PMML format is as expected assert(svmModelExport.isInstanceOf[PMMLModelExport]) pmml = svmModelExport.asInstanceOf[PMMLModelExport].getPmml() - assert(pmml.getHeader().getDescription() === "linear SVM") + assert(pmml.getHeader().getDescription() === "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise") //check that the number of fields match the weights size assert(pmml.getDataDictionary().getNumberOfFields() === svmModel.weights.size + 1) //this verify that there is a model attached to the pmml object and the model is a regression one @@ -96,7 +96,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ //ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml") //ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml") //ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml") - //ModelExporter.toPMML(svmModel,"/tmp/svm.xml") + //ModelExporter.toPMML(svmModel,"/tmp/linearsvm.xml") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala new file mode 100644 index 0000000000000..27093938102ba --- /dev/null +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.export.pmml + +import org.dmg.pmml.RegressionModel +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.export.ModelExportFactory +import org.apache.spark.mllib.export.ModelExportType +import org.apache.spark.mllib.util.LinearDataGenerator + +class LogisticRegressionPMMLModelExportSuite extends FunSuite{ + + test("LogisticRegressionPMMLModelExport generate PMML format") { + + //arrange models to test + val linearInput = LinearDataGenerator.generateLinearInput( + 3.0, Array(10.0, 10.0), 1, 17) + val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label); + + //act by exporting the model to the PMML format + val logisticModelExport = ModelExportFactory.createModelExport(logisticRegressionModel, ModelExportType.PMML) + //assert that the PMML format is as expected + assert(logisticModelExport.isInstanceOf[PMMLModelExport]) + var pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml() + 
assert(pmml.getHeader().getDescription() === "logistic regression: if predicted value > 0.5, the outcome is positive, or negative otherwise") + //check that the number of fields matches the weights size plus one for the target field + assert(pmml.getDataDictionary().getNumberOfFields() === logisticRegressionModel.weights.size + 1) + //this verifies that there is a model attached to the pmml object and the model is a regression one + //it also verifies that the pmml model has a regression table (for target category YES) with the same number of predictors as the model weights + assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] + .getRegressionTables().get(0).getTargetCategory() === "YES") + assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] + .getRegressionTables().get(0).getNumericPredictors().size() === logisticRegressionModel.weights.size) + //verify that there is a second table with target category NO and no predictors + assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] + .getRegressionTables().get(1).getTargetCategory() === "NO") + assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] + .getRegressionTables().get(1).getNumericPredictors().size() === 0) + + //manual checking + //ModelExporter.toPMML(logisticRegressionModel,"/tmp/logisticregression.xml") + + } + +}