Skip to content

Commit

Permalink
add na.replace in pyspark
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-wang committed May 12, 2015
1 parent 640f63b commit 63ac579
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,11 @@ private[spark] object PythonUtils {
def toSeq[T](cols: JList[T]): Seq[T] = {
cols.toList.toSeq
}

/**
* Convert java map of K, V into Map of K, V (for calling API with varargs)
*/
def toMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = {
jm.toMap
}
}
85 changes: 85 additions & 0 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,10 @@ def _jseq(self, cols, converter=None):
"""Return a JVM Seq of Columns from a list of Column or names"""
return _to_seq(self.sql_ctx._sc, cols, converter)

def _jmap(self, jm):
"""Return a JVM Map from a dict"""
return _to_map(self.sql_ctx._sc, jm)

def _jcols(self, *cols):
"""Return a JVM Seq of Columns from a list of Column or column names
Expand Down Expand Up @@ -924,6 +928,77 @@ def fillna(self, value, subset=None):

return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)

def replacena(self, to_replace, value, subset=None):
"""Returns a new :class:`DataFrame` replacing a value with another value.
alias for ``na.replace()``.
:param to_replace: int, long, float, string, or list.
Value to be replaced.
The replacement value must be an int, long, float, or string.
:param value: int, long, float, string, or list.
Value to use to replace holes.
The replacement value must be an int, long, float, or string.
:param subset: optional list of column names to consider.
Columns specified in subset that do not have matching data type are ignored.
For example, if `value` is a string, and subset contains a non-string column,
then the non-string column is simply ignored.
>>> df4.na.replace(10, 20).show()
+----+------+-----+
| age|height| name|
+----+------+-----+
| 20| 80|Alice|
| 5| null| Bob|
|null| null| Tom|
|null| null| null|
+----+------+-----+
>>> df4.replacena(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+----+------+----+
| age|height|name|
+----+------+----+
| 10| 80| A|
| 5| null| B|
|null| null| Tom|
|null| null|null|
+----+------+----+
"""
if not isinstance(to_replace, (float, int, long, basestring, list, tuple)):
raise ValueError("to_replace should be a float, int, long, string, list, or tuple")

if not isinstance(value, (float, int, long, basestring, list, tuple)):
raise ValueError("value should be a float, int, long, string, list, or tuple")

if isinstance(to_replace, dict) and not isinstance(value, (list, tuple, dict)):
raise TypeError("")

if isinstance(to_replace, (float, int, long, basestring)):
to_replace = [to_replace]

if isinstance(value, (float, int, long, basestring)):
value = [value]

if isinstance(to_replace, tuple):
to_replace = list(to_replace)

if isinstance(value, tuple):
value = list(value)

if isinstance(to_replace, list) and isinstance(value, list):
if len(to_replace) != len(value):
raise ValueError("to_replace and value lists should be of the same length")

rep_dict = dict(zip(to_replace, value))
if subset is None:
return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
elif isinstance(subset, basestring):
subset = [subset]

if not isinstance(subset, (list, tuple)):
raise ValueError("subset should be a list or tuple of column names")

return DataFrame(
self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)

def corr(self, col1, col2, method=None):
"""
Calculates the correlation of two columns of a DataFrame as a double value. Currently only
Expand Down Expand Up @@ -1225,6 +1300,11 @@ def _to_seq(sc, cols, converter=None):
cols = [converter(c) for c in cols]
return sc._jvm.PythonUtils.toSeq(cols)

def _to_map(sc, jm):
"""
Convert a dict into a JVM Map.
"""
return sc._jvm.PythonUtils.toMap(jm)

def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
Expand Down Expand Up @@ -1482,6 +1562,11 @@ def fill(self, value, subset=None):

fill.__doc__ = DataFrame.fillna.__doc__

def replace(self, to_replace, value, subset=None):
return self.df.replacena(to_replace=to_replace, value=value, subset=subset)

replace.__doc__ = DataFrame.replacena.__doc__


class DataFrameStatFunctions(object):
"""Functionality for statistic functions with :class:`DataFrame`.
Expand Down

0 comments on commit 63ac579

Please sign in to comment.