Skip to content

Commit

Permalink
Merge pull request #2 from mengxr/SPARK-1406
Browse files Browse the repository at this point in the history
more code style
  • Loading branch information
selvinsource committed Apr 21, 2015
2 parents e2313df + 3c22f79 commit a0a55f7
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.spark.mllib.pmml

import java.io.File
import java.io.OutputStream
import java.io.StringWriter
import java.io.{File, OutputStream, StringWriter}
import javax.xml.transform.stream.StreamResult

import org.jpmml.model.JAXBUtil
Expand All @@ -33,37 +31,37 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
* developed by the Data Mining Group (www.dmg.org).
*/
trait PMMLExportable {

/**
 * Marshals the PMML form of this model into the supplied stream result.
 *
 * @param streamResult destination handed to the JAXB marshaller
 */
private def toPMML(streamResult: StreamResult): Unit = {
  // Ask the factory for the exporter that handles this model instance.
  val exporter = PMMLModelExportFactory.createPMMLModelExport(this)
  val document = exporter.getPmml
  JAXBUtil.marshalPMML(document, streamResult)
}

/**
 * Writes the model, in PMML format, to a file on the local file system.
 *
 * @param localPath path of the file to write
 */
def toPMML(localPath: String): Unit = {
  val target = new File(localPath)
  toPMML(new StreamResult(target))
}

/**
 * Writes the model, in PMML format, to a directory on a distributed file system.
 *
 * @param sc   Spark context used to build the RDD that is saved
 * @param path output directory passed to `saveAsTextFile`
 */
def toPMML(sc: SparkContext, path: String): Unit = {
  val document = toPMML()
  // One element in one slice (second argument to parallelize).
  val singleton = sc.parallelize(Array(document), 1)
  singleton.saveAsTextFile(path)
}

/**
 * Writes the model, in PMML format, to the given output stream.
 *
 * @param outputStream stream that receives the PMML document
 */
def toPMML(outputStream: OutputStream): Unit = {
  // NOTE(review): the stream is not closed here; presumably the caller owns it — confirm.
  val sink = new StreamResult(outputStream)
  toPMML(sink)
}

/**
* Export the model to a String in PMML format
*/
Expand All @@ -72,5 +70,5 @@ trait PMMLExportable {
toPMML(new StreamResult(writer))
writer.toString
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,47 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel
* PMML Model Export for GeneralizedLinearModel abstract class
*/
private[mllib] class GeneralizedLinearPMMLModelExport(
model : GeneralizedLinearModel,
description : String)
extends PMMLModelExport{
model: GeneralizedLinearModel,
description: String)
extends PMMLModelExport {

populateGeneralizedLinearPMML(model)

/**
* Export the input GeneralizedLinearModel model to PMML format.
*/
private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {
pmml.getHeader.setDescription(description)
pmml.getHeader.setDescription(description)

if(model.weights.size > 0){
val fields = new SArray[FieldName](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTable = new RegressionTable(model.intercept)
val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.REGRESSION)
.withModelName(description)
.withRegressionTables(regressionTable)
if (model.weights.size > 0) {
val fields = new SArray[FieldName](model.weights.size)
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTable = new RegressionTable(model.intercept)
val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.REGRESSION)
.withModelName(description)
.withRegressionTables(regressionTable)

for (i <- 0 until model.weights.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

// for completeness add target field
val targetField = FieldName.create("target")
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))
for (i <- 0 until model.weights.size) {
fields(i) = FieldName.create("field_" + i)
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(fields(i))
.withUsageType(FieldUsageType.ACTIVE))
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}

dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
// for completeness add target field
val targetField = FieldName.create("target")
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))

pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)
}
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)

pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionModel
private[mllib] class LogisticRegressionPMMLModelExport(
model : LogisticRegressionModel,
description : String)
extends PMMLModelExport{
extends PMMLModelExport {

populateLogisticRegressionPMML(model)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ import org.dmg.pmml.RegressionModel
import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.regression.LassoModel
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.RidgeRegressionModel
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator

class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
Expand Down Expand Up @@ -87,7 +85,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
test("svm pmml export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
// assert that the PMML format is as expected
assert(svmModelExport.isInstanceOf[PMMLModelExport])
val pmml = svmModelExport.getPmml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ import org.apache.spark.mllib.linalg.Vectors
class KMeansPMMLModelExportSuite extends FunSuite {

test("KMeansPMMLModelExport generate PMML format") {
// arrange model to test
val clusterCenters = Array(
Vectors.dense(1.0, 2.0, 6.0),
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0))
val kmeansModel = new KMeansModel(clusterCenters)

val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel)

// assert that the PMML format is as expected
assert(modelExport.isInstanceOf[PMMLModelExport])
val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.util.LinearDataGenerator

class LogisticRegressionPMMLModelExportSuite extends FunSuite{
class LogisticRegressionPMMLModelExportSuite extends FunSuite {

test("LogisticRegressionPMMLModelExport generate PMML format") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val logisticRegressionModel =
new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)

val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)

// assert that the PMML format is as expected
Expand All @@ -48,5 +48,5 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite{
// verify if there is a second table with target category 0 and no predictors
assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@ package org.apache.spark.mllib.pmml.export

import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LassoModel
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.RidgeRegressionModel
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator

class PMMLModelExportFactorySuite extends FunSuite {
Expand All @@ -38,33 +35,32 @@ class PMMLModelExportFactorySuite extends FunSuite {
val kmeansModel = new KMeansModel(clusterCenters)

val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel)

assert(modelExport.isInstanceOf[KMeansPMMLModelExport])
}

test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
+ "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
}

val linearRegressionModel =
new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel)
assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
+ "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)

val ridgeRegressionModel =
new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel)
assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
val linearRegressionModel =
new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel)
assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])

val ridgeRegressionModel =
new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel)
assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])

val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])

val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
}
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
}

test("PMMLModelExportFactory create LogisticRegressionPMMLModelExport "
+ "when passing a LogisticRegressionModel") {
Expand All @@ -76,14 +72,13 @@ class PMMLModelExportFactorySuite extends FunSuite {
PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)

assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport])
}

test("PMMLModelExportFactory throw IllegalArgumentException "
+ "when passing an unsupported model") {
}

test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
val invalidModel = new Object

intercept[IllegalArgumentException] {
PMMLModelExportFactory.createPMMLModelExport(invalidModel)
}
}
}
}

0 comments on commit a0a55f7

Please sign in to comment.