Skip to content

Commit

Permalink
[SPARK-1406] added pmml export for LinearRegressionModel,
Browse files Browse the repository at this point in the history
RidgeRegressionModel and LassoModel
  • Loading branch information
selvinsource committed Nov 28, 2014
1 parent e29dfb9 commit 78515ec
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ package org.apache.spark.mllib.export
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
import org.apache.spark.mllib.export.ModelExportType._
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport
import org.apache.spark.mllib.regression.RidgeRegressionModel
import org.apache.spark.mllib.regression.LassoModel

private[mllib] object ModelExportFactory {

Expand All @@ -31,7 +35,14 @@ private[mllib] object ModelExportFactory {
def createModelExport(model: Any, exportType: ModelExportType): ModelExport = {
return exportType match{
case PMML => model match{
case kmeans: KMeansModel => new KMeansPMMLModelExport(kmeans)
case kmeans: KMeansModel =>
new KMeansPMMLModelExport(kmeans)
case linearRegression: LinearRegressionModel =>
new GeneralizedLinearPMMLModelExport(linearRegression, "linear regression")
case ridgeRegression: RidgeRegressionModel =>
new GeneralizedLinearPMMLModelExport(ridgeRegression, "ridge regression")
case lassoRegression: LassoModel =>
new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression")
case _ =>
throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.export.pmml

import org.dmg.pmml.Array.Type
import org.dmg.pmml.Cluster
import org.dmg.pmml.ClusteringField
import org.dmg.pmml.ClusteringModel
import org.dmg.pmml.ClusteringModel.ModelClass
import org.dmg.pmml.CompareFunctionType
import org.dmg.pmml.ComparisonMeasure
import org.dmg.pmml.ComparisonMeasure.Kind
import org.dmg.pmml.DataDictionary
import org.dmg.pmml.DataField
import org.dmg.pmml.DataType
import org.dmg.pmml.FieldName
import org.dmg.pmml.FieldUsageType
import org.dmg.pmml.MiningField
import org.dmg.pmml.MiningFunctionType
import org.dmg.pmml.MiningSchema
import org.dmg.pmml.OpType
import org.dmg.pmml.SquaredEuclidean
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.GeneralizedLinearModel
import org.dmg.pmml.RegressionModel
import org.dmg.pmml.RegressionTable
import org.dmg.pmml.NumericPredictor

/**
 * PMML exporter for models derived from GeneralizedLinearModel
 * (the factory uses it for linear, ridge and lasso regression models).
 *
 * @param model the model whose weights and intercept are written to the PMML document
 * @param description text used both as the PMML header description and as the model name
 */
private[mllib] class GeneralizedLinearPMMLModelExport(
    model: GeneralizedLinearModel,
    description: String)
  extends PMMLModelExport {

  // The PMML document is filled in as part of construction.
  // NOTE(review): `pmml` is not declared in this class; presumably it is inherited from
  // PMMLModelExport - confirm.
  populateGeneralizedLinearPMML(model)

  /**
   * Writes the description into the PMML header and, when the model has at least one weight,
   * attaches a data dictionary and a regression model built from the weights and intercept.
   *
   * @param model the generalized linear model to export
   */
  private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {

    pmml.getHeader().setDescription(description)

    val numWeights = model.weights.size

    // A model without weights only gets the header description; nothing else is attached.
    if (numWeights > 0) {

      val dictionary = new DataDictionary()
      val schema = new MiningSchema()
      val table = new RegressionTable(model.intercept)

      // The regression model is created around `schema` and `table` before they are filled in.
      // NOTE(review): this relies on the with* calls below updating those same instances
      // (their return values are discarded) - confirm against the JPMML API.
      val regression = new RegressionModel(schema, MiningFunctionType.REGRESSION)
        .withModelName(description)
        .withRegressionTables(table)

      // Each weight i gets a continuous double field named "field_<i>", which is added to the
      // data dictionary, marked ACTIVE in the mining schema and given a numeric predictor
      // holding the weight value.
      (0 until numWeights).foreach { i =>
        val field = FieldName.create("field_" + i)
        dictionary.withDataFields(new DataField(field, OpType.CONTINUOUS, DataType.DOUBLE))
        schema.withMiningFields(new MiningField(field).withUsageType(FieldUsageType.ACTIVE))
        table.withNumericPredictors(new NumericPredictor(field, model.weights(i)))
      }

      // Record how many data fields the dictionary now holds.
      dictionary.withNumberOfFields(dictionary.getDataFields().size())

      pmml.setDataDictionary(dictionary)
      pmml.withModels(regression)
    }
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode
MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length)
.withModelName("k-means")

for ( i <- 0 to (clusterCenter.size - 1)) {
for ( i <- 0 until clusterCenter.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary
.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
import org.scalatest.FunSuite
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport
import org.apache.spark.mllib.regression.LassoModel
import org.apache.spark.mllib.regression.RidgeRegressionModel

class ModelExportFactorySuite extends FunSuite{

Expand All @@ -43,6 +48,33 @@ class ModelExportFactorySuite extends FunSuite{

}

// Verifies that the factory returns a GeneralizedLinearPMMLModelExport for each of the
// supported regression model types when PMML export is requested.
// The two halves of the test name are joined with an explicit space so the name does not
// read "...passing aLinearRegressionModel...".
test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " +
  "LinearRegressionModel, RidgeRegressionModel or LassoModel") {

  // arrange: one generated point supplies the weights (its features) and the
  // intercept (its label) for all three models
  val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
  val point = linearInput(0)
  val linearRegressionModel = new LinearRegressionModel(point.features, point.label)
  val ridgeRegressionModel = new RidgeRegressionModel(point.features, point.label)
  val lassoModel = new LassoModel(point.features, point.label)

  // act
  val linearModelExport =
    ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML)
  // assert
  assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])

  // act
  val ridgeModelExport =
    ModelExportFactory.createModelExport(ridgeRegressionModel, ModelExportType.PMML)
  // assert
  assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])

  // act
  val lassoModelExport =
    ModelExportFactory.createModelExport(lassoModel, ModelExportType.PMML)
  // assert
  assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])

}

test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") {

//arrange
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.export.pmml

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.export.ModelExportFactory
import org.apache.spark.mllib.export.ModelExportType
import org.apache.spark.mllib.regression.LassoModel
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.RidgeRegressionModel
import org.apache.spark.mllib.util.LinearDataGenerator
import org.scalatest.FunSuite
import org.dmg.pmml.RegressionModel

/**
 * Checks the PMML document produced when a LinearRegressionModel, RidgeRegressionModel or
 * LassoModel is exported through ModelExportFactory.
 */
class GeneralizedLinearPMMLModelExportSuite extends FunSuite {

  test("GeneralizedLinearPMMLModelExport generate PMML format") {

    // arrange models to test: one generated point supplies the weights (its features)
    // and the intercept (its label) for all three models
    val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
    val point = linearInput(0)
    val linearRegressionModel = new LinearRegressionModel(point.features, point.label)
    val ridgeRegressionModel = new RidgeRegressionModel(point.features, point.label)
    val lassoModel = new LassoModel(point.features, point.label)

    // each model paired with the description expected in the PMML header
    val modelsWithDescriptions = Seq(
      linearRegressionModel -> "linear regression",
      ridgeRegressionModel -> "ridge regression",
      lassoModel -> "lasso regression")

    modelsWithDescriptions.foreach { case (model, description) =>
      // the clue prefixes any failure message, so it is clear which model failed
      withClue(description + ": ") {

        // act by exporting the model to the PMML format
        val modelExport = ModelExportFactory.createModelExport(model, ModelExportType.PMML)

        // assert that the export is PMML based and that the header is as expected
        val pmml = modelExport match {
          case pmmlExport: PMMLModelExport => pmmlExport.getPmml()
          case other => fail("Expected a PMMLModelExport but got " + other.getClass)
        }
        assert(pmml.getHeader().getDescription() === description)

        // check that the number of fields matches the weights size
        assert(pmml.getDataDictionary().getNumberOfFields() === model.weights.size)

        // verify that a regression model is attached to the pmml object and that its first
        // regression table has as many numeric predictors as the model has weights
        pmml.getModels().get(0) match {
          case regression: RegressionModel =>
            val predictors = regression.getRegressionTables().get(0).getNumericPredictors()
            assert(predictors.size() === model.weights.size)
          case other =>
            fail("Expected a RegressionModel but got " + other.getClass)
        }
      }
    }

    // manual checking
    // ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml")
    // ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml")
    // ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml")

  }

}

0 comments on commit 78515ec

Please sign in to comment.