Skip to content

Commit

Permalink
[SPARK-1406] Update code to latest pmml model
Browse files Browse the repository at this point in the history
  • Loading branch information
selvinsource committed Apr 21, 2015
1 parent dea98ca commit 25dce33
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
val dataDictionary = new DataDictionary
val miningSchema = new MiningSchema
val regressionTable = new RegressionTable(model.intercept)
val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.REGRESSION)
val regressionModel = new RegressionModel()
.withFunctionName(MiningFunctionType.REGRESSION)
.withMiningSchema(miningSchema)
.withModelName(description)
.withRegressionTables(regressionTable)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode
val comparisonMeasure = new ComparisonMeasure()
.withKind(ComparisonMeasure.Kind.DISTANCE)
.withMeasure(new SquaredEuclidean())
val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure,
MiningFunctionType.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED,
model.clusterCenters.length)
val clusteringModel = new ClusteringModel()
.withModelName("k-means")
.withMiningSchema(miningSchema)
.withComparisonMeasure(comparisonMeasure)
.withFunctionName(MiningFunctionType.CLUSTERING)
.withModelClass(ClusteringModel.ModelClass.CENTER_BASED)
.withNumberOfClusters(model.clusterCenters.length)

for (i <- 0 until clusterCenter.size) {
fields(i) = FieldName.create("field_" + i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ private[mllib] class LogisticRegressionPMMLModelExport(
val miningSchema = new MiningSchema
val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1")
val regressionTableNO = new RegressionTable(0.0).withTargetCategory("0")
val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.CLASSIFICATION)
val regressionModel = new RegressionModel()
.withFunctionName(MiningFunctionType.CLASSIFICATION)
.withMiningSchema(miningSchema)
.withModelName(description)
.withNormalizationMethod(RegressionNormalizationMethodType.LOGIT)
.withRegressionTables(regressionTableYES, regressionTableNO)
Expand Down

0 comments on commit 25dce33

Please sign in to comment.