diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala deleted file mode 100644 index a6daa6686cf5a..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.export - -private[mllib] trait ModelExport { - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala deleted file mode 100644 index f59040f3a3440..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.export - -/** - * Defines export types. - * - PMML exports the machine learning models in an XML-based file format - * called Predictive Model Markup Language developed by the Data Mining Group (www.dmg.org). - */ -private[mllib] object ModelExportType extends Enumeration{ - - type ModelExportType = Value - val PMML = Value - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExporter.scala b/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExporter.scala deleted file mode 100644 index af35fee5eec06..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExporter.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.export - -import java.io.File -import javax.xml.transform.stream.StreamResult - -import org.jpmml.model.JAXBUtil - -import org.apache.spark.mllib.export.pmml.PMMLModelExport - -object ModelExporter { - - /** - * Export the input model to the stream result in PMML format - */ - def toPMML(inputModel: Any, streamResult: StreamResult): Unit = { - val modelExport = ModelExportFactory.createModelExport(inputModel, ModelExportType.PMML) - val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml() - JAXBUtil.marshalPMML(pmml, streamResult) - } - - /** - * Export the input model to a local path in PMML format - */ - def toPMML(inputModel: Any, localPath: String): Unit = { - toPMML(inputModel, new StreamResult(new File(localPath))) - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala new file mode 100644 index 0000000000000..b0ebc85b3719b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml + +import java.io.File +import java.io.OutputStream +import java.io.StringWriter +import javax.xml.transform.stream.StreamResult +import org.jpmml.model.JAXBUtil +import org.apache.spark.mllib.pmml.export.PMMLModelExport +import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory + +/** + * Export model to the PMML format + * Predictive Model Markup Language (PMML) in an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +trait PMMLExportable { + + /** + * Export the model to the stream result in PMML format + */ + private def toPMML(streamResult: StreamResult): Unit = { + val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) + JAXBUtil.marshalPMML(pmmlModelExport.getPmml(), streamResult) + } + + /** + * Export the model to a local File in PMML format + */ + def toPMML(localPath: String): Unit = { + toPMML(new StreamResult(new File(localPath))) + } + + /** + * Export the model to the Outputtream in PMML format + */ + def toPMML(outputStream: OutputStream): Unit = { + toPMML(new StreamResult(outputStream)) + } + + /** + * Export the model to a String in PMML format + */ + def toPMML(): String = { + var writer = new StringWriter(); + toPMML(new StreamResult(writer)) + return writer.toString(); + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala similarity index 98% rename from mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala rename to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala index 8b3d1ce9e3e0f..94bbd705a9b69 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.export.pmml +package org.apache.spark.mllib.pmml.export import org.dmg.pmml.DataDictionary import org.dmg.pmml.DataField @@ -29,7 +29,6 @@ import org.dmg.pmml.NumericPredictor import org.dmg.pmml.OpType import org.dmg.pmml.RegressionModel import org.dmg.pmml.RegressionTable - import org.apache.spark.mllib.regression.GeneralizedLinearModel /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala similarity index 98% rename from mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala rename to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index c10d48fb8eb4c..901fbb6858a20 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.export.pmml +package org.apache.spark.mllib.pmml.export import org.dmg.pmml.Array.Type import org.dmg.pmml.Cluster @@ -35,7 +35,6 @@ import org.dmg.pmml.MiningFunctionType import org.dmg.pmml.MiningSchema import org.dmg.pmml.OpType import org.dmg.pmml.SquaredEuclidean - import org.apache.spark.mllib.clustering.KMeansModel /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala similarity index 98% rename from mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala rename to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala index 0d65bc9ddc627..0b1d1d465b939 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.export.pmml +package org.apache.spark.mllib.pmml.export import org.dmg.pmml.DataDictionary import org.dmg.pmml.DataField @@ -29,8 +29,8 @@ import org.dmg.pmml.NumericPredictor import org.dmg.pmml.OpType import org.dmg.pmml.RegressionModel import org.dmg.pmml.RegressionTable -import org.apache.spark.mllib.classification.LogisticRegressionModel import org.dmg.pmml.RegressionNormalizationMethodType +import org.apache.spark.mllib.classification.LogisticRegressionModel /** * PMML Model Export for LogisticRegressionModel class diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala similarity index 91% rename from mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala rename to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index cf9f993e88e59..14ab5e0d2c7b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -15,21 +15,17 @@ * limitations under the License. */ -package org.apache.spark.mllib.export.pmml +package org.apache.spark.mllib.pmml.export import java.text.SimpleDateFormat import java.util.Date - import scala.beans.BeanProperty - import org.dmg.pmml.Application import org.dmg.pmml.Header import org.dmg.pmml.PMML import org.dmg.pmml.Timestamp -import org.apache.spark.mllib.export.ModelExport - -private[mllib] trait PMMLModelExport extends ModelExport{ +private[mllib] trait PMMLModelExport { /** * Holder of the exported model in PMML format diff --git a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala similarity index 67% rename from mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala rename to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index 618fe79a7b14a..f97b1ace61ef9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -15,30 +15,23 @@ * limitations under the License. */ -package org.apache.spark.mllib.export +package org.apache.spark.mllib.pmml.export import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.clustering.KMeansModel -import org.apache.spark.mllib.export.ModelExportType.ModelExportType -import org.apache.spark.mllib.export.ModelExportType.PMML -import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport -import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport -import org.apache.spark.mllib.export.pmml.LogisticRegressionPMMLModelExport import org.apache.spark.mllib.regression.LassoModel import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel -private[mllib] object ModelExportFactory { +private[mllib] object PMMLModelExportFactory { /** - * Factory object to help creating the necessary ModelExport implementation - * taking as input the ModelExportType (for example PMML) - * and the machine learning model (for example KMeansModel). + * Factory object to help creating the necessary PMMLModelExport implementation + * taking as input the machine learning model (for example KMeansModel). */ - def createModelExport(model: Any, exportType: ModelExportType): ModelExport = { - return exportType match{ - case PMML => model match{ + def createPMMLModelExport(model: Any): PMMLModelExport = { + return model match{ case kmeans: KMeansModel => new KMeansPMMLModelExport(kmeans) case linearRegression: LinearRegressionModel => @@ -54,10 +47,8 @@ private[mllib] object ModelExportFactory { case logisticRegression: LogisticRegressionModel => new LogisticRegressionPMMLModelExport(logisticRegression, "logistic regression") case _ => - throw new IllegalArgumentException("Export not supported for model: " + model.getClass) - } - case _ => throw new IllegalArgumentException("Export type not supported:" + exportType) - } + throw new IllegalArgumentException("PMML Export not supported for model: " + model.getClass) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala similarity index 85% rename from mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala index e27c193a7f7f8..9b0c81b5b099f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -15,14 +15,11 @@ * limitations under the License. */ -package org.apache.spark.mllib.export.pmml +package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.scalatest.FunSuite - import org.apache.spark.mllib.classification.SVMModel -import org.apache.spark.mllib.export.ModelExportFactory -import org.apache.spark.mllib.export.ModelExportType import org.apache.spark.mllib.regression.LassoModel import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel @@ -41,7 +38,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label); //act by exporting the model to the PMML format - val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) //assert that the PMML format is as expected assert(linearModelExport.isInstanceOf[PMMLModelExport]) var pmml = linearModelExport.asInstanceOf[PMMLModelExport].getPmml() @@ -54,7 +51,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ .getRegressionTables().get(0).getNumericPredictors().size() === linearRegressionModel.weights.size) //act - val ridgeModelExport = ModelExportFactory.createModelExport(ridgeRegressionModel, ModelExportType.PMML) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) //assert that the PMML format is as expected assert(ridgeModelExport.isInstanceOf[PMMLModelExport]) pmml = ridgeModelExport.asInstanceOf[PMMLModelExport].getPmml() @@ -67,7 +64,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ .getRegressionTables().get(0).getNumericPredictors().size() === ridgeRegressionModel.weights.size) //act - val lassoModelExport = ModelExportFactory.createModelExport(lassoModel, ModelExportType.PMML) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) //assert that the PMML format is as expected assert(lassoModelExport.isInstanceOf[PMMLModelExport]) pmml = lassoModelExport.asInstanceOf[PMMLModelExport].getPmml() @@ -80,7 +77,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ .getRegressionTables().get(0).getNumericPredictors().size() === lassoModel.weights.size) //act - val svmModelExport = ModelExportFactory.createModelExport(svmModel, ModelExportType.PMML) + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) //assert that the PMML format is as expected assert(svmModelExport.isInstanceOf[PMMLModelExport]) pmml = svmModelExport.asInstanceOf[PMMLModelExport].getPmml() @@ -93,10 +90,10 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{ .getRegressionTables().get(0).getNumericPredictors().size() === svmModel.weights.size) //manual checking - //ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml") - //ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml") - //ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml") - //ModelExporter.toPMML(svmModel,"/tmp/linearsvm.xml") + //linearRegressionModel.toPMML("/tmp/linearregression.xml") + //ridgeRegressionModel.toPMML("/tmp/ridgeregression.xml") + //lassoModel.toPMML("/tmp/lassoregression.xml") + //svmModel.toPMML("/tmp/linearsvm.xml") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala similarity index 84% rename from mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index 471e311b53653..00682b0f78190 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -15,14 +15,11 @@ * limitations under the License. */ -package org.apache.spark.mllib.export.pmml +package org.apache.spark.mllib.pmml.export import org.dmg.pmml.ClusteringModel import org.scalatest.FunSuite - import org.apache.spark.mllib.clustering.KMeansModel -import org.apache.spark.mllib.export.ModelExportFactory -import org.apache.spark.mllib.export.ModelExportType import org.apache.spark.mllib.linalg.Vectors class KMeansPMMLModelExportSuite extends FunSuite{ @@ -38,7 +35,7 @@ class KMeansPMMLModelExportSuite extends FunSuite{ val kmeansModel = new KMeansModel(clusterCenters); //act by exporting the model to the PMML format - val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML) + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) //assert that the PMML format is as expected assert(modelExport.isInstanceOf[PMMLModelExport]) @@ -51,8 +48,9 @@ class KMeansPMMLModelExportSuite extends FunSuite{ assert(pmml.getModels().get(0).asInstanceOf[ClusteringModel].getNumberOfClusters() === clusterCenters.size) //manual checking - //ModelExporter.toPMML(kmeansModel,new StreamResult(System.out)) - //ModelExporter.toPMML(kmeansModel,"/tmp/kmeans.xml") + //kmeansModel.toPMML("/tmp/kmeans.xml") + //kmeansModel.toPMML(System.out) + //System.out.println(kmeansModel.toPMML()) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala similarity index 88% rename from mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala index 0bb6c9a60a485..b96194d47b882 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/pmml/LogisticRegressionPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala @@ -15,14 +15,11 @@ * limitations under the License. */ -package org.apache.spark.mllib.export.pmml +package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.scalatest.FunSuite - import org.apache.spark.mllib.classification.LogisticRegressionModel -import org.apache.spark.mllib.export.ModelExportFactory -import org.apache.spark.mllib.export.ModelExportType import org.apache.spark.mllib.util.LinearDataGenerator class LogisticRegressionPMMLModelExportSuite extends FunSuite{ @@ -35,7 +32,7 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite{ val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label); //act by exporting the model to the PMML format - val logisticModelExport = ModelExportFactory.createModelExport(logisticRegressionModel, ModelExportType.PMML) + val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) //assert that the PMML format is as expected assert(logisticModelExport.isInstanceOf[PMMLModelExport]) var pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml() @@ -55,7 +52,7 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite{ .getRegressionTables().get(1).getNumericPredictors().size() === 0) //manual checking - //ModelExporter.toPMML(logisticRegressionModel,"/tmp/logisticregression.xml") + //logisticRegressionModel.toPMML("/tmp/logisticregression.xml") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala similarity index 68% rename from mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index d63a544ebdf97..5b34e5a8329fb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.mllib.export +package org.apache.spark.mllib.pmml.export import org.scalatest.FunSuite - import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.clustering.KMeansModel @@ -27,13 +26,10 @@ import org.apache.spark.mllib.regression.LassoModel import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel import org.apache.spark.mllib.util.LinearDataGenerator -import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport -import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport -import org.apache.spark.mllib.export.pmml.LogisticRegressionPMMLModelExport -class ModelExportFactorySuite extends FunSuite{ +class PMMLModelExportFactorySuite extends FunSuite{ - test("ModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { + test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { //arrange val clusterCenters = Array( @@ -44,14 +40,14 @@ class ModelExportFactorySuite extends FunSuite{ val kmeansModel = new KMeansModel(clusterCenters); //act - val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML) + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) //assert assert(modelExport.isInstanceOf[KMeansPMMLModelExport]) } - test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " + test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " +"LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") { //arrange @@ -63,28 +59,28 @@ class ModelExportFactorySuite extends FunSuite{ val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) //act - val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) //assert assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) //act - val ridgeModelExport = ModelExportFactory.createModelExport(ridgeRegressionModel, ModelExportType.PMML) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) //assert assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) //act - val lassoModelExport = ModelExportFactory.createModelExport(lassoModel, ModelExportType.PMML) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) //assert assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) //act - val svmModelExport = ModelExportFactory.createModelExport(svmModel, ModelExportType.PMML) + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) //assert assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) } - test("ModelExportFactory create LogisticRegressionPMMLModelExport when passing a LogisticRegressionModel") { + test("PMMLModelExportFactory create LogisticRegressionPMMLModelExport when passing a LogisticRegressionModel") { //arrange val linearInput = LinearDataGenerator.generateLinearInput( @@ -92,13 +88,13 @@ class ModelExportFactorySuite extends FunSuite{ val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label); //act - val logisticRegressionModelExport = ModelExportFactory.createModelExport(logisticRegressionModel, ModelExportType.PMML) + val logisticRegressionModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) //assert assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport]) } - test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") { + test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { //arrange val invalidModel = new Object; @@ -106,7 +102,7 @@ class ModelExportFactorySuite extends FunSuite{ //assert intercept[IllegalArgumentException] { //act - ModelExportFactory.createModelExport(invalidModel, ModelExportType.PMML) + PMMLModelExportFactory.createPMMLModelExport(invalidModel) } }