diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 938a7998cdf5f..354e90f3eeaa6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -17,9 +17,7 @@ package org.apache.spark.mllib.pmml -import java.io.File -import java.io.OutputStream -import java.io.StringWriter +import java.io.{File, OutputStream, StringWriter} import javax.xml.transform.stream.StreamResult import org.jpmml.model.JAXBUtil @@ -33,7 +31,7 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory * developed by the Data Mining Group (www.dmg.org). */ trait PMMLExportable { - + /** * Export the model to the stream result in PMML format */ @@ -41,14 +39,14 @@ trait PMMLExportable { val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult) } - + /** * Export the model to a local file in PMML format */ def toPMML(localPath: String): Unit = { toPMML(new StreamResult(new File(localPath))) } - + /** * Export the model to a directory on a distributed file system in PMML format */ @@ -56,14 +54,14 @@ trait PMMLExportable { val pmml = toPMML() sc.parallelize(Array(pmml), 1).saveAsTextFile(path) } - + /** * Export the model to the OutputStream in PMML format */ def toPMML(outputStream: OutputStream): Unit = { toPMML(new StreamResult(outputStream)) } - + /** * Export the model to a String in PMML format */ @@ -72,5 +70,5 @@ trait PMMLExportable { toPMML(new StreamResult(writer)) writer.toString } - + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala index baab1a2dbf963..8c079d5aec42c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel * PMML Model Export for GeneralizedLinearModel abstract class */ private[mllib] class GeneralizedLinearPMMLModelExport( - model : GeneralizedLinearModel, - description : String) - extends PMMLModelExport{ + model: GeneralizedLinearModel, + description: String) + extends PMMLModelExport { populateGeneralizedLinearPMML(model) @@ -37,37 +37,37 @@ private[mllib] class GeneralizedLinearPMMLModelExport( * Export the input GeneralizedLinearModel model to PMML format. */ private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = { - pmml.getHeader.setDescription(description) + pmml.getHeader.setDescription(description) - if(model.weights.size > 0){ - val fields = new SArray[FieldName](model.weights.size) - val dataDictionary = new DataDictionary - val miningSchema = new MiningSchema - val regressionTable = new RegressionTable(model.intercept) - val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.REGRESSION) - .withModelName(description) - .withRegressionTables(regressionTable) + if (model.weights.size > 0) { + val fields = new SArray[FieldName](model.weights.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val regressionTable = new RegressionTable(model.intercept) + val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.REGRESSION) + .withModelName(description) + .withRegressionTables(regressionTable) - for (i <- 0 until model.weights.size) { - fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) - miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) - } - - // for completeness add target field - val targetField = FieldName.create("target") - dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) - miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + for (i <- 0 until model.weights.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + // for completeness add target field + val targetField = FieldName.create("target") + dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(targetField) + .withUsageType(FieldUsageType.TARGET)) - pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) - } + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala index 75c28e1c03514..6e818c7709bda 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionModel private[mllib] class LogisticRegressionPMMLModelExport( model : LogisticRegressionModel, description : String) - extends PMMLModelExport{ + extends PMMLModelExport { populateLogisticRegressionPMML(model) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala index f48d39f889cd3..417ea12ccfa0b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -21,9 +21,7 @@ import org.dmg.pmml.RegressionModel import org.scalatest.FunSuite import org.apache.spark.mllib.classification.SVMModel -import org.apache.spark.mllib.regression.LassoModel -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.RidgeRegressionModel +import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator class GeneralizedLinearPMMLModelExportSuite extends FunSuite { @@ -87,7 +85,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite { test("svm pmml export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) // assert that the PMML format is as expected assert(svmModelExport.isInstanceOf[PMMLModelExport]) val pmml = svmModelExport.getPmml diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index f34e2a210a9fd..d3c1dd85fa3b1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -26,15 +26,14 @@ import org.apache.spark.mllib.linalg.Vectors class KMeansPMMLModelExportSuite extends FunSuite { test("KMeansPMMLModelExport generate PMML format") { - // arrange model to test val clusterCenters = Array( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0)) val kmeansModel = new KMeansModel(clusterCenters) - + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) - + // assert that the PMML format is as expected assert(modelExport.isInstanceOf[PMMLModelExport]) val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala index af642702ed942..696f95ed873bb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala @@ -23,13 +23,13 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.util.LinearDataGenerator -class LogisticRegressionPMMLModelExportSuite extends FunSuite{ +class LogisticRegressionPMMLModelExportSuite extends FunSuite { test("LogisticRegressionPMMLModelExport generate PMML format") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) - + val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) // assert that the PMML format is as expected @@ -48,5 +48,5 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite{ // verify if there is a second table with target category 0 and no predictors assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) - } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index b466e08d09e6d..a94854e4c0f20 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -19,13 +19,10 @@ package org.apache.spark.mllib.pmml.export import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.LogisticRegressionModel -import org.apache.spark.mllib.classification.SVMModel +import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LassoModel -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.RidgeRegressionModel +import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator class PMMLModelExportFactorySuite extends FunSuite { @@ -38,33 +35,32 @@ class PMMLModelExportFactorySuite extends FunSuite { val kmeansModel = new KMeansModel(clusterCenters) val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) - + assert(modelExport.isInstanceOf[KMeansPMMLModelExport]) - } - - test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " - + "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") { - val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + } - val linearRegressionModel = - new LinearRegressionModel(linearInput(0).features, linearInput(0).label) - val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) - assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " + + "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - val ridgeRegressionModel = - new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) - val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) - assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + val linearRegressionModel = + new LinearRegressionModel(linearInput(0).features, linearInput(0).label) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) + assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + val ridgeRegressionModel = + new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) + assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) - val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) - val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) - assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) + assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) - val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) - } + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + } test("PMMLModelExportFactory create LogisticRegressionPMMLModelExport " + "when passing a LogisticRegressionModel") { @@ -76,14 +72,13 @@ class PMMLModelExportFactorySuite extends FunSuite { PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport]) - } - - test("PMMLModelExportFactory throw IllegalArgumentException " - + "when passing an unsupported model") { + } + + test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { val invalidModel = new Object - + intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(invalidModel) } - } + } }