Skip to content

Commit

Permalink
added logistic regression
Browse files Browse the repository at this point in the history
  • Loading branch information
selvinsource committed Dec 9, 2014
1 parent da2ec11 commit 03bc3a5
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

package org.apache.spark.mllib.export

import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.export.ModelExportType.ModelExportType
import org.apache.spark.mllib.export.ModelExportType.PMML
import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
import org.apache.spark.mllib.export.pmml.LogisticRegressionPMMLModelExport
import org.apache.spark.mllib.regression.LassoModel
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.RidgeRegressionModel
Expand All @@ -46,7 +48,14 @@ private[mllib] object ModelExportFactory {
case lassoRegression: LassoModel =>
new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression")
case svm: SVMModel =>
new GeneralizedLinearPMMLModelExport(svm, "linear SVM")
new GeneralizedLinearPMMLModelExport(
svm,
"linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
case logisticRegression: LogisticRegressionModel =>
new LogisticRegressionPMMLModelExport(
logisticRegression,
"logistic regression: if predicted value > 0.5, "
+ "the outcome is positive, or negative otherwise")
case _ =>
throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.dmg.pmml.RegressionTable
import org.apache.spark.mllib.regression.GeneralizedLinearModel

/**
* PMML Model Export for GeneralizedLinear abstract class
* PMML Model Export for GeneralizedLinearModel abstract class
*/
private[mllib] class GeneralizedLinearPMMLModelExport(
model : GeneralizedLinearModel,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.export.pmml

import org.dmg.pmml.DataDictionary
import org.dmg.pmml.DataField
import org.dmg.pmml.DataType
import org.dmg.pmml.FieldName
import org.dmg.pmml.FieldUsageType
import org.dmg.pmml.MiningField
import org.dmg.pmml.MiningFunctionType
import org.dmg.pmml.MiningSchema
import org.dmg.pmml.NumericPredictor
import org.dmg.pmml.OpType
import org.dmg.pmml.RegressionModel
import org.dmg.pmml.RegressionTable
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.dmg.pmml.RegressionNormalizationMethodType

/**
 * PMML Model Export for LogisticRegressionModel class.
 *
 * The model is exported as a PMML classification RegressionModel using LOGIT
 * normalization and two regression tables:
 *  - target category "YES": carries the model intercept and one numeric predictor
 *    per model weight;
 *  - target category "NO": intercept 0.0 and no predictors.
 *
 * @param model the logistic regression model to export
 * @param description text stored both as the PMML header description and as the
 *                    PMML model name
 */
private[mllib] class LogisticRegressionPMMLModelExport(
    model: LogisticRegressionModel,
    description: String)
  extends PMMLModelExport {

  /**
   * Export the input LogisticRegressionModel model to PMML format.
   * Runs as part of construction, so the PMML document is populated as soon as
   * the exporter has been created.
   */
  populateLogisticRegressionPMML(model)

  /**
   * Fills the `pmml` document with the header description and, when the model has
   * at least one weight, the data dictionary and the regression model.
   * A model without weights only gets the header description.
   */
  private def populateLogisticRegressionPMML(model: LogisticRegressionModel): Unit = {

    pmml.getHeader().setDescription(description)

    val numFeatures = model.weights.size

    if (numFeatures > 0) {

      val dataDictionary = new DataDictionary()

      val miningSchema = new MiningSchema()

      val regressionTableYES = new RegressionTable(model.intercept)
        .withTargetCategory("YES")

      val regressionTableNO = new RegressionTable(0.0)
        .withTargetCategory("NO")

      val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.CLASSIFICATION)
        .withModelName(description)
        .withNormalizationMethod(RegressionNormalizationMethodType.LOGIT)
        .withRegressionTables(regressionTableYES, regressionTableNO)

      // one continuous input field ("field_<i>") per weight, registered in the data
      // dictionary, the mining schema and the "YES" regression table
      for (i <- 0 until numFeatures) {
        val field = FieldName.create("field_" + i)
        dataDictionary
          .withDataFields(new DataField(field, OpType.CONTINUOUS, DataType.DOUBLE))
        miningSchema
          .withMiningFields(new MiningField(field)
            .withUsageType(FieldUsageType.ACTIVE))
        regressionTableYES
          .withNumericPredictors(new NumericPredictor(field, model.weights(i)))
      }

      // add target field: this is a classification model whose regression tables use
      // the target categories "YES"/"NO", so the target has to be a categorical
      // string field rather than a continuous double
      val targetField = FieldName.create("target")
      dataDictionary
        .withDataFields(
          new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)
        )
      miningSchema
        .withMiningFields(new MiningField(targetField)
          .withUsageType(FieldUsageType.TARGET))

      dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())

      pmml.setDataDictionary(dataDictionary)
      pmml.withModels(regressionModel)

    }

  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.mllib.export

import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
Expand All @@ -28,6 +29,7 @@ import org.apache.spark.mllib.regression.RidgeRegressionModel
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
import org.apache.spark.mllib.export.pmml.LogisticRegressionPMMLModelExport

class ModelExportFactorySuite extends FunSuite{

Expand Down Expand Up @@ -55,10 +57,10 @@ class ModelExportFactorySuite extends FunSuite{
//arrange
val linearInput = LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 1, 17)
val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label);
val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label);
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label);
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label);
val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)

