From 95e3ff27d9dfd242bce322d6b3b635eb7fea2a8d Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 19 Dec 2015 00:33:44 +0900 Subject: [PATCH] Fixed ParamMap#filter to avoid ClassCastException and SI-6654 --- .../org/apache/spark/ml/param/params.scala | 11 ++++++++-- .../apache/spark/ml/param/ParamsSuite.scala | 21 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ee7e89edd8798..750fd05b6684e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -859,8 +859,15 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Filters this param map for the given parent. */ def filter(parent: Params): ParamMap = { - val filtered = map.filterKeys(_.parent == parent) - new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]]) + // Don't use filterKeys because mutable.Map#filterKeys + // returns the instance of collections.Map, not mutable.Map. + // Otherwise, we get ClassCastException. + // Not using filterKeys also avoid SI-6654 + val filtered = map.filter { + case (k, _) => + k.parent == parent.uid + } + new ParamMap(filtered) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index a1878be747ceb..ddcedc4121cc9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.param import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MyParams import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { @@ -349,6 +350,26 @@ class ParamsSuite extends SparkFunSuite { val t3 = t.copy(ParamMap(t.maxIter -> 20)) assert(t3.isSet(t3.maxIter)) } + + test("Filtering ParamMap") { + val params1 = new MyParams("my_params1") + val params2 = new MyParams("my_params2") + val paramMap = ParamMap( + params1.intParam -> 1, + params2.intParam -> 1, + params1.doubleParam -> 0.2, + params2.doubleParam -> 0.2) + val filteredParamMap = paramMap.filter(params1) + + assert(filteredParamMap.size === 2) + filteredParamMap.toSeq.foreach { + case ParamPair(p, _) => + assert(p.parent === params1.uid) + } + + // Following assertion is to avoid SI-6654 + assert(filteredParamMap.isInstanceOf[Serializable]) + } } object ParamsSuite extends SparkFunSuite {