From 78515ec9d8403d555b996163ca32409d91cb30f7 Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Fri, 28 Nov 2014 07:39:57 +0000 Subject: [PATCH] [SPARK-1406] added pmml export for LinearRegressionModel, RidgeRegressionModel and LassoModel --- .../mllib/export/ModelExportFactory.scala | 13 ++- .../GeneralizedLinearPMMLModelExport.scala | 94 +++++++++++++++++++ .../export/pmml/KMeansPMMLModelExport.scala | 2 +- .../export/ModelExportFactorySuite.scala | 32 +++++++ ...eneralizedLinearPMMLModelExportSuite.scala | 87 +++++++++++++++++ 5 files changed, 226 insertions(+), 2 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala index 26bc2a499778e..079a5efc5c219 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala @@ -20,6 +20,10 @@ package org.apache.spark.mllib.export import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport import org.apache.spark.mllib.export.ModelExportType._ +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport +import org.apache.spark.mllib.regression.RidgeRegressionModel +import org.apache.spark.mllib.regression.LassoModel private[mllib] object ModelExportFactory { @@ -31,7 +35,14 @@ private[mllib] object ModelExportFactory { def createModelExport(model: Any, exportType: ModelExportType): ModelExport = { return exportType match{ case PMML => model match{ - case kmeans: KMeansModel => new KMeansPMMLModelExport(kmeans) + case kmeans: KMeansModel => + new KMeansPMMLModelExport(kmeans) + case linearRegression: LinearRegressionModel => + new GeneralizedLinearPMMLModelExport(linearRegression, "linear regression") + case ridgeRegression: RidgeRegressionModel => + new GeneralizedLinearPMMLModelExport(ridgeRegression, "ridge regression") + case lassoRegression: LassoModel => + new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression") case _ => throw new IllegalArgumentException("Export not supported for model: " + model.getClass) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala new file mode 100644 index 0000000000000..edfacafa258ed --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.export.pmml + +import org.dmg.pmml.Array.Type +import org.dmg.pmml.Cluster +import org.dmg.pmml.ClusteringField +import org.dmg.pmml.ClusteringModel +import org.dmg.pmml.ClusteringModel.ModelClass +import org.dmg.pmml.CompareFunctionType +import org.dmg.pmml.ComparisonMeasure +import org.dmg.pmml.ComparisonMeasure.Kind +import org.dmg.pmml.DataDictionary +import org.dmg.pmml.DataField +import org.dmg.pmml.DataType +import org.dmg.pmml.FieldName +import org.dmg.pmml.FieldUsageType +import org.dmg.pmml.MiningField +import org.dmg.pmml.MiningFunctionType +import org.dmg.pmml.MiningSchema +import org.dmg.pmml.OpType +import org.dmg.pmml.SquaredEuclidean +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.GeneralizedLinearModel +import org.dmg.pmml.RegressionModel +import org.dmg.pmml.RegressionTable +import org.dmg.pmml.NumericPredictor + +/** + * PMML Model Export for GeneralizedLinear abstract class + */ +private[mllib] class GeneralizedLinearPMMLModelExport( + model : GeneralizedLinearModel, + description : String) + extends PMMLModelExport{ + + /** + * Export the input GeneralizedLinearModel model to PMML format + */ + populateGeneralizedLinearPMML(model) + + private def populateGeneralizedLinearPMML(model : GeneralizedLinearModel): Unit = { + + pmml.getHeader().setDescription(description) + + if(model.weights.size > 0){ + + val fields = new Array[FieldName](model.weights.size) + + val dataDictionary = new DataDictionary() + + val miningSchema = new MiningSchema() + + val regressionTable = new RegressionTable(model.intercept) + + val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.REGRESSION) + .withModelName(description).withRegressionTables(regressionTable) + + for ( i <- 0 until model.weights.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary + .withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } + + dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size()) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + + } + + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala index 909f172fc1f72..c10d48fb8eb4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala @@ -71,7 +71,7 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length) .withModelName("k-means") - for ( i <- 0 to (clusterCenter.size - 1)) { + for ( i <- 0 until clusterCenter.size) { fields(i) = FieldName.create("field_" + i) dataDictionary .withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala index eb0b7f18b1046..ec4b300d71e56 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala @@ -22,6 +22,11 @@ import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.scalatest.FunSuite import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport +import org.apache.spark.mllib.regression.LassoModel +import org.apache.spark.mllib.regression.RidgeRegressionModel class ModelExportFactorySuite extends FunSuite{ @@ -43,6 +48,33 @@ class ModelExportFactorySuite extends FunSuite{ } + test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a" + +"LinearRegressionModel, RidgeRegressionModel or LassoModel") { + + //arrange + val linearInput = LinearDataGenerator.generateLinearInput( + 3.0, Array(10.0, 10.0), 1, 17) + val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label); + val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label); + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label); + + //act + val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML) + //assert + assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + //act + val ridgeModelExport = ModelExportFactory.createModelExport(ridgeRegressionModel, ModelExportType.PMML) + //assert + assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + //act + val lassoModelExport = ModelExportFactory.createModelExport(lassoModel, ModelExportType.PMML) + //assert + assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + } + test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") { //arrange diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala new file mode 100644 index 0000000000000..20245c6917c5b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.export.pmml + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.export.ModelExportFactory +import org.apache.spark.mllib.export.ModelExportType +import org.apache.spark.mllib.regression.LassoModel +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.RidgeRegressionModel +import org.apache.spark.mllib.util.LinearDataGenerator +import org.scalatest.FunSuite +import org.dmg.pmml.RegressionModel + +class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ + + test("GeneralizedLinearPMMLModelExport generate PMML format") { + + //arrange models to test + val linearInput = LinearDataGenerator.generateLinearInput( + 3.0, Array(10.0, 10.0), 1, 17) + val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label); + val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label); + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label); + + //act by exporting the model to the PMML format + val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML) + //assert that the PMML format is as expected + assert(linearModelExport.isInstanceOf[PMMLModelExport]) + var pmml = linearModelExport.asInstanceOf[PMMLModelExport].getPmml() + assert(pmml.getHeader().getDescription() === "linear regression") + //check that the number of fields match the weights size + assert(pmml.getDataDictionary().getNumberOfFields() === linearRegressionModel.weights.size) + //this verify that there is a model attached to the pmml object and the model is a regression one + //it also verifies that the pmml model has a regression table with the same number of predictors of the model weights + assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] + .getRegressionTables().get(0).getNumericPredictors().size() === linearRegressionModel.weights.size) + + //act + val ridgeModelExport = ModelExportFactory.createModelExport(ridgeRegressionModel, ModelExportType.PMML) + //assert that the PMML format is as expected + assert(ridgeModelExport.isInstanceOf[PMMLModelExport]) + pmml = ridgeModelExport.asInstanceOf[PMMLModelExport].getPmml() + assert(pmml.getHeader().getDescription() === "ridge regression") + //check that the number of fields match the weights size + assert(pmml.getDataDictionary().getNumberOfFields() === ridgeRegressionModel.weights.size) + //this verify that there is a model attached to the pmml object and the model is a regression one + //it also verifies that the pmml model has a regression table with the same number of predictors of the model weights + assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] + .getRegressionTables().get(0).getNumericPredictors().size() === ridgeRegressionModel.weights.size) + + //act + val lassoModelExport = ModelExportFactory.createModelExport(lassoModel, ModelExportType.PMML) + //assert that the PMML format is as expected + assert(lassoModelExport.isInstanceOf[PMMLModelExport]) + pmml = lassoModelExport.asInstanceOf[PMMLModelExport].getPmml() + assert(pmml.getHeader().getDescription() === "lasso regression") + //check that the number of fields match the weights size + assert(pmml.getDataDictionary().getNumberOfFields() === lassoModel.weights.size) + //this verify that there is a model attached to the pmml object and the model is a regression one + //it also verifies that the pmml model has a regression table with the same number of predictors of the model weights + assert(pmml.getModels().get(0).asInstanceOf[RegressionModel] + .getRegressionTables().get(0).getNumericPredictors().size() === lassoModel.weights.size) + + //manual checking + //ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml") + //ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml") + //ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml") + + } + +}