//act
val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML)
Expand All @@ -82,6 +84,20 @@ class ModelExportFactorySuite extends FunSuite{

}

test("ModelExportFactory create LogisticRegressionPMMLModelExport when passing a LogisticRegressionModel") {

  // arrange: build a logistic regression model from the first generated sample
  val samples = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
  val firstSample = samples(0)
  val logisticRegressionModel =
    new LogisticRegressionModel(firstSample.features, firstSample.label)

  // act: ask the factory for a PMML exporter
  val createdExport =
    ModelExportFactory.createModelExport(logisticRegressionModel, ModelExportType.PMML)

  // assert: the factory picked the logistic-regression-specific exporter
  assert(createdExport.isInstanceOf[LogisticRegressionPMMLModelExport])

}

test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") {

//arrange
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
//assert that the PMML format is as expected
assert(svmModelExport.isInstanceOf[PMMLModelExport])
pmml = svmModelExport.asInstanceOf[PMMLModelExport].getPmml()
assert(pmml.getHeader().getDescription() === "linear SVM")
assert(pmml.getHeader().getDescription() === "linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
//check that the number of fields match the weights size
assert(pmml.getDataDictionary().getNumberOfFields() === svmModel.weights.size + 1)
//this verifies that there is a model attached to the pmml object and that the model is a regression one
Expand All @@ -96,7 +96,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
//ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml")
//ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml")
//ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml")
//ModelExporter.toPMML(svmModel,"/tmp/svm.xml")
//ModelExporter.toPMML(svmModel,"/tmp/linearsvm.xml")

}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.export.pmml

import org.dmg.pmml.RegressionModel
import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.export.ModelExportFactory
import org.apache.spark.mllib.export.ModelExportType
import org.apache.spark.mllib.util.LinearDataGenerator

class LogisticRegressionPMMLModelExportSuite extends FunSuite {

  test("LogisticRegressionPMMLModelExport generate PMML format") {

    // arrange: build the model under test from the first generated sample
    val samples = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
    val firstSample = samples(0)
    val logisticRegressionModel =
      new LogisticRegressionModel(firstSample.features, firstSample.label)

    // act: export the model to the PMML format through the factory
    val exported =
      ModelExportFactory.createModelExport(logisticRegressionModel, ModelExportType.PMML)

    // assert: the exporter is a PMML one and carries the expected description
    assert(exported.isInstanceOf[PMMLModelExport])
    val pmml = exported.asInstanceOf[PMMLModelExport].getPmml()
    assert(pmml.getHeader().getDescription() ===
      "logistic regression: if predicted value > 0.5, the outcome is positive, or negative otherwise")

    // the data dictionary holds one field per weight plus the target field
    val expectedPredictors = logisticRegressionModel.weights.size
    assert(pmml.getDataDictionary().getNumberOfFields() === expectedPredictors + 1)

    // the first attached model must be a regression model whose first table is the
    // "YES" one, with as many numeric predictors as the model has weights
    val yesTable = pmml.getModels().get(0).asInstanceOf[RegressionModel]
      .getRegressionTables().get(0)
    assert(yesTable.getTargetCategory() === "YES")
    assert(yesTable.getNumericPredictors().size() === expectedPredictors)

    // the second table is the "NO" one and has no predictors at all
    val noTable = pmml.getModels().get(0).asInstanceOf[RegressionModel]
      .getRegressionTables().get(1)
    assert(noTable.getTargetCategory() === "NO")
    assert(noTable.getNumericPredictors().size() === 0)

    //manual checking
    //ModelExporter.toPMML(logisticRegressionModel,"/tmp/logisticregression.xml")

  }

}

0 comments on commit 03bc3a5

Please sign in to comment.