Skip to content

Commit

Permalink
overload StringArrayParam.w
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed May 7, 2015
1 parent c81072d commit c221db9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
22 changes: 13 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.NoSuchElementException

import scala.annotation.varargs
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.collection.JavaConverters._

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.util.Identifiable
Expand Down Expand Up @@ -228,7 +228,8 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array

override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)

private[param] def wCast(value: Seq[String]): ParamPair[Array[String]] = w(value.toArray)
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}

/**
Expand Down Expand Up @@ -323,13 +324,7 @@ trait Params extends Identifiable with Serializable {
* Sets a parameter in the embedded param map.
*/
protected final def set[T](param: Param[T], value: T): this.type = {
shouldOwn(param)
if (param.isInstanceOf[StringArrayParam] && value.isInstanceOf[Seq[_]]) {
paramMap.put(param.asInstanceOf[StringArrayParam].wCast(value.asInstanceOf[Seq[String]]))
} else {
paramMap.put(param.w(value))
}
this
set(param -> value)
}

/**
Expand All @@ -339,6 +334,15 @@ trait Params extends Identifiable with Serializable {
set(getParam(param), value)
}

/**
* Sets a parameter in the embedded param map.
*/
protected final def set(paramPair: ParamPair[_]): this.type = {
shouldOwn(paramPair.param)
paramMap.put(paramPair)
this
}

/**
* Optionally returns the user-supplied value of a param.
*/
Expand Down
8 changes: 4 additions & 4 deletions python/pyspark/ml/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def _transfer_params_to_java(self, params, java_obj):
for param in self.params:
if param in paramMap:
value = paramMap[param]
if isinstance(value, list):
value = _jvm().PythonUtils.toSeq(value)
java_obj.set(param.name, value)
java_param = java_obj.getParam(param.name)
java_obj.set(java_param.w(value))

def _empty_java_param_map(self):
"""
Expand All @@ -82,7 +81,8 @@ def _create_java_param_map(self, params, java_obj):
paramMap = self._empty_java_param_map()
for param, value in params.items():
if param.parent is self:
paramMap.put(java_obj.getParam(param.name), value)
java_param = java_obj.getParam(param.name)
paramMap.put(java_param.w(value))
return paramMap


Expand Down

0 comments on commit c221db9

Please sign in to comment.