Skip to content

Commit

Permalink
[SPARK-1406] Throw IllegalArgumentException when exporting a multinomial
Browse files Browse the repository at this point in the history
logistic regression
  • Loading branch information
selvinsource committed Apr 25, 2015
1 parent 25dce33 commit cfcb596
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ private[mllib] object PMMLModelExportFactory {
new GeneralizedLinearPMMLModelExport(svm,
"linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
case logistic: LogisticRegressionModel =>
new LogisticRegressionPMMLModelExport(logistic, "logistic regression")
if(logistic.numClasses == 2)
new LogisticRegressionPMMLModelExport(logistic, "logistic regression")
else
throw new IllegalArgumentException(
"PMML Export not supported for Multinomial Logistic Regression")
case _ =>
throw new IllegalArgumentException(
"PMML Export not supported for model: " + model.getClass.getName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ class PMMLModelExportFactorySuite extends FunSuite {

assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport])
}

test("PMMLModelExportFactory throw IllegalArgumentException "
+ "when passing a Multinomial Logistic Regression") {
/** 3 classes, 2 features */
val multiclassLogisticRegressionModel = new LogisticRegressionModel(
weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
numFeatures = 2, numClasses = 3)

intercept[IllegalArgumentException] {
PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
}
}

test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
val invalidModel = new Object
Expand Down

0 comments on commit cfcb596

Please sign in to comment.