Skip to content

Commit

Permalink
Fixed ParamMap#filter to avoid ClassCastException and SI-6654
Browse files Browse the repository at this point in the history
  • Loading branch information
sarutak committed Dec 18, 2015
1 parent 2bebaa3 commit 95e3ff2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
11 changes: 9 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -859,8 +859,15 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
* Filters this param map for the given parent.
*/
def filter(parent: Params): ParamMap = {
val filtered = map.filterKeys(_.parent == parent)
new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]])
// Don't use filterKeys because mutable.Map#filterKeys
// returns the instance of collections.Map, not mutable.Map.
// Otherwise, we get ClassCastException.
// Not using filterKeys also avoid SI-6654
val filtered = map.filter {
case (k, _) =>
k.parent == parent.uid
}
new ParamMap(filtered)
}

/**
Expand Down
21 changes: 21 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.param

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.MyParams
import org.apache.spark.mllib.linalg.{Vector, Vectors}

class ParamsSuite extends SparkFunSuite {
Expand Down Expand Up @@ -349,6 +350,26 @@ class ParamsSuite extends SparkFunSuite {
val t3 = t.copy(ParamMap(t.maxIter -> 20))
assert(t3.isSet(t3.maxIter))
}

test("Filtering ParamMap") {
val params1 = new MyParams("my_params1")
val params2 = new MyParams("my_params2")
val paramMap = ParamMap(
params1.intParam -> 1,
params2.intParam -> 1,
params1.doubleParam -> 0.2,
params2.doubleParam -> 0.2)
val filteredParamMap = paramMap.filter(params1)

assert(filteredParamMap.size === 2)
filteredParamMap.toSeq.foreach {
case ParamPair(p, _) =>
assert(p.parent === params1.uid)
}

// Following assertion is to avoid SI-6654
assert(filteredParamMap.isInstanceOf[Serializable])
}
}

object ParamsSuite extends SparkFunSuite {
Expand Down

0 comments on commit 95e3ff2

Please sign in to comment.