[SPARK-12424][ML] The implementation of ParamMap#filter is wrong. #10381

Closed
wants to merge 4 commits into from
8 changes: 6 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -859,8 +859,12 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
* Filters this param map for the given parent.
*/
def filter(parent: Params): ParamMap = {
val filtered = map.filterKeys(_.parent == parent)
new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]])
// Don't use filterKeys here: mutable.Map#filterKeys returns a
// collection.Map, not a mutable.Map, so casting the result to
// mutable.Map throws ClassCastException.
// Avoiding filterKeys also avoids SI-6654.
val filtered = map.filter { case (k, _) => k.parent == parent.uid }
Member

Hasn't this changed the logic slightly? Now you compare to parent.uid.

Member Author

I think the original logic is wrong: .parent is a member of Param and its type is String, while the type of the method parameter parent is Params. The old predicate _.parent == parent therefore compared a String with a Params, which never matches.

In the implementation of Param, the member parent is set to the uid of an Identifiable, and Identifiable is a trait that Params extends.

class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
  extends Serializable {

  def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
    this(parent.uid, name, doc, isValid)
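
To make the type mismatch concrete, here is a minimal sketch that does not use Spark. Owner and SimpleParam are hypothetical stand-ins for Params and Param, reduced to the parent/uid relationship quoted above.

```scala
object ParentComparison {
  // Simplified stand-ins, NOT the real Spark classes.
  trait Identifiable { def uid: String }

  // Plays the role of a Params instance.
  class Owner(val uid: String) extends Identifiable

  // Plays the role of Param: parent is stored as the owner's uid (a String).
  class SimpleParam(val parent: String, val name: String) {
    def this(parent: Identifiable, name: String) = this(parent.uid, name)
  }

  def main(args: Array[String]): Unit = {
    val owner = new Owner("my_params1")
    val param = new SimpleParam(owner, "intParam")

    // Upcast to Any so the comparison of unrelated types compiles without a warning.
    val ownerAsAny: Any = owner

    // String vs. Owner: never equal, so a filter on this predicate keeps nothing.
    assert(!(param.parent == ownerAsAny))

    // String vs. String: matches, which is what the patched filter compares.
    assert(param.parent == owner.uid)
  }
}
```

Under these assumptions, the old predicate would have filtered out every entry even if the cast had succeeded, which is why comparing against parent.uid changes the logic on purpose.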

new ParamMap(filtered)
}
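
The difference between filter and filterKeys described in the code comment can be checked with plain Scala collections. The following is a sketch under the assumption that a mutable.Map of strings behaves like the map inside ParamMap; the object name and keys are made up for illustration. Note that Scala 2.13 deprecates filterKeys and returns a view there, so the call compiles with a deprecation warning on that version.

```scala
import scala.collection.mutable

object FilterVsFilterKeys {
  def main(args: Array[String]): Unit = {
    val m = mutable.Map("a_x" -> 1, "a_y" -> 2, "b_x" -> 3)

    // filter builds a new collection of the same kind as the receiver.
    val viaFilter: Any = m.filter { case (k, _) => k.startsWith("a_") }

    // filterKeys returns a wrapper (or a view on 2.13), not a mutable.Map.
    val viaFilterKeys: Any = m.filterKeys(_.startsWith("a_"))

    assert(viaFilter.isInstanceOf[mutable.Map[_, _]])
    assert(!viaFilterKeys.isInstanceOf[mutable.Map[_, _]])

    // The concrete class name depends on the Scala version in use.
    println(viaFilterKeys.getClass.getName)
  }
}
```

Because the filterKeys result is not a mutable.Map, casting it with asInstanceOf and passing it to a constructor that expects mutable.Map fails at runtime, which matches the ClassCastException mentioned in the patch.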

/**
28 changes: 28 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -17,7 +17,10 @@

package org.apache.spark.ml.param

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.MyParams
import org.apache.spark.mllib.linalg.{Vector, Vectors}

class ParamsSuite extends SparkFunSuite {
@@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite {
val t3 = t.copy(ParamMap(t.maxIter -> 20))
assert(t3.isSet(t3.maxIter))
}

test("Filtering ParamMap") {
val params1 = new MyParams("my_params1")
val params2 = new MyParams("my_params2")
val paramMap = ParamMap(
params1.intParam -> 1,
params2.intParam -> 1,
params1.doubleParam -> 0.2,
params2.doubleParam -> 0.2)
val filteredParamMap = paramMap.filter(params1)

assert(filteredParamMap.size === 2)
filteredParamMap.toSeq.foreach {
case ParamPair(p, _) =>
assert(p.parent === params1.uid)
}

// The previous implementation of ParamMap#filter used
// mutable.Map#filterKeys internally, and the object that method
// returns is not serializable (see SI-6654).
// mutable.Map#filter is used now, and its result is serializable,
// so verify that the filtered ParamMap can be serialized.
val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
objOut.writeObject(filteredParamMap)
}
}
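
The serializability check at the end of the test can also be written as a round trip. This is a sketch only: it uses a plain mutable.Map rather than ParamMap, and roundTrip is a hypothetical helper, not part of Spark or of this patch.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import scala.collection.mutable

object SerializationCheck {
  // Hypothetical helper: Java-serializes a value to bytes and reads it back.
  def roundTrip[T <: AnyRef](value: T): T = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(value) finally out.close()

    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    try in.readObject().asInstanceOf[T] finally in.close()
  }

  def main(args: Array[String]): Unit = {
    val m = mutable.Map("a" -> 1, "b" -> 2, "c" -> 3)

    // The result of filter is itself a mutable.Map, which is serializable.
    val filtered = m.filter { case (k, _) => k != "c" }

    // writeObject would throw NotSerializableException for a non-serializable value.
    val copy = roundTrip(filtered)
    assert(copy == filtered)
  }
}
```

Reading the bytes back is stricter than only calling writeObject, since it also confirms that the deserialized map equals the original.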

object ParamsSuite extends SparkFunSuite